In [None]:
import os
import torch
import numpy as np

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA instead of failing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from minimodel import data
# Seed NumPy's global random number generator up front.
np.random.seed(1)

# Index of the recording to analyse; used as a key into the lookup tables of `minimodel.data`.
mouse_id = 4

# Folders (relative to the working directory) holding the input data and the model checkpoints.
data_path, weight_path = './data', './checkpoints'

In [None]:
# load images
# The file name for this mouse is looked up in `data.img_file_name`. Later cells index the
# returned array with `istim_train` / `istim_test` to pick the images for each split.
img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])

In [None]:
# load neurons
# The recording's file name is assembled from the mouse name and experiment date in `data.db`.
db_entry = data.db[mouse_id]
fname = f"{db_entry['mname']}_nat60k_{db_entry['datexp']}.npz"
neuron_file = os.path.join(data_path, fname)
(spks, istim_train, istim_test,
 xpos, ypos, spks_rep_all) = data.load_neurons(file_path=neuron_file, mouse_id=mouse_id)
# `spks` is 2-D: its first axis is unpacked as the stimulus count, its second as the neuron count.
n_stim, n_neurons = spks.shape

In [None]:
# split train and validation set
# `itrain` / `ival` are used later to index both `spks` and `img[istim_train]`.
# train_frac=0.9 presumably assigns 90% of the training stimuli to `itrain` — confirm in data.split_train_val.
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

In [None]:
# Normalize the responses. `itrain` is passed in, so the normalization statistics presumably come
# from the training split only — TODO confirm in data.normalize_spks.
# NOTE: this rebinds `spks` and `spks_rep_all` to the normalized arrays, so re-running this cell
# without first re-running the "load neurons" cell feeds already-normalized data back in.
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

In [None]:
# Fit a single neuron. `ineur` is a one-element list so that column indexing keeps `spks_*` 2-D.
ineuron = 3218
ineur = [ineuron]
spks_train = torch.from_numpy(spks[itrain][:, ineur])
spks_val = torch.from_numpy(spks[ival][:, ineur])
print('spks_train: ', spks_train.shape, spks_train.min(), spks_train.max())
print('spks_val: ', spks_val.shape, spks_val.min(), spks_val.max())

# Select the training-set images once and reuse the result for both splits, instead of
# repeating the `img[istim_train]` selection for train and validation separately.
img_trainval = img[istim_train]
# unsqueeze(1) inserts a singleton axis at position 1 (presumably the channel axis the conv core expects).
img_train = torch.from_numpy(img_trainval[itrain]).to(device).unsqueeze(1)
img_val = torch.from_numpy(img_trainval[ival]).to(device).unsqueeze(1)
del img_trainval  # only the device tensors are needed from here on
img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

print('img_train: ', img_train.shape, img_train.min(), img_train.max())
print('img_val: ', img_val.shape, img_val.min(), img_val.max())
print('img_test: ', img_test.shape, img_test.min(), img_test.max())

# Image height and width, read from the last two tensor axes.
input_Ly, input_Lx = img_train.shape[-2:]

In [None]:
# build model
from minimodel import model_builder

# Hyperparameters of the single-neuron model.
nlayers = 2
nconv1 = 16
nconv2 = 64
seed = 1          # in this notebook only passed to create_model_name
hs_readout = 0.03
wc_coef = 0.2

# NN=1 because a single neuron (`ineur`) is modelled; the full model later in the notebook
# passes NN=n_neurons instead.
model, in_channels = model_builder.build_model(
    NN=1,
    n_layers=nlayers,
    n_conv=nconv1,
    n_conv_mid=nconv2,
    Wc_coef=wc_coef,
)
model = model.to(device)

# Checkpoint location: the file name is derived from the mouse, the neuron and the settings above.
model_name = model_builder.create_model_name(
    data.mouse_names[mouse_id],
    data.exp_date[mouse_id],
    ineuron=ineur[0],
    n_layers=nlayers,
    in_channels=in_channels,
    seed=seed,
    hs_readout=hs_readout,
)
model_path = os.path.join(weight_path, model_name)

In [None]:
# train model
from minimodel import model_trainer

if not os.path.exists(model_path):
    # No checkpoint for this neuron yet: initialise from the pretrained full model, then train.
    # Mouse 5's full-model checkpoint carries an extra '_xrange_176' suffix in its file name.
    fullmodel_suffix = '_xrange_176' if mouse_id == 5 else ''
    pretrained_model_path = os.path.join(
        weight_path,
        f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_2layer_16_320_clamp_norm_depthsep_pool{fullmodel_suffix}.pt',
    )
    print('pretrained_model_path: ', pretrained_model_path)
    pretrained_state_dict = torch.load(pretrained_model_path, map_location=device)
    # initialize conv1 with the fullmodel weights and freeze it (excluded from gradient updates)
    conv1_weight = model.core.features.layer0.conv.weight
    conv1_weight.data = pretrained_state_dict['core.features.layer0.conv.weight']
    conv1_weight.requires_grad = False
    best_state_dict = model_trainer.train(model, spks_train, spks_val, img_train, img_val, device=device, l2_readout=0.2, hs_readout=hs_readout)
    torch.save(best_state_dict, model_path)
    print('saved model', model_path)
# Always load from disk so that a fresh run and a cached run end up with the same weights.
# map_location (as for the pretrained weights above) lets a checkpoint saved on another
# device be loaded onto the current one.
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)

In [None]:
# test model
# Predict the single-neuron model's responses to the test images (`img_test`).
test_pred = model_trainer.test_epoch(model, img_test)
# Sanity check: shape and value range of the predictions.
print('test_pred: ', test_pred.shape, test_pred.min(), test_pred.max())

In [None]:
from minimodel import metrics
# From every entry of the repeated-presentation responses keep only the modelled neuron's
# column (`ineur` is a list, so each slice stays 2-D).
spks_rep = [rep_spks[:, ineur] for rep_spks in spks_rep_all]

# Score the single-neuron model's test predictions against the repeats.
test_fev, test_feve = metrics.feve(spks_rep, test_pred)
print('FEVE (test): ', np.mean(test_feve))

# check Wc

In [None]:
import matplotlib.pyplot as plt

# Readout weights `Wc` as a NumPy array with singleton dimensions removed.
Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
Wc_sorted = np.sort(Wc)

# Plot the weights in ascending order to show how many are far from zero.
fig, ax = plt.subplots(figsize=(3, 3))
ax.set_title('Wc')
ax.plot(Wc_sorted)
plt.show()

# check fullmodel performance

In [None]:
from minimodel import model_builder

# The full model's settings and checkpoint path get their own `fullmodel_*` names so they do not
# overwrite `nconv2`, `in_channels`, `model_name` and `model_path` of the single-neuron model;
# otherwise re-running the training cell after this one would target the full-model checkpoint.
fullmodel_nlayers = 2
fullmodel_nconv1 = 16
fullmodel_nconv2 = 320
fullmodel, fullmodel_in_channels = model_builder.build_model(
    NN=n_neurons,
    n_layers=fullmodel_nlayers,
    n_conv=fullmodel_nconv1,
    n_conv_mid=fullmodel_nconv2,
)
fullmodel_name = model_builder.create_model_name(
    data.mouse_names[mouse_id],
    data.exp_date[mouse_id],
    n_layers=fullmodel_nlayers,
    in_channels=fullmodel_in_channels,
)
fullmodel_path = os.path.join(weight_path, fullmodel_name)

# map_location lets a checkpoint saved on another device be loaded onto the current one.
fullmodel.load_state_dict(torch.load(fullmodel_path, map_location=device))
print('loaded model', fullmodel_path)
fullmodel = fullmodel.to(device)

In [None]:
# Predict the test images with the full model.
# NOTE: this rebinds `test_pred`, which until now held the single-neuron model's predictions;
# the full model was built with NN=n_neurons, so the modelled neuron is selected with `ineur` afterwards.
test_pred = model_trainer.test_epoch(fullmodel, img_test)
print('test_pred: ', test_pred.shape, test_pred.min(), test_pred.max())

In [None]:
from minimodel import metrics
# Score the full model on the same neuron (`test_pred` holds the full model's predictions here).
# The result is stored under its own names so that `test_feve` keeps the single-neuron model's
# score: the figure at the end of the notebook puts `test_feve[0]` in its title.
fullmodel_test_fev, fullmodel_test_feve = metrics.feve(spks_rep, test_pred[:, ineur])
print('FEVE (test): ', np.mean(fullmodel_test_feve))

# calculate category variance (FECV)

In [None]:
# load txt16 data
fname = 'text16_%s_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
# NOTE(security): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code — only open files from a trusted source.
# The context manager closes the .npz archive once the needed arrays have been read.
with np.load(os.path.join(data_path, fname), allow_pickle=True) as dat:
    txt16_spks_test = dat['ss_all']
    txt16_istim_test = dat['ss_istim'].astype(int)
    txt16_labels_test = dat['ss_labels']
    txt16_istim_train = dat['istim'].astype(int)
    txt16_labels_train = dat['labels']

# The test responses are 3-D; the axes are unpacked as (stimuli, repeats, neurons).
nstim, nrep, nneuron = txt16_spks_test.shape
# Repeat every test stimulus index / label `nrep` times in a row, giving one entry per repeat.
txt16_istim_test = np.repeat(txt16_istim_test[:, np.newaxis], nrep, axis=1).flatten()
txt16_labels_test = np.repeat(txt16_labels_test[:, np.newaxis], nrep, axis=1).flatten()

print('txt16_labels_test shape:', txt16_labels_test.shape)

print('txt16_labels_train shape:', txt16_labels_train.shape)

# Concatenate train entries first, then the (repeated) test entries.
txt16_labels = np.hstack((txt16_labels_train, txt16_labels_test))
txt16_istim = np.hstack((txt16_istim_train, txt16_istim_test))

In [None]:
# load txt16 images
# Index the freshly loaded, un-normalized image array directly instead of binding it to `img`.
# That keeps the image array loaded at the top of the notebook intact, so the cell that builds
# img_train / img_val / img_test can still be re-run after this one.
txt16_img = data.load_images(data_path, mouse_id, file='nat60k_text16.mat', normalize=False)[txt16_istim]
# zscore txt16_imgs with the mean / std taken over all selected txt16 images
img_mean = txt16_img.mean()
img_std = txt16_img.std()
txt16_img_zscore = (txt16_img - img_mean) / img_std
# unsqueeze(1) inserts a singleton axis at position 1, as for the other image tensors.
txt16_img_zscore = torch.from_numpy(txt16_img_zscore).to(device).unsqueeze(1)
print(txt16_img_zscore.shape)

In [None]:
# Predict the single-neuron model's responses to the z-scored txt16 images.
txt16_pred = model_trainer.test_epoch(model, txt16_img_zscore)
print('test pred:', txt16_pred.shape)

In [None]:
# Category-variance metric of the predictions with respect to the txt16 labels.
# The predictions are transposed before the call (presumably so neurons come first — confirm in
# metrics.fecv_pairwise); entry 0 of the result is reported for the single modelled neuron.
catvar = metrics.fecv_pairwise(txt16_pred.T, txt16_labels)
print(f'FECV (neuron {ineur[0]}): {catvar[0]:.4f}')

# visualize neuron

In [None]:
# number of training images (one row of `conv2_fvs` per image)
Nimgs_unique = img_train.shape[0]

# get conv2 features of train images (in batches)
model.eval()
batch_size = 160
nconv2 = 64  # must match the n_conv_mid the single-neuron model was built with
conv2_fvs = np.zeros((Nimgs_unique, nconv2))
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for i in range(0, Nimgs_unique, batch_size):
        images = img_train[i:i+batch_size].to(device)
        conv2_fv = model.core(images)
        # Contract each channel's feature map (y, x) with the readout's Wy and Wx, leaving
        # axes (image, channel, r); squeeze() then drops the singleton axes.
        wxy_fv = torch.einsum('iry, irx, ncyx -> ncr', model.readout.Wy, model.readout.Wx, conv2_fv).detach().cpu().numpy().squeeze()
        conv2_fvs[i:i+batch_size] = wxy_fv

