<a href="https://colab.research.google.com/github/Singular-Brain/DeepBioLCNet/blob/main/UnsupervisedDeep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Notebook setups

In [1]:
!pip install -q git+https://github.com/Singular-Brain/DeepBioLCNet

[?25l[K     |██▊                             | 10 kB 16.5 MB/s eta 0:00:01[K     |█████▍                          | 20 kB 22.8 MB/s eta 0:00:01[K     |████████▏                       | 30 kB 12.5 MB/s eta 0:00:01[K     |██████████▉                     | 40 kB 9.4 MB/s eta 0:00:01[K     |█████████████▋                  | 51 kB 5.4 MB/s eta 0:00:01[K     |████████████████▎               | 61 kB 5.9 MB/s eta 0:00:01[K     |███████████████████             | 71 kB 5.7 MB/s eta 0:00:01[K     |█████████████████████▊          | 81 kB 6.4 MB/s eta 0:00:01[K     |████████████████████████▌       | 92 kB 6.4 MB/s eta 0:00:01[K     |███████████████████████████▏    | 102 kB 5.3 MB/s eta 0:00:01[K     |██████████████████████████████  | 112 kB 5.3 MB/s eta 0:00:01[K     |████████████████████████████████| 120 kB 5.3 MB/s 
[K     |████████████████████████████████| 73 kB 1.7 MB/s 
[K     |████████████████████████████████| 280 kB 43.8 MB/s 
[K     |█████████████████████████████

In [2]:
!wget https://data.deepai.org/mnist.zip
!mkdir -p ../data/MNIST/TorchvisionDatasetWrapper/raw
!unzip mnist.zip -d ../data/MNIST/TorchvisionDatasetWrapper/raw/

--2021-10-22 18:03:41--  https://data.deepai.org/mnist.zip
Resolving data.deepai.org (data.deepai.org)... 138.201.36.183
Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11597176 (11M) [application/x-zip-compressed]
Saving to: ‘mnist.zip’


2021-10-22 18:03:42 (11.0 MB/s) - ‘mnist.zip’ saved [11597176/11597176]

Archive:  mnist.zip
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-labels-idx1-ubyte.gz  


In [3]:
!git clone https://github.com/Singular-Brain/DeepBioLCNet/

Cloning into 'DeepBioLCNet'...
remote: Enumerating objects: 347, done.[K
remote: Counting objects: 100% (347/347), done.[K
remote: Compressing objects: 100% (273/273), done.[K
remote: Total 347 (delta 124), reused 249 (delta 57), pack-reused 0[K
Receiving objects: 100% (347/347), 31.43 MiB | 26.15 MiB/s, done.
Resolving deltas: 100% (124/124), done.


In [4]:
from bindsnet.network.nodes import Nodes
import os
import torch
import random
import numpy as np
import copy
import math
import matplotlib.pyplot as plt
import collections
from torchvision import transforms
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sn
import torch.nn.functional as fn

from abc import ABC, abstractmethod
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Union, Tuple, Optional, Sequence
from torch.nn.modules.utils import _pair

from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import metrics
from sklearn.metrics import accuracy_score

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes, AdaptiveLIFNodes, IFNodes
from bindsnet.network.topology import LocalConnection, Connection, LocalConnectionOrig, MaxPool2dLocalConnection
from bindsnet.network.monitors import Monitor, AbstractMonitor, TensorBoardMonitor
from bindsnet.learning import PostPre, MSTDP, MSTDPET, WeightDependentPostPre, Hebbian
from bindsnet.learning.reward import DynamicDopamineInjection, DopaminergicRPE
from bindsnet.analysis.plotting import plot_locally_connected_weights,plot_locally_connected_weights_meh,plot_locally_connected_weights_one_chan,plot_spikes,plot_locally_connected_weights_meh2,plot_convergence_and_histogram,plot_locally_connected_weights_meh3
from bindsnet.analysis.visualization import plot_weights_movie, plot_spike_trains_for_example,summary, plot_voltage
from bindsnet.utils import reshape_locally_connected_weights, reshape_locally_connected_weights_meh, reshape_conv2d_weights

## Sets up Gpu use and manual seed


In [5]:
if torch.cuda.is_available():
    device =  torch.device("cuda")
    gpu = True
else:
    device =  torch.device("cpu")
    gpu = False

def manual_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


SEED = 2045 # The Singularity is Near!
manual_seed(SEED)

torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

Running on Device =  cuda


# Custom Monitors

## Reward Monitor

In [6]:
class RewardMonitor(AbstractMonitor):
    # language=rst
    """
    Records state variables of interest.
    """

    def __init__(
        self,
        time: None,
        batch_size: int = 1,
        device: str = "cpu",
    ):
        # language=rst
        """
        Constructs a ``Monitor`` object.

        :param obj: An object to record state variables from during network simulation.
        :param state_vars: Iterable of strings indicating names of state variables to record.
        :param time: If not ``None``, pre-allocate memory for state variable recording.
        :param device: Allow the monitor to be on different device separate from Network device
        """
        super().__init__()

        self.time = time
        self.batch_size = batch_size
        self.device = device

        # if time is not specified the monitor variable accumulate the logs
        if self.time is None:
            self.device = "cpu"

        self.recording = []
        self.reset_state_variables()

    def get(self,) -> torch.Tensor:
        # language=rst
        """
        Return recording to user.

        :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded state
        variable.
        Note, if time == `None`, get return the logs and empty the monitor variable

        """
        # return_logs = torch.as_tensor(self.recording)
        # if self.time is None:
        #     self.recording = []
        return self.recording

    def record(self, **kwargs) -> None:
        # language=rst
        """
        Appends the current value of the recorded state variables to the recording.
        """
        if "reward" in kwargs:
            self.recording.append(kwargs["reward"])
        # remove the oldest element (first in the list)
        # if self.time is not None:
        #     self.recording.pop(0)

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets recordings to empty ``List``s.
        """
        self.recording = []


## Plot Eligibility trace

In [7]:
class PlotET(AbstractMonitor):
    # language=rst
    """
    Records state variables of interest.
    """

    def __init__(
        self,
        i,
        j,
        source,
        target,
        connection,
    ):
        # language=rst
        """
        Constructs a ``Monitor`` object.

        :param obj: An object to record state variables from during network simulation.
        :param state_vars: Iterable of strings indicating names of state variables to record.
        :param time: If not ``None``, pre-allocate memory for state variable recording.
        :param device: Allow the monitor to be on different device separate from Network device
        """
        super().__init__()
        self.i = i
        self.j = j
        self.source = source
        self.target = target
        self.connection = connection

        self.reset_state_variables()

    def get(self,) -> torch.Tensor:
        # language=rst
        """
        Return recording to user.

        :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded state
        variable.
        Note, if time == `None`, get return the logs and empty the monitor variable

        """
        # return_logs = torch.as_tensor(self.recording)
        # if self.time is None:
        #     self.recording = []
        return self.recording

    def record(self, **kwargs) -> None:
        # language=rst
        """
        Appends the current value of the recorded state variables to the recording.
        """
        if hasattr(self.connection.update_rule, 'p_plus'):
            self.recording['spikes_i'].append(self.source.s.ravel()[self.i].item())
            self.recording['spikes_j'].append(self.target.s.ravel()[self.j].item())
            self.recording['p_plus'].append(self.connection.update_rule.p_plus[self.i].item())
            self.recording['p_minus'].append(self.connection.update_rule.p_minus[self.j].item())
            self.recording['eligibility'].append(self.connection.update_rule.eligibility[self.i,self.j].item())
            self.recording['eligibility_trace'].append(self.connection.update_rule.eligibility_trace[self.i,self.j].item())
            self.recording['w'].append(self.connection.w[self.i,self.j].item())

    def plot(self):

        fig, axs  = plt.subplots(7)
        fig.set_size_inches(10, 20)
        for i, (name, p) in enumerate(self.recording.items()):
            axs[i].plot(p[-250:])
            axs[i].set_title(name)
    
        fig.show()

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets recordings to empty ``List``s.
        """
        self.recording = {
        'spikes_i': [],
        'spikes_j': [],
        'p_plus':[],
        'p_minus':[],
        'eligibility':[],
        'eligibility_trace':[],
        'w': [],
        }


## Kernel 

In [8]:
class AbstractKernel(ABC):
    def __init__(self, kernel_size):
        """
        Base class for generating image filter kernels such as Gabor, DoG, etc. Each subclass should override :attr:`__call__` function.
        Instantiates a ``Filter Kernel`` object.
        :param window_size : The size of the kernel (int)
        """
        self.window_size = kernel_size

    def __call__(self):
        pass


In [9]:
class DoGKernel(AbstractKernel):
    def __init__(self, kernel_size: Union[int, Tuple[int, int]], sigma1 : float, sigma2 : float):
        """
        Generates DoG filter kernels.
        :param kernel_size: Horizontal and vertical size of DOG kernels.(If pass int, we consider it as a square filter) 
        :param sigma1 : The sigma parameter for the first Gaussian function.
        :param sigma2 : The sigma parameter for the second Gaussian function.
        """
        super(DoGKernel, self).__init__(kernel_size)
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        
    def __call__(self):
        k = self.window_size//2
        x, y = np.mgrid[-k:k+1:1, -k:k+1:1]
        a = 1.0 / (2 * math.pi)
        prod = x*x + y*y
        f1 = (1/(self.sigma1*self.sigma1)) * np.exp(-0.5 * (1/(self.sigma1*self.sigma1)) * (prod))
        f2 = (1/(self.sigma2*self.sigma2)) * np.exp(-0.5 * (1/(self.sigma2*self.sigma2)) * (prod))
        dog = a * (f1-f2)
        dog_mean = np.mean(dog)
        dog = dog - dog_mean
        dog_max = np.max(dog)
        dog = dog / dog_max
        dog_tensor = torch.from_numpy(dog)
        # returns a 2d tensor corresponding to the requested DoG filter
        return dog_tensor.float()

In [10]:
class Filter():
    """
    Applies a filter transform. Each filter contains a sequence of :attr:`FilterKernel` objects.
    The result of each filter kernel will be passed through a given threshold (if not :attr:`None`).
    Args:
        filter_kernels (sequence of FilterKernels): The sequence of filter kernels.
        padding (int, optional): The size of the padding for the convolution of filter kernels. Default: 0
        thresholds (sequence of floats, optional): The threshold for each filter kernel. Default: None
        use_abs (boolean, optional): To compute the absolute value of the outputs or not. Default: False
    .. note::
        The size of the compund filter kernel tensor (stack of individual filter kernels) will be equal to the 
        greatest window size among kernels. All other smaller kernels will be zero-padded with an appropriate 
        amount.
    """
    # filter_kernels must be a list of filter kernels
    # thresholds must be a list of thresholds for each kernel
    def __init__(self, filter_kernels, padding=0, thresholds=None, use_abs=False):
        tensor_list = []
        self.max_window_size = 0
        for kernel in filter_kernels:
            if isinstance(kernel, torch.Tensor):
                tensor_list.append(kernel)
                self.max_window_size = max(self.max_window_size, kernel.size(-1))
            else:
                tensor_list.append(kernel().unsqueeze(0))
                self.max_window_size = max(self.max_window_size, kernel.window_size)
        for i in range(len(tensor_list)):
            p = (self.max_window_size - filter_kernels[i].window_size)//2
            tensor_list[i] = fn.pad(tensor_list[i], (p,p,p,p))

        self.kernels = torch.stack(tensor_list)
        self.number_of_kernels = len(filter_kernels)
        self.padding = padding
        if isinstance(thresholds, list):
            self.thresholds = thresholds.clone().detach()
            self.thresholds.unsqueeze_(0).unsqueeze_(2).unsqueeze_(3)
        else:
            self.thresholds = thresholds
        self.use_abs = use_abs

    # returns a 4d tensor containing the flitered versions of the input image
    # input is a 4d tensor. dim: (minibatch=1, filter_kernels, height, width)
    def __call__(self, input):

        # if input.dim() == 3:
        #     input2 = torch.unsqueeze(input, 0)
        input.unsqueeze_(0)
        output = fn.conv2d(input, self.kernels, padding = self.padding).float()
        if not(self.thresholds is None):
            output = torch.where(output < self.thresholds, torch.tensor(0.0, device=output.device), output)
        if self.use_abs:
            torch.abs_(output)
        return output.squeeze(0)

# Design network

In [11]:
def convergence(c):
    if c.norm is None:
        return 1-torch.mean((c.w-c.wmin)*(c.wmax-c.w))/((c.wmax-c.wmin)/2)**2
    else:
        mean_norm_factor = c.norm / c.w.shape[-1]
        return  1-(torch.mean((c.w-c.wmin)*(c.wmax-c.w))/((mean_norm_factor-c.wmin)*(c.wmax-mean_norm_factor)))

In [12]:
compute_size = lambda inp_size, k, s: int((inp_size-k)/s) + 1

class LCNet(Network):
    def __init__(
        self,
        dt: float,
        time: int,
        channels: list,
        filters: list,
        strides: list,
        input_channels: int,
        nu_LC: Union[float, Tuple[float, float]],
        NodesType_LC,
        crop_size:int,
        inh_LC: bool,
        nu_inh_LC: float,
        inh_factor_LC: float,
        norm_factor_LC: float,
        update_rule_LC,
        update_rule_inh_LC,
        wmin: float,
        wmax: float ,
        soft_bound: bool,
        theta_plus: float,
        tc_theta_decay: float,
        tc_trace:int,
        trace_additive,
        load_path,
        save_path,
        confusion_matrix,
        lc_weights_vis,
        out_weights_vis,
        lc_convergence_vis,
        out_convergence_vis,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Constructor for class ``BioLCNet``.

        :param n_inpt: Number of input neurons. Matches the 1D size of the input data.
        :param n_neurons: Number of excitatory, inhibitory neurons.
        :param exc: Strength of synapse weights from excitatory to inhibitory layer.
        :param inh: Strength of synapse weights from inhibitory to excitatory layer.
        :param dt: Simulation time step.
        :param nu: Single or pair of learning rates for pre- and post-synaptic events,
            respectively.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param wmin: Minimum allowed weight on input to excitatory synapses.
        :param wmax: Maximum allowed weight on input to excitatory synapses.
        :param norm: Input to excitatory layer connection weights normalization
            constant.
        :param theta_plus: On-spike increment of ``(adaptive)LIFNodes`` membrane
            threshold potential.
        :param tc_theta_decay: Time constant of ``(adaptive)LIFNodes`` threshold
            potential decay.
        :param inpt_shape: The dimensionality of the input layer.
        """
        manual_seed(SEED)
        super().__init__(dt=dt, reward_fn = None, online=True)
        
        self.dt = dt
        self.save_path = save_path
        self.load_path = load_path
        self.time = time
        self.input_channels = input_channels
        self.crop_size = crop_size
        self.clamp_intensity = kwargs.get('clamp_intensity',None)
        self.soft_bound = soft_bound
        self.confusion_matrix = confusion_matrix
        self.lc_weights_vis = lc_weights_vis
        self.out_weights_vis = out_weights_vis
        self.lc_convergence_vis = lc_convergence_vis
        self.out_convergence_vis = out_convergence_vis
        self.convergences = {}
        self.wmin = wmin 
        self.wmax = wmax

        ### nodes
        inp = Input(shape= [input_channels,crop_size,crop_size], traces=True, tc_trace=tc_trace,traces_additive = trace_additive)
        self.add_layer(inp, name="layer0")

        last_layer_size = crop_size
        last_layer = inp
        last_layer_channels = 1
        inh_LC_type = kwargs.get('inh_LC_type', 'recurrent')
        LC_weights_path = kwargs.get('LC_weights_path', None)

        for i, (n_channels,filter_size,stride) in enumerate(zip(channels,filters,strides)):
            new_layer_size = compute_size(last_layer_size, filter_size, stride)
            new_layer = NodesType_LC(
                shape= [n_channels, new_layer_size, new_layer_size],
                traces=True, tc_trace=tc_trace,traces_additive = trace_additive,
                tc_theta_decay = tc_theta_decay, theta_plus = theta_plus
                )
            self.add_layer(new_layer, name=f"layer{i+1}")
            ### connections 
            LC_connection = LocalConnection(
                last_layer, new_layer,
                filter_size, stride,
                last_layer_channels, n_channels,
                input_shape = [last_layer_size]*2,
                nu = nu_LC, update_rule= update_rule_LC,
                wmin = wmin, wmax= wmax, soft_bound = soft_bound,
                norm = norm_factor_LC)
            
            if LC_weights_path:
                LC_connection.w.data = torch.load(LC_weights_path)['state_dict']['input_to_main1.w']
                print("Weights loaded ...")

            self.add_connection(LC_connection, f"layer{i}", f"layer{i+1}")
            ###
            last_layer_size = new_layer_size
            last_layer = new_layer
            last_layer_channels = n_channels
            ### Inhibitory connection
            if inh_LC:
                #if inh_LC_type != 'lateral':
                w_inh_LC = torch.zeros(n_channels,new_layer_size,new_layer_size,n_channels,new_layer_size,new_layer_size)
                for c in range(n_channels):
                    for w1 in range(new_layer_size):
                        for w2 in range(new_layer_size):
                            w_inh_LC[c,w1,w2,:,w1,w2] = - inh_factor_LC
                            w_inh_LC[c,w1,w2,c,w1,w2] = 0
            
                w_inh_LC = w_inh_LC.reshape(new_layer.n,new_layer.n)
                                                            
                LC_recurrent_inhibition = Connection(
                    source=new_layer,
                    target=new_layer,
                    w=w_inh_LC,
                    nu= nu_inh_LC,
                )
                self.add_connection(LC_recurrent_inhibition, f"layer{i+1}", f"layer{i+1}")

                #else:
                inh_neurons_layer = LIFNodes(
                    shape= [n_channels],
                    traces=True, tc_trace=tc_trace,traces_additive = trace_additive,
                    tc_theta_decay = tc_theta_decay, theta_plus = theta_plus
                    ).to(device)
                self.add_layer(inh_neurons_layer, name=f"layer{i+1}_inh_neurons")

                w_inh_LC = torch.zeros(n_channels, n_channels,new_layer_size,new_layer_size)
                w_inh_LC_rev = torch.zeros(n_channels, new_layer_size, new_layer_size, n_channels)

                for c in range(n_channels):
                    for w1 in range(new_layer_size):
                        for w2 in range(new_layer_size): 
                            w_inh_LC[c,c,w1,w2] = -inh_factor_LC
                            w_inh_LC_rev[c,w1,w2,c] = ((13.0/4)*new_layer_size**2)/(new_layer_size**2)

                w_inh_LC = w_inh_LC.reshape(n_channels, new_layer.n)
                w_inh_LC_rev = w_inh_LC_rev.reshape(new_layer.n, n_channels)

                LC_lateral_inhibition = Connection(
                    source=inh_neurons_layer,
                    target=new_layer,
                    w=w_inh_LC,
                    nu= nu_inh_LC,
                ).to(device)

                LC_lateral_inhibition_rev = Connection(
                    source=new_layer,
                    target=inh_neurons_layer,
                    w=w_inh_LC_rev,
                    nu= nu_inh_LC,
                ).to(device)

                self.add_connection(LC_lateral_inhibition, f"layer{i+1}_inh_neurons", f"layer{i+1}")
                self.add_connection(LC_lateral_inhibition_rev, f"layer{i+1}",f"layer{i+1}_inh_neurons")

        self.to(device)


    def fit(
        self,
        dataloader,
        n_train = 2000,
        verbose = True,
    ):
        manual_seed(SEED)
        self.verbose = verbose
        # add Monitors
        #Plot_et = PlotET(i = 0, j = 0, source = self.layers["main"], target = self.layers["output"], connection = self.connections[("main","output")])
        #tensorboard = TensorBoardMonitor(self, time = self.time)
        #self.add_monitor(main_monitor, name="main")
        #self.add_monitor(Plot_et, name="Plot_et")
        #self.add_monitor(tensorboard, name="tensorboard")
        self.spikes = {}
        for layer in set(self.layers):
            self.spikes[layer] = Monitor(self.layers[layer], state_vars=["s"], time=None)
            self.add_monitor(self.spikes[layer], name="%s_spikes" % layer)

        if self.load_path:
            self.model_params = torch.load(self.load_path, torch.device('cuda') if gpu else torch.device('cpu'))
            self.load_state_dict(torch.load(self.load_path, torch.device('cuda') if gpu else torch.device('cpu'))['state_dict'])
            iteration =  self.model_params['iteration']
            hparams = self.model_params['hparams']
            train_accs = self.model_params['train_accs']
            val_accs = self.model_params['val_accs']
            acc_rewards = self.model_params['acc_rewards']
            print(f'Previous model loaded! Resuming training from iteration {iteration}..., last running training accuracy: {train_accs[-1]}, last validation accuracy: {val_accs[-1]}\n') if self.verbose else None
        else:
            print(f'Previous model not found! Training from the beginning...\n') if self.verbose else None


        pbar = tqdm(total=n_train)
        self.reset_state_variables()

        for (i, datum) in enumerate(dataloader):
            if self.load_path:
                if i <= iteration:
                    n_train += 1
                    continue
            if i > n_train:
                break

            image = datum["encoded_image"]
            # Run the network on the input.
            if gpu:
                inputs = {"layer0": image.cuda().view(self.time, 1, self.input_channels, self.crop_size, self.crop_size)}
            else:
                inputs = {"layer0": image.view(self.time, 1, self.input_channels, self.crop_size, self.crop_size)}
            ### Spike clamping (baseline activity)
            clamp = {}
            if self.clamp_intensity is not None:
                encoder = PoissonEncoder(time = self.time, dt = self.dt)
                clamp['output'] = encoder.enc(datum = torch.rand(self.layers['output'].n)*self.clamp_intensity,time = self.time, dt = self.dt)

            self.run(inputs=inputs, 
                    time=self.time, 
                    one_step=True,
                    clamp = clamp
                     )

            # Get voltage recording.
            #tensorboard.update(step= i)

            ### convergence
            # w_lc1 = self.connections[('input', 'main1')].w
            # w_last_main_out = self.connections[(self.final_connection_source_name,'output')].w
            # convg_lc1=  1 -  torch.mean((w_lc1 - self.wmin)*(self.wmax-w_lc1))
            # convg_out= 1 -  torch.mean((w_last_main_out - self.wmin)*(self.wmax-w_last_main_out))
            # if self.norm_factor_LC is not None:
            #     mean_norm_factor_lc = self.norm_factor_LC / w_lc1.shape[-1]
            #     convg_lc1=  1 - ( torch.mean((w_lc1 - self.wmin)*(self.wmax-w_lc1)) / ((mean_norm_factor_lc - self.wmin)*(self.wmax - mean_norm_factor_lc)) )
            # if self.norm_factor_out is not None:    
            #     mean_norm_factor_out = self.norm_factor_out / w_last_main_out.shape[-2]
            #     convg_out= 1 - ( torch.mean((w_last_main_out - self.wmin)*(self.wmax-w_last_main_out)) / ((mean_norm_factor_out - self.wmin)*(self.wmax - mean_norm_factor_out)) )
            # self.convergences['lc1'].append((convg_lc1 * 10**4).round() / (10**4))
            # self.convergences['last_main_out'].append((convg_out * 10**4).round() / (10**4)) 

            #print(f"\r{i}/{n_train}",
                # sum_spikes, 'pred_label:',
                # predicted_label.item(), 'GT:', label.item(),
                # ', Acc Rew:', round(sum(reward_monitor.get()).item(),4),
                # f"Pos dps: {self.reward_fn.dps:.5f}, Neg dps: {self.reward_fn.neg_dps:.5f}, Rew base: {self.reward_fn.rew_base:.5f}, Pun base: {self.reward_fn.punish_base:.5f}, RPe: {self.reward_fn.reward_predict_episode:.3f}",
                # f"input_mean_fire_freq: {torch.mean(image.float())*1000:.1f},main_mean_fire_freq:{torch.mean(lc_spikes1.float())*1000:.1f}",#" main2_mean_fire_freq:{torch.mean(lc_spikes2.float())*1000:.1f}",
                # f"output_mean_fire_freq:{torch.mean(out_spikes.float())*1000:.1f}",
                # f"mean_lc1_w: {torch.mean(w_lc1):.5f}, mean_fc_w:{torch.mean(w_last_main_out):.5f}",
                # f"std_lc1_w: {torch.std(w_lc1):.5f}, std_fc_w:{torch.std(w_last_main_out):.5f}",
                # f"convergence_lc1: {convg_lc1:.5f}, convergence_fc: {convg_out:.5f}",
                # end = '')       
            if verbose:
                print(f"\r ",
                    f"input_mean_fire_freq: {torch.mean(image.float())*1000:.1f}",
                    '- Spikes:', {name: monitor.get('s').sum().item() for name, monitor in self.spikes.items()},
                    '- Layers Mean:', list(map(lambda x: {x[0]: round(x[1].w.mean().item(),3)}, self.connections.items())),
                    '- convergences:', {name: round(convergence(c).item() , 5) for name, c in self.connections.items() if c.wmin!=-np.inf and c.wmax!=np.inf },
                    '- std:', {name: round(c.w.std().item() , 5) for name, c in self.connections.items() if c.wmin!=-np.inf and c.wmax!=np.inf },
                    end = ''
                    )    

            #Plot_et.plot()    
            self.reset_state_variables()  # Reset state variables.
            
            # pbar.set_description_str("Running accuracy: " + "{:.2f}".format(acc) + "%, " + "Current val accuracy: " + "{:.2f}".format(val_acc) + "%, ")
            pbar.update()

        # tensorboard.writer.add_hparams(
        #     {k:(v if type(v) in (int, float, bool, str, torch.Tensor) else str(v)) for k,v in {**train_hparams, **data_hparams, **network_hparams, **reward_hparams}.items() },
        #     result_metrics
        # )


    def predict(
        self,
        val_loader,
        n_pred,
    ):
        manual_seed(SEED)

        monitors = []
        for layer in self.layers:
            if layer == 'layer0':
                continue
            monitor = Monitor(self.layers[layer], ["s", "v"], time=time)
            monitors.append(monitor)
            self.add_monitor(monitor, name=f"{layer} monitor")
        self.train(False)
        S= []
        V= []
        y= []
        pbar = tqdm(total=n_pred)
        for (i, datum) in enumerate(val_loader):
            if i > n_pred:
                break

            image = datum["encoded_image"]
            if self.label is None : 
              label = datum["label"]
            else :
              label = self.label

            # Run the network on the input.
            if gpu:
                inputs = {"layer0": image.cuda().view(self.time, 1, self.input_channels, self.crop_size, self.crop_size)}
            else:
                inputs = {"layer0": image.view(self.time, 1, self.input_channels, self.crop_size, self.crop_size)}

            self.run(inputs=inputs, 
                    time=self.time, 
                    **reward_hparams,
                    one_step = True,
                    true_label = label.int().item(),
                    dopaminergic_layers= self.dopaminergic_layers,
                     )

            S.append(list(map(lambda x: x.get('s'), monitors)))
            V.append(list(map(lambda x: x.get('v'), monitors)))
            y.append(label)

            self.reset_state_variables()  # Reset state variables.
            pbar.update()

        self.train(True)
        return S, V ,y

# Load Dataset

In [13]:
class ClassSelector(torch.utils.data.sampler.Sampler):
    """Select target classes from the dataset"""
    def __init__(self, target_classes, data_source, mask = None):
        if mask is not None:
            self.mask = mask
        else:
            self.mask = torch.tensor([1 if data_source[i]['label'] in target_classes else 0 for i in range(len(data_source))])
        self.data_source = data_source

    def __iter__(self):
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        return len(self.data_source)

In [14]:
kernels = [DoGKernel(7,1,2),
			DoGKernel(7,2,1),]
filter = Filter(kernels, padding = 3, thresholds = 50/255)
# Load MNIST data.

def load_datasets(network_hparams, data_hparams, mask=None, test_mask=None):
    manual_seed(SEED)
    dataset = MNIST(
        PoissonEncoder(time=network_hparams['time'], dt=network_hparams['dt']),
        None,
        root=os.path.join("..", "..", "data", "MNIST"),
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
            # filter,
            transforms.Lambda(lambda x: (
                x.round() if data_hparams['round_input'] else x
            ) * data_hparams['intensity']),
            transforms.CenterCrop(data_hparams['crop_size'])]
        ),
    )

    # Create a dataloader to iterate and batch data
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                            sampler = ClassSelector(
                                                    target_classes = target_classes,
                                                    data_source = dataset,
                                                    mask = mask,
                                                    ) if target_classes else None
                                            )

    # Load test dataset
    test_dataset = MNIST(   
        PoissonEncoder(time=network_hparams['time'], dt=network_hparams['dt']),
        None,
        root=os.path.join("..", "..", "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(),
            # filter,
            transforms.Lambda(lambda x: (
                x.round() if data_hparams['round_input'] else x
            ) * data_hparams['intensity']),
            transforms.CenterCrop(data_hparams['crop_size'])]
        ),
    )

    val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                            sampler = ClassSelector(
                                                    target_classes = target_classes,
                                                    data_source = test_dataset,
                                                    mask = mask_test,
                                                    ) if target_classes else None
                                            )
    

    return dataloader, val_loader

# Set up hyper-parameters

In [15]:
# Dataset Hyperparameters
target_classes = None #(0,1)
if target_classes:
    npz_file = np.load(f'bindsnet/mask_{"_".join([str(i) for i in target_classes])}.npz')
    # npz_file = np.load('bindsnet/mask_0_1.npz') ##### KESAFAT KARI !!!
    mask, mask_test = torch.from_numpy(npz_file['arr_0']), torch.from_numpy(npz_file['arr_1'])
    n_classes = len(target_classes)
    
else:
    mask = None
    mask_test = None
    n_classes = 10

data_hparams = { 
    'intensity': 128,
    'crop_size': 22,
    'round_input': False,
}

In [16]:
train_hparams = {
    'n_train' : 4000,
}

network_hparams = {
    # net structure
    'channels': [100],
    'filters': [15],
    'strides': [4],
    'input_channels': 1,
    # time & Phase
    'time': 250,
    'dt' : 1,
    # Nodes
    'NodesType_LC': AdaptiveLIFNodes,
    'theta_plus': 0.05,
    'tc_theta_decay': 1e6,
    'tc_trace':20,
    'trace_additive' : False,
    # Learning
    'update_rule_LC': PostPre,
    'update_rule_inh_LC' : None,
    'nu_LC': (1e-4, 1e-2),
    'soft_bound': False,
    'wmin': 0,
    'wmax': 1,
    # Inhibition
    'nu_inh_LC': 0.0,
    'inh_LC': True,
    'inh_LC_type': 'lateral',
    'inh_factor_LC': 100,
    # Normalization
    'norm_factor_LC': 0.25*15*15,
    'norm_factor_inh': None,
    'norm_factor_inh_LC': None,
    # clamp
    'clamp_intensity': None,
    # Save
    'LC_weights_path': None,#'/content/drive/My Drive/LCNet/BioLCNet_layer1_Shallow_f15_s4_inh100_norm25_ch100.pth',
    'save_path': '/content/drive/My Drive/LCNet/LCNet_phase2_baseline2_gpu_tmp.pth',
    'load_path': None,#'/content/drive/My Drive/LCNet/LCNet_phase2_baseline2_gpu.pth',
    # Plot:
    'confusion_matrix' : False,
    'lc_weights_vis': False,
    'out_weights_vis': False,
    'lc_convergence_vis': False,
    'out_convergence_vis': False,
}


In [17]:
dataloader, val_loader = load_datasets(network_hparams, data_hparams, mask, mask_test)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz
Extracting ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/TorchvisionDatasetWrapper/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/TorchvisionDatasetWrapper/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Using downloaded and verified file: ../../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz
Extracting ../../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/TorchvisionDatasetWrapper/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k

# Training

In [None]:
# from google.colab import drive
# if network_hparams['save_path'] or network_hparams['LC_weights_path']:    
#     drive.mount('/content/drive')

In [None]:
manual_seed(SEED)
net = LCNet(**network_hparams, **train_hparams, **data_hparams)
net.fit(dataloader = dataloader, **train_hparams)

## Visualization

In [20]:
def plot_semantic_pooling(lc : object,
    input_channel: int = 0,
    output_channel: int = None,
    lines: bool = True,
    figsize: Tuple[int, int] = (5, 5),
    cmap: str = "hot_r",
    color: str='r',
    ) -> AxesImage:
    # language=rst
    """
    Plot a connection weight matrix of a :code:`Connection` with `locally connected
    structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
    :param lc: LC connection object of LCNet
    :param input_channel: indicates weights which connected to this channel of input 
    :param output_channel: indicates weights of specific channel in the output layer
    :param lines: Whether or not to draw horizontal and vertical lines separating input regions.
    :param figsize: Horizontal, vertical figure size in inches.
    :param cmap: Matplotlib colormap.
    :return: Used for re-drawing the weights plot.
    """

    n_sqrt = int(np.ceil(np.sqrt(lc.out_channels)))
    sel_slice = lc.w.view(lc.in_channels, lc.out_channels, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu()
    input_size = _pair(int(np.sqrt(lc.source.n)))
    
    if output_channel is None:
        sel_slice = sel_slice[input_channel, ...]
        reshaped = reshape_LC_weights(sel_slice, lc.out_channels, lc.kernel_size, lc.conv_size, input_size)
    else:
        sel_slice = sel_slice[input_channel, output_channel, ...]
        sel_slice = sel_slice.unsqueeze(0)
        reshaped = reshape_LC_weights(sel_slice, 1, lc.kernel_size, lc.conv_size, input_size)
        print(reshaped.shape)

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
    div = make_axes_locatable(ax)
    cax = div.append_axes("right", size="5%", pad=0.05)

    if lines and  output_channel is None:
        for i in range(
            n_sqrt * lc.kernel_size[0],
            n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
            n_sqrt * lc.kernel_size[0],
        ):
            ax.axhline(i - 0.5, color=color, linestyle="--")

        for i in range(
            n_sqrt * lc.kernel_size[1],
            n_sqrt * lc.conv_size[1] * lc.kernel_size[1],
            n_sqrt * lc.kernel_size[1],
        ):
            ax.axvline(i - 0.5, color=color, linestyle="--")

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_aspect("auto")

    plt.colorbar(im, cax=cax)
    fig.tight_layout()

    return im

In [19]:
def reshape_LC_weights(
    w: torch.Tensor,
    n_filters: int,
    kernel_size: Union[int, Tuple[int, int]],
    conv_size: Union[int, Tuple[int, int]],
    input_sqrt: Union[int, Tuple[int, int]],
) -> torch.Tensor:
    # language=rst
    """
    Get the weights from a locally connected layer and reshape them to be two-dimensional and square.
    :param w: Weights from a locally connected layer.
    :param n_filters: No. of neuron filters.
    :param kernel_size: Side length(s) of convolutional kernel.
    :param conv_size: Side length(s) of convolution population.
    :param input_sqrt: Sides length(s) of input neurons.
    :return: Locally connected weights reshaped as a collection of spatially ordered square grids.
    """
    k1, k2 = kernel_size
    c1, c2 = conv_size
    i1, i2 = input_sqrt
    c1sqrt, c2sqrt = int(math.ceil(math.sqrt(c1))), int(math.ceil(math.sqrt(c2)))
    fs = int(math.ceil(math.sqrt(n_filters)))

    w_ = torch.zeros((n_filters * k1, k2 * c1 * c2))

    for n1 in range(c1):
        for n2 in range(c2):
            for feature in range(n_filters):
                n = n1 * c2 + n2
                filter_ = w[feature, n1, n2, :, :
                ].view(k1, k2)
                w_[feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2] = filter_

    if c1 == 1 and c2 == 1:
        square = torch.zeros((i1 * fs, i2 * fs))

        for n in range(n_filters):
            square[
                (n // fs) * i1 : ((n // fs) + 1) * i2,
                (n % fs) * i2 : ((n % fs) + 1) * i2,
            ] = w_[n * i1 : (n + 1) * i2]

        return square
    else:
        square = torch.zeros((k1 * fs * c1, k2 * fs * c2))

        for n1 in range(c1):
            for n2 in range(c2):
                for f1 in range(fs):
                    for f2 in range(fs):
                        if f1 * fs + f2 < n_filters:
                            square[
                                k1 * (n1 * fs + f1) : k1 * (n1 * fs + f1 + 1),
                                k2 * (n2 * fs + f2) : k2 * (n2 * fs + f2 + 1),
                            ] = w_[
                                (f1 * fs + f2) * k1 : (f1 * fs + f2 + 1) * k1,
                                (n1 * c2 + n2) * k2 : (n1 * c2 + n2 + 1) * k2,
                            ]

        return square

In [21]:
def plot_LC_weights(lc : object,
    input_channel: int = 0,
    output_channel: int = None,
    lines: bool = True,
    figsize: Tuple[int, int] = (5, 5),
    cmap: str = "hot_r",
    color: str='r',
    ) -> AxesImage:
    # language=rst
    """
    Plot a connection weight matrix of a :code:`Connection` with `locally connected
    structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
    :param lc: LC connection object of LCNet
    :param input_channel: indicates weights which connected to this channel of input 
    :param output_channel: indicates weights of specific channel in the output layer
    :param lines: Whether or not to draw horizontal and vertical lines separating input regions.
    :param figsize: Horizontal, vertical figure size in inches.
    :param cmap: Matplotlib colormap.
    :return: Used for re-drawing the weights plot.
    """

    n_sqrt = int(np.ceil(np.sqrt(lc.out_channels)))
    sel_slice = lc.w.view(lc.in_channels, lc.out_channels, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu()
    input_size = _pair(int(np.sqrt(lc.source.n)))
    
    if output_channel is None:
        sel_slice = sel_slice[input_channel, ...]
        reshaped = reshape_LC_weights(sel_slice, lc.out_channels, lc.kernel_size, lc.conv_size, input_size)
    else:
        sel_slice = sel_slice[input_channel, output_channel, ...]
        sel_slice = sel_slice.unsqueeze(0)
        reshaped = reshape_LC_weights(sel_slice, 1, lc.kernel_size, lc.conv_size, input_size)
        print(reshaped.shape)

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
    div = make_axes_locatable(ax)
    cax = div.append_axes("right", size="5%", pad=0.05)

    if lines and  output_channel is None:
        for i in range(
            n_sqrt * lc.kernel_size[0],
            n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
            n_sqrt * lc.kernel_size[0],
        ):
            ax.axhline(i - 0.5, color=color, linestyle="--")

        for i in range(
            n_sqrt * lc.kernel_size[1],
            n_sqrt * lc.conv_size[1] * lc.kernel_size[1],
            n_sqrt * lc.kernel_size[1],
        ):
            ax.axvline(i - 0.5, color=color, linestyle="--")

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_aspect("auto")

    plt.colorbar(im, cax=cax)
    fig.tight_layout()

    return im

In [22]:
def plot_LC_activation_map(lc : object,
    spikes: torch.tensor,
    input_channel: int = 0,
    scale_factor: float = 1.0,
    lines: bool = True,
    figsize: Tuple[int, int] = (5, 5),
    cmap: str = "hot_r",
    color: str='r'
    ) -> AxesImage:
    # language=rst
    """
    Plot an activation map of a :code:`Connection` with `locally connected
    structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
    :param lc: LC connection object of LCNet
    :param input_channel: indicates weights which connected to this channel of input 
    :param scale_factor: determines intensity of activation map 
    :param lines: Whether or not to draw horizontal and vertical lines separating input regions.
    :param figsize: Horizontal, vertical figure size in inches.
    :param cmap: Matplotlib colormap.
    :return: Used for re-drawing the weights plot.
    """
    spikes = spikes.sum(0).squeeze().view(lc.conv_size[0]*int(np.sqrt(lc.out_channels)),lc.conv_size[1]*int(np.sqrt(lc.out_channels)))
    x = scale_factor * spikes / torch.max(spikes)
    x = x.clip(lc.wmin,lc.wmax).repeat_interleave(lc.kernel_size[0], dim=0).repeat_interleave(lc.kernel_size[1], dim=1).cpu()
    n_sqrt = int(np.ceil(np.sqrt(lc.out_channels)))

    sel_slice = lc.w.view(lc.in_channels, lc.out_channels, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu()
    sel_slice = sel_slice[input_channel, ...]
    input_size = _pair(int(np.sqrt(lc.source.n)))
    reshaped = reshape_LC_weights(sel_slice, lc.out_channels, lc.kernel_size, lc.conv_size, input_size)

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(reshaped.cpu()*x, cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
    div = make_axes_locatable(ax)
    cax = div.append_axes("right", size="5%", pad=0.05)

    if lines:
        for i in range(
            n_sqrt * lc.kernel_size[0],
            n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
            n_sqrt * lc.kernel_size[0],
        ):
            ax.axhline(i - 0.5, color=color, linestyle="--")

        for i in range(
            n_sqrt * lc.kernel_size[1],
            n_sqrt * lc.conv_size[1] * lc.kernel_size[1],
            n_sqrt * lc.kernel_size[1],
        ):
            ax.axvline(i - 0.5, color=color, linestyle="--")

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_aspect("auto")

    plt.colorbar(im, cax=cax)
    fig.tight_layout()

    return im

In [23]:
def plot_FC_response_map(lc: object,
    fc: object,
    ind_neuron_in_group: int,
    label: int,
    n_per_class: int,
    input_channel: int = 0,
    scale_factor: float = 1.0,
    lines: bool = True,
    figsize: Tuple[int, int] = (5, 5),
    cmap: str = "hot_r",
    color: str='r'
    ) -> AxesImage:
    # language=rst
    """
    Plot a connection weight matrix of a :code:`Connection` with `locally connected
    structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
    :param lc: LC connection object of LCNet
    :param fc: FC connection object of LCNet
    :param input_channel: indicates weights which connected to this channel of input
    :param scale_factor: determines intensity of activation map  
    :param lines: Whether or not to draw horizontal and vertical lines separating input regions.
    :param figsize: Horizontal, vertical figure size in inches.
    :param cmap: Matplotlib colormap.
    :return: Used for re-drawing the weights plot.
    """

    n_sqrt = int(np.ceil(np.sqrt(lc.out_channels)))

    sel_slice = lc.w.view(lc.in_channels, lc.out_channels, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu()
    sel_slice = sel_slice[input_channel, ...]
    input_size = _pair(int(np.sqrt(lc.source.n)))
    reshaped = reshape_LC_weights(sel_slice, lc.out_channels, lc.kernel_size, lc.conv_size, input_size)
	
    ind_neuron = label * n_per_class + ind_neuron_in_group
    w = fc.w[:,ind_neuron].view(reshaped.shape[0]//lc.kernel_size[0],reshaped.shape[1]//lc.kernel_size[1])
    w = w.clip(lc.wmin,lc.wmax).repeat_interleave(lc.kernel_size[0], dim=0).repeat_interleave(lc.kernel_size[1], dim=1).cpu()

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(reshaped.cpu()*w, cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
    div = make_axes_locatable(ax)
    cax = div.append_axes("right", size="5%", pad=0.05)

    if lines:
        for i in range(
            n_sqrt * lc.kernel_size[0],
            n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
            n_sqrt * lc.kernel_size[0],
        ):
            ax.axhline(i - 0.5, color=color, linestyle="--")

        for i in range(
            n_sqrt * lc.kernel_size[1],
            n_sqrt * lc.conv_size[1] * lc.kernel_size[1],
            n_sqrt * lc.kernel_size[1],
        ):
            ax.axvline(i - 0.5, color=color, linestyle="--")

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_aspect("auto")

    plt.colorbar(im, cax=cax)
    fig.tight_layout()

    return im

In [None]:
# net.fit(dataloader = dataloader, **train_hparams)

In [None]:
# net.spikes['layer1'].get('s').sum()

In [None]:
# net.connections

In [None]:
plot_locally_connected_weights_one_chan(net.connections[('layer0', 'layer1')].w,100,1,0,22,15,2,0)

In [None]:
# net.connections[('layer1_inh_neuron', 'layer1')].w

# **Following lines are NOT edited yet!**

### Visualization 

In [None]:
plot_locally_connected_weights_meh(net.connections[('layer0', 'layer1')].w,100,1,0,22,15,2)

In [None]:
dataloader_iterator = iter(dataloader)
for i in range(10):
    datum = next(dataloader_iterator)
    plt.imshow(datum['encoded_image'].squeeze().sum(0))
    plt.show()
    net.one_step(datum)
    spikes = net.spikes["main1"].get("s").sum(0).squeeze().view(4*5,4*5)
    print(torch.max(spikes))
    plot_locally_connected_weights_meh2(net.connections[('input', 'main1')].w,spikes,25,1,0,22,13,4)
    plt.show()

## Tensorboard

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
%tensorboard --logdir '/content/runs'

## Save/Load Sessions

Save tensorBoard Session 

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!cp -a /content/runs/. /content/drive/MyDrive/LCNet/logs/

Read Saved Sessions

In [None]:
%tensorboard --logdir /content/drive/MyDrive/LCNet/logs/

## Optuna

install and import optuna

In [None]:
!pip install optuna
import optuna

Define objective function

In [None]:
STUDY_NAME  = ''
DATA_PATH = ''
N_TRIALS = ''

def objective(trial):
    ### Suggest parameters: 
    num_layers = trial.suggest_int('Number of Layers', 1, 4)
    dropout_rate  = trial.suggest_float('Dropout', 0, .99)
    activation = trial.suggest_categorical('activation', ['relu', 'selu', 'sigmoid', 'elu'])
    lr = trial.suggest_float('Learning rate', 1e-6, 1)
    network_hparams.update({
        
    })
    ###Define your model
    manual_seed(SEED)
    hparams = {**reward_hparams, **network_hparams, **train_hparams, **data_hparams}
    net = LCNet(**hparams, reward_fn = DynamicDopamineInjection)
    va_acc = net.fit(dataloader = dataloader, val_loader = val_loader, reward_hparams = reward_hparams, **train_hparams)
    ### Define objective value
    objective_value = min(va_acc)

    
    return objective_value


Run the study

In [None]:
study = optuna.create_study(study_name = STUDY_NAME , storage=f"sqlite:////content/drive/MyDrive/LCNet/optuna/optuna_study.db", load_if_exists=True)
study.optimize(objective, n_trials=N_TRIALS )

In [None]:
study.best_params

{'Dropout': 0.09142336347778651,
 'Layer 1': 82,
 'Layer 2': 508,
 'Layer 3': 286,
 'Learning rate': 0.008771184760927113,
 'Number of Layers': 3,
 'activation': 'relu',
 'l1': None}

Visualization

In [None]:
plot_parallel_coordinate(study, params=["Learning rate", "Number of Layers", "Dropout"])

# SVM 1

In [None]:
network_hparams.update({'LC_weights_path':'/content/drive/My Drive/LCNet/BioLCNet_layer1_Shallow_f13_s3_inh100_norm3.pth'})
train_hparams.update({'n_train' : 1,})

In [None]:
from google.colab import drive
if network_hparams['save_path'] or network_hparams['LC_weights_path']:    
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
manual_seed(SEED)
hparams = {**reward_hparams, **network_hparams, **train_hparams, **data_hparams}
net = LCNet(**hparams, reward_fn = DynamicDopamineInjection)
net.fit(dataloader = dataloader, val_loader = val_loader, reward_hparams = reward_hparams, **train_hparams)

S, V, y = net.predict(
    val_loader= val_loader,
    n_pred= 9999,
)

Weights loaded ...
Previous model not found! Training from the beginning...



  0%|          | 0/1 [00:00<?, ?it/s]

output tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) pred_label: 0 GT: 0 , Acc Rew: 0.0 Pos dps: 1.00000, Neg dps: 1.00000, Rew base: 1.00000, Pun base: 1.00000, RPe: 0.000 input_mean_fire_freq: 31.6,main_mean_fire_freq:8.5 output_mean_fire_freq:19.6 mean_lc1_w: 0.25000, mean_fc_w:0.49627 std_lc1_w: 0.23600, std_fc_w:0.28884 convergence_lc1: 0.29705, convergence_fc: 0.83342

  0%|          | 0/9999 [00:00<?, ?it/s]

In [None]:
def SVM(layer, monitor):
    X = np.array([i[layer].squeeze(1).view(i[layer].shape[0],-1).sum(0).numpy()  for i in monitor])
    y_array = np.array(y)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_array, test_size=0.2, random_state=42,
    )

    classifier = svm.SVC()
    classifier.fit(X_train, y_train)
    
    return accuracy_score(y_test, classifier.predict(X_test))

for l in range(len(S[0])):
    print(f'* Layer {l}')
    print(f"Accuracy for spikes:", SVM(l, S))
    print(f"Accuracy for voltage:", SVM(l, V))

* Layer 0
Accuracy for spikes: 0.8755
Accuracy for voltage: 0.814
* Layer 1
Accuracy for spikes: 0.444
Accuracy for voltage: 0.4195


# SVM 2 

In [None]:
network_hparams.update({'LC_weights_path':'/content/drive/My Drive/LCNet/BioLCNet_layer1_Shallow_f15_s4_inh100_norm25_ch100.pth',})
train_hparams.update({'n_train' : 0,})

In [None]:
from google.colab import drive
if network_hparams['save_path'] or network_hparams['LC_weights_path']:    
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
manual_seed(SEED)
hparams = {**reward_hparams, **network_hparams, **train_hparams, **data_hparams}
net = LCNet(**hparams, reward_fn = DynamicDopamineInjection)
net.fit(dataloader = dataloader, val_loader = val_loader, reward_hparams = reward_hparams, **train_hparams)

S, V, y = net.predict(
    val_loader= val_loader,
    n_pred= 9999,
)

Weights loaded ...
Previous model not found! Training from the beginning...



0it [00:00, ?it/s]

output tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) pred_label: 0 GT: 5 , Acc Rew: 0.0 Pos dps: 1.00000, Neg dps: 1.00000, Rew base: 1.00000, Pun base: 1.00000, RPe: 0.000 input_mean_fire_freq: 28.0,main_mean_fire_freq:4.0 output_mean_fire_freq:11.6 mean_lc1_w: 0.25000, mean_fc_w:0.50152 std_lc1_w: 0.23792, std_fc_w:0.28744 convergence_lc1: 0.30191, convergence_fc: 0.83260

  0%|          | 0/9999 [00:00<?, ?it/s]

In [None]:
def SVM(layer, monitor):
    X = np.array([i[layer].squeeze(1).view(i[layer].shape[0],-1).sum(0).numpy()  for i in monitor])
    y_array = np.array(y)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y_array, test_size=0.2, random_state=42,
    )

    classifier = svm.SVC()
    classifier.fit(X_train, y_train)
    
    return accuracy_score(y_test, classifier.predict(X_test))

for l in range(len(S[0])):
    print(f'* Layer {l}')
    print(f"Accuracy for spikes:", SVM(l, S))
    print(f"Accuracy for voltage:", SVM(l, V))

* Layer 0
Accuracy for spikes: 0.833
Accuracy for voltage: 0.7495
