In [None]:
import torch

from network import Network
from data import get_MNIST_data_loaders
from train import train_network
from evaluate import accuracy

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

variant = "l"
in_dim = 1 if variant == "p" else 28
hidden_dim = 64
out_dim = 10
batch_size = 256

train_loader, test_loader = get_MNIST_data_loaders(batch_size, variant=variant)
model = Network(
    in_dim, 
    hidden_dim, 
    out_dim, 
    neuron_type="ekfr", 
    freeze_neurons=False, 
    freeze_g=True
)

In [None]:
train_network(
    model, 
    train_loader, 
    epochs=30, 
    lr=0.05, 
    variant=variant,
    C=0
)

In [None]:
train_acc = accuracy(model, train_loader, variant=variant, device=device)
test_acc = accuracy(model, test_loader, variant=variant, device=device)
print(f"Train accuracy: {train_acc} | Test accuracy: {test_acc}")

In [None]:
def load_network(fname):
    state = torch.load(fname)
    model = Network(
        1 if state["variant"] == "p" else 28, 
        state["hidden_dim"], 
        10, 
        neuron_type=state["neuron_type"], 
        freeze_neurons=state["freeze_neurons"], 
        freeze_g=state["freeze_activations"]
    )
    model.load_state_dict(state["model_state_dict"])
    return model

In [None]:
model = load_network("model/network_params/l_ekfr_64_True_True.pt")

In [None]:
import matplotlib.pyplot as plt

xs = torch.linspace(0, 20, 100)
xss = torch.stack([xs for _ in range(hidden_dim)], dim=1)
ys_a = torch.stack([model.hidden_neurons.kernel(x, var="a") for x in xs]).detach()
ys_b = torch.stack([model.hidden_neurons.kernel(x, var="b") for x in xs]).detach()

plt.figure()
plt.plot(xss, ys_a, alpha=0.5)

plt.figure()
plt.plot(xss, ys_b, alpha=0.5);

In [None]:
import torch
torch.load("model/network_params/l_ekfr_128_True_True.pt")

In [5]:
import torch
import torch.nn.functional as F

In [None]:
def train_model(
    model,
    criterion,
    optimizer,
    Is_tr,
    fs_tr,
    epochs: int = 100,
    print_every: int = 10,
    bin_size = 20,
    up_factor = 10,
    scheduler = None
):
    losses = []

    for epoch in range(epochs):
        total_loss = 0
        for Is, fs in zip(Is_tr, fs_tr):
            batch_size = Is.shape[0]
            loss = torch.zeros(batch_size).to(model.device)
            model.reset(batch_size)
            
            for i in range(Is.shape[1]):
                f = model(Is[:, i])
                
                # up-weight loss for non-zero firing rate
                alpha = torch.ones(batch_size).to(model.device)
                alpha[torch.logical_or(fs[:, i] > 0, f > 0)] = up_factor
                loss += alpha * criterion(f * bin_size, fs[:, i] * bin_size)
            
            mean_loss = torch.mean(loss)
            optimizer.zero_grad()
            mean_loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += mean_loss.item()
        losses.append(total_loss)

        if scheduler is not None:
            scheduler.step()
        
        if (epoch+1) % print_every == 0:
            if scheduler is None:
                print(f"Epoch {epoch+1} | Loss: {total_loss}")
            else:
                curr_lr = scheduler.get_last_lr()
                print(f"Epoch {epoch+1} | Loss: {total_loss} | lr: {curr_lr}")

        if len(losses) >= 3 and losses[-1] == losses[-2] == losses[-3]:
            return losses
        
    return losses

In [6]:
class ExponentialKernelFiringRateModel(torch.nn.Module):
    def __init__(
        self, 
        g, # activation function
        ds,
        bin_size,
        freeze_g = True,
        device = None
    ):
        super().__init__()
        self.g = g
        self.bin_size = bin_size
        self.device = device
        
        self.ds = torch.nn.Parameter(ds.clone().detach(), requires_grad=False)
        self.n = len(self.ds)
        self.a = torch.nn.Parameter(torch.ones(self.n) + torch.randn(self.n) * 0.001)
        self.b = torch.nn.Parameter(torch.randn(self.n) * 0.001)
        
        if freeze_g: self.g.freeze_parameters()
            
    
    # outputs a tensor of shape [B], firing rate predictions at time t
    def forward(
        self,
        currents # shape [B], currents for time t
    ):
        x = torch.outer(currents, self.a) # shape [B, n]
        y = 1000 * torch.outer(self.fs, self.b) # shape [B, n]
        self.v =  (1 - self.ds) * self.v + x + y # shape [B, n]
        print(self.v.shape)
        self.fs = self.g(torch.mean(self.v, dim=1).unsqueeze(1)) # shape [B]
        return self.fs
    
    def reset(self, batch_size):
        self.v = torch.zeros(batch_size, self.n).to(self.device)
        self.fs = torch.zeros(batch_size).to(self.device)
    
    @classmethod
    def from_params(cls, params, freeze_g=True, device=None):
        g = PolynomialActivation.from_params(params["g"])
        model = cls(g, params["ds"], params["bin_size"], freeze_g=freeze_g, device=device)
        model.a = torch.nn.Parameter(params["a"])
        model.b = torch.nn.Parameter(params["b"])
        return model

    def get_params(self):
        return {
            "a": self.a.detach().cpu(),
            "b": self.b.detach().cpu(),
            "g": self.g.get_params(),
            "ds": self.ds.detach().cpu(),
            "bin_size": self.bin_size
        }
    
    def freeze_parameters(self):
        for _, p in self.named_parameters():
            p.requires_grad = False
            
    def unfreeze_parameters(self): # problematic
        for _, p in self.named_parameters():
            p.requires_grad = True
            
    def kernel(self, x, var="a"):
        a = self.a if var == "a" else self.b
        return torch.sum(a * torch.pow(1-self.ds, x))
    
    # Is: shape [seq_length]
    def predict(self, Is):
        pred_fs = []
        vs = []
        
        with torch.no_grad():
            self.reset(1)
            for i in range(len(Is)):
                f = self.forward(Is[i].reshape(1))
                vs.append(self.v.clone())
                pred_fs.append(f.clone())
        return torch.stack(pred_fs).squeeze(), torch.stack(vs).squeeze()
    