# sort the features and select top 8 image for each channel
fv_isort = np.argsort(-conv2_fvs, axis=0)  # per channel: image indices, largest feature first
Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
# keep only channels whose readout weight is not (almost) zero
ivalid_Wc = np.where(np.abs(Wc)>0.01)[0]
print('ivalid_Wc:', len(ivalid_Wc))
fv_isort = fv_isort[:, ivalid_Wc]
fv_isort_top8 = fv_isort[:8]
Nimg, Nchannel = fv_isort_top8.shape

# get mask of the images
from minimodel.utils import get_image_mask
ineuron_mask_up = get_image_mask(model, Ly=input_Ly, Lx=input_Lx)

# get predictions from the training set
neuron_activity_model = model_trainer.test_epoch(model, img_train)
neuron_activity_model = neuron_activity_model.squeeze()
# training-image indices ordered from highest to lowest predicted activity
prediction_isort = np.argsort(neuron_activity_model)[::-1]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from minimodel.utils import add_channel_frame

# Summary figure on an 8-row grid:
#   left  - 4x4 mosaic of the 16 training images with the highest predicted activity
#   right - one row per displayed channel with that channel's top `Nimg` images
valid_wc = Wc[ivalid_Wc]
isort = np.argsort(valid_wc)[::-1]  # positions in valid_wc, most positive weight first
Nchannel = 8  # number of grid rows (the left mosaic needs all eight)
import matplotlib as mpl
mpl.rcParams['font.family'] = 'Arial'

# Ranks (positions in `isort`) of the channels to draw: up to six with the most positive weights,
# followed by those with the most negative weights. Capping at the number of valid channels
# avoids repeating a channel, or indexing past the end, when fewer than `Nchannel` pass the threshold.
n_valid = len(valid_wc)
n_rows = min(Nchannel, n_valid)
n_top = min(6, n_rows)
channel_ranks = list(range(n_top)) + list(range(n_valid - (n_rows - n_top), n_valid))
max_abs_wc = np.max(np.abs(valid_wc)) if n_valid > 0 else 1.0

# Combined plot layout: four unit-width columns for the mosaic, then Nimg half-width columns.
fig = plt.figure(figsize=(Nimg * 2 + 20, Nchannel * 1.1))
gs = plt.GridSpec(Nchannel, Nimg + 4, figure=fig, hspace=0.3, wspace=0.1, width_ratios=[1] * 4 + [0.5] * Nimg)

# Plot one (4x4 grid on the left side, occupying 2 rows per row)
nshow = 16
for i in range(nshow):
    row = (i // 4) * 2
    col = i % 4
    ax = fig.add_subplot(gs[row:row + 2, col])
    ax.imshow(img_train[prediction_isort[i]].cpu().numpy().squeeze() * ineuron_mask_up, cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')

# Plot two (one row per selected channel on the right side)
axs = np.empty((Nchannel, Nimg), dtype=object)
for i, irank in enumerate(channel_ranks):
    ich = isort[irank]  # column of fv_isort_top8 / position in valid_wc for this row
    for j in range(Nimg):
        axs[i, j] = fig.add_subplot(gs[i, j + 4])  # offset by 4 columns to sit right of the mosaic
        axs[i, j].imshow(img_train[fv_isort_top8[j, ich]].cpu().numpy().squeeze() * ineuron_mask_up, cmap='gray', vmin=-1, vmax=1)
        axs[i, j].axis('off')
    wc_value = valid_wc[ich]
    # Frame colour encodes the weight's sign; the last argument is the weight's magnitude
    # relative to the largest absolute weight.
    color = 'red' if wc_value > 0 else 'blue'
    add_channel_frame(axs, i, 0, Nimg - 1, color, np.abs(wc_value / max_abs_wc))

    ax = axs[i, Nimg - 1]  # Rightmost axis in the row
    # The label is the channel's 1-based rank when sorted by readout weight (descending).
    ax.text(1.1, 0.5, f'channel {irank+1}', transform=ax.transAxes,
            verticalalignment='center', fontsize=16, color='black', alpha=0.8)
# NOTE: `test_feve` and `catvar` are read from earlier cells; `test_feve` holds the result of
# whichever metrics.feve call assigned it last.
plt.suptitle(f'neuron {ineur[0]}, FEVE={test_feve[0]:.3f}, FECV={catvar[0]:.3f}', fontsize=18)
plt.show()