In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
import numpy as np

In [3]:
class RBFNet(nn.Module):
    def __init__(self, num_neurons, num_classes):
        super(RBFNet, self).__init__()
        self.num_neurons = num_neurons
        self.num_classes = num_classes
        self.fc1 = nn.Linear(input_size, num_neurons, bias=False)
        self.fc2 = nn.Linear(num_neurons, num_classes, bias=False)

    def forward(self, x):
        x = torch.exp(-gamma * torch.cdist(x, self.fc1.weight) ** 2)
        x = self.fc2(x)
        return x

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 159589857.22it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 29984082.63it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 45766053.42it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 12898123.74it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [5]:
input_size = 28 * 28
num_classes = 10
learning_rate = 0.001
num_epochs = 5
batch_size = 100

num_neurons_list = [5, 10, 15, 20, 25]

for num_neurons in num_neurons_list:
    kmeans = KMeans(n_clusters=num_neurons, random_state=0)
    kmeans.fit(train_dataset.data.numpy().reshape(-1, 28*28))
    centers = torch.tensor(kmeans.cluster_centers_).float()
    gamma = 1.0 / torch.median(torch.cdist(centers, centers))
    model = RBFNet(num_neurons, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.reshape(-1, 28*28)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    with torch.no_grad():
        correct = 0
        total = 0
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
        for images, labels in test_loader:
            images = images.reshape(-1, 28*28)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Accuracy of the network with {} neurons in the hidden layer: {} %'.format(num_neurons, 100 * correct / total))



Epoch [1/5], Step [100/600], Loss: 2.2790
Epoch [1/5], Step [200/600], Loss: 2.2742
Epoch [1/5], Step [300/600], Loss: 2.2815
Epoch [1/5], Step [400/600], Loss: 2.2838
Epoch [1/5], Step [500/600], Loss: 2.2871
Epoch [1/5], Step [600/600], Loss: 2.2332
Epoch [2/5], Step [100/600], Loss: 2.2624
Epoch [2/5], Step [200/600], Loss: 2.2567
Epoch [2/5], Step [300/600], Loss: 2.2547
Epoch [2/5], Step [400/600], Loss: 2.2071
Epoch [2/5], Step [500/600], Loss: 2.1969
Epoch [2/5], Step [600/600], Loss: 2.2026
Epoch [3/5], Step [100/600], Loss: 2.1924
Epoch [3/5], Step [200/600], Loss: 2.2031
Epoch [3/5], Step [300/600], Loss: 2.1507
Epoch [3/5], Step [400/600], Loss: 2.1030
Epoch [3/5], Step [500/600], Loss: 2.1042
Epoch [3/5], Step [600/600], Loss: 2.0878
Epoch [4/5], Step [100/600], Loss: 2.0849
Epoch [4/5], Step [200/600], Loss: 2.1022
Epoch [4/5], Step [300/600], Loss: 2.0804
Epoch [4/5], Step [400/600], Loss: 2.0355
Epoch [4/5], Step [500/600], Loss: 2.0459
Epoch [4/5], Step [600/600], Loss:



Epoch [1/5], Step [100/600], Loss: 2.2926
Epoch [1/5], Step [200/600], Loss: 2.2741
Epoch [1/5], Step [300/600], Loss: 2.2728
Epoch [1/5], Step [400/600], Loss: 2.2620
Epoch [1/5], Step [500/600], Loss: 2.2580
Epoch [1/5], Step [600/600], Loss: 2.2620
Epoch [2/5], Step [100/600], Loss: 2.2269
Epoch [2/5], Step [200/600], Loss: 2.2264
Epoch [2/5], Step [300/600], Loss: 2.2199
Epoch [2/5], Step [400/600], Loss: 2.2160
Epoch [2/5], Step [500/600], Loss: 2.1718
Epoch [2/5], Step [600/600], Loss: 2.1576
Epoch [3/5], Step [100/600], Loss: 2.1005
Epoch [3/5], Step [200/600], Loss: 2.1129
Epoch [3/5], Step [300/600], Loss: 2.0537
Epoch [3/5], Step [400/600], Loss: 2.0607
Epoch [3/5], Step [500/600], Loss: 1.9754
Epoch [3/5], Step [600/600], Loss: 1.9999
Epoch [4/5], Step [100/600], Loss: 2.0149
Epoch [4/5], Step [200/600], Loss: 1.9983
Epoch [4/5], Step [300/600], Loss: 1.9270
Epoch [4/5], Step [400/600], Loss: 1.8912
Epoch [4/5], Step [500/600], Loss: 1.8973
Epoch [4/5], Step [600/600], Loss:



Epoch [1/5], Step [100/600], Loss: 2.2892
Epoch [1/5], Step [200/600], Loss: 2.2787
Epoch [1/5], Step [300/600], Loss: 2.2743
Epoch [1/5], Step [400/600], Loss: 2.2728
Epoch [1/5], Step [500/600], Loss: 2.2795
Epoch [1/5], Step [600/600], Loss: 2.2340
Epoch [2/5], Step [100/600], Loss: 2.1888
Epoch [2/5], Step [200/600], Loss: 2.2114
Epoch [2/5], Step [300/600], Loss: 2.1701
Epoch [2/5], Step [400/600], Loss: 2.1091
Epoch [2/5], Step [500/600], Loss: 2.0739
Epoch [2/5], Step [600/600], Loss: 2.1024
Epoch [3/5], Step [100/600], Loss: 2.0283
Epoch [3/5], Step [200/600], Loss: 2.0264
Epoch [3/5], Step [300/600], Loss: 1.9961
Epoch [3/5], Step [400/600], Loss: 1.9572
Epoch [3/5], Step [500/600], Loss: 1.8978
Epoch [3/5], Step [600/600], Loss: 1.9280
Epoch [4/5], Step [100/600], Loss: 1.8682
Epoch [4/5], Step [200/600], Loss: 1.8240
Epoch [4/5], Step [300/600], Loss: 1.8133
Epoch [4/5], Step [400/600], Loss: 1.8641
Epoch [4/5], Step [500/600], Loss: 1.7251
Epoch [4/5], Step [600/600], Loss:



Epoch [1/5], Step [100/600], Loss: 2.2918
Epoch [1/5], Step [200/600], Loss: 2.2697
Epoch [1/5], Step [300/600], Loss: 2.2594
Epoch [1/5], Step [400/600], Loss: 2.2456
Epoch [1/5], Step [500/600], Loss: 2.2539
Epoch [1/5], Step [600/600], Loss: 2.1935
Epoch [2/5], Step [100/600], Loss: 2.2002
Epoch [2/5], Step [200/600], Loss: 2.1892
Epoch [2/5], Step [300/600], Loss: 2.1475
Epoch [2/5], Step [400/600], Loss: 2.1373
Epoch [2/5], Step [500/600], Loss: 2.0662
Epoch [2/5], Step [600/600], Loss: 2.0752
Epoch [3/5], Step [100/600], Loss: 2.0631
Epoch [3/5], Step [200/600], Loss: 1.9564
Epoch [3/5], Step [300/600], Loss: 1.9252
Epoch [3/5], Step [400/600], Loss: 1.9125
Epoch [3/5], Step [500/600], Loss: 1.8144
Epoch [3/5], Step [600/600], Loss: 1.8294
Epoch [4/5], Step [100/600], Loss: 1.7822
Epoch [4/5], Step [200/600], Loss: 1.7470
Epoch [4/5], Step [300/600], Loss: 1.7096
Epoch [4/5], Step [400/600], Loss: 1.6345
Epoch [4/5], Step [500/600], Loss: 1.6861
Epoch [4/5], Step [600/600], Loss:



Epoch [1/5], Step [100/600], Loss: 2.3005
Epoch [1/5], Step [200/600], Loss: 2.2678
Epoch [1/5], Step [300/600], Loss: 2.2686
Epoch [1/5], Step [400/600], Loss: 2.2473
Epoch [1/5], Step [500/600], Loss: 2.2372
Epoch [1/5], Step [600/600], Loss: 2.2132
Epoch [2/5], Step [100/600], Loss: 2.2102
Epoch [2/5], Step [200/600], Loss: 2.1662
Epoch [2/5], Step [300/600], Loss: 2.1158
Epoch [2/5], Step [400/600], Loss: 2.0760
Epoch [2/5], Step [500/600], Loss: 2.0901
Epoch [2/5], Step [600/600], Loss: 2.0263
Epoch [3/5], Step [100/600], Loss: 1.9873
Epoch [3/5], Step [200/600], Loss: 1.9108
Epoch [3/5], Step [300/600], Loss: 1.8539
Epoch [3/5], Step [400/600], Loss: 1.8479
Epoch [3/5], Step [500/600], Loss: 1.7809
Epoch [3/5], Step [600/600], Loss: 1.6709
Epoch [4/5], Step [100/600], Loss: 1.7474
Epoch [4/5], Step [200/600], Loss: 1.7467
Epoch [4/5], Step [300/600], Loss: 1.7036
Epoch [4/5], Step [400/600], Loss: 1.6931
Epoch [4/5], Step [500/600], Loss: 1.5342
Epoch [4/5], Step [600/600], Loss: