In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms

In [2]:
class KSOM(nn.Module):
    def __init__(self, input_size, output_size):
        super(KSOM, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.weights = nn.Parameter(torch.randn(output_size, input_size))

    def forward(self, x):
        x = x.view(-1, self.input_size)
        distance = torch.cdist(x, self.weights)
        winner = torch.argmin(distance, dim=1)
        winner_weights = torch.index_select(self.weights, 0, winner)
        return winner_weights


In [3]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


In [4]:
import torchvision.datasets as datasets
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

Files already downloaded and verified


In [8]:
ksom = KSOM(input_size=32*32*3, output_size=100)
num_epochs=5

In [9]:
import torch.optim as optim

optimizer = optim.SGD(ksom.parameters(), lr=0.1)

In [10]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        inputs, _ = data
        inputs = inputs.view(-1, 32*32*3)
        optimizer.zero_grad()
        winner = ksom(inputs)
        loss = torch.mean(+
                          torch.sum((ksom.weights - inputs.unsqueeze(1)) ** 2, dim=-1))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: loss={loss.item()}")

print('Finished Training')

Epoch 1: loss=815.142822265625
Epoch 2: loss=941.3677368164062
Epoch 3: loss=680.262451171875
Epoch 4: loss=774.7720336914062
Epoch 5: loss=790.931396484375
Finished Training
