In [None]:
import torch

from network import Network
from data import get_MNIST_data_loaders
from train import train_network
from evaluate import accuracy

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"{device=}")

# Experiment configuration; these names are read by the cells below.
variant = "l"
in_dim = 28 if variant != "p" else 1  # input size is 1 for the "p" variant, 28 otherwise
hidden_dim, out_dim = 64, 10
batch_size = 256

# Data loaders for the chosen variant.
train_loader, test_loader = get_MNIST_data_loaders(batch_size, variant=variant)

# Network that the training cell will fit.
model = Network(
    in_dim, hidden_dim, out_dim,
    neuron_type="ekfr",
    freeze_neurons=False,
    freeze_g=True,
)

In [None]:
# Fit the network built above on the training loader (30 epochs, learning rate 0.05).
train_network(model, train_loader, epochs=30, lr=0.05, variant=variant, C=0)

In [None]:
# Score the current model on the training split first, then on the test split.
train_acc, test_acc = (
    accuracy(model, loader, variant=variant, device=device)
    for loader in (train_loader, test_loader)
)
print(f"Train accuracy: {train_acc} | Test accuracy: {test_acc}")

In [None]:
def load_network(fname, map_location="cpu"):
    """Rebuild a ``Network`` from a checkpoint written to ``fname``.

    The checkpoint is read with ``torch.load`` and must be a mapping with the
    keys ``"variant"``, ``"hidden_dim"``, ``"neuron_type"``,
    ``"freeze_neurons"``, ``"freeze_activations"`` and ``"model_state_dict"``.

    Args:
        fname: Path (or file-like object) accepted by ``torch.load``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint that was saved from CUDA tensors can still be opened on
            a machine without a GPU. The weights are copied into the freshly
            constructed network by ``load_state_dict`` either way.

    Returns:
        The reconstructed ``Network`` with the stored weights loaded.

    Raises:
        KeyError: If one of the expected checkpoint keys is missing.
    """
    state = torch.load(fname, map_location=map_location)

    # Input size follows the same rule as the configuration cell:
    # 1 for the "p" variant, 28 for every other variant.
    in_dim = 1 if state["variant"] == "p" else 28

    model = Network(
        in_dim,
        state["hidden_dim"],
        10,
        neuron_type=state["neuron_type"],
        freeze_neurons=state["freeze_neurons"],
        # The checkpoint stores this flag as "freeze_activations";
        # the constructor calls it freeze_g.
        freeze_g=state["freeze_activations"],
    )
    model.load_state_dict(state["model_state_dict"])
    return model

In [None]:
# Rebind `model` to a network restored from a saved checkpoint; this replaces the
# model that was built and trained in the earlier cells.
model = load_network("model/network_params/l_ekfr_64_True_True.pt")

In [None]:
import matplotlib.pyplot as plt

# Points at which the hidden-layer kernels are evaluated.
xs = torch.linspace(0, 20, 100)

# One stacked row per sample point, detached from autograd before plotting.
ys_a = torch.stack([model.hidden_neurons.kernel(x, var="a") for x in xs]).detach()
ys_b = torch.stack([model.hidden_neurons.kernel(x, var="b") for x in xs]).detach()

# Size the x grid from the kernel output instead of the notebook-level
# `hidden_dim`, so the plot still works when the loaded checkpoint has a
# different width than the configuration cell.
# Assumes kernel(...) returns one value per hidden neuron, i.e. ys_a has shape
# (len(xs), n_neurons) -- TODO confirm.
xss = xs.unsqueeze(1).repeat(1, ys_a.shape[1])

# One figure per kernel variant; each column of `ys` is drawn as its own curve.
for var_name, ys in (("a", ys_a), ("b", ys_b)):
    plt.figure()
    plt.plot(xss, ys, alpha=0.5)
    plt.xlabel("x")
    plt.ylabel(f'kernel(x, var="{var_name}")')
    plt.title(f'Hidden-neuron kernels, var="{var_name}"')

In [None]:
import torch

# Show the raw contents of a saved checkpoint. Mapping to the CPU lets this run
# on machines without CUDA even if the checkpoint was saved from GPU tensors.
torch.load("model/network_params/l_ekfr_128_True_True.pt", map_location="cpu")

