In [None]:
import os

import numpy as np
import torch
from scipy.stats import pearsonr, zscore

from minimodel import data, metrics, model_builder, model_trainer

# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# can still run (slowly) on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_path = '../../data'                  # txt16 responses (.npz) and images (.mat)
weight_path = '../checkpoints/fullmodel'  # root directory of the model checkpoints

# Load per-neuron results

In [None]:
# Per-neuron results of the pairwise category-variance ("catvar") analysis,
# saved as one .npz file per neuron.
# Model stages; presumably the column order of each file's 'catvar' (TODO confirm against its 'op_names').
op_list = ['conv2_1x1', 'conv2_spatial', 'conv2_relu', 'Wxy', 'elu']
catvar_data_path = os.path.join('outputs', 'catvar', 'monkey')  # where the catvar files are saved
nneurons = 166
n_txt16_trials = 12800  # length of each file's 'pred' / rows of its 'channel_activities'
n_conv2_channels = 64   # columns of 'channel_activities' and of the readout weights

catvar_all = np.zeros((nneurons, len(op_list)))
channel_activities_all = np.zeros((nneurons, n_txt16_trials, n_conv2_channels))
wc_all = np.zeros((nneurons, n_conv2_channels))
pred_all = np.zeros((nneurons, n_txt16_trials))
ineurons = np.arange(nneurons)

# The readout weights of every neuron live in a single summary file: open it
# once here rather than re-opening it on every loop iteration.
with np.load('outputs/minimodel_monkey_result.npz', allow_pickle=True) as summary:
    wc_summary = summary['wc_all']

# NOTE: `ineuron` is intentionally left bound after this loop -- the
# model-building cell further down passes it to create_model_name.
for i, ineuron in enumerate(ineurons):
    file_name = f'monkey_minimodel_16_64_pairwise_catvar_neuron{ineuron}_result.npz'
    # `with` closes the archive; the arrays are read into memory on access.
    with np.load(os.path.join(catvar_data_path, file_name), allow_pickle=True) as dat:
        catvar_all[i] = dat['catvar']
        channel_activities_all[i] = dat['channel_activities']
        pred_all[i] = dat['pred'].squeeze()
    wc_all[i] = wc_summary[i]


In [None]:
# Correlation of the positively weighted conv2 channels.
# For each neuron, select the channels whose readout weight exceeds 0.01,
# compute the Pearson correlation of their activities across stimuli for every
# channel pair, and average over the pairs.
mean_correlations = np.zeros(channel_activities_all.shape[0])

for neuron_idx in range(channel_activities_all.shape[0]):
    positive_indices = np.flatnonzero(wc_all[neuron_idx] > 0.01)

    if positive_indices.size < 2:
        # A pairwise correlation needs at least two channels. These neurons keep
        # the value 0, which here means "undefined", not "uncorrelated".
        continue

    selected = channel_activities_all[neuron_idx][:, positive_indices]  # (stimuli, channels)
    # One corrcoef call replaces a pearsonr call per pair. rowvar=False treats
    # columns as variables. The strict upper triangle lists each pair once, in
    # the same (i < j) order as a nested pair loop. A constant channel yields NaN,
    # as pearsonr would.
    corr_matrix = np.corrcoef(selected, rowvar=False)
    pair_corrs = corr_matrix[np.triu_indices(positive_indices.size, k=1)]
    mean_correlations[neuron_idx] = pair_corrs.mean()

In [None]:
# Summary results of the per-neuron minimodels.
with np.load('outputs/minimodel_monkey_result.npz', allow_pickle=True) as dat:
    feve_all = dat['feve_all']
    test_pred_all = dat['test_pred_all']
    wc_all = dat['wc_all']

# Number of conv2 channels each neuron's readout uses (|weight| > 0.01).
# Named distinctly because `nconv2` is re-bound to the conv2 width (64) in the
# model-building cell, which silently discarded this array.
nconv2_used = np.sum(np.abs(wc_all) > 0.01, axis=1)
print(test_pred_all.shape)

# Conv1 category variance (catvar)

In [None]:
# Load the txt16 dataset (stimulus indices, category labels, responses).
# Trials are ordered train first, then the repeated test trials.
mouse_id = 4
fname = 'text16_%s_%s.npz' % (data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
with np.load(os.path.join(data_path, fname), allow_pickle=True) as dat:
    # Test split: responses are (nstim, nrep, nneuron). Flatten stimulus x repeat
    # into trials, and repeat each stimulus index / label nrep times to match.
    txt16_spks_test = dat['ss_all']
    nstim, nrep, nneuron = txt16_spks_test.shape
    txt16_istim_test = dat['ss_istim'].astype(int)
    txt16_istim_test = np.repeat(txt16_istim_test[:, np.newaxis], nrep, axis=1).flatten()
    txt16_spks_test = txt16_spks_test.reshape(-1, nneuron)
    txt16_labels_test = dat['ss_labels']
    txt16_labels_test = np.repeat(txt16_labels_test[:, np.newaxis], nrep, axis=1).flatten()

    print('txt16_spks_test shape:', txt16_spks_test.shape)
    print('txt16_labels_test shape:', txt16_labels_test.shape)

    # Train split: transposed so its rows are trials, like the test split (required by vstack below).
    txt16_spks_train = dat['sp'].T
    txt16_istim_train = dat['istim'].astype(int)
    txt16_labels_train = dat['labels']

print('txt16_spks_train shape:', txt16_spks_train.shape)
print('txt16_labels_train shape:', txt16_labels_train.shape)

txt16_spks = np.vstack((txt16_spks_train, txt16_spks_test))
txt16_labels = np.hstack((txt16_labels_train, txt16_labels_test))
txt16_istim = np.hstack((txt16_istim_train, txt16_istim_test))
assert len(txt16_labels) == len(txt16_istim) == txt16_spks.shape[0]

print('txt16_spks shape:', txt16_spks.shape)
print('txt16_labels shape:', txt16_labels.shape)

txt16_spks = zscore(txt16_spks, axis=0)

# pred_all (loaded earlier) must have one column per trial here. This only
# checks lengths; identical trial ordering is assumed -- TODO confirm.
assert pred_all.shape[1] == len(txt16_labels), (pred_all.shape, txt16_labels.shape)
pred_catvar = metrics.category_variance_pairwise(pred_all, txt16_labels)

In [None]:
conv1_op_list = ['conv1', 'conv1_relu', 'conv1_pool']

# Load the txt16 images and pick the image shown on each trial (same trial order as txt16_labels).
img = data.load_images(data_path, mouse_id, file=os.path.join(data_path, 'nat60k_text16.mat'), normalize=False)
print('img: ', img.shape, img.min(), img.max(), img.dtype)

txt16_img = img[txt16_istim]
print(txt16_img.shape, txt16_img.max(), txt16_img.min())

# z-score with one global mean/std over the selected trials. Repeated test
# images therefore count once per repeat.
img_mean = txt16_img.mean()
img_std = txt16_img.std()
txt16_img_zscore = (txt16_img - img_mean) / img_std
print(txt16_img_zscore.shape, txt16_img_zscore.max(), txt16_img_zscore.min())
# Cast to float32: integer images would z-score to float64, which a float32
# model (assumed default dtype) rejects. unsqueeze(1) adds the channel axis
# expected by the conv layer.
txt16_img_zscore = torch.from_numpy(txt16_img_zscore).to(device=device, dtype=torch.float32).unsqueeze(1)
print(txt16_img_zscore.shape)

In [None]:
# build model
# Hyper-parameters of the trained minimodel. They select both the architecture
# (build_model) and the checkpoint file name (create_model_name).
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 64
wc_coef = 0.2
hs_readout = 0.003
l2_readout = 0.2  # not used below
Lx, Ly = 80, 80
model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2,
                                               input_Lx=Lx, input_Ly=Ly, Wc_coef=wc_coef)
# NOTE(review): `ineuron` is not set in this cell. It is the loop variable left
# over from the catvar-loading cell (its last value), so the conv1 catvar below
# comes from that single neuron's minimodel. It is later copied to every neuron.
# Confirm this is intended.
model_name = model_builder.create_model_name('monkeyV1', '2019', ineuron=ineuron, n_layers=nlayers,
                                             in_channels=in_channels, seed=seed, hs_readout=hs_readout)
# Use a new name instead of re-assigning `weight_path`. Re-assigning made the
# cell non-idempotent: every re-run appended another 'minimodel/monkeyV1'.
minimodel_weight_path = os.path.join(weight_path, 'minimodel', 'monkeyV1')
model_path = os.path.join(minimodel_weight_path, model_name)
print('model path: ', model_path)
# map_location loads the checkpoint onto `device` whatever device it was saved from.
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
model = model.to(device)
model.eval()  # inference behaviour for the norm layer

# Run layer0 stage by stage. Gradients are not needed, and autograd would
# otherwise keep every intermediate activation of the full stimulus set alive.
layer0 = model.core.features.layer0
with torch.no_grad():
    conv1_fvs = layer0.conv(txt16_img_zscore)
    print('after conv1: ', conv1_fvs.shape, conv1_fvs.max(), conv1_fvs.min())
    conv1_bn_fvs = layer0.norm(conv1_fvs)
    print('after conv1_bn: ', conv1_bn_fvs.shape, conv1_bn_fvs.max(), conv1_bn_fvs.min())
    conv1_relu_fvs = layer0.activation(conv1_bn_fvs)
    print('after conv1_relu: ', conv1_relu_fvs.shape, conv1_relu_fvs.max(), conv1_relu_fvs.min())
    conv1_pool_fvs = layer0.pool(conv1_relu_fvs)
    print('after conv1_pool: ', conv1_pool_fvs.shape, conv1_pool_fvs.max(), conv1_pool_fvs.min())

# One feature array per entry of conv1_op_list. 'conv1' is the output before the norm layer.
conv1_fvs_all = [fvs.detach().cpu().numpy() for fvs in (conv1_fvs, conv1_relu_fvs, conv1_pool_fvs)]
assert len(conv1_fvs_all) == len(conv1_op_list)

conv1_catvar_all = np.zeros(len(conv1_op_list))
for i, fvs in enumerate(conv1_fvs_all):
    fv = fvs.reshape(fvs.shape[0], -1)  # (nstim, nfeatures)
    cat_var = metrics.category_variance_pairwise(fv.T, txt16_labels)
    conv1_catvar_all[i] = np.nanmean(cat_var)  # average over features, ignoring NaNs

print(conv1_catvar_all)

In [None]:
# Prepend the conv1-stage catvar (one value per stage, identical for every
# neuron) to the per-neuron catvar of the later stages, and the matching names
# to op_list.
nneurons = catvar_all.shape[0]
n_conv1_ops = len(conv1_op_list)
if op_list[:n_conv1_ops] == conv1_op_list:
    # The cell already ran. Drop the previously prepended conv1 columns/names so
    # they are replaced, not duplicated (the original cell stacked another copy
    # on every re-run).
    catvar_all = catvar_all[:, n_conv1_ops:]
    op_list = op_list[n_conv1_ops:]

conv1_catvar_per_neuron = np.tile(conv1_catvar_all, (nneurons, 1))  # (nneurons, n_conv1_ops)
print(conv1_catvar_per_neuron.shape, catvar_all.shape)
catvar_all = np.hstack([conv1_catvar_per_neuron, catvar_all])
print(catvar_all.shape)

op_list = conv1_op_list + op_list
assert catvar_all.shape[1] == len(op_list)

# Save results

In [None]:
# Write the combined catvar analysis to a single .npz file.
save_path = './outputs/'
fname = 'catvar_monkey_result.npz'
fpath = os.path.join(save_path, fname)

data_dict = {
    'model_catvar': catvar_all,            # per-neuron catvar, one column per entry of op_names
    'mean_correlation': mean_correlations,
    'op_names': op_list,
    'wc_all': wc_all,
    # channel_activities_all is not saved (presumably because of its size).
    'pred_catvar': pred_catvar,
}
np.savez(fpath, **data_dict)
print(f'saved to {fpath}')