In [1]:
import os
import random
import pickle

import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from model import (
    BatchEKFR, BatchGFR, BatchPolynomialActivation, PolynomialActivation, 
    GeneralizedFiringRateModel, ExponentialKernelFiringRateModel
)

from network import (
    get_params, get_neuron_layer, Network
)

In [3]:
def get_data_loaders(batch_size):
    """Create shuffled MNIST data loaders for the train and test splits.

    Each split is downloaded on demand (``download=True``) into its own
    directory (``data/mnist/train`` / ``data/mnist/test``) and every image is
    converted with ``torchvision.transforms.ToTensor``.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple; both loaders shuffle.
    """
    # (root directory, train flag) for each split, training split first.
    splits = (('data/mnist/train', True), ('data/mnist/test', False))

    # Build both datasets before wrapping either of them in a loader.
    datasets = [
        torchvision.datasets.MNIST(
            root,
            download=True,
            train=is_train,
            transform=torchvision.transforms.ToTensor(),
        )
        for root, is_train in splits
    ]

    train_loader, test_loader = (
        torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in datasets
    )
    return train_loader, test_loader
    
def reshape_image(x, variant="p"):
    """Arrange a batch of images as a batch of input sequences.

    Args:
        x: Tensor of shape [batch_size, 28, 28].
        variant: ``"p"`` flattens each image into a sequence of scalars, i.e.
            shape [batch_size, 28 * 28, 1]. Any other value returns ``x``
            untouched, so each image row is one sequence step.

    Returns:
        Tensor of shape [batch_size, seq_length, in_dim].
    """
    # Every variant other than "p" is a pass-through.
    if variant != "p":
        return x

    n_images = x.shape[0]
    return x.reshape(n_images, -1, 1)
    
