In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [2]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [41]:
def compute_accuracy(network, x, y):
    with torch.no_grad():
        correct, total = 0, 0
        logits = network(x)
        y_pred = torch.argmax(logits, dim=1)
        correct += torch.sum(y_pred == y)
        total += x.size(0)
        return (correct / total).item()

def compute_loss(network, x, y):
    with torch.no_grad():
        loss_fn = nn.CrossEntropyLoss()
        one_hots = torch.eye(10, 10).to('cuda')
        total = 0
        points = 0
        logits = network(x)
        one_hots = one_hots[y]
        total += loss_fn(logits, one_hots)
        points += x.size(0)
        return total / points

In [42]:
def data(seed=0):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    seed = seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    # load dataset
    train = torchvision.datasets.MNIST(root='.', train=True, transform=torchvision.transforms.ToTensor(), download=True)
    test = torchvision.datasets.MNIST(root='.', train=False, transform=torchvision.transforms.ToTensor(), download=True)

    train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train, range(1000)), batch_size=1000, shuffle=True)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(test, range(1000)), batch_size=1000, shuffle=True)

    train_x, train_y = next(iter(train_loader))
    test_x, test_y = next(iter(test_loader))

    train_x = train_x.to(device)
    train_y = train_y.to(device)
    test_x = test_x.to(device)
    test_y = test_y.to(device)

    return train_x, train_y, test_x, test_y

In [43]:
def spike_init(model):
    # multiply weights by 8
    for name, param in model.named_parameters():
        if 'weight' in name:
            param.data = param.data * 8

In [44]:
model = MLP().to(device)
spike_init(model)
train_x, train_y, test_x, test_y = data()

In [45]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)

In [46]:
train_x.shape, train_y.shape, test_x.shape, test_y.shape

(torch.Size([1000, 1, 28, 28]),
 torch.Size([1000]),
 torch.Size([1000, 1, 28, 28]),
 torch.Size([1000]))

In [47]:
epochs = 100
lr = 0.001
optimizer = optim.SGD(model.parameters(), lr=lr)
train_losses = []
train_accuracies = []
test_accuarcies = []

In [48]:
def plot(train_losses, train_accuracies, test_accuracies):
    clear_output(True)
    plt.figure(figsize=(12, 6))
    plt.subplot(121)
    plt.plot(train_losses)
    plt.title('Train loss')
    plt.subplot(122)
    plt.plot(train_accuracies, label='train')
    plt.plot(test_accuracies, label='test')
    plt.title('Accuracy')
    plt.legend()
    plt.show()

In [49]:
for i in range(epochs):
    optimizer.zero_grad()
    train_acc = compute_accuracy(model, train_x, train_y)
    test_acc = compute_accuracy(model, test_x, test_y)
    train_losses.append(compute_loss(model, train_x, train_y))
    train_accuracies.append(train_acc)
    test_accuarcies.append(test_acc)
    train_losses[0].backward()
    optimizer.step()
    plot(train_losses, train_accuracies, test_accuarcies)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn