In [None]:
import torch

from network import Network
from data import get_MNIST_data_loaders
from train import train_network
from evaluate import accuracy

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `model` is not moved to `device` in this cell — confirm that
# train_network / accuracy take care of device placement.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"{device=}")

# Experiment configuration.
variant = "l"
hidden_dim, out_dim = 64, 10
batch_size = 256

# The "p" variant uses an input width of 1; every other variant uses 28.
# NOTE(review): presumably "p" feeds one pixel per step and the others one
# 28-pixel image row per step — confirm against data.py.
in_dim = 28
if variant == "p":
    in_dim = 1

# Data loaders for the chosen variant, then the network to be trained.
train_loader, test_loader = get_MNIST_data_loaders(batch_size, variant=variant)

model = Network(
    in_dim,
    hidden_dim,
    out_dim,
    neuron_type="ekfr",
    freeze_neurons=False,
    freeze_g=True,
)

In [None]:
# Training hyperparameters, gathered in one place so they are easy to tweak.
# NOTE(review): the meaning of `C` is not visible from this notebook
# (presumably a regularisation coefficient) — confirm in train.py.
train_settings = dict(epochs=30, lr=0.05, variant=variant, C=0)

train_network(model, train_loader, **train_settings)

In [None]:
# Score the current `model` on the training split first, then the test split,
# using the same variant and device for both.
train_acc, test_acc = [
    accuracy(model, loader, variant=variant, device=device)
    for loader in (train_loader, test_loader)
]
print(f"Train accuracy: {train_acc} | Test accuracy: {test_acc}")

In [None]:
def load_network(fname, map_location="cpu"):
    """Rebuild a ``Network`` from a saved checkpoint and restore its weights.

    The checkpoint must be a mapping containing the keys ``variant``,
    ``hidden_dim``, ``neuron_type``, ``freeze_neurons``,
    ``freeze_activations`` and ``model_state_dict``.

    Args:
        fname: Path or file-like object accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            that a checkpoint written during a CUDA run can still be opened
            on a machine without a GPU. Pass ``None`` to keep the devices
            recorded in the file.

    Returns:
        A ``Network`` built from the stored hyperparameters with the stored
        state dict loaded into it.

    Raises:
        KeyError: If the checkpoint lacks one of the required keys.
    """
    # NOTE(security): torch.load unpickles data; only open trusted checkpoints.
    state = torch.load(fname, map_location=map_location)

    # Fail early with one readable message instead of a bare KeyError raised
    # half-way through constructing the network.
    required_keys = (
        "variant",
        "hidden_dim",
        "neuron_type",
        "freeze_neurons",
        "freeze_activations",
        "model_state_dict",
    )
    missing = [key for key in required_keys if key not in state]
    if missing:
        raise KeyError(f"checkpoint {fname!r} is missing keys: {missing}")

    model = Network(
        # Same input-width rule as the setup cell: 1 for "p", 28 otherwise.
        1 if state["variant"] == "p" else 28,
        state["hidden_dim"],
        # Output width is fixed at 10, matching out_dim in the setup cell.
        10,
        neuron_type=state["neuron_type"],
        freeze_neurons=state["freeze_neurons"],
        # The checkpoint key is "freeze_activations"; Network calls it freeze_g.
        freeze_g=state["freeze_activations"],
    )
    model.load_state_dict(state["model_state_dict"])
    return model

In [None]:
# Rebind `model` to a network restored from disk; this replaces whatever
# `model` was bound to by earlier cells.
checkpoint_path = "model/network_params/l_ekfr_64_True_True.pt"
model = load_network(checkpoint_path)

In [None]:
import matplotlib.pyplot as plt

# Evaluate the hidden-neuron kernel at 100 evenly spaced points in [0, 20]
# for both kernel variables. Each kernel call's result is stacked along a new
# leading axis, so row i of ys_a / ys_b corresponds to xs[i].
xs = torch.linspace(0, 20, 100)
ys_a = torch.stack([model.hidden_neurons.kernel(x, var="a") for x in xs]).detach()
ys_b = torch.stack([model.hidden_neurons.kernel(x, var="b") for x in xs]).detach()

# Pass the 1-D `xs` directly: matplotlib reuses a 1-D x for every column of a
# 2-D y. This avoids tiling `xs` with the notebook-level `hidden_dim`, which
# can disagree with the width of a model loaded from a checkpoint and would
# then make the plot call fail on a shape mismatch.
for var_name, ys in (("a", ys_a), ("b", ys_b)):
    fig, ax = plt.subplots()
    ax.plot(xs, ys, alpha=0.5)
    ax.set(
        title=f'hidden_neurons.kernel(x, var="{var_name}")',
        xlabel="x",
        ylabel="kernel value",
    )

In [2]:
import torch
# Display the raw checkpoint contents as the cell output.
inspect_path = "model/network_params/l_ekfr_128_True_True.pt"
torch.load(inspect_path)

{'model_state_dict': OrderedDict([('fc1.weight',
               tensor([[ 2.3318,  0.2502, -1.3602,  ...,  0.8194,  3.7907,  4.7172],
                       [ 2.2966, -1.7131, -3.2557,  ..., 19.2983, 14.1500, 10.4019],
                       [ 3.4089,  8.7084,  3.6688,  ...,  6.9564,  5.8821,  6.9156],
                       ...,
                       [-3.5263, -4.2181,  0.7863,  ...,  6.4039, 13.5257,  7.1356],
                       [ 4.5333, -1.5102,  0.6951,  ...,  0.9540, -0.6415,  2.9155],
                       [ 2.9168,  5.7390,  0.9828,  ...,  3.2207, -1.6872,  2.3443]])),
              ('fc1.bias',
               tensor([ 9.9357e-01,  6.9180e-01,  1.1555e+00,  4.8776e-01,  2.9918e+00,
                        1.0772e+00,  6.4470e-01,  1.1532e+00,  9.2926e-01,  1.5291e+00,
                        4.6782e-01,  6.7773e-01,  2.3076e-01,  1.8055e+00,  1.6767e+00,
                        2.6175e-01,  6.1275e-01,  5.9095e-01,  8.0888e-01,  2.5238e-01,
                        9.0050e