In [1]:
import torch
#import torch_directml
import torch.nn.functional as F
import matplotlib.pyplot as plt

from network import Network
from data import get_MNIST_data_loaders
from train import train_network
from evaluate import accuracy

import numpy as np

In [2]:
# Seed the RNGs this notebook draws from (torch for weight initialisation,
# numpy for the neuron resampling in `permuted_network`) so that
# Restart & Run All reproduces the same models and accuracies.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

variant = "l"
batch_size = 256

train_loader, test_loader = get_MNIST_data_loaders(batch_size, variant=variant)

device=device(type='cpu')


In [15]:
def load_network(fname):
    """Rebuild a `Network` from a checkpoint dict saved with ``torch.save``.

    Args:
        fname: path to a checkpoint containing at least the keys "variant",
            "hidden_dim", "freeze_neurons", "freeze_activations" and
            "model_state_dict".

    Returns:
        A `Network` with the saved weights loaded.
    """
    # map_location="cpu" lets checkpoints written from a GPU session load on a
    # CPU-only host. load_state_dict copies the values into the model's own
    # parameters, so the model's device is not affected by this choice.
    state = torch.load(fname, map_location="cpu")
    in_dim = 1 if state["variant"] == "p" else 28
    model = Network(
        in_dim,
        state["hidden_dim"],
        10,  # output size used throughout this notebook (out_dim = 10)
        freeze_neurons=state["freeze_neurons"],
        freeze_g=state["freeze_activations"],
    )
    model.load_state_dict(state["model_state_dict"])
    return model

def load_rnn(fname, lstm=None):
    """Rebuild an `RNN` (plain RNN or LSTM) from a checkpoint dict.

    Args:
        fname: path to a ``torch.save``-d dict with at least "variant",
            "hidden_dim" and "model_state_dict".
        lstm: ``True``/``False`` forces the recurrent cell type. ``None``
            (default) uses the checkpoint's "lstm" entry when present, and
            otherwise infers it from the saved ``rnn.weight_ih_l0``: an LSTM
            stacks four gates, so that weight has ``4 * hidden_dim`` rows,
            versus ``hidden_dim`` rows for a plain RNN.

    Returns:
        An `RNN` with the saved weights loaded.
    """
    # map_location="cpu": GPU-written checkpoints still load on a CPU-only host;
    # load_state_dict copies into the model's existing parameters anyway.
    state = torch.load(fname, map_location="cpu")
    weights = state["model_state_dict"]
    hidden_dim = state["hidden_dim"]
    if lstm is None:
        lstm = state.get("lstm")
    if lstm is None:
        lstm = weights["rnn.weight_ih_l0"].shape[0] == 4 * hidden_dim
    model = RNN(
        1 if state["variant"] == "p" else 28,
        hidden_dim,
        10,
        lstm=lstm,
    )
    model.load_state_dict(weights)
    return model

def permuted_network(fname, variant, n_neurons, replace=True):
    """Build a fresh frozen-neuron `Network` whose hidden neurons come from a checkpoint.

    The per-neuron tensors of the checkpoint's hidden layer (``g.max_current``,
    ``g.max_firing_rate``, ``g.b``, ``g.poly_coeff``, ``a`` and ``b``) are
    gathered along their first dimension with a random index vector and
    installed, frozen, in a newly constructed `Network` of the same size
    (``freeze_neurons=True, freeze_g=True``). Any other parameters of the new
    network are left as constructed.

    Args:
        fname: checkpoint path, loaded with `load_network`.
        variant: unused; kept so existing call sites keep working.
        n_neurons: number of hidden neurons to draw; must equal the
            checkpoint's ``hidden_dim``.
        replace: if True (default, the original behaviour) neuron indices are
            drawn uniformly *with replacement*, so some neurons may repeat and
            others be dropped. If False, a true permutation is used.

    Returns:
        The new `Network`; the copied tensors have ``requires_grad=False``.

    Raises:
        ValueError: if ``n_neurons`` differs from the checkpoint's hidden size.
    """
    model = load_network(fname)
    if n_neurons != model.hidden_dim:
        raise ValueError(
            f"n_neurons={n_neurons} does not match the checkpoint's "
            f"hidden_dim={model.hidden_dim}"
        )

    new_model = Network(
        model.in_dim,
        model.hidden_dim,
        model.out_dim,
        freeze_neurons=True,
        freeze_g=True,
    )

    # Uses numpy's global RNG; seed it for reproducible resampling.
    if replace:
        idxs = (np.random.rand(n_neurons) * n_neurons).astype(int)
    else:
        idxs = np.random.permutation(n_neurons)
    idx = torch.as_tensor(idxs, dtype=torch.long)

    def frozen_rows(param):
        # Gather the selected neurons (first-dim rows) into a new, non-trainable Parameter.
        return torch.nn.Parameter(param.detach()[idx].clone(), requires_grad=False)

    src_g, dst_g = model.hidden_neurons.g, new_model.hidden_neurons.g
    for name in ("max_current", "max_firing_rate", "b", "poly_coeff"):
        setattr(dst_g, name, frozen_rows(getattr(src_g, name)))
    for name in ("a", "b"):
        setattr(new_model.hidden_neurons, name, frozen_rows(getattr(model.hidden_neurons, name)))

    return new_model

