In [1]:
import os
import random
import pickle

import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from model import BatchEKFR, BatchPolynomialActivation, PolynomialActivation

In [10]:
class GeneralizedFiringRateModel(torch.nn.Module):
    """Firing-rate model with free-form linear filters over currents and rates.

    The prediction at time t is ``g(a @ I[t-k+1 : t+1] + 1000 * b @ f[t-l : t])``
    where ``a`` (length ``k``) filters the most recent currents, ``b``
    (length ``l``) filters the most recent firing rates and ``g`` is the
    activation module. The last entry of each filter applies to the most
    recent sample.
    """

    def __init__(
        self,
        g, # activation function
        k: int, # number of previous timesteps for current I
        l: int, # number of timesteps for firing rate
        bin_size = 0,
        freeze_g: bool = True
    ):
        """Create a model with zero-initialised filters.

        Args:
            g: activation module; must provide ``freeze_parameters()`` and
                ``get_params()``.
            k: length of the current filter ``a``.
            l: length of the firing-rate filter ``b``.
            bin_size: stored as-is and returned by ``get_params``.
            freeze_g: when true, ``g.freeze_parameters()`` is called.
        """
        super().__init__()
        self.g = g
        self.k = k
        self.l = l
        self.a = torch.nn.Parameter(torch.zeros(k))
        self.b = torch.nn.Parameter(torch.zeros(l))
        self.bin_size = bin_size

        # freeze activation parameters
        if freeze_g:
            g.freeze_parameters()

    def forward(
        self,
        currents, # currents tensor, up to time t
        fs # firing rates, up to time t-1
    ):
        """Predict the firing rate at time t from the input histories.

        Only the last ``k`` currents and the last ``l`` firing rates are used.
        """
        x = self.a @ currents[-self.k:]
        y = 1000 * self.b @ fs[-self.l:]
        return self.g(x + y)

    def smoothness_reg(self):
        """Return the smoothness penalty summed over both filters.

        Each filter is prefixed with a zero, and the squared first differences
        are weighted by the squared distance from the most recent sample
        (weights ``len, ..., 1``), then divided by the filter length.
        """
        def penalty(kernel):
            # Build the helpers from the parameter itself so that they share
            # its device and dtype (the model may have been moved with .to()).
            padded = torch.cat([kernel.new_zeros(1), kernel])
            weights = torch.arange(
                len(kernel), 0, -1, dtype=kernel.dtype, device=kernel.device
            )
            return (torch.diff(padded) ** 2) @ (weights ** 2) / len(kernel)

        return penalty(self.a) + penalty(self.b)

    @classmethod
    def from_params(cls, params, freeze_g=True):
        """Rebuild a model from a dict shaped like the output of ``get_params``.

        Reads the keys ``"a"``, ``"b"``, ``"g"`` and ``"bin_size"``. The filter
        lengths ``k`` and ``l`` are taken from the stored tensors.
        """
        g = PolynomialActivation.from_params(params["g"])
        a, b = params["a"], params["b"]
        # The lengths must come from the stored tensors: the model does not
        # exist yet at this point.
        model = cls(g, len(a), len(b), params["bin_size"], freeze_g=freeze_g)
        # Clone so that training the model never mutates the caller's dict.
        model.a = torch.nn.Parameter(a.clone().detach())
        model.b = torch.nn.Parameter(b.clone().detach())
        return model

    @classmethod
    def from_ekfr(cls, ekfr_model, k, l, freeze_g=True):
        """Build a model whose filters sample ``ekfr_model.kernel`` at lags 0..k-1 / 0..l-1.

        Lags are stored in decreasing order so that the last filter entry
        (lag 0) lines up with the most recent sample in ``forward``.
        """
        g = PolynomialActivation.from_params(ekfr_model.g.get_params())
        model = cls(g, k, l, ekfr_model.bin_size, freeze_g=freeze_g)
        # float() detaches each one-element kernel value from autograd.
        a = torch.tensor(
            [float(ekfr_model.kernel(i, var="a")) for i in range(k-1, -1, -1)],
            dtype=torch.float32,
        )
        b = torch.tensor(
            [float(ekfr_model.kernel(i, var="b")) for i in range(l-1, -1, -1)],
            dtype=torch.float32,
        )
        model.a = torch.nn.Parameter(a)
        model.b = torch.nn.Parameter(b)
        return model

    def get_params(self):
        """Return the filters, activation parameters and bin size as a dict of CPU values."""
        return {
            "a": self.a.detach().cpu(),
            "b": self.b.detach().cpu(),
            "g": self.g.get_params(),
            "bin_size": self.bin_size
        }

    def freeze_parameters(self):
        """Disable gradients for every registered parameter (including those of ``g``)."""
        for _, p in self.named_parameters():
            p.requires_grad = False

    def unfreeze_parameters(self):
        """Enable gradients for every registered parameter (including those of ``g``)."""
        for _, p in self.named_parameters():
            p.requires_grad = True

    def predict(self, Is):
        """Run the model autoregressively over a 1-D current trace.

        The trace is left-padded with ``max(k, l)`` zeros and the firing-rate
        history starts at zero. Returns a float64 NumPy array with one
        prediction per input sample.
        """
        k, l = self.k, self.l
        pad = max(k, l)
        Is_pad = F.pad(Is, (pad, 0), "constant")

        with torch.no_grad():
            # Preallocate the rate history instead of re-concatenating it on
            # every step (which made the loop quadratic in the trace length).
            rates = torch.zeros(len(Is_pad), dtype=self.b.dtype, device=self.b.device)
            for i in range(pad, len(Is_pad)):
                f = self.forward(Is_pad[:i+1], rates[:i])
                rates[i] = f.reshape(())
        return rates[pad:].to(torch.float64).cpu().numpy()
    
class ExponentialKernelFiringRateModel(torch.nn.Module):
    """Firing-rate model whose filters are mixtures of exponential kernels.

    The model keeps a state ``v`` of shape [B, n]. Each step applies
    ``v <- (1 - ds) * v + outer(I, a) + 1000 * outer(f_prev, b)`` and emits
    ``g(v @ w)``. ``reset`` must be called before the first ``forward``.
    """

    def __init__(
        self,
        g, # activation function
        ds,
        bin_size,
        freeze_g = True,
        device = None
    ):
        """Create a model with ``n = len(ds)`` exponential components.

        Args:
            g: activation module; must provide ``freeze_parameters()`` and
                ``get_params()``.
            ds: 1-D tensor of per-component decay rates; stored as a
                non-trainable parameter.
            bin_size: stored as-is and returned by ``get_params``.
            freeze_g: when true, ``g.freeze_parameters()`` is called.
            device: optional device the whole module is moved to.
        """
        super().__init__()
        self.g = g
        self.bin_size = bin_size

        self.ds = torch.nn.Parameter(ds.clone().detach(), requires_grad=False)
        self.n = len(self.ds)
        self.a = torch.nn.Parameter(torch.ones(self.n) + torch.randn(self.n) * 0.001)
        self.b = torch.nn.Parameter(torch.randn(self.n) * 0.001)
        # Reshape BEFORE wrapping: Parameter(...).reshape(...) returns a plain
        # non-leaf tensor that is never registered with the module, so it would
        # be missing from parameters(), state_dict() and freeze/unfreeze.
        w_init = ((torch.randn(self.n) * 0.001 + 1) / self.n).reshape(-1, 1)
        self.w = torch.nn.Parameter(w_init)

        if freeze_g:
            self.g.freeze_parameters()
        if device is not None:
            self.to(device)

    # outputs a tensor of shape [B], firing rate predictions at time t
    def forward(
        self,
        currents, # shape [B], currents for time t
        fs # shape [B], firing rates for time t-1
    ):
        """Advance the state by one step and return the predicted rates, shape [B]."""
        x = torch.outer(currents, self.a) # shape [B, n]
        y = 1000 * torch.outer(fs, self.b) # shape [B, n]
        self.v = (1 - self.ds) * self.v + x + y # shape [B, n]
        return self.g(self.v @ self.w).reshape(-1) # shape [B]

    def reset(self, batch_size):
        """Zero the state ``v`` for a new batch of ``batch_size`` sequences."""
        # Allocate next to the parameters; the class never defined `self.device`.
        self.v = torch.zeros(batch_size, self.n, device=self.ds.device)

    @classmethod
    def from_params(cls, params, freeze_g=True):
        """Rebuild a model from a dict shaped like the output of ``get_params``.

        Reads ``"g"``, ``"ds"``, ``"a"``, ``"b"``, ``"w"`` and, optionally,
        ``"bin_size"`` (defaults to 20 when absent).
        """
        g = PolynomialActivation.from_params(params["g"])
        model = cls(g, params["ds"], params.get("bin_size", 20), freeze_g=freeze_g)
        model.a = torch.nn.Parameter(params["a"])
        model.b = torch.nn.Parameter(params["b"])
        model.w = torch.nn.Parameter(params["w"])
        return model

    def get_params(self):
        """Return all model parameters as a dict of CPU values."""
        return {
            "a": self.a.detach().cpu(),
            "b": self.b.detach().cpu(),
            "g": self.g.get_params(),
            "w": self.w.detach().cpu(),
            "ds": self.ds.detach().cpu(),
            "bin_size": self.bin_size
        }

    def freeze_parameters(self):
        """Disable gradients for every registered parameter (including those of ``g``)."""
        for _, p in self.named_parameters():
            p.requires_grad = False

    def unfreeze_parameters(self):
        """Enable gradients for every registered parameter.

        Note that this also covers ``ds`` and the parameters of ``g``.
        """
        for _, p in self.named_parameters():
            p.requires_grad = True

    def kernel(self, x, var="a"):
        """Return the effective linear filter value at lag ``x``.

        This is the weight with which an input ``x`` steps in the past reaches
        the pre-activation ``v @ w``: ``sum_i w_i * coef_i * (1 - ds_i) ** x``,
        where ``coef`` is ``a`` (currents) for ``var == "a"`` and ``b``
        (firing rates, before the factor 1000 applied in ``forward``) otherwise.
        """
        coef = self.a if var == "a" else self.b
        # w has shape [n, 1]; flatten it so the product stays element-wise
        # instead of broadcasting to an [n, n] outer product.
        weights = self.w.reshape(-1)
        # forward() keeps a fraction (1 - ds) of the state per step, so that is
        # the base of the exponential.
        return torch.sum(weights * coef * torch.pow(1 - self.ds, x))

    # Is: shape [seq_length]
    def predict(self, Is):
        """Run the model over one current trace without tracking gradients.

        Returns ``(rates, states)``: the squeezed stack of per-step predictions
        and the squeezed stack of per-step states ``v``.
        """
        pred_fs = []
        vs = []
        f = torch.zeros(1, device=self.ds.device)

        with torch.no_grad():
            self.reset(1)
            for i in range(len(Is)):
                f = self.forward(Is[i].reshape(1), f.reshape(1))
                vs.append(self.v.clone())
                pred_fs.append(f.clone())
        return torch.stack(pred_fs).squeeze(), torch.stack(vs).squeeze()

In [12]:
def get_params(save_path="model/params/"):
    params = {}
    for fname in os.listdir(save_path):
        with open(f"{save_path}{fname}", "rb") as f:
            p = pickle.load(f)
            params[int(fname.split(".")[0])] = p
    return params

def get_random_neurons(n_neurons, save_path="model/params/", threshold=0.7):
    """Sample ``n_neurons`` fitted neuron models without replacement.

    Only cells whose stored ``"evr"`` value is at least ``threshold`` are
    eligible. Returns a list of ``(cell_id, model)`` tuples, where each model
    is built from the cell's stored ``"params"`` entry.
    """
    all_params = get_params(save_path)
    eligible = [cid for cid, entry in all_params.items() if entry["evr"] >= threshold]
    picked = random.sample(eligible, k=n_neurons)
    return [
        (cid, ExponentialKernelFiringRateModel.from_params(all_params[cid]["params"]))
        for cid in picked
    ]

class Network(torch.nn.Module):
    """Recurrent network whose hidden units are fitted neuron models.

    Each step computes ``xh <- hidden_neurons(fc1(x) + fc2(xh))`` and returns
    ``fc3(xh)``. ``reset`` must be called before the first ``forward`` of every
    sequence batch.
    """

    def __init__(self, in_dim, hidden_dim, out_dim) -> None:
        """Build the three linear layers and sample ``hidden_dim`` frozen neurons."""
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.fc1 = torch.nn.Linear(in_dim, hidden_dim)      # input  -> hidden
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)  # hidden -> hidden (recurrent)
        self.fc3 = torch.nn.Linear(hidden_dim, out_dim)     # hidden -> output

        self.hidden_neurons = BatchEKFR(get_random_neurons(hidden_dim), freeze_g=True)
        self.hidden_neurons.freeze_parameters()

    def reset(self, batch_size):
        """Reset the neuron state and the hidden activity for a new batch."""
        self.hidden_neurons.reset(batch_size)
        self.xh = torch.zeros(batch_size, self.hidden_dim)

    def zero_input(self, batch_size):
        """Return an all-zero input of shape [batch_size, in_dim]."""
        # Use the instance's own input width rather than a notebook global, so
        # the method works regardless of which cells have been run.
        return torch.zeros(batch_size, self.in_dim)

    # x: [batch_size, in_dim]
    def forward(self, x):
        """Advance the network by one time step and return the readout."""
        x_in = self.fc1(x)
        x_rec = self.fc2(self.xh)
        self.xh = self.hidden_neurons(x_in + x_rec)
        out = self.fc3(self.xh)
        return out

In [13]:
def get_data_loaders(batch_size):
    """Return shuffled ``(train_loader, test_loader)`` over MNIST.

    Both splits are downloaded on demand, converted with ``ToTensor`` and
    served in batches of ``batch_size``.
    """
    loaders = []
    for root, is_train in (('data/mnist/train', True), ('data/mnist/test', False)):
        dataset = torchvision.datasets.MNIST(
            root,
            download=True,
            train=is_train,
            transform=torchvision.transforms.ToTensor(),
        )
        loaders.append(
            torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        )
    train_loader, test_loader = loaders
    return train_loader, test_loader
    
def reshape_image(x, variant="p"):
    """Turn an image batch into a sequence batch.

    Args:
        x: tensor of shape [batch_size, 28, 28].
        variant: ``"p"`` flattens every image into a sequence of single
            pixels; any other value returns ``x`` unchanged.

    Returns:
        tensor of shape [batch_size, seq_length, in_dim].
    """
    if variant != "p":
        return x
    n_images = x.shape[0]
    return x.reshape(n_images, -1, 1)
    
def train_network(model, train_loader, epochs, variant="p"):
    """Train ``model`` on sequential MNIST with Adam and cross-entropy.

    For every batch the image is fed step by step, then the network is read
    out over five further steps with zero input; the loss is the mean
    cross-entropy over those readout steps. The summed batch loss is printed
    after every epoch.

    Args:
        model: network providing ``reset(batch_size)``, ``zero_input(batch_size)``
            and a per-step ``__call__``.
        train_loader: iterable of ``(images, labels)`` batches.
        epochs: number of passes over ``train_loader``.
        variant: forwarded to ``reshape_image``.
    """
    n_readout_steps = 5
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)

    for epoch in range(epochs):
        total_loss = 0.0

        for x, label in tqdm(train_loader):
            x = x.reshape(x.shape[0], 28, 28)
            x = reshape_image(x, variant=variant)
            batch = x.shape[0]
            target = F.one_hot(label, num_classes=10).to(torch.float32)

            # sequentially send input into network
            model.reset(batch)
            for i in range(x.shape[1]):
                model(x[:, i, :])

            # read out the prediction over a few zero-input steps
            loss = 0
            for _ in range(n_readout_steps):
                pred_y = model(model.zero_input(batch))
                loss = loss + criterion(pred_y, target)
            loss = loss / n_readout_steps

            optimizer.zero_grad()
            # NOTE(review): retain_graph=True is kept from the original; it is
            # presumably unnecessary if model.reset() clears all state - confirm.
            loss.backward(retain_graph=True)
            optimizer.step()

            # Accumulate a plain float: adding the tensor itself would chain
            # every batch's autograd graph together and keep it in memory for
            # the whole epoch.
            total_loss += loss.item()

        print(f"Epoch {epoch+1} / Loss: {total_loss}")
            
def accuracy(model, data_loader, variant="p"):
    """Return the fraction of correctly classified samples in ``data_loader``.

    Each image is fed step by step; the prediction is the argmax of the
    softmax outputs summed over five zero-input readout steps. Gradients are
    not tracked.
    """
    correct, total = 0, 0
    with torch.no_grad():
        for x, label in tqdm(data_loader):
            batch = x.shape[0]
            x = reshape_image(x.reshape(batch, 28, 28), variant=variant)

            # sequentially send input into network
            model.reset(batch)
            for step in range(x.shape[1]):
                model(x[:, step, :])

            # accumulate class probabilities over the readout steps
            votes = torch.zeros(batch, 10)
            for _ in range(5):
                logits = model(model.zero_input(batch))
                votes += F.softmax(logits, dim=1)

            hits = torch.argmax(votes, dim=1) == label
            correct += torch.sum(hits)
            total += batch
    return correct / total

In [5]:
# --- experiment configuration ---
variant = "l"  # "p": one pixel per step; anything else: one 28-pixel row per step
batch_size, hidden_dim, out_dim, epochs = 256, 32, 10, 30
in_dim = 28 if variant != "p" else 1

# --- data, model, training ---
train_loader, test_loader = get_data_loaders(batch_size)
model = Network(in_dim, hidden_dim, out_dim)
train_network(model, train_loader, epochs, variant=variant)

100%|█████████████████████████████████████████| 235/235 [00:12<00:00, 18.31it/s]


Epoch 1 / Loss: 541.18408203125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.98it/s]


Epoch 2 / Loss: 541.1002807617188


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.75it/s]


Epoch 3 / Loss: 541.1058959960938


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.64it/s]


Epoch 4 / Loss: 541.1051025390625


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.68it/s]


Epoch 5 / Loss: 541.0718994140625


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.63it/s]


Epoch 6 / Loss: 541.1148681640625


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.45it/s]


Epoch 7 / Loss: 541.1693115234375


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.79it/s]


Epoch 8 / Loss: 541.1134033203125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.60it/s]


Epoch 9 / Loss: 541.1106567382812


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.74it/s]


Epoch 10 / Loss: 541.1167602539062


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.94it/s]


Epoch 11 / Loss: 541.1646728515625


100%|█████████████████████████████████████████| 235/235 [00:14<00:00, 16.49it/s]


Epoch 12 / Loss: 541.0778198242188


100%|█████████████████████████████████████████| 235/235 [00:14<00:00, 16.76it/s]


Epoch 13 / Loss: 541.0667114257812


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.60it/s]


Epoch 14 / Loss: 541.1650390625


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.35it/s]


Epoch 15 / Loss: 541.087890625


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.37it/s]


