# Compute and Save Per-Cell GQMs

In [1]:
import sys
import os

# Put the local code directory first on the module search path so the project
# packages imported below (presumably NDNT / NTdatasets / ColorDataUtils) resolve.
sys.path.insert(0, '/home/ifernand/Code/')
dirname = '/home/ifernand/Cloud_SynthData_Proj'

# Report which machine this kernel runs on (the `nodename` field of uname).
myhost = os.uname().nodename
print("Running on Computer: [%s]" % (myhost,))

import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from scipy import io as sio
import torch
import time
import h5py

# NDN tools
import NDNT
import NDNT.utils as utils
from NDNT.modules.layers import *
from NDNT.networks import *
import NDNT.NDN as NDN
from NTdatasets.conway.synthcloud_datasets import SimCloudData
from NTdatasets.generic import GenericDataset
from ColorDataUtils.multidata_utils import MultiExperiment
import ColorDataUtils.ConwayUtils as CU
from ColorDataUtils import readout_fit
from ColorDataUtils.simproj_utils import *
from NDNT.utils import fit_lbfgs, fit_lbfgs_batch
from NDNT.utils import imagesc   
from NDNT.utils import ss

# Compute device for model fitting: prefer the second GPU ("cuda:1") when the
# machine has more than one, fall back to the first GPU on single-GPU machines
# (where "cuda:1" would be an invalid device ordinal), and to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
# Fitted models are moved back to this CPU device so GPU memory is released.
device0 = torch.device("cpu")
dtype = torch.float32

# IPython autoreload: mode 2 reloads imported modules before each cell executes,
# so edits to the locally developed packages on sys.path are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

Running on Computer: [sc]
Invoking __init__.py for NDNT.utils


## Load Data

In [2]:
# Load RF center positions and cell keys from the preprocessed HDF5 file.
out_file = 'data/cloud_data_stim_dim120_robs_sqrad_0.3.hdf5'  # alt: 'data/cloud_data_stim_dim80_robs_sqrad_0.1.hdf5'
with h5py.File(out_file, 'r') as f:
    x_pos = f['x_pos'][:]
    y_pos = f['y_pos'][:]
    # Read cell_key into memory once. Writing f['cell_key'][:] inside the
    # comprehension re-read the entire dataset from disk for every cell.
    cell_key_raw = f['cell_key'][:]
    # Keys are stored as bytes; decode one per RF position (one entry per cell).
    cell_key = [str(cell_key_raw[i], encoding='utf-8') for i in range(x_pos.shape[0])]
    file_start_pos = list(f['file_start_pos'][:])

In [3]:
# Build the stimulus/response dataset from the same HDF5 file (spatially
# downsampled by 2, 11 time lags) and report how long construction takes.
# perf_counter is the monotonic clock intended for measuring elapsed intervals;
# the value reported is elapsed wall time, not CPU time.
start    = time.perf_counter()
data     = SimCloudData(file_name=out_file, down_sample=2, num_lags=11)
end      = time.perf_counter()
print('Wall time', end-start, 'sec')

CPU Time 47.18281126022339 sec


In [4]:
# Dataset dimensions reused by every later cell.
stim_dims = data.stim_dims
num_lags = data.num_lags
L  = stim_dims[1]                 # spatial size of the stimulus along one side
NC = data[0]['robs'].shape[1]     # number of cells = columns of robs
NT = data.NT
# Later cells use L for BOTH spatial axes (pixel2grid(..., L=L) and the
# (NC, L, L, num_lags, 3) weight array), so the stimulus must be square.
assert stim_dims[1] == stim_dims[2], f"expected a square stimulus, got stim_dims={stim_dims}"
print('stim_dims =', stim_dims)
print('num_lags =', num_lags)
print('L =', L)
print('Number of cells', NC)
print('Number of time points', NT)

stim_dims = [1, 60, 60, 1]
num_lags = 11
L = 60
Number of cells 1705
Number of time points 510000


In [5]:
# RF centers: degrees -> (downsampled) pixel coordinates -> grid coordinates,
# one (x, y) row per cell.
pxl_x_pos, pxl_y_pos = deg2pxl(x_pos, y_pos, L, down_sample=2)
pxl_xy = np.stack((pxl_x_pos, pxl_y_pos), axis=1)
mu0s = utils.pixel2grid(pxl_xy, L=L)
print('Spatial mu0:', mu0s.shape)

Spatial mu0: (1705, 2)


## GQM

