In [4]:
import os
import random
import pickle

import torch
import torch.nn.functional as F
import torchvision
import numpy as np

from model import load_model

In [5]:
class GFRLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()

class FiringRateModel(torch.nn.Module):
    def __init__(
        self, 
        g, # activation function
        ds,
        bin_size = 20,
        device = None,
        freeze_g = True
    ):
        super().__init__()
        self.g = g
        self.bin_size = bin_size
        self.dt = bin_size / 1000
        self.device = device
        
        self.ds = torch.nn.Parameter(torch.tensor(ds).to(torch.float32), requires_grad=False)
        self.n = len(self.ds)
        self.a = torch.nn.Parameter(torch.ones(self.n) + torch.randn(self.n) * 0.001)
        self.b = torch.nn.Parameter(torch.randn(self.n) * 0.001)
        self.w = torch.nn.Parameter((torch.ones(self.n) + torch.randn(self.n) * 0.001) / self.n).reshape(-1, 1).to(device)
        
        # freeze activation parameters
        if freeze_g:
            for _, p in self.g.named_parameters():
                p.requires_grad = False
    
    # outputs a tensor of shape [B], firing rate predictions at time t
    def forward(
        self,
        currents, # shape [B], currents for time t
        fs # shape [B], firing rates for time t-1
    ):
        print("hi")
        x = torch.outer(currents, self.a) # shape [B, n]
        y = 1000 * torch.outer(fs, self.b) # shape [B, n]
        self.v =  (1 - self.ds) * self.v + x + y # shape [B, n]
        return self.g(self.v @ self.w).reshape(-1) # shape [B]
    
    def reset(self, batch_size):
        self.v = torch.zeros(batch_size, self.n).to(self.device)

    def init_from_params(self, params, freeze_g=True):
        self.a = torch.nn.Parameter(params["a"])
        self.b = torch.nn.Parameter(params["b"])
        self.g = PolynomialActivation()
        self.g.init_from_params(params["g"])
        self.ds = torch.nn.Parameter(params["ds"], requires_grad=False)
        self.w = torch.nn.Parameter(params["w"])
        self.n = len(self.ds)
        
        if freeze_g:
            for _, p in self.g.named_parameters():
                p.requires_grad = False

    def get_params(self):
        return {
            "a": self.a.detach().cpu(),
            "b": self.b.detach().cpu(),
            "g": self.g.get_params(),
            "w": self.w.detach().cpu(),
            "ds": self.ds.detach().cpu()
        }

class PolynomialActivation(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    # z: shape [B, 1]
    def forward(self, z):
        x = (z - self.b) / self.max_current # shape [B, 1]
        poly = x.pow(self.p) @ (self.poly_coeff ** 2).reshape(-1, 1) # shape [B, degree]
        tan = self.max_firing_rate * F.tanh(poly) # ceil is the max firing rate
        return F.relu(tan) # shape [B, 1]
    
    # slightly ad hoc parameter initialization
    def init_params(self, bin_size, degree, max_current, max_firing_rate, Is, fs):
        self.bin_size = bin_size
        self.degree = degree # polynomial degree
        self.max_current = max_current # used for normalization
        self.max_firing_rate = max_firing_rate
        self.p = torch.nn.Parameter(torch.tensor([d for d in range(degree+1)]), requires_grad=False)
        
        x1, x2, y1, y2 = tuple([torch.tensor(0.0)] * 4)
        xs, ys = map(list, zip(*sorted(zip(Is.cpu(), fs.cpu()), key=lambda x: x[0])))
        i = np.argmax(ys)
        x2, y2 = xs[i], ys[i]
        for i in range(0, len(ys)):
            if ys[i] > 0.01:
                x1, y1 = (xs[i-1], ys[i-1]) if i - 1 > 0 else (xs[i], ys[i])
                break
        self.b = torch.nn.Parameter(x1.clone())
        self.poly_coeff = torch.randn(self.degree + 1) * 1e-1 # to make sure there is some gradient
        self.poly_coeff[1] = np.abs((y2 - y1) / (x2 - x1) * self.max_current) #* torch.abs(torch.randn(1)[0] * 7 + 15)
        self.poly_coeff = torch.nn.Parameter(self.poly_coeff)
        
    def init_from_file(self, filename):
        try:
            with open(filename, "rb") as file:
                params = pickle.load(file)
        except:
            print("Error")
        finally:
            self.init_from_params(params)

    def init_from_params(self, params):
        self.bin_size = params["bin_size"]
        self.max_current = params["max_current"]
        self.max_firing_rate = params["max_firing_rate"]
        self.poly_coeff = torch.nn.Parameter(params["poly_coeff"])
        self.b = torch.nn.Parameter(params["b"])
        self.degree = len(self.poly_coeff) - 1
        self.p = torch.tensor([d for d in range(self.degree+1)])
        self.p = torch.nn.Parameter(self.p, requires_grad=False)
    
    def get_params(self):
        return {
            "max_current": self.max_current,
            "max_firing_rate": self.max_firing_rate,
            "poly_coeff": self.poly_coeff.detach().cpu(),
            "b": self.b.detach().cpu(),
            "bin_size": self.bin_size
        }

    def save_params(self, filename):
        d = self.get_params()
        with open(filename, 'wb') as handle:
            pickle.dump(d, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [35]:
def get_params(save_path="model/params/"):
    params = {}
    for fname in os.listdir(save_path):
        with open(f"{save_path}{fname}", "rb") as f:
            p = pickle.load(f)
            params[int(fname.split(".")[0])] = p
    return params

def get_random_neurons(n_neurons, save_path="model/params/"):
    params = get_params(save_path)
    cell_ids = []

    for cell_id in params:
        if params[cell_id]["evr"] >= 0.7:
            cell_ids.append(cell_id)

    chosen_ids = random.sample(cell_ids, k=n_neurons)
    neurons = {}
    for cell_id in chosen_ids:
        model = load_model(params[cell_id]["params"])
        neurons[cell_id] = model
    return neurons

class Network(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, n_neurons) -> None:
        super().__init__()
        
        # first in_dim neurons are input neurons
        # last out_dim neurons are output neurons
        assert n_neurons >= hidden_dim
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_neurons = n_neurons
        
        ## connectivity matrix
        self.A = torch.nn.Parameter(torch.randn(n_neurons, n_neurons)) / np.sqrt(n_neurons)
        # readout weights
        self.W = torch.nn.Parameter(torch.randn(hidden_dim, out_dim))
        
        neurons = get_random_neurons(n_neurons)
        self.neurons = torch.nn.ModuleList([neurons[cell_id] for cell_id in neurons])
        self.cell_ids = list(neurons.keys())
    
    def reset(self, batch_size):
        for neuron in self.neurons:
            neuron.reset(batch_size)
        self.fs = torch.zeros(batch_size, self.n_neurons)
        
    def zero_input(self, batch_size):
        return torch.zeros(batch_size, in_dim)
    
    # x: [batch_size, in_dim]
    def forward(self, x):
        for i in range(self.in_dim):
            I = x[:, i] + torch.tensordot(self.fs, self.A[i, :], dims=1)
            f = self.fs[:, i]
            self.fs[:, i] = self.neurons[i](I, f)
        #print(self.fs[:, -self.out_dim:].shape, self.w.shape)
        return torch.tensordot(self.fs[:, -self.out_dim:], self.w, dims=1)

In [44]:
def get_data_loaders(batch_size):
    train_set = torchvision.datasets.MNIST('data/mnist/train', download=True, train=True, transform=torchvision.transforms.ToTensor())
    test_set = torchvision.datasets.MNIST('data/mnist/test', download=True, train=False, transform=torchvision.transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True)
    
    return train_loader, test_loader
    
# x: shape [batch_size, 28, 28]
# returns shape [batch_size, seq_length, in_dim]
def reshape_image(x, variant="p"):
    if variant == "p":
        return x.reshape(x.shape[0], -1, 1)
    else:
        return x
    
def train_network(model, train_loader, epochs, variant="p"):
    criterion = torch.nn.CrossEntropyLoss()
    #optimizer = torch.optim.RMSProp(model.parameters(), lr=0.05)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, centered=True)
    
    for epoch in range(epochs):
        total_loss = 0
        
        for x, label in train_loader:
            x = reshape_image(x, variant=variant)
            
            # sequentially send input into network
            model.reset(x.shape[0])
            for i in range(x.shape[1]):
                model(x[:, i, :])
                
            pred_y = model(model.zero_input(x.shape[0]))
            print(pred_y.shape)
            loss = criterion(pred_y, F.one_hot(label, num_classes=10))
            
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            total_loss += loss
            
        if (epoch+1) % 1 == 0:
            print(f"Epoch {epoch+1} / Loss: {total_loss}")

In [45]:
batch_size = 32
variant = "p"
in_dim = 1 if variant == "p" else 28
train_loader, test_loader = get_data_loaders(batch_size)
model = Network(in_dim, 16, 16)
train_network(model, train_loader, 10)

torch.Size([32])


RuntimeError: 0D or 1D target tensor expected, multi-target not supported

In [None]:
a = torch.zeros(32, 16)
b = torch.zeros(16)
a.shape, b.shape

In [34]:
torch.tensordot(a, b, dims=1).shape

torch.Size([32])