In [5]:
import torch
import torch.nn.functional as F

In [None]:
def train_model(
    model,
    criterion,
    optimizer,
    Is_tr,
    fs_tr,
    epochs: int = 100,
    print_every: int = 10,
    bin_size = 20,
    up_factor = 10,
    scheduler = None
):
    """Train ``model`` step by step on paired input/target sequences.

    Each batch resets the model with ``model.reset(batch_size)`` and feeds it
    one column ``Is[:, i]`` at a time. The per-sample losses are accumulated
    over all steps, averaged over the batch, and followed by one optimizer step.

    Args:
        model: Callable exposing ``device`` and ``reset(batch_size)``;
            ``model(Is[:, i])`` yields the prediction for step ``i``.
        criterion: Loss applied to ``(prediction * bin_size, target * bin_size)``.
            Its result is multiplied by a per-sample weight and summed into a
            length-``batch_size`` tensor, so it presumably has to be
            element-wise (e.g. ``reduction="none"``) -- TODO confirm with callers.
        optimizer: Optimizer updated once per batch.
        Is_tr: Iterable of input batches, each indexed as ``Is[:, i]`` with
            ``Is.shape[0]`` samples and ``Is.shape[1]`` steps.
        fs_tr: Iterable of target batches, paired with ``Is_tr`` via ``zip``
            (extra items in the longer iterable are ignored).
        epochs: Maximum number of passes over the data.
        print_every: Print progress every this many epochs. ``0`` or ``None``
            disables printing (this used to raise ``ZeroDivisionError`` /
            ``TypeError``).
        bin_size: Factor applied to both prediction and target before the
            criterion is evaluated.
        up_factor: Loss weight for samples whose target or prediction is
            positive at the current step; all other samples get weight 1.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.

    Returns:
        List with one entry per completed epoch: the sum of the mean batch
        losses of that epoch. Training stops early once the last three entries
        are exactly equal.
    """
    losses = []
    log_enabled = bool(print_every)

    for epoch in range(epochs):
        total_loss = 0
        for Is, fs in zip(Is_tr, fs_tr):
            batch_size, n_steps = Is.shape[0], Is.shape[1]
            device = model.device
            # Allocate directly on the target device instead of CPU-then-copy.
            loss = torch.zeros(batch_size, device=device)
            model.reset(batch_size)

            for i in range(n_steps):
                f = model(Is[:, i])
                target = fs[:, i]

                # up-weight loss for non-zero firing rate
                alpha = torch.ones(batch_size, device=device)
                alpha[torch.logical_or(target > 0, f > 0)] = up_factor
                loss += alpha * criterion(f * bin_size, target * bin_size)

            mean_loss = torch.mean(loss)
            optimizer.zero_grad()
            # NOTE(review): retain_graph=True is kept from the original; it is only
            # needed if the model reuses graph pieces across batches -- confirm, since
            # retaining the graph otherwise just costs memory.
            mean_loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += mean_loss.item()
        losses.append(total_loss)

        if scheduler is not None:
            scheduler.step()

        if log_enabled and (epoch + 1) % print_every == 0:
            if scheduler is None:
                print(f"Epoch {epoch+1} | Loss: {total_loss}")
            else:
                curr_lr = scheduler.get_last_lr()
                print(f"Epoch {epoch+1} | Loss: {total_loss} | lr: {curr_lr}")

        # Stop once the epoch loss has been exactly constant for three epochs.
        if len(losses) >= 3 and losses[-1] == losses[-2] == losses[-3]:
            return losses

    return losses

In [7]:
# NOTE(review): PolynomialActivation and ExponentialKernelFiringRateModel are not
# imported in any of the cells shown here -- confirm where they come from.
g = PolynomialActivation(1, 100, 100, 20)

# Float32 vector [1, 2, 3, 4, 5] passed as the model's second constructor argument.
kernel_values = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
model = ExponentialKernelFiringRateModel(g, kernel_values, 20)
del kernel_values  # keep the notebook namespace identical to the original cell