In [6]:
# Regularization weights: Treg is shared; the *L / *Q pairs configure the
# linear and quadratic stimulus layers respectively (L1regQ=None for quadratic).
Treg = 1
XregL, XregQ = 10000.0, 10.0
LOCregL, LOCregQ = 10.0, 1000.0
L1regL, L1regQ = 1.0, None


def _stim_layer(num_filters, NLtype, d2x, l1, glocalx):
    """Return a Tlayer spec on the stimulus.

    Both stimulus layers share every setting except filter count, nonlinearity
    and the three regularization weights passed in. A fresh reg_vals dict is
    built per call so the two layers never share dict objects.
    """
    reg_vals = {'d2x': d2x, 'd2t': Treg, 'l1': l1, 'glocalx': glocalx,
                'edge_t': 10, 'bcs': {'d2t': 1, 'd2x': 1}}
    return Tlayer.layer_dict(
        input_dims=data.stim_dims, num_filters=num_filters, bias=False, norm_type=0,
        NLtype=NLtype, initialize_center=True, num_lags=data.num_lags,
        reg_vals=reg_vals)


# linear layer: one filter, linear output
glm_layer = _stim_layer(1, 'lin', XregL, L1regL, LOCregL)
# quadratic layer: two filters, squared outputs
gqm_layer = _stim_layer(2, 'square', XregQ, L1regQ, LOCregQ)

# one feed-forward network per stimulus pathway (networks 0 and 1)
lin_net = FFnetwork.ffnet_dict(xstim_n='stim', layer_list=[glm_layer])
quad_net = FFnetwork.ffnet_dict(xstim_n='stim', layer_list=[gqm_layer])

# combination layer: one softplus output over networks 0 and 1, weights initialized to ones
comb_layer = NDNLayer.layer_dict(num_filters=1, NLtype='softplus', bias=True)
comb_layer['weights_initializer'] = 'ones'

comb_net = FFnetwork.ffnet_dict(xstim_n=None, ffnet_n=[0, 1], layer_list=[comb_layer])

In [7]:
# Fit one GQM per cell on the training blocks. Each model is fit on `device`
# and then moved back to the CPU (device0), so only one model sits on the GPU
# at a time.
# The training slice is built ONCE. Indexing data[data.train_blks] inside the
# loop rebuilt the full slice (stim, robs and dfs for every cell) three times
# per cell.
train_data = data[data.train_blks]
gqms = []
for i in range(NC):
    gdata = GenericDataset({'stim': train_data['stim'],
                            'robs': train_data['robs'][:, i],
                            'dfs': train_data['dfs'][:, i]}, device=device)

    gqm = NDN(ffnet_list=[lin_net, quad_net, comb_net], loss_type='poisson')
    # val=False on network 2's weights: presumably freezes the combination
    # weights at their 'ones' initialization so only the bias is fit there.
    gqm.set_parameters(val=False, name='weight', ffnet_target=2)

    gqm = gqm.to(device)
    fit_lbfgs(gqm, gdata[:], verbose=False)
    gqm = gqm.to(device0)
    gqms.append(gqm)
    print('Cell', i, 'GQM complete')

    gdata = None  # drop this cell's device copy before building the next one
del train_data    # release the cached training slice once all cells are fit

Cell 0 GQM complete
Cell 1 GQM complete
Cell 2 GQM complete
Cell 3 GQM complete
Cell 4 GQM complete
Cell 5 GQM complete
Cell 6 GQM complete
Cell 7 GQM complete
Cell 8 GQM complete
Cell 9 GQM complete
Cell 10 GQM complete
Cell 11 GQM complete
Cell 12 GQM complete
Cell 13 GQM complete
Cell 14 GQM complete
Cell 15 GQM complete
Cell 16 GQM complete
Cell 17 GQM complete
Cell 18 GQM complete
Cell 19 GQM complete
Cell 20 GQM complete
Cell 21 GQM complete
Cell 22 GQM complete
Cell 23 GQM complete
Cell 24 GQM complete
Cell 25 GQM complete
Cell 26 GQM complete
Cell 27 GQM complete
Cell 28 GQM complete
Cell 29 GQM complete
Cell 30 GQM complete
Cell 31 GQM complete
Cell 32 GQM complete
Cell 33 GQM complete
Cell 34 GQM complete
Cell 35 GQM complete
Cell 36 GQM complete
Cell 37 GQM complete
Cell 38 GQM complete
Cell 39 GQM complete
Cell 40 GQM complete
Cell 41 GQM complete
Cell 42 GQM complete
Cell 43 GQM complete
Cell 44 GQM complete
Cell 45 GQM complete
Cell 46 GQM complete
Cell 47 GQM complete
Ce