def accuracy(model, data_loader, variant="p", readout_steps=5):
    """Measure classification accuracy of ``model`` over ``data_loader``.

    Each batch is reshaped to [batch_size, 28, 28], converted to a sequence
    with ``reshape_image`` and fed to the model one step at a time after
    ``model.reset``. The model is then called ``readout_steps`` more times on
    ``model.zero_input(...)``; the softmax outputs of those calls are summed
    and the arg-max of the sum is taken as the predicted class.

    Args:
        model: Network exposing ``reset(batch_size)``, ``zero_input(batch_size)``
            and ``__call__(step_input)``.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        variant: Passed through to ``reshape_image``.
        readout_steps: Number of zero-input calls whose softmax outputs are
            summed. Defaults to 5.

    Returns:
        ``correct / total`` as a 0-dim tensor (fraction of correct predictions).

    Raises:
        ValueError: If ``readout_steps`` is smaller than 1.
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    if readout_steps < 1:
        raise ValueError("readout_steps must be at least 1")

    with torch.no_grad():
        correct, total = 0, 0
        for x, label in tqdm(data_loader):
            x = x.reshape(x.shape[0], 28, 28)
            x = reshape_image(x, variant=variant)

            # sequentially send input into network
            model.reset(x.shape[0])
            for i in range(x.shape[1]):
                model(x[:, i, :])

            # Start the sum from the first softmax output instead of a
            # pre-allocated zeros(batch, 10): the accumulator then takes the
            # class count, dtype and device of the model output.
            total_pred = None
            for _ in range(readout_steps):
                pred_y = model(model.zero_input(x.shape[0]))
                probs = F.softmax(pred_y, dim=1)
                total_pred = probs if total_pred is None else total_pred + probs

            correct += torch.sum(torch.argmax(total_pred, dim=1) == label)
            total += x.shape[0]
    return correct / total
    
def train_network(model, train_loader, epochs, lr=0.005, variant="p", readout_steps=5):
    """Train ``model`` with Adam and cross-entropy, printing the loss per epoch.

    For every batch the images are reshaped to [batch_size, 28, 28], converted
    to a sequence with ``reshape_image`` and fed to the model step by step
    after ``model.reset``. The model is then called ``readout_steps`` times on
    ``model.zero_input(...)``; the cross-entropy of those outputs against the
    one-hot labels is averaged, ``model.reg()`` is added, and one optimizer
    step is taken.

    Args:
        model: Network exposing ``parameters()``, ``reset(batch_size)``,
            ``zero_input(batch_size)``, ``reg()`` and ``__call__(step_input)``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for Adam.
        variant: Passed through to ``reshape_image``.
        readout_steps: Number of zero-input calls whose losses are averaged.
            Defaults to 5.

    Raises:
        ValueError: If ``readout_steps`` is smaller than 1.
    """
    if readout_steps < 1:
        raise ValueError("readout_steps must be at least 1")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # Plain Python float: summing the loss tensors themselves would keep
        # every batch's autograd graph reachable until the end of the epoch.
        total_loss = 0.0

        for x, label in tqdm(train_loader):
            x = x.reshape(x.shape[0], 28, 28)
            x = reshape_image(x, variant=variant)

            # sequentially send input into network
            model.reset(x.shape[0])
            for i in range(x.shape[1]):
                model(x[:, i, :])

            # The target does not change across readout steps; build it once.
            target = F.one_hot(label, num_classes=10).to(torch.float32)

            loss = 0
            for _ in range(readout_steps):
                pred_y = model(model.zero_input(x.shape[0]))
                loss += criterion(pred_y, target)
            loss /= readout_steps

            loss += model.reg()

            optimizer.zero_grad()
            # NOTE(review): retain_graph=True is kept from the original; it is
            # only needed if the model reuses graph state across batches —
            # confirm against Network.reset and drop it if not.
            loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += loss.item()

        # Summed (not averaged) batch loss for the epoch.
        print(f"Epoch {epoch+1} / Loss: {total_loss}")

In [4]:
# Seed every RNG the notebook imports so that weight initialisation and
# loader shuffling are repeatable across "Restart & Run All".
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Experiment configuration.
batch_size = 256
# "p": each image is fed as 784 steps of 1 value;
# anything else: each image is fed as 28 steps of 28 values (one row per step).
variant = "l"
in_dim = 1 if variant == "p" else 28
out_dim = 10
hidden_dim = 128
epochs = 30

train_loader, test_loader = get_data_loaders(batch_size)
model = Network(in_dim, hidden_dim, out_dim, neuron_type="ekfr", freeze_neurons=False)

In [None]:
# Fit the model on the training split using the settings from the config cell.
train_network(
    model,
    train_loader,
    epochs,
    lr=0.005,
    variant=variant,
)

100%|█████████████████████████████████████████| 235/235 [00:39<00:00,  5.92it/s]


Epoch 1 / Loss: 447.07000732421875


100%|█████████████████████████████████████████| 235/235 [00:41<00:00,  5.70it/s]


Epoch 2 / Loss: 270.161376953125


100%|█████████████████████████████████████████| 235/235 [00:40<00:00,  5.77it/s]


Epoch 3 / Loss: 185.6604766845703


100%|█████████████████████████████████████████| 235/235 [00:40<00:00,  5.74it/s]


Epoch 4 / Loss: 146.80886840820312


100%|█████████████████████████████████████████| 235/235 [00:40<00:00,  5.75it/s]


Epoch 5 / Loss: 125.95415496826172


100%|█████████████████████████████████████████| 235/235 [00:40<00:00,  5.80it/s]


Epoch 6 / Loss: 113.37059020996094


100%|█████████████████████████████████████████| 235/235 [00:40<00:00,  5.75it/s]


Epoch 7 / Loss: 103.77672576904297


100%|█████████████████████████████████████████| 235/235 [00:40<00:00,  5.78it/s]


Epoch 8 / Loss: 98.12919616699219


 87%|███████████████████████████████████▌     | 204/235 [00:34<00:05,  5.66it/s]

In [None]:
# Score the model on both splits: training split first, then test split.
train_acc, test_acc = (
    accuracy(model, loader, variant=variant)
    for loader in (train_loader, test_loader)
)
print(f"Train accuracy: {train_acc} / Test accuracy: {test_acc}")

In [None]:
# Heatmap of the model's Wh tensor; the title tells this figure apart from
# the other weight plots, and the trailing ';' hides the Colorbar repr.
plt.matshow(model.Wh.detach())
plt.title("model.Wh")
plt.colorbar();

In [None]:
# Heatmap of the model's Wx tensor; the title tells this figure apart from
# the other weight plots, and the trailing ';' hides the Colorbar repr.
plt.matshow(model.Wx.detach())
plt.title("model.Wx")
plt.colorbar();

In [None]:
# Heatmap of the model's Wy tensor; the title tells this figure apart from
# the other weight plots, and the trailing ';' hides the Colorbar repr.
plt.matshow(model.Wy.detach())
plt.title("model.Wy")
plt.colorbar();