In [92]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms

In [93]:
class KSOM(nn.Module):
    def __init__(self, input_dim, output_dim, learning_rate, weights):
        super(KSOM, self).__init__()
        self.fc1 = nn.Linear(32*32*3,100)
        self.fc2 = nn.Linear(100,32*32*3)
        # self.fc2 = nn.Linear(100,10)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.learning_rate = learning_rate
        self.sigma = output_dim / 2
        self.weights = nn.Parameter(weights)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.softmax(x)
        x = x.view(-1, 32*32*3)
        # print(self.weights.shape,x.unsqueeze(1).shape)
        distances = torch.sum((self.weights - x.unsqueeze(1)) ** 2, dim=-1)
        winner = torch.argmin(distances, dim=1)
        neighborhood = torch.exp(-distances / (2 * self.sigma ** 2))
        delta = self.learning_rate * neighborhood.unsqueeze(-1) * (x.unsqueeze(1) - self.weights)
        self.weights.data[winner] += torch.sum(delta, dim=0)[winner]
        # print(winner.size())
        return winner

In [94]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [95]:
import torchvision.datasets as datasets

In [96]:
dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)

In [97]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [98]:
ksom = KSOM(input_dim=32*32*3, output_dim=10, learning_rate=0.1, weights=torch.rand(10,32*32*3))
# print(ksom)
# ksom.load_state_dict(torch.load('./ksomCIFAR.pth'))

In [99]:
num_epochs=5

In [100]:
import torch.optim as optim

optimizer = optim.SGD(ksom.parameters(), lr=0.1)

In [101]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        inputs, _ = data
        optimizer.zero_grad()
        inputs = inputs.view(-1,32*32*3)
        winner = ksom(inputs)
        # print(ksom.weights.size(),inputs.size())
        loss = torch.mean(torch.sum((ksom.weights - inputs.unsqueeze(1)) ** 6, dim=-1))
        loss=loss/5000
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: loss={loss.item()}")

print('Finished Training')
print(ksom.weights)

Epoch 1: loss=1.1541579961776733
Epoch 2: loss=0.7707021236419678
Epoch 3: loss=0.7693989276885986
Epoch 4: loss=0.7894042134284973
Epoch 5: loss=0.8451470732688904
Finished Training
Parameter containing:
tensor([[0.6331, 0.6483, 0.3936,  ..., 0.4617, 0.6844, 0.4543],
        [0.6907, 0.3962, 0.2816,  ..., 0.6596, 0.4473, 0.0265],
        [0.2398, 0.7034, 0.5666,  ..., 0.6983, 0.2708, 0.3242],
        ...,
        [0.0370, 0.4914, 0.5477,  ..., 0.0996, 0.6905, 0.5493],
        [0.0843, 0.1876, 0.5059,  ..., 0.1204, 0.6342, 0.3960],
        [0.5489, 0.1323, 0.5134,  ..., 0.0376, 0.6469, 0.6547]],
       requires_grad=True)


In [102]:
torch.save(ksom.state_dict(), './ksomCIFAR.pth')