In [13]:
# Stack every cell's filters into one array of shape
# (cell, spatial, spatial, lag, filter). Filter 0 is the linear filter;
# filters 1 and 2 are the two quadratic filters.
# np.zeros defaults to float64: at (1705, 60, 60, 11, 3) this is ~1.6 GB.
assert len(gqms) == NC, f"expected {NC} fitted GQMs, found {len(gqms)}"
all_gqm_weights = np.zeros((NC, L, L, num_lags, 3))
for i, gqm_i in enumerate(gqms):
    lin_w = gqm_i.get_weights(ffnet_target=0)
    quad_w = gqm_i.get_weights(ffnet_target=1)  # fetched once; holds both quadratic filters
    all_gqm_weights[i, :, :, :, 0] = lin_w[:, :, :, 0]   # lin weights
    all_gqm_weights[i, :, :, :, 1] = quad_w[:, :, :, 0]  # quad 1 weights
    all_gqm_weights[i, :, :, :, 2] = quad_w[:, :, :, 1]  # quad 2 weights
print(all_gqm_weights.shape)

(1705, 60, 60, 11, 3)


In [14]:
import pickle

# Serialize the stacked per-cell GQM weight array.
gqm_weights_path = 'data/all_neuron_GQM_weights_sqrad_0.3.pkl'
with open(gqm_weights_path, mode='wb') as weights_file:
    pickle.dump(all_gqm_weights, weights_file)

## Complexity

In [22]:
# Per-cell ratio on the validation blocks:
#     var(linear-network output) / (var(linear output) + var(summed quadratic outputs))
# Both outputs are taken from networks 0 and 1, i.e. before the combination
# layer and its softplus.
# NOTE(review): this is the *linear* fraction (near 1 = mostly linear, near 0 =
# mostly quadratic) despite the name `complexity`; confirm downstream users
# expect this orientation.
val_data = data[data.val_blks]
# Only the stimulus is read for the predictions, so build the device copy once
# instead of once per cell. robs/dfs are passed (cell 0's columns) only to keep
# the same GenericDataset keys; they are never used here.
val_stim = GenericDataset({'stim': val_data['stim'],
                           'robs': val_data['robs'][:, 0],
                           'dfs': val_data['dfs'][:, 0]}, device=device)[:]['stim']
del val_data

complexity = np.zeros(NC)
for i in range(NC):
    gqm = gqms[i].to(device)
    with torch.no_grad():  # inference only: skip building an autograd graph
        lin_pred = gqm.networks[0](val_stim).cpu().numpy()[:, 0]
        quad_pred = np.sum(gqm.networks[1](val_stim).cpu().numpy(), axis=1)
    gqm = gqm.to(device0)

    lin_var = np.var(lin_pred)
    quad_var = np.var(quad_pred)
    total_var = lin_var + quad_var
    # Both pathways flat -> ratio undefined; store NaN explicitly rather than
    # relying on a 0/0 RuntimeWarning.
    complexity[i] = lin_var / total_var if total_var > 0 else np.nan

    print('Cell', i, 'complexity computed')
val_stim = None  # release the device copy of the validation stimulus

Cell 0 complexity computed
Cell 1 complexity computed
Cell 2 complexity computed
Cell 3 complexity computed
Cell 4 complexity computed
Cell 5 complexity computed
Cell 6 complexity computed
Cell 7 complexity computed
Cell 8 complexity computed
Cell 9 complexity computed
Cell 10 complexity computed
Cell 11 complexity computed
Cell 12 complexity computed
Cell 13 complexity computed
Cell 14 complexity computed
Cell 15 complexity computed
Cell 16 complexity computed
Cell 17 complexity computed
Cell 18 complexity computed
Cell 19 complexity computed
Cell 20 complexity computed
Cell 21 complexity computed
Cell 22 complexity computed
Cell 23 complexity computed
Cell 24 complexity computed
Cell 25 complexity computed
Cell 26 complexity computed
Cell 27 complexity computed
Cell 28 complexity computed
Cell 29 complexity computed
Cell 30 complexity computed
Cell 31 complexity computed
Cell 32 complexity computed
Cell 33 complexity computed
Cell 34 complexity computed
Cell 35 complexity computed
Ce

In [23]:
# Persist the per-cell linear/quadratic variance ratios computed above.
np.save(file='data/all_neuron_complexity_sqrad_0.3.npy', arr=complexity)