In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
import numpy as np

In [2]:
class RBFNet(nn.Module):
    def __init__(self, num_neurons, num_classes):
        super(RBFNet, self).__init__()
        self.num_neurons = num_neurons
        self.num_classes = num_classes
        self.fc1 = nn.Linear(input_size, num_neurons, bias=False)
        self.fc2 = nn.Linear(num_neurons, num_classes, bias=False)

    def forward(self, x):
        x = torch.exp(-gamma * torch.cdist(x, self.fc1.weight) ** 2)
        x = self.fc2(x)
        return x

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|████████████████████████████| 9912422/9912422 [00:05<00:00, 1818032.28it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|███████████████████████████████| 28881/28881 [00:00<00:00, 86525495.59it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████| 1648877/1648877 [00:01<00:00, 1047662.96it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 4542/4542 [00:00<00:00, 18406307.99it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:
input_size = 28 * 28
num_classes = 10
learning_rate = 0.001
num_epochs = 5
batch_size = 100

num_neurons_list = [5, 10, 15, 20, 25]

for num_neurons in num_neurons_list:
    kmeans = KMeans(n_clusters=num_neurons, random_state=0)
    kmeans.fit(train_dataset.data.numpy().reshape(-1, 28*28))
    centers = torch.tensor(kmeans.cluster_centers_).float()
    gamma = 1.0 / torch.median(torch.cdist(centers, centers))
    model = RBFNet(num_neurons, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.reshape(-1, 28*28)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    with torch.no_grad():
        correct = 0
        total = 0
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
        for images, labels in test_loader:
            images = images.reshape(-1, 28*28)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Accuracy of the network with {} neurons in the hidden layer: {} %'.format(num_neurons, 100 * correct / total))

Epoch [1/5], Step [100/600], Loss: 2.3017
Epoch [1/5], Step [200/600], Loss: 2.2964
Epoch [1/5], Step [300/600], Loss: 2.2734
Epoch [1/5], Step [400/600], Loss: 2.2591
Epoch [1/5], Step [500/600], Loss: 2.2781
Epoch [1/5], Step [600/600], Loss: 2.2609
Epoch [2/5], Step [100/600], Loss: 2.2758
Epoch [2/5], Step [200/600], Loss: 2.2439
Epoch [2/5], Step [300/600], Loss: 2.2358
Epoch [2/5], Step [400/600], Loss: 2.2369
Epoch [2/5], Step [500/600], Loss: 2.2184
Epoch [2/5], Step [600/600], Loss: 2.1862
Epoch [3/5], Step [100/600], Loss: 2.2010
Epoch [3/5], Step [200/600], Loss: 2.1091
Epoch [3/5], Step [300/600], Loss: 2.1630
Epoch [3/5], Step [400/600], Loss: 2.1395
Epoch [3/5], Step [500/600], Loss: 2.1155
Epoch [3/5], Step [600/600], Loss: 2.1175
Epoch [4/5], Step [100/600], Loss: 2.0712
Epoch [4/5], Step [200/600], Loss: 2.0616
Epoch [4/5], Step [300/600], Loss: 2.0720
Epoch [4/5], Step [400/600], Loss: 2.0853
Epoch [4/5], Step [500/600], Loss: 2.0461
Epoch [4/5], Step [600/600], Loss: