In [7]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from network import Network
from data import get_MNIST_data_loaders
from train import train_network
from evaluate import accuracy

import numpy as np

In [2]:
def load_network(fname, map_location="cpu"):
    """Rebuild a ``Network`` from a checkpoint dict saved with ``torch.save``.

    Parameters
    ----------
    fname : str or path-like
        Path to a ``.pt`` file holding a dict with at least the keys
        ``"variant"``, ``"hidden_dim"``, ``"freeze_neurons"``,
        ``"freeze_activations"`` and ``"model_state_dict"``.
    map_location : str, torch.device or None, optional
        Forwarded to ``torch.load``. Defaults to ``"cpu"`` so checkpoints that
        were saved while the model lived on a GPU can still be opened on a
        CPU-only machine. The returned model is created by ``Network(...)``
        and filled via ``load_state_dict`` (which copies into the existing
        parameters), so this only controls where the checkpoint tensors are
        staged. Pass ``None`` for ``torch.load``'s own default behaviour.

    Returns
    -------
    Network
        Model with input size 1 for the ``"p"`` variant and 28 for any other
        variant, 10 outputs, and the stored weights loaded.
    """
    state = torch.load(fname, map_location=map_location)
    model = Network(
        1 if state["variant"] == "p" else 28,
        state["hidden_dim"],
        10,
        freeze_neurons=state["freeze_neurons"],
        # The checkpoint key is "freeze_activations"; the constructor calls it freeze_g.
        freeze_g=state["freeze_activations"],
    )
    model.load_state_dict(state["model_state_dict"])
    return model

In [3]:
# Configuration: compute device, input variant, batch size and RNG seeds.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

variant = "p"
batch_size = 256

# Seed the RNGs the notebook draws from (numpy is used by permuted_network's
# neuron resampling; torch presumably drives weight init and loader shuffling)
# so that Restart & Run All reproduces the same results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

train_loader, test_loader = get_MNIST_data_loaders(batch_size, variant=variant)

device=device(type='cpu')


In [4]:
# Load a previously trained checkpoint (the filename suggests an "l"-variant,
# 256-hidden-unit model). NOTE(review): later cells rebind `model` before training.
model = load_network(fname="model/network_params/l_256_False_True.pt")

In [23]:
def permuted_network(fname, variant, replace=False, rng=None):
    """Build a fresh ``Network`` whose hidden-neuron parameters are reshuffled
    copies of those stored in a checkpoint.

    The checkpoint is loaded with ``load_network``. A new ``Network`` with the
    same ``in_dim`` / ``hidden_dim`` / ``out_dim`` is created with
    ``freeze_neurons=True`` and ``freeze_g=True``. Each per-neuron tensor of
    the loaded model (``hidden_neurons.g.max_current``,
    ``hidden_neurons.g.max_firing_rate``, ``hidden_neurons.g.b``,
    ``hidden_neurons.g.poly_coeff``, ``hidden_neurons.a``,
    ``hidden_neurons.b``) is indexed along dim 0 with one shared index vector.
    It is then installed on the new model as a non-trainable ``Parameter``.
    Every other parameter of the new model keeps its fresh initialisation.

    Parameters
    ----------
    fname : str or path-like
        Checkpoint path passed to ``load_network``.
    variant : str
        Unused; kept so existing calls keep working. The input size is taken
        from the loaded model instead.
    replace : bool, optional
        ``False`` (default) uses a true permutation of the hidden neurons.
        ``True`` samples neuron indices uniformly *with* replacement, so some
        neurons may be duplicated and others dropped.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the global ``numpy.random`` state,
        so ``np.random.seed`` controls it.

    Returns
    -------
    Network
        The new model with shuffled, frozen hidden-neuron parameters.
    """
    model = load_network(fname)
    hidden_dim = model.hidden_dim

    new_model = Network(
        model.in_dim,
        hidden_dim,
        model.out_dim,
        freeze_neurons=True,
        freeze_g=True,
    )

    rng = np.random if rng is None else rng
    # Size the index vector from hidden_dim: a hard-coded length breaks any
    # checkpoint whose hidden size differs.
    if replace:
        idxs = rng.choice(hidden_dim, size=hidden_dim, replace=True)
    else:
        idxs = rng.permutation(hidden_dim)
    idxs = torch.as_tensor(idxs, dtype=torch.long)

    def _shuffled(tensor):
        # Advanced indexing along dim 0 returns a new tensor, detached from the source.
        return torch.nn.Parameter(tensor.detach()[idxs], requires_grad=False)

    src_g = model.hidden_neurons.g
    dst_g = new_model.hidden_neurons.g
    # NOTE(review): the original code assigned max_firing_rate twice; if a
    # different g attribute was meant for the duplicate line, add it here.
    for name in ("max_current", "max_firing_rate", "b", "poly_coeff"):
        setattr(dst_g, name, _shuffled(getattr(src_g, name)))
    for name in ("a", "b"):
        setattr(new_model.hidden_neurons, name, _shuffled(getattr(model.hidden_neurons, name)))

    return new_model

In [25]:
# Rebuild the checkpoint's model with its hidden neurons reshuffled.
# NOTE(review): this appears to be an "l"-variant model while the data loaders
# use `variant` ("p"), and the next cell rebinds `model` anyway — confirm intent.
model = permuted_network(fname="model/network_params/l_256_False_True.pt", variant="l")

In [None]:
# Fresh model with freeze_neurons / freeze_g enabled, sized for the current
# data variant: input width 1 for "p", 28 for anything else.
hidden_dim = 128
out_dim = 10
if variant == "p":
    in_dim = 1
else:
    in_dim = 28

model = Network(in_dim, hidden_dim, out_dim, freeze_neurons=True, freeze_g=True)

In [None]:
# Train the current `model` on the training split.
# NOTE(review): `C` is forwarded as-is to train_network (presumably a
# regularisation weight), and `device` is not passed — confirm against train.py.
train_network(
    model,
    train_loader,
    variant=variant,
    lr=1e-3,
    epochs=30,
    C=0,
)

In [None]:
# Evaluate on both splits (train first, then test) and report the results.
train_acc, test_acc = (
    accuracy(model, loader, variant=variant, device=device)
    for loader in (train_loader, test_loader)
)
print(f"Train accuracy: {train_acc} | Test accuracy: {test_acc}")

In [None]:
# Total element count of the parameters in `model` that have requires_grad set.
sum(param.numel() for param in model.parameters() if param.requires_grad)

In [None]:
# Metadata stored alongside the weights. Keep lr/epochs in sync with the
# train_network(...) call that produced `model` (epochs=30, lr=1e-3);
# otherwise the checkpoint misreports how it was trained.
variant = "p" if model.in_dim == 1 else "l"
lr = 1e-3
epochs = 30
hidden_dim = model.hidden_dim
neuron_type = "ekfr"
# Must match the freeze_neurons / freeze_g flags the model was constructed with.
freeze_neurons = True
freeze_activations = True

torch.save(
    {
        # Note: Module.to moves `model` itself to the CPU before taking the state dict.
        "model_state_dict": model.to(torch.device("cpu")).state_dict(),
        "train_accuracy": train_acc,
        "test_accuracy": test_acc,
        "lr": lr,
        "epochs": epochs,
        "hidden_dim": hidden_dim,
        "neuron_type": neuron_type,
        "variant": variant,
        "freeze_neurons": freeze_neurons,
        "freeze_activations": freeze_activations
    },
    f"model/network_params/{variant}_{neuron_type}_{hidden_dim}_{freeze_neurons}_{freeze_activations}.pt"
)