class PolynomialActivation(torch.nn.Module):
    def __init__(self, degree, max_current, max_firing_rate, bin_size):
        super().__init__()
        self.degree = degree
        self.max_current = max_current
        self.max_firing_rate = max_firing_rate
        self.bin_size = bin_size
        
        self.p = torch.nn.Parameter(torch.tensor([d for d in range(degree+1)]), requires_grad=False)
        self.poly_coeff = torch.nn.Parameter(torch.randn(self.degree + 1))
        self.b = torch.nn.Parameter(torch.tensor(0.0))
    
    # z: shape [B, 1]
    def forward(self, z):
        x = (z - self.b) / self.max_current # shape [B, n]
        poly = torch.einsum("ijk,jk->ij", x.unsqueeze(dim=2).pow(self.p.reshape(1, 1, -1)), self.poly_coeff ** 2) # shape [B, n]
        tan = self.max_firing_rate * F.tanh(poly) # ceil is the max firing rate
        return F.relu(tan).to(torch.float32) # shape [B, n]
    
    # initialize based on linear approximation of data
    @classmethod
    def from_data(cls, degree, max_current, max_firing_rate, bin_size, Is, fs):
        g = cls(degree, max_current, max_firing_rate, bin_size)
        
        x1, x2, y1, y2 = tuple([torch.tensor(0.0)] * 4)
        xs, ys = map(list, zip(*sorted(zip(Is.cpu(), fs.cpu()), key=lambda x: x[0])))
        i = np.argmax(ys)
        x2, y2 = xs[i], ys[i]
        for i in range(0, len(ys)):
            if ys[i] > 0.01:
                x1, y1 = (xs[i-1], ys[i-1]) if i - 1 > 0 else (xs[i], ys[i])
                break
                
        g.b = torch.nn.Parameter(x1.clone())
        poly_coeff = torch.randn(degree + 1) * 1e-1
        poly_coeff[1] = np.abs((y2 - y1) / (x2 - x1) * max_current)
        g.poly_coeff = torch.nn.Parameter(poly_coeff)
        
        return g
    
    @classmethod
    def from_params(cls, params):
        poly_coeff = torch.nn.Parameter(params["poly_coeff"])
        degree = len(poly_coeff) - 1
        max_current = params["max_current"]
        max_firing_rate = params["max_firing_rate"]
        bin_size = params["bin_size"]
        g = cls(degree, max_current, max_firing_rate, bin_size)
        g.poly_coeff = poly_coeff
        g.b = torch.nn.Parameter(params["b"])
        return g

    def get_params(self):
        return {
            "max_current": self.max_current,
            "max_firing_rate": self.max_firing_rate,
            "poly_coeff": self.poly_coeff.detach().cpu(),
            "b": self.b.detach().cpu(),
            "bin_size": self.bin_size
        }
    
    def freeze_parameters(self):
        for _, p in self.named_parameters():
            p.requires_grad = False
            
    def unfreeze_parameters(self):
        for _, p in self.named_parameters():
            p.requires_grad = True

In [7]:
g = PolynomialActivation(1, 100, 100, 20)
model = ExponentialKernelFiringRateModel(g, torch.tensor([1, 2, 3, 4, 5]).to(torch.float32), 20)

In [8]:
model.reset(10)
model(torch.ones(10))

torch.Size([10, 5])


tensor([[0.0374],
        [0.0374],
        [0.0374],
        [0.0374],
        [0.0374],
        [0.0374],
        [0.0374],
        [0.0374],
        [0.0374],
        [0.0374]], grad_fn=<ReluBackward0>)

In [None]:
torch.mean(torch.ones(10,5), dim=1).shape