In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms

In [2]:
class KSOM(nn.Module):
    def __init__(self, input_dim, output_dim, learning_rate, weights):
        super(KSOM, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.learning_rate = learning_rate
        self.sigma = output_dim / 2
        self.weights = nn.Parameter(weights)

    def forward(self, x):
    
        x = x.view(-1, self.input_dim)
        distances = torch.sum((self.weights - x.unsqueeze(1)) ** 2, dim=-1)
        winner = torch.argmin(distances, dim=1)
        neighborhood = torch.exp(-distances / (2 * self.sigma ** 2))
        delta = self.learning_rate * neighborhood.unsqueeze(-1) * (x.unsqueeze(1) - self.weights)
        self.weights.data[winner] += torch.sum(delta, dim=0)[winner]

        return winner

In [3]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [4]:
import torchvision.datasets as datasets
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

Files already downloaded and verified


In [8]:
ksom = KSOM(input_dim=32*32*3, output_dim=100, learning_rate=0.1, weights=torch.rand(100, 32*32*3))
num_epochs=5

In [9]:
import torch.optim as optim

optimizer = optim.SGD(ksom.parameters(), lr=0.1)

In [10]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        inputs, _ = data
        inputs = inputs.view(-1, 32*32*3)
        optimizer.zero_grad()
        winner = ksom(inputs)
        loss = torch.mean(+
                          torch.sum((ksom.weights - inputs.unsqueeze(1)) ** 2, dim=-1))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: loss={loss.item()}")

print('Finished Training')

Epoch 1: loss=2262.60302734375
Epoch 2: loss=2331.546875
Epoch 3: loss=2255.106689453125
Epoch 4: loss=2399.530517578125
Epoch 5: loss=2138.78369140625
Finished Training