In [4]:
class RNN(torch.nn.Module):
    """Recurrent baseline (vanilla RNN or LSTM) advanced one time step per call.

    The recurrent state is stored on the module (``self.xh``, plus ``self.xc``
    for an LSTM), so `reset` must be called before the first step of each
    batch. `forward` then advances the state by one step and passes the new
    hidden state through a linear read-out.

    Args:
        in_dim: size of each per-step input vector.
        hidden_dim: number of units in the single recurrent layer.
        out_dim: size of the linear read-out.
        lstm: use ``torch.nn.LSTM`` instead of ``torch.nn.RNN``.
        device: device for the zero tensors built by `reset` / `zero_input`.
            ``None`` (default) means "the device of the model's parameters",
            so the state follows ``model.to(...)``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, lstm=False, device=None):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        if lstm:
            self.rnn = torch.nn.LSTM(in_dim, hidden_dim)
        else:
            self.rnn = torch.nn.RNN(in_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, out_dim)
        self.device = device
        self.lstm = lstm

    def _zeros(self, *shape):
        """Zero tensor with the parameters' dtype, on the target device."""
        # Taking device/dtype from the weights keeps the state usable after
        # model.to("cuda") or model.double(); an explicit `device` still wins.
        weight = self.fc.weight
        device = weight.device if self.device is None else self.device
        return torch.zeros(*shape, device=device, dtype=weight.dtype)

    def reset(self, batch_size):
        """Zero the recurrent state for a batch of ``batch_size`` sequences."""
        # nn.RNN / nn.LSTM take state shaped (num_layers=1, batch, hidden).
        self.xh = self._zeros(1, batch_size, self.hidden_dim)
        if self.lstm:
            self.xc = self._zeros(1, batch_size, self.hidden_dim)

    def zero_input(self, batch_size):
        """All-zero single-step input of shape (batch_size, in_dim)."""
        return self._zeros(batch_size, self.in_dim)

    def forward(self, x):
        """Advance the stored state by one step.

        Args:
            x: one time step of input, shape (batch, in_dim).

        Returns:
            The read-out for this step, shape (batch, out_dim).
        """
        # Add a length-1 sequence axis: the recurrent layers are seq-first.
        step = x.unsqueeze(dim=0)
        if self.lstm:
            z, (self.xh, self.xc) = self.rnn(step, (self.xh, self.xc))
        else:
            z, self.xh = self.rnn(step, self.xh)
        return self.fc(z.reshape(x.shape[0], self.hidden_dim))

In [8]:
# Pretrained Network checkpoint; by the save-cell naming scheme this is
# variant "l", 64 hidden units, freeze_neurons=True, freeze_activations=True.
network_checkpoint_path = "model/network_params/l_64_True_True.pt"
model = load_network(network_checkpoint_path)

In [16]:
# Pretrained LSTM baseline (variant "l", 64 hidden units).
lstm_checkpoint_path = "model/network_params/l_64_lstm.pt"
model = load_rnn(lstm_checkpoint_path, lstm=True)

In [22]:
# Model dimensions: input width is 1 for variant "p" and 28 otherwise.
in_dim = 28 if variant != "p" else 1
hidden_dim = 30  # 64 matches the l_64_* checkpoints loaded above
out_dim = 10

# Optimisation settings.
lr, epochs = 1e-3, 100

# Freezing flags (also encoded in the Network checkpoint filename when saving).
freeze_neurons = freeze_activations = True

In [23]:
# Fresh LSTM baseline sized by the configuration cell.
model = RNN(in_dim=in_dim, hidden_dim=hidden_dim, out_dim=out_dim, lstm=True)

In [None]:
if freeze_neurons:
    # Take hidden neurons from a checkpoint of the same variant and width that
    # was trained with freeze_neurons=False, freeze_activations=True, resample
    # them, and freeze them in a fresh Network. Note: permuted_network always
    # builds with freeze_g=True, independent of `freeze_activations` here.
    model = permuted_network(
        f"model/network_params/{variant}_{hidden_dim}_False_True.pt",
        variant,
        hidden_dim,
    )
else:
    model = Network(
        in_dim,
        hidden_dim,
        out_dim,
        freeze_neurons=freeze_neurons,
        freeze_g=freeze_activations,
    )

In [25]:
# Train whichever model the cells above left in `model`.
train_network(model, train_loader, lr=lr, epochs=epochs, variant=variant)

Epoch 1 | Loss: 382.4295349121094
Epoch 2 | Loss: 154.76206970214844
Epoch 3 | Loss: 104.01383972167969
Epoch 4 | Loss: 80.13578033447266
Epoch 5 | Loss: 65.38260650634766
Epoch 6 | Loss: 55.955020904541016
Epoch 7 | Loss: 49.36930847167969
Epoch 8 | Loss: 43.924015045166016
Epoch 9 | Loss: 39.960758209228516
Epoch 10 | Loss: 36.19942855834961
Epoch 11 | Loss: 33.557701110839844
Epoch 12 | Loss: 31.307674407958984
Epoch 13 | Loss: 29.214557647705078
Epoch 14 | Loss: 27.726770401000977
Epoch 15 | Loss: 26.066349029541016
Epoch 16 | Loss: 24.633543014526367
Epoch 17 | Loss: 23.23685646057129
Epoch 18 | Loss: 22.37148094177246
Epoch 19 | Loss: 21.550722122192383
Epoch 20 | Loss: 20.70616912841797
Epoch 21 | Loss: 19.613544464111328
Epoch 22 | Loss: 19.134441375732422
Epoch 23 | Loss: 18.733346939086914
Epoch 24 | Loss: 17.594409942626953
Epoch 25 | Loss: 17.008438110351562
Epoch 26 | Loss: 17.099416732788086
Epoch 27 | Loss: 16.186809539794922
Epoch 28 | Loss: 16.370229721069336
Epoch 29 

In [26]:
# Evaluate on both splits with identical settings (train first, then test).
eval_kwargs = dict(variant=variant, device=device)
train_acc, test_acc = (
    accuracy(model, loader, **eval_kwargs) for loader in (train_loader, test_loader)
)
print(f"Train accuracy: {train_acc} | Test accuracy: {test_acc}")

Train accuracy: 0.9944833517074585 | Test accuracy: 0.978600025177002


In [27]:
# Number of trainable scalars (frozen parameters excluded).
sum(p.numel() for p in model.parameters() if p.requires_grad)

7510

In [None]:
# Network checkpoint; the filename encodes variant, width and freezing flags.
# Note: nn.Module.to moves `model` itself to the CPU before its state is read.
network_checkpoint = dict(
    model_state_dict=model.to(torch.device("cpu")).state_dict(),
    train_accuracy=train_acc,
    test_accuracy=test_acc,
    lr=lr,
    epochs=epochs,
    hidden_dim=hidden_dim,
    variant=variant,
    freeze_neurons=freeze_neurons,
    freeze_activations=freeze_activations,
)
torch.save(
    network_checkpoint,
    f"model/network_params/{variant}_{hidden_dim}_{freeze_neurons}_{freeze_activations}.pt",
)

In [None]:
# The "_rnn" suffix is what load_rnn(..., lstm=False) expects, so refuse to
# write an LSTM (or a Network) under that name.
if not isinstance(model, RNN) or model.lstm:
    raise ValueError("expected a plain (non-LSTM) RNN model for an *_rnn.pt checkpoint")
torch.save(
    {
        "model_state_dict": model.to(torch.device("cpu")).state_dict(),
        "train_accuracy": train_acc,
        "test_accuracy": test_acc,
        "lr": lr,
        "epochs": epochs,
        "hidden_dim": hidden_dim,
        "variant": variant,
        "freeze_neurons": freeze_neurons,
        "freeze_activations": freeze_activations,
        # Record the cell type so loaders need not rely on the filename.
        "lstm": model.lstm,
    },
    f"model/network_params/{variant}_{hidden_dim}_rnn.pt"
)

In [None]:
# The "_lstm" suffix is what load_rnn(..., lstm=True) expects, so refuse to
# write a plain RNN (or a Network) under that name.
if not isinstance(model, RNN) or not model.lstm:
    raise ValueError("expected an LSTM-based RNN model for an *_lstm.pt checkpoint")
torch.save(
    {
        "model_state_dict": model.to(torch.device("cpu")).state_dict(),
        "train_accuracy": train_acc,
        "test_accuracy": test_acc,
        "lr": lr,
        "epochs": epochs,
        "hidden_dim": hidden_dim,
        "variant": variant,
        "freeze_neurons": freeze_neurons,
        "freeze_activations": freeze_activations,
        # Record the cell type so loaders need not rely on the filename.
        "lstm": model.lstm,
    },
    f"model/network_params/{variant}_{hidden_dim}_lstm.pt"
)