In [6]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

In [7]:
class KSOM(nn.Module):
    """Kohonen-style self-organizing map over flattened input vectors.

    The map stores ``output_dim`` prototype vectors of length ``input_dim`` in
    ``self.weights``. Calling the module finds each sample's best-matching unit
    (BMU), moves the winning prototypes toward the batch in place, and returns
    the BMU indices.

    Args:
        input_dim: Length of each flattened input vector.
        output_dim: Number of prototypes (map units).
        learning_rate: Step size of the in-place SOM update.
        weights: Initial prototypes, presumably shaped ``(output_dim, input_dim)``
            (forward broadcasts them against rows of width ``input_dim``).
        sigma: Width of the Gaussian neighborhood kernel. Defaults to
            ``output_dim / 2``.
    """

    def __init__(self, input_dim, output_dim, learning_rate, weights, sigma=None):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.learning_rate = learning_rate
        self.sigma = output_dim / 2 if sigma is None else sigma
        self.weights = nn.Parameter(weights)

    def forward(self, x):
        """Assign each sample to its BMU and update the winning prototypes in place.

        Args:
            x: Tensor viewable as ``(-1, input_dim)``; each row is one sample.

        Returns:
            Long tensor of shape ``(batch,)`` holding each sample's BMU index.
        """
        x = x.view(-1, self.input_dim)
        # The SOM rule is applied by hand, so keep it out of autograd. Tracking
        # it would only retain large intermediates that are never backpropagated,
        # and it avoids mutating the parameter through the discouraged `.data`.
        with torch.no_grad():
            w = self.weights
            # Squared Euclidean distance from every sample to every prototype:
            # shape (batch, output_dim).
            distances = torch.sum((w - x.unsqueeze(1)) ** 2, dim=-1)
            winner = torch.argmin(distances, dim=1)
            # NOTE(review): the Gaussian kernel is evaluated on input-space
            # distances, not on distances between units on a map grid as in a
            # textbook SOM - confirm this is intended.
            neighborhood = torch.exp(-distances / (2 * self.sigma ** 2))
            # Batch-summed update  lr * sum_b h[b, j] * (x_b - w_j), rewritten as
            # lr * (h^T x - (sum_b h[b, j]) * w_j). This avoids materialising a
            # (batch, output_dim, input_dim) tensor.
            update = self.learning_rate * (
                neighborhood.t() @ x - neighborhood.sum(dim=0).unsqueeze(1) * w
            )
            # Each distinct winning prototype receives the update exactly once.
            # Duplicate BMU indices do not accumulate.
            winners = torch.unique(winner)
            w[winners] += update[winners]

        return winner

In [8]:
# Preprocessing pipeline:
# - resize to 32x32 (presumably a no-op for CIFAR-10's 32x32 images);
# - convert to a float tensor in [0, 1];
# - map to roughly [-1, 1] via (t - 0.5) / 0.5. The single-element mean/std
#   is broadcast across all channels.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [9]:
import torchvision.datasets as datasets

In [10]:
# CIFAR-10 training split read from a local copy. Downloading is disabled, so
# the dataset files must already exist under ../Datasets/.
dataset = datasets.CIFAR10(
    '../Datasets/',
    train=True,
    transform=transform,
    download=False,
)

In [11]:
# Shuffled mini-batches of 32 (image, label) pairs drawn from `dataset`.
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,
)

In [12]:
# Build a 100-unit map over flattened 32x32x3 images. If a checkpoint from an
# earlier run exists, resume from it. The guard lets the notebook run top to
# bottom on a fresh kernel, before the checkpoint has ever been saved.
KSOM_CHECKPOINT = './ksomCIFAR.pth'
ksom = KSOM(input_dim=32*32*3, output_dim=100, learning_rate=0.1, weights=torch.rand(100, 32*32*3))
if os.path.exists(KSOM_CHECKPOINT):
    ksom.load_state_dict(torch.load(KSOM_CHECKPOINT))
else:
    print(f"No checkpoint at {KSOM_CHECKPOINT}; starting from random weights.")

<All keys matched successfully>

In [13]:
# Number of full passes over the training set.
num_epochs = 5

In [14]:
import torch.optim as optim

# Plain SGD over the map's parameters (its prototype matrix), step size 0.1.
optimizer = optim.SGD(params=ksom.parameters(), lr=0.1)

In [15]:
# Hybrid training. Each forward call applies the KSOM winner update in place;
# an SGD step on a distance-based loss then adjusts the same prototypes.
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for inputs, _ in dataloader:  # class labels are not used
        # Flatten each image into one row of length ksom.input_dim.
        inputs = inputs.view(-1, ksom.input_dim)
        optimizer.zero_grad()
        winner = ksom(inputs)
        # NOTE(review): the 6th power and the /5000 rescaling are unexplained
        # magic numbers; a quantization error would normally use squared
        # distance. The mean also pulls *every* prototype toward *every* sample,
        # which works against the SOM's winner-only update. Confirm intent.
        loss = torch.mean(torch.sum((ksom.weights - inputs.unsqueeze(1)) ** 6, dim=-1))
        loss = loss / 5000
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean over all batches of the epoch, not just the last batch.
    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    print(f"Epoch {epoch+1}: loss={epoch_loss}")

print('Finished Training')

KeyboardInterrupt: 

In [None]:
# Persist the trained prototypes to the checkpoint path the loading cell reads.
ksom_state = ksom.state_dict()
torch.save(ksom_state, './ksomCIFAR.pth')
