In [1]:
import os

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from network import Network
from data import get_MNIST_data_loaders
from train import train_network
from evaluate import accuracy

In [2]:
def load_network(fname):
    """Rebuild a ``Network`` from a checkpoint file and load its saved weights.

    Args:
        fname: Path (or file-like object) handed to ``torch.load``. The
            checkpoint must be a dict containing the keys ``"variant"``,
            ``"hidden_dim"``, ``"freeze_neurons"``, ``"freeze_activations"``
            and ``"model_state_dict"``.

    Returns:
        A ``Network`` built from the stored hyper-parameters, with the stored
        state dict loaded into it.
    """
    # Map every stored tensor onto the CPU so the checkpoint can be opened on a
    # machine without CUDA even if it was written from a GPU session.
    state = torch.load(fname, map_location="cpu")

    # Variant "p" uses an input dimension of 1; every other variant uses 28.
    in_dim = 1 if state["variant"] == "p" else 28

    model = Network(
        in_dim,
        state["hidden_dim"],
        10,  # third positional argument: output dimension
        freeze_neurons=state["freeze_neurons"],
        # The checkpoint key is "freeze_activations"; Network calls it freeze_g.
        freeze_g=state["freeze_activations"],
    )
    model.load_state_dict(state["model_state_dict"])
    return model

In [3]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

# Seed PyTorch's global RNG so a Restart & Run All reproduces the same random
# draws (e.g. weight initialisation and any torch-driven shuffling).
seed = 0
torch.manual_seed(seed)

# Experiment configuration.
variant = "p"
in_dim = 1 if variant == "p" else 28  # variant "p" -> 1 input feature, else 28
hidden_dim = 128
out_dim = 10
batch_size = 256

train_loader, test_loader = get_MNIST_data_loaders(batch_size, variant=variant)

device=device(type='cpu')


In [3]:
# Instantiate the network from the configured dimensions, with both freeze
# flags (neurons and g) switched on.
model = Network(in_dim, hidden_dim, out_dim, freeze_neurons=True, freeze_g=True)

In [30]:
# Train `model` on the training loader: 30 epochs, lr argument 1e-3, C = 0.
train_network(model, train_loader, epochs=30, lr=1e-3, variant=variant, C=0)

Epoch 1 | Loss: 46.079750061035156 | lr: [0.0005]
Epoch 2 | Loss: 46.058109283447266 | lr: [0.0005]
Epoch 3 | Loss: 46.01806640625 | lr: [0.0005]
Epoch 4 | Loss: 45.99532699584961 | lr: [0.0005]
Epoch 5 | Loss: 46.052974700927734 | lr: [0.0005]
Epoch 6 | Loss: 45.98429870605469 | lr: [0.0005]
Epoch 7 | Loss: 46.000118255615234 | lr: [0.0005]
Epoch 8 | Loss: 45.968379974365234 | lr: [0.0005]
Epoch 9 | Loss: 45.97325897216797 | lr: [0.0005]
Epoch 10 | Loss: 46.04072570800781 | lr: [0.00025]
Epoch 11 | Loss: 45.86256408691406 | lr: [0.00025]
Epoch 12 | Loss: 45.87454605102539 | lr: [0.00025]
Epoch 13 | Loss: 45.865272521972656 | lr: [0.00025]
Epoch 14 | Loss: 45.91133117675781 | lr: [0.00025]
Epoch 15 | Loss: 45.876670837402344 | lr: [0.00025]
Epoch 16 | Loss: 45.90795135498047 | lr: [0.00025]
Epoch 17 | Loss: 45.853858947753906 | lr: [0.00025]
Epoch 18 | Loss: 45.8655891418457 | lr: [0.00025]
Epoch 19 | Loss: 45.83961486816406 | lr: [0.00025]
Epoch 20 | Loss: 45.82017517089844 | lr: [0.0

In [31]:
# Score the model on the training loader first, then on the test loader.
train_acc, test_acc = (
    accuracy(model, loader, variant=variant, device=device)
    for loader in (train_loader, test_loader)
)
print(f"Train accuracy: {train_acc} | Test accuracy: {test_acc}")

Train accuracy: 0.9493833184242249 | Test accuracy: 0.9323999881744385


In [7]:
# Total number of scalar entries across parameters that still require gradients.
sum(param.numel() for param in model.parameters() if param.requires_grad)

1450

In [32]:
# --- Checkpoint metadata -----------------------------------------------------
# Derive the variant and hidden size from the model itself so they cannot drift
# from the object being saved.
variant = "p" if model.in_dim == 1 else "l"
hidden_dim = model.hidden_dim

# Record the values that were actually passed to `train_network` for this run
# (lr=1e-3, epochs=30); the checkpoint previously stored 1e-5 / 10.
lr = 1e-3
epochs = 30

neuron_type = "ekfr"
# NOTE(review): these must mirror the freeze flags the model was built with.
freeze_neurons = True
freeze_activations = True

checkpoint_path = (
    f"model/network_params/{variant}_{neuron_type}_{hidden_dim}_{freeze_neurons}_{freeze_activations}.pt"
)
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

torch.save(
    {
        # `.to(...)` moves the module itself to the CPU (in place) before its
        # state dict is taken, so the stored tensors are CPU tensors.
        "model_state_dict": model.to(torch.device("cpu")).state_dict(),
        "train_accuracy": train_acc,
        "test_accuracy": test_acc,
        "lr": lr,
        "epochs": epochs,
        "hidden_dim": hidden_dim,
        "neuron_type": neuron_type,
        "variant": variant,
        "freeze_neurons": freeze_neurons,
        "freeze_activations": freeze_activations
    },
    checkpoint_path
)