In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [24]:
class KSOM(nn.Module):
    def __init__(self, input_dim, output_dim, learning_rate, weights):
        super(KSOM, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.learning_rate = learning_rate
        
        self.weights = nn.Parameter(weights)

    def forward(self, x):
    
        x = x.view(-1, self.input_dim)
        distances = torch.sum((self.weights - x.unsqueeze(1)) ** 2, dim=-1)
        winner = torch.argmin(distances, dim=1)#used to return the index of the minimum value of all elements in the input tensor.
        
        delta = x.unsqueeze(1)+self.learning_rate * (x.unsqueeze(1) - self.weights)
        self.weights.data[winner] += torch.sum(delta, dim=0)[winner]
        print(self.weights.data[winner])
        return winner

In [25]:
x = torch.tensor([0.0, 0.2, 0.1, 0.2, 0.0])

weights = torch.tensor([
    [1.0, 0.9, 0.7, 0.3, 0.2],
    [0.6, 0.7, 0.5, 0.4, 1.0],
])

In [26]:
ksom = KSOM(input_dim=5, output_dim=2, learning_rate=0.2, weights=weights)


winner = ksom(x)
print("Winning cluster unit:", winner)


print("Updated weights:", ksom.weights)

tensor([[0.4800, 0.8000, 0.5200, 0.5600, 0.8000]])
Winning cluster unit: tensor([1])
Updated weights: Parameter containing:
tensor([[1.0000, 0.9000, 0.7000, 0.3000, 0.2000],
        [0.4800, 0.8000, 0.5200, 0.5600, 0.8000]], requires_grad=True)


In [27]:
print("Updated learning rate:",0.5*ksom.learning_rate)

Updated learning rate: 0.1


In [28]:
y = torch.tensor([0.1, 0.3, 0.1, 0.4, 0.2])

weights = torch.tensor([
    [5.0, 0.7, 0.4, 0.2, 0.5],
    [0.8, 0.9, 0.3, 0.6, 2.0],
])

In [29]:
ksom = KSOM(input_dim=5, output_dim=2, learning_rate=0.2, weights=weights)


winner = ksom(y)
print("Winning cluster unit:", winner)


print("Updated weights:", ksom.weights)

tensor([[0.7600, 1.0800, 0.3600, 0.9600, 1.8400]])
Winning cluster unit: tensor([1])
Updated weights: Parameter containing:
tensor([[5.0000, 0.7000, 0.4000, 0.2000, 0.5000],
        [0.7600, 1.0800, 0.3600, 0.9600, 1.8400]], requires_grad=True)


In [30]:
print("Updated learning rate:",0.5*ksom.learning_rate)

Updated learning rate: 0.1
