# LGN Unit Analysis

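How many LGN units does the model need? This notebook fits otherwise-identical models with 1, 2, 3, and 4 LGN units and compares their null-adjusted log-likelihoods on the validation blocks.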

January 17, 2024

In [1]:
import sys
import os

import platform

# Record which machine produced the saved outputs.
# platform.node() is the portable equivalent of os.uname()[1] (os.uname is POSIX-only).
myhost = platform.node()
print("Running on Computer: [%s]" % myhost)

# Project locations. Defaults are the author's paths; set NDN_CODE_DIR / SYNTHDATA_PROJ_DIR
# to run elsewhere. The code dir is put on sys.path before the project imports below
# (presumably where NDNT / NTdatasets / ColorDataUtils live — confirm).
sys.path.insert(0, os.environ.get('NDN_CODE_DIR', '/home/ifernand/Code/'))
dirname = os.environ.get('SYNTHDATA_PROJ_DIR', '/home/ifernand/Cloud_SynthData_Proj')

import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from scipy import io as sio
import torch
import time
import h5py

# NDN tools
import NDNT
import NDNT.utils as utils
from NDNT.modules.layers import *
from NDNT.networks import *
import NDNT.NDN as NDN
from NTdatasets.conway.synthcloud_datasets import SimCloudData
from NTdatasets.generic import GenericDataset
from ColorDataUtils.multidata_utils import MultiExperiment
import ColorDataUtils.ConwayUtils as CU
from ColorDataUtils import readout_fit
from ColorDataUtils.simproj_utils import *
from NDNT.utils import fit_lbfgs, fit_lbfgs_batch
from NDNT.utils import imagesc   
from NDNT.utils import ss
from NDNT.utils import subplot_setup
from NDNT.utils import figure_export

# Clustering
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA, KernelPCA
from sklearn.manifold import TSNE
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import DBSCAN
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression

# Compute device: the second GPU ("cuda:1") when at least two are visible, otherwise the
# first GPU, otherwise the CPU. Unconditionally asking for cuda:1 would fail later on a
# single-GPU machine even though torch.cuda.is_available() is True.
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
if n_gpus > 1:
    device = torch.device("cuda:1")
elif n_gpus == 1:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device0 = torch.device("cpu")
dtype = torch.float32

# IPython autoreload: reload imported modules before executing code so edits to
# project modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

Running on Computer: [sc]
Invoking __init__.py for NDNT.utils


## Load Data

In [2]:
# Build the simulated cloud dataset for the four V1 cell populations
# (down_sample=2, num_lags=11).
start    = time.perf_counter()
data     = SimCloudData(cell_type_list=['V1_Exc_L4', 'V1_Inh_L4', 'V1_Exc_L2/3', 'V1_Inh_L2/3'], down_sample=2, num_lags=11)
end      = time.perf_counter()
# perf_counter measures elapsed wall-clock time (not CPU time), so label it accordingly.
print('Load time (wall clock)', end-start, 'sec')

CPU Time 53.34195899963379 sec


In [3]:
# Load baseline LL
# Baseline log-likelihoods (LL) saved from earlier GLM and GQM fits.
# These are pickle files: np.load falls back to unpickling non-.npy data, which is why
# allow_pickle=True is needed. Only load them from a trusted source (unpickling can run code).
GLM_LL, GQM_LL = (
    np.load(ll_path, allow_pickle=True)
    for ll_path in (
        'data/cloud_data_stim_dim_120_sqrad_0.3_GLM_LL.pkl',
        'data/cloud_data_stim_dim_120_sqrad_0.3_GQM_LL.pkl',
    )
)

In [4]:
# Per-population metadata stored on the dataset.
cell_idx_dict = data.cell_idx_dict
thetas_dict = data.thetas
# Stack the per-population angles in the same population order used to build `data`.
# (They are later passed to degrees2mu, so presumably in degrees — confirm.)
V1_populations = ('V1_Exc_L4', 'V1_Inh_L4', 'V1_Exc_L2/3', 'V1_Inh_L2/3')
V1_thetas = np.concatenate([thetas_dict[pop] for pop in V1_populations])
mu0s = data.mu0s

In [5]:
# Core sizes of the dataset, echoed so the saved output records them.
stim_dims = data.stim_dims
num_lags = data.num_lags
L  = stim_dims[1]   # second entry of stim_dims (60 in the saved output)
NC = data.NC
NT = data.NT

for label, value in [
    ('stim_dims =', stim_dims),
    ('num_lags =', num_lags),
    ('L =', L),
    ('Number of cells', NC),
    ('Number of time points', NT),
]:
    print(label, value)

stim_dims = [1, 60, 60, 1]
num_lags = 11
L = 60
Number of cells 1491
Number of time points 510000


## Convert Degrees to Qmu

In [6]:
def degrees2mu(theta_deg, angles, continuous=True, max_angle=180 ):
    """
    Convert angles in degrees into Qmu values on the (-1, 1] orientation axis.

    Angles are wrapped into [0, max_angle) and mapped linearly onto [-1, 1), then
    shifted by 1/len(angles) so that 0 degrees lands on the centre of the first
    angle bin rather than on the -1 edge. Values pushed past +1 by that shift wrap
    around to the negative end, preserving the circular topology.

    Args:
        theta_deg (array-like or scalar): angles in degrees. Non-ndarray inputs
            (lists, Python scalars) are first converted to a float32 array.
        angles (sequence): angles represented in the OriConv layers; only its
            length (the number of angle bins) is used.
        continuous (bool): if True (default) keep the exact angle; if False, first
            snap each angle to the nearest multiple of max_angle/len(angles).
        max_angle: angular period in degrees (default 180, but could be 360).
    Returns:
        np.ndarray of Qmu values with the same shape as theta_deg (0-d for a
        scalar input).
    Raises:
        ValueError: if `angles` is empty.
    """
    num_angles = len(angles)
    if num_angles == 0:
        raise ValueError('degrees2mu: `angles` must contain at least one angle')

    # convert inputs to np.array
    if not isinstance(theta_deg, np.ndarray):
        theta_deg = np.array(theta_deg, dtype=np.float32)
    if not continuous:
        dQ = max_angle/num_angles
        theta_deg = dQ * np.round(theta_deg/dQ)
    theta_deg = (theta_deg%max_angle)  # map between 0 and max_angle

    mu_offset = 1/num_angles # first bin at 0 degrees is a shifted mu value (bin centre, not the edge)
    Qmus = (theta_deg-max_angle/2) / (max_angle/2) + mu_offset

    # Wrap back into (-1, 1]. np.where is used instead of masked in-place assignment
    # because a scalar/0-d input yields an immutable numpy scalar here.
    Qmus = np.where(Qmus <= -1, Qmus + 2, Qmus)
    Qmus = np.where(Qmus > 1, Qmus - 2, Qmus)

    return Qmus

In [7]:
# Orientation bins used by the OriConv layers: 0, 30, ..., 150 degrees.
angles = np.arange(0, 180, 30, dtype=int)

# Initial readout angle positions (Qmu) for every V1 cell.
Qmu0s = degrees2mu(V1_thetas, angles)
print('Angle mu0:', Qmu0s.shape)

Angle mu0: (1491,)


## Create Model

In [8]:
# Adam Parameters
# AdamW optimizer settings, shared by every model fit below.
adam_pars = utils.create_optimizer_params(
    optimizer_type='AdamW',
    learning_rate=0.01,
    weight_decay=0.2,
    batch_size=1,
    accumulated_grad_batches=3,
    early_stopping_patience=4,
    optimize_graph=False,
)
adam_pars['device'] = device

In [10]:
# Build one model per candidate LGN-unit count. The models differ only in the number of
# filters in the first (LGN) layer; every other layer, regularizer and the seed are shared.

# Regularization strengths
XTreg = 0.0001  # 'd2xt' on the LGN layer
Xreg0 = 1.0 # d2/dx  -- LGN layer
Creg0 = 0.05 # center -- LGN layer

Xreg1 = 0.05 # d2/dx  -- projection layer
Creg1 = 0.001 # center -- projection layer

MaxReg = 0.001  # 'max' regularization on the readout layer

angle_mode = 'nearest' # 'bilinear'  (later assigned to the readout's Qsample_mode)

NQ = len(angles)  # number of orientation bins; passed to the scaffold as num_lags_out

num_LGN_units = [1,2,3,4]
models = {}  # keyed by '<n> LGN Units'

for i in range(len(num_LGN_units)):
    # Filters per conv layer (LGN, projection, two OriConv layers) and their spatial widths.
    num_subs = [num_LGN_units[i], 24, 16, 16]
    fws = [19, 19, 5, 5]  

    # LGN LAYER
    # NOTE(review): uses num_lags-1 temporal lags, presumably because the TimeShiftLayer
    # below supplies the remaining lag — confirm.
    clayersQ = [STconvLayer.layer_dict( 
        input_dims = data.stim_dims, num_filters=num_subs[0], norm_type=1,
        filter_dims=[1,fws[0],fws[0],num_lags-1] , bias=False, NLtype='relu',
        padding='circular', output_norm='batch', window='hamming', initialize_center=True,
        reg_vals={'d2xt':XTreg, 'd2x':Xreg0, 'center': Creg0} )]

    # PROJECTION LAYER (half of its filters are inhibitory: num_inh = num_filters//2)
    clayersQ.append(
        OriConvLayer.layer_dict(
            num_filters=num_subs[1], num_inh=num_subs[1]//2,
            filter_width=fws[1], NLtype='relu', norm_type=1,
            bias=False, output_norm='batch', window='hamming', padding='circular', initialize_center=True, 
            reg_vals={'d2x':Xreg1, 'center': Creg1}, angles=angles) )

    # TIME SHIFT LAYER
    clayersQ.append(TimeShiftLayer.layer_dict())

    # REST: remaining OriConv layers (clayersQ indices 3 and 4), unregularized, no window
    for ii in range(2,len(fws)):
        clayersQ.append(OriConvLayer.layer_dict( 
            num_filters=num_subs[ii], num_inh=num_subs[ii]//2, bias=False, norm_type=1, 
            filter_width=fws[ii], NLtype='relu',
            output_norm='batch', initialize_center=True, #window='hamming', 
            angles=angles) )
        
    # scaffold_levels=[1,3,4] selects the projection layer and the two later OriConv layers
    # (clayersQ indices); presumably their outputs together form the scaffold output — confirm.
    scaffold_netQ =  FFnetwork.ffnet_dict(
        ffnet_type='scaffold3d', xstim_n='stim', layer_list=clayersQ, scaffold_levels=[1,3,4], num_lags_out=NQ)

    # Readout: one softplus output per cell (num_filters=NC), reading from the scaffold (ffnet 0).
    readout_parsQ = ReadoutLayerQsample.layer_dict(
        num_filters=NC, NLtype='softplus', bias=True, pos_constraint=True,
        reg_vals={'max': MaxReg})

    readout_netQ = FFnetwork.ffnet_dict(xstim_n = None, ffnet_n=[0], layer_list = [readout_parsQ], ffnet_type='readout')
    
    # Same seed for every model so only the LGN-unit count differs between them.
    models[str(num_LGN_units[i])+' LGN Units'] = NDN(ffnet_list = [scaffold_netQ, readout_netQ], loss_type='poisson', seed=100)
    print('Model', i, 'created')

Model 0 created
Model 1 created
Model 2 created
Model 3 created


In [None]:
# Fit each model and score it on the validation blocks.
# Readout initialisation: `mu` starts at mu0s and is fit; the first Qmu column is set to
# Qmu0s and held fixed. (Presumably mu = spatial readout positions, Qmu = orientation
# positions — confirm against ReadoutLayerQsample.)
LL_dict = {}
for i, n_units in enumerate(num_LGN_units):
    model_name = str(n_units)+' LGN Units'
    cnn = models[model_name]
    readout = cnn.networks[1].layers[0]
    readout.mu.data = torch.tensor(mu0s, dtype=torch.float32)
    readout.Qmu.data[:,0] = torch.tensor(Qmu0s, dtype=torch.float32)
    readout.fit_mus(True)
    readout.fit_Qmus(False)
    readout.Qsample_mode = angle_mode
    cnn.block_sample = True

    start = time.perf_counter()
    cnn.fit(data, **adam_pars, verbose=False)
    end = time.perf_counter()

    # Null-adjusted log-likelihoods on the held-out validation blocks.
    LL_dict[model_name] = cnn.eval_models(data, data_inds=data.val_blks, device=device, batch_size=3, null_adjusted=True)

    # Report the wall-clock fit time alongside the completion message.
    print('Model', i, 'done fitting', '(fit time %0.1f sec)' % (end-start))

  ReadoutLayer: fitting mus
  ReadoutLayer: not fitting Qmus


Eval models: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:33<00:00,  1.02it/s]


Model 0 done fitting
  ReadoutLayer: fitting mus
  ReadoutLayer: not fitting Qmus


Eval models: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:37<00:00,  1.09s/it]


Model 1 done fitting
  ReadoutLayer: fitting mus
  ReadoutLayer: not fitting Qmus
