In [None]:
'''
Take the trained weight matrices (input weights, recurrent weights and output weights)
and the spike raster from 10, and instantiate a trained network.
Run this network on Sine_Wave_Dataset100 (without clock-like input)
to see whether transfer learning (trained on changing-amplitude, changing-period and clock-like input)
can facilitate learning with non-clock-like input.
'''

In [1]:
import os
import torch
import snntorch as snn
import matplotlib.pyplot as plt
from snntorch import surrogate
from snntorch import spikegen
from snntorch import functional
from snntorch import LIF
from snntorch import spikeplot as splt
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
from warnings import warn
import torch.nn as nn
import csv

In [8]:
from classes.Sine_Wave_Dataset import SineWave100 #changing amp and per with resetting clock
from classes.Custom_Loss import CustomLoss_task

In [9]:
# Import the trained input, recurrent and output weight matrices from a saved checkpoint.
data_dir = 'dataMP'
file_name = 'level1_loss0_epoch495_batch39.npz'
file_path = os.path.join(data_dir, file_name)

# Fail fast: without the checkpoint the tensors below would never be defined and
# the network constructor would later die with an unrelated NameError.
if not os.path.exists(file_path):
    raise FileNotFoundError(
        f"Trained checkpoint not found: {file_path!r}. "
        "It is required to initialise the network weights."
    )

data = np.load(file_path)
# Checkpoint from epoch 495, batch 39. The arrays are stored in nn.Linear layout
# (out_features x in_features); each one is transposed here and transposed back
# again when it is assigned to the corresponding layer.
l1_mx = torch.from_numpy(data['input_weights']).T    # presumably 2x200 after transpose
rec_mx = torch.from_numpy(data['rec_weights']).T     # presumably 200x200
l2_mx = torch.from_numpy(data['output_weights']).T   # presumably 200x1 after transpose

In [13]:
from classes.helper1 import conn_mx, hid_mx
from classes.RLIF1 import RLIF1


#set the recurrent and output layer matrices to those of the already trained model (trained on sinewave0 dataloader train_data_hpc0)
#set the input layer (2x200): both rows are copied from the input weights of the trained model
class RSNN2(nn.Module):
    """Recurrent spiking network initialised from previously trained weights.

    The constructor reads the module-level tensors ``l1_mx``, ``rec_mx`` and
    ``l2_mx`` (loaded from a saved checkpoint) and installs them as the weights
    of the input layer, the recurrent layer and the output layer.

    The first ``num_excitatory`` columns of the recurrent weight matrix are
    treated as excitatory-originating (must stay >= 0) and the remaining columns
    as inhibitory-originating (must stay <= 0); see ``positive_negative_weights``.
    """

    # Parameters of the lognormal distribution used to draw replacement weights.
    LOGNORMAL_MU = -0.64
    LOGNORMAL_SIGMA = 0.51

    def __init__(self):
        super().__init__()
        num_inputs = 2
        num_hidden = 200
        num_output = 1
        beta = 0.85
        pe_e = 0.16

        # The first 160 hidden units are excitatory; the remaining 40 are inhibitory.
        num_excitatory = 160
        self.num_excitatory = num_excitatory
        # Per-call history of sign violations found by positive_negative_weights().
        self.false_neg = []
        self.false_pos = []

        # Input -> hidden layer: start from a conn_mx matrix, then overwrite
        # both rows with the trained input weights.
        input_hid_mx = conn_mx(num_inputs, num_hidden, pe_e)
        input_hid_mx[0, :] = l1_mx[0, :]
        input_hid_mx[1, :] = l1_mx[1, :]
        self.input_hid_mx = input_hid_mx
        self.l1 = nn.Linear(num_inputs, num_hidden)
        self.l1.weight.data = input_hid_mx.T

        # Recurrent layer, initialised with the trained recurrent weights.
        self.rlif1 = RLIF1(reset_mechanism="zero", threshold=1, beta=beta,
                           linear_features=num_hidden, all_to_all=True)
        self.rlif1.recurrent.weight.data = rec_mx.T

        # Hidden -> output layer, initialised with the trained output weights.
        self.l2 = nn.Linear(num_hidden, num_output)
        self.l2.weight.data = l2_mx.T

    def forward(self, inputs):
        """Run the network over every step along the first axis of ``inputs``.

        Args:
            inputs: tensor iterated along dim 0, one slice per simulation step
                (presumably shape (timesteps, num_inputs) - confirm with the dataset).

        Returns:
            Tuple ``(cur2_rec, spk1_rec)``: the stacked output-layer currents and
            the stacked hidden-layer spikes, one entry per step. Both are also
            stored on the instance as ``self.cur2_rec`` / ``self.spk1_rec``.
        """
        spk1, mem1 = self.rlif1.init_rleaky()
        spike_history = []
        current_history = []

        for step in range(inputs.shape[0]):
            cur1 = self.l1(inputs[step, :])
            spk1, mem1 = self.rlif1(cur1, spk1, mem1)
            cur2 = self.l2(spk1)

            spike_history.append(spk1)
            current_history.append(cur2)

        self.spk1_rec = torch.stack(spike_history)
        self.cur2_rec = torch.stack(current_history)
        return self.cur2_rec, self.spk1_rec

    def positive_negative_weights(self):
        """Enforce the sign constraint on the recurrent weights and regrow connections.

        Steps:
          1. Count negative weights in the excitatory columns and positive
             weights in the inhibitory columns, and append the counts to
             ``self.false_neg`` / ``self.false_pos``.
          2. Clamp the offending weights to zero (in place).
          3. Re-populate randomly chosen zero-valued positions with lognormal
             magnitudes: ``num_false_pos`` new excitatory weights and
             ``num_false_neg`` new (negated) inhibitory weights. A region is
             only re-populated when it holds more zero entries than requested.

        TODO (from the original author): for any vanishing excitatory neuron,
        populate another excitatory one.
        """
        weights = self.rlif1.recurrent.weight.data
        n_exc = self.num_excitatory

        # Slices are views, so the in-place clamps below modify `weights`.
        excitatory_weights = weights[:, :n_exc]
        inhibitory_weights = weights[:, n_exc:]

        num_false_neg = torch.sum(excitatory_weights < 0).item()
        num_false_pos = torch.sum(inhibitory_weights > 0).item()
        self.false_neg.append(num_false_neg)
        self.false_pos.append(num_false_pos)

        # Clamp sign-switched values at 0.
        excitatory_weights.clamp_(min=0)
        inhibitory_weights.clamp_(max=0)

        self._repopulate_zeros(weights, 0, n_exc, num_false_pos, sign=1.0)
        self._repopulate_zeros(weights, n_exc, weights.shape[1], num_false_neg, sign=-1.0)

    def _repopulate_zeros(self, weights, col_start, col_stop, count, sign):
        """Assign ``count`` lognormal values (times ``sign``) to distinct zero entries
        of ``weights[:, col_start:col_stop]``; do nothing if there are not more
        zero entries than ``count``."""
        if count <= 0:
            return

        rows, cols = (weights[:, col_start:col_stop] == 0).nonzero(as_tuple=True)
        # Number of zero entries (the tuple itself always has length 2).
        num_zero = rows.numel()
        if num_zero <= count:
            return

        # One shared draw selects matching (row, col) pairs, without replacement,
        # so every chosen position really is a zero entry and is chosen only once.
        chosen = torch.randperm(num_zero, device=rows.device)[:count]
        magnitudes = torch.from_numpy(
            np.random.lognormal(mean=self.LOGNORMAL_MU, sigma=self.LOGNORMAL_SIGMA, size=count)
        ).to(dtype=weights.dtype, device=weights.device)
        weights[rows[chosen], col_start + cols[chosen]] = sign * magnitudes

In [14]:
from classes.helper1 import count_spikes

def train_modelA(args):
    """Train ``model`` while keeping the recurrent matrix sparse and sign-constrained.

    Args:
        args: sequence of eight items, in this order:
            model            - network exposing ``rlif1.recurrent``, ``l1``, ``l2`` and
                               ``positive_negative_weights()``; calling it returns
                               ``(output, spikes)``.
            optimizer        - optimizer over the model parameters.
            dataloader       - yields ``(inputs, targets)`` batches.
            step_dataloader  - accepted for interface compatibility; not used.
            criterion        - callable ``(outputs, targets, firing_rates) -> loss`` that
                               also exposes ``task_loss`` and ``rate_loss`` after a call.
            criterium_idx    - value written into the checkpoint file name.
            num_epochs       - number of passes over ``dataloader``.
            num_timesteps    - accepted for interface compatibility; not used.

    Side effects:
        Prints the loss of every batch. For every batch of every 5th epoch,
        writes a checkpoint to ``dataMP/level100_loss{idx}_epoch{e}_batch{b}.npz``.
        The ``spikes`` entry holds the spikes of the last sample in the batch.
    """
    (model, optimizer, dataloader, step_dataloader,
     criterion, criterium_idx, num_epochs, num_timesteps) = args
    model.train()
    # `.data` shares storage with the parameter, so in-place edits reach the model.
    weights = model.rlif1.recurrent.weight.data

    # Make sure the checkpoint directory exists before the first save.
    save_dir = 'dataMP'
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        for i, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()

            outputs = torch.empty(0, dtype=torch.float32, requires_grad=True)
            firing_rate_per_batch = torch.empty(0, dtype=torch.float32, requires_grad=True)

            # The model processes one sample at a time; results are concatenated per batch.
            for sample in inputs:
                output, spikes = model(sample)
                spikes = spikes.T
                outputs = torch.cat((outputs, output.view(1, -1)))
                firing_rate = count_spikes(spikes)
                firing_rate_per_batch = torch.cat((firing_rate_per_batch, firing_rate))

            loss = criterion(outputs, targets, firing_rate_per_batch)
            print("loss: ", loss)

            loss.backward()
            # Remember which recurrent weights are zero before the update ...
            zero_mask = weights == 0
            optimizer.step()

            # ... and reset them afterwards so the existing sparsity is maintained.
            weights[zero_mask] = 0
            model.positive_negative_weights()

            if epoch % 5 == 0:
                checkpoint_path = os.path.join(
                    save_dir,
                    f'level{100}_loss{criterium_idx}_epoch{epoch}_batch{i}.npz')
                np.savez(checkpoint_path,
                         task_loss=criterion.task_loss.item(),
                         firing_rate_loss=criterion.rate_loss.item(),
                         spikes=spikes.detach().numpy(),
                         input_weights=model.l1.weight.data.detach().numpy(),
                         rec_weights=model.rlif1.recurrent.weight.data.detach().numpy(),
                         output_weights=model.l2.weight.data.detach().numpy(),
                         inputs=inputs.detach().numpy(),
                         outputs=outputs.detach().numpy(),
                         targets=targets.detach().numpy())

In [None]:
# Transfer-learning run: fine-tune the pre-trained network on the SineWave100 dataset.
dataset100 = SineWave100.SineWaveDataset100('train_data/train_data_sine_hpc.csv')
dataloader100 = DataLoader(dataset100, batch_size=25, shuffle=True)
loss_task = CustomLoss_task.CustomLoss_task()
net_100 = RSNN2()
optimizer_100 = torch.optim.Adam(net_100.parameters(), lr=0.05)
# The run uses 500 epochs; the variable is now the single source of truth
# (it was previously set to 1000 but ignored in favour of a literal 500).
num_epochs = 500
num_timesteps = 300

# Argument order: model, optimizer, dataloader, step_dataloader (unused by the
# trainer), criterion, criterium_idx, num_epochs, num_timesteps.
train_modelA([net_100, optimizer_100, dataloader100, 100, loss_task, 0, num_epochs, num_timesteps])



loss:  tensor(1324.7240, grad_fn=<MseLossBackward0>)
loss:  tensor(1053.4404, grad_fn=<MseLossBackward0>)
loss:  tensor(1653.3936, grad_fn=<MseLossBackward0>)
loss:  tensor(1650.3685, grad_fn=<MseLossBackward0>)
loss:  tensor(1415.8304, grad_fn=<MseLossBackward0>)
loss:  tensor(1421.0380, grad_fn=<MseLossBackward0>)
loss:  tensor(1563.6764, grad_fn=<MseLossBackward0>)
loss:  tensor(1204.7181, grad_fn=<MseLossBackward0>)
loss:  tensor(1472.1830, grad_fn=<MseLossBackward0>)
loss:  tensor(1064.1030, grad_fn=<MseLossBackward0>)
loss:  tensor(1670.4272, grad_fn=<MseLossBackward0>)