Epoch 16 / Loss: 541.1319580078125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.28it/s]


Epoch 17 / Loss: 541.1241455078125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.41it/s]


Epoch 18 / Loss: 541.123046875


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.20it/s]


Epoch 19 / Loss: 541.0089111328125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.41it/s]


Epoch 20 / Loss: 541.1427001953125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.18it/s]


Epoch 21 / Loss: 541.044189453125


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.02it/s]


Epoch 22 / Loss: 541.0504760742188


100%|█████████████████████████████████████████| 235/235 [00:14<00:00, 16.67it/s]


Epoch 23 / Loss: 541.11669921875


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 16.96it/s]


Epoch 24 / Loss: 541.073974609375


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.26it/s]


Epoch 25 / Loss: 541.0989990234375


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.15it/s]


Epoch 26 / Loss: 541.05615234375


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.22it/s]


Epoch 27 / Loss: 541.1557006835938


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 16.94it/s]


Epoch 28 / Loss: 541.0872802734375


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 17.07it/s]


Epoch 29 / Loss: 541.1174926757812


100%|█████████████████████████████████████████| 235/235 [00:13<00:00, 16.96it/s]


Epoch 30 / Loss: 541.1862182617188


In [6]:
# Score the trained network on both splits (train first, then test).
train_acc, test_acc = [
    accuracy(model, loader, variant=variant)
    for loader in (train_loader, test_loader)
]
print(f"Train accuracy: {train_acc} / Test accuracy: {test_acc}")

100%|█████████████████████████████████████████| 235/235 [00:07<00:00, 33.53it/s]
100%|███████████████████████████████████████████| 40/40 [00:01<00:00, 32.25it/s]

Train accuracy: 0.11236666887998581 / Test accuracy: 0.11349999904632568





In [None]:
# Network has no `Wh` attribute; its hidden -> hidden (recurrent) weights are
# held by the `fc2` linear layer.
plt.matshow(model.fc2.weight.detach().cpu())
plt.title("Recurrent weights (fc2)")
plt.colorbar();

In [None]:
# Network has no `Wx` attribute; its input -> hidden weights are held by the
# `fc1` linear layer.
plt.matshow(model.fc1.weight.detach().cpu())
plt.title("Input weights (fc1)")
plt.colorbar();

In [None]:
# Network has no `Wy` attribute; its hidden -> output weights are held by the
# `fc3` linear layer.
plt.matshow(model.fc3.weight.detach().cpu())
plt.title("Output weights (fc3)")
plt.colorbar();

In [14]:
# Draw a single (cell_id, model) pair from the fitted neurons.
a = get_random_neurons(n_neurons=1)

In [15]:
# Convert each sampled exponential-kernel model into a generalized model with
# 10-step current and firing-rate filters.
b = [
    GeneralizedFiringRateModel.from_ekfr(neuron, k=10, l=10)
    for _cell_id, neuron in a
]

In [17]:
# Project-local helpers. NOTE(review): `explained_variance_ratio` is not used in
# the cells visible here, and these imports would normally live in the import
# cell at the top of the notebook.
from evaluate import explained_variance_ratio
from data import get_train_test_data, get_data
# `a` is a list of (cell_id, model) tuples; take the id of the first sampled neuron.
cell_id = a[0][0]
print(cell_id)
# Load the recordings for this cell. The recorded run below failed here with
# FileNotFoundError because the processed pickle for this cell id was missing
# under data/processed_data/.
data = get_data(cell_id, aligned=False)
# Split into train / validation / test currents (Is_*) and firing rates (fs_*).
# NOTE(review): the second argument (20) is presumably the bin size - confirm
# against get_train_test_data.
Is_tr, fs_tr, Is_val, fs_val, Is_te, fs_te, stims = get_train_test_data(data, 20)

485932822


FileNotFoundError: [Errno 2] No such file or directory: 'data/processed_data/processed_I_and_firing_rate_485932822.pickle'