In [None]:
import os
import torch
import numpy as np
# Seed NumPy's global RNG so NumPy-based randomness in this notebook is repeatable.
np.random.seed(1)

# Dataset files are read from `data_path`; checkpoints are read from / written under `weight_path`.
data_path = './data'
weight_path = './checkpoints'

# Prefer the GPU when PyTorch can see one, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
ineuron = 85

# Load the dataset and keep only column `ineuron` of every per-neuron array,
# re-adding a trailing singleton axis so the arrays keep an explicit neuron dimension.
dat = np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz'))
images = dat['images']
responses = np.expand_dims(dat['responses'][:, ineuron], axis=1)
real_responses = np.expand_dims(dat['real_responses'][:, ineuron], axis=1)
test_images = dat['test_images']
test_responses = np.expand_dims(dat['test_responses'][:, :, ineuron], axis=2)
test_real_responses = np.expand_dims(dat['test_real_responses'][:, :, ineuron], axis=2)
train_idx, val_idx = dat['train_idx'], dat['val_idx']
repetitions = [dat['repetitions'][ineuron]]
monkey_id = dat['subject_id']
image_ids = dat['image_ids']

# Normalize by the standard deviation of the responses flagged in `real_responses`:
# entries where the flag is falsy are replaced by NaN and therefore ignored by nanstd.
responses_nan = np.where(real_responses, responses, np.nan)
resp_std = np.nanstd(responses_nan)
responses, test_responses = responses / resp_std, test_responses / resp_std

# Split images, responses and validity flags into train / validation subsets.
train_images, val_images = images[train_idx], images[val_idx]
train_responses, val_responses = responses[train_idx], responses[val_idx]
train_real_responses, val_real_responses = real_responses[train_idx], real_responses[val_idx]

print('train:', train_images.shape, train_responses.shape, train_real_responses.shape)
print('val:', val_images.shape, val_responses.shape, val_real_responses.shape)
print('test:', test_images.shape, test_responses.shape, test_real_responses.shape)

# The ranges are printed BEFORE the test responses are NaN-masked below,
# so min()/max() here are still computed on plain numbers.
print('resp:', responses.min(), responses.max())
print('test resp:', test_responses.min(), test_responses.max())

# Mask test entries whose flag in `test_real_responses` is falsy.
test_responses = np.where(test_real_responses, test_responses, np.nan)

# Number of neurons kept (axis 1 of the responses) and the sizes of image axes 2 and 3.
NN = train_responses.shape[1]
Lx, Ly = train_images.shape[2:4]

In [None]:
# Convert the train / validation arrays to torch tensors.
# torch.as_tensor wraps a NumPy array without copying (like torch.from_numpy) but,
# unlike from_numpy, also accepts an existing tensor and returns it unchanged.
# Because each name is rebound to its own converted value, that makes this cell
# safe to re-run; torch.from_numpy would raise a TypeError the second time.
train_images = torch.as_tensor(train_images)
val_images = torch.as_tensor(val_images)
train_responses = torch.as_tensor(train_responses)
val_responses = torch.as_tensor(val_responses)
train_real_responses = torch.as_tensor(train_real_responses)
val_real_responses = torch.as_tensor(val_real_responses)

In [None]:
# build model
from minimodel import model_builder

# Settings of the single-neuron minimodel. `hs_readout` and `l2_readout` are not used
# by build_model; they go into the checkpoint name and the training call.
seed = 1
nlayers, nconv1, nconv2 = 2, 16, 64
wc_coef = 0.2
hs_readout = 0.004
l2_readout = 0.2

model, in_channels = model_builder.build_model(
    NN=1,
    n_layers=nlayers,
    n_conv=nconv1,
    n_conv_mid=nconv2,
    input_Lx=Lx,
    input_Ly=Ly,
    Wc_coef=wc_coef,
)

# The checkpoint file name is derived from the model settings, so each
# neuron / seed / hs_readout combination maps to its own file.
model_name = model_builder.create_model_name(
    'monkeyV1',
    '2019',
    ineuron=ineuron,
    n_layers=nlayers,
    in_channels=in_channels,
    seed=seed,
    hs_readout=hs_readout,
)
model_path = os.path.join(weight_path, 'minimodel', model_name)
print('model path: ', model_path)

model = model.to(device)

In [None]:
# Train only when no checkpoint exists yet for this configuration.
if not os.path.exists(model_path):
    from minimodel import model_trainer

    # Create the checkpoint folder BEFORE training: torch.save does not create
    # missing directories, and failing after training would throw the result away.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # initialize model conv1 from the pretrained full model
    pretrained_model_path = os.path.join(weight_path, 'fullmodel', 'monkeyV1_2019_2layer_16_320_clamp_norm_depthsep_pool.pt')
    pretrained_state_dict = torch.load(pretrained_model_path, map_location=device)
    model.core.features.layer0.conv.weight.data = pretrained_state_dict['core.features.layer0.conv.weight']
    # freeze the copied conv1 weights so training does not update them
    model.core.features.layer0.conv.weight.requires_grad = False

    best_state_dict = model_trainer.monkey_train(
        model, train_responses, train_real_responses, val_responses, val_real_responses,
        train_images, val_images, device=device, hs_readout=hs_readout, l2_readout=l2_readout)
    torch.save(best_state_dict, model_path)
    print('model saved', model_path)

In [None]:
# map_location remaps the stored tensors onto the device used by this kernel, so a
# checkpoint written during a GPU session still loads on a CPU-only machine
# (the pretrained full-model weights are loaded the same way in the training cell).
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
model = model.to(device)

In [None]:
from minimodel import model_trainer
# torch.as_tensor accepts the NumPy array produced by the data cell as well as a tensor
# left over from a previous run of this cell, so the cell can be executed repeatedly
# (torch.from_numpy raises a TypeError once test_images is already a tensor).
test_images = torch.as_tensor(test_images).to(device)
# Predict the test-image responses with the minimodel.
spks_pred_test = model_trainer.test_epoch(model, test_images)
print('predctions:', spks_pred_test.shape, spks_pred_test.min(), spks_pred_test.max())

In [None]:
from minimodel import metrics
# Score the minimodel's test predictions against the (NaN-masked) test responses.
# monkey_feve returns two values; the mean of the second one is reported as FEVE.
test_fev, test_feve = metrics.monkey_feve(
    test_responses,
    spks_pred_test,
    repetitions,
)
print('FEVE (test):', np.mean(test_feve))

# check Wc

In [None]:
import matplotlib.pyplot as plt

# Bring the readout's Wc tensor to the host as a NumPy array with singleton axes removed.
Wc = model.readout.Wc.detach().cpu().numpy().squeeze()

# Show the Wc values in ascending order.
fig, ax = plt.subplots(figsize=(3, 3))
ax.set_title('Wc')
ax.plot(np.sort(Wc))
plt.show()

# check fullmodel performance

In [None]:
from minimodel import model_builder
# Build the full model and load its pretrained weights.
# Everything specific to the full model lives under a `fullmodel*` name so that the
# minimodel's `model_path`, `model_name`, `in_channels` and `nconv2` are NOT overwritten;
# otherwise re-running the minimodel load cell after this one would try to load the
# full-model checkpoint into the minimodel.
seed = 1
nlayers = 2
nconv1 = 16
fullmodel_nconv2 = 320
fullmodel, fullmodel_in_channels = model_builder.build_model(NN=166, n_layers=nlayers, n_conv=nconv1, n_conv_mid=fullmodel_nconv2, input_Lx=Lx, input_Ly=Ly)
fullmodel_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=fullmodel_in_channels)
fullmodel_path = os.path.join(weight_path, 'fullmodel', fullmodel_name)
print('model path: ', fullmodel_path)

# map_location lets a checkpoint saved on a GPU load on a CPU-only kernel as well.
fullmodel.load_state_dict(torch.load(fullmodel_path, map_location=device))
print('loaded model', fullmodel_path)
fullmodel = fullmodel.to(device)

In [None]:
from minimodel import model_trainer
# `test_images` was already turned into a tensor on `device` when the minimodel was
# evaluated, so it is passed to the full model unchanged.
# Note that `spks_pred_test` now holds the full model's predictions.
spks_pred_test = model_trainer.test_epoch(fullmodel, test_images)
print(
    'predctions:',
    spks_pred_test.shape,
    spks_pred_test.min(),
    spks_pred_test.max(),
)

In [None]:
from minimodel import metrics
# Keep the full model's scores under their own names. Reusing `test_fev` / `test_feve`
# would overwrite the minimodel's values, and the summary figure at the end of the
# notebook puts `test_feve[0]` in its title, which must be the minimodel's FEVE.
nstim = spks_pred_test.shape[0]
# Select the selected neuron's column of the full-model predictions and keep it 2-D (nstim, 1).
fullmodel_test_fev, fullmodel_test_feve = metrics.monkey_feve(test_responses, spks_pred_test[:, ineuron].reshape((nstim, 1)), repetitions)
print('FEVE (test):', np.mean(fullmodel_test_feve))

# category variance (FECV)

In [None]:
# load txt16 data
from minimodel import data

# The file name is built from entry 3 of the project's session table.
session = data.db[3]
fname = 'text16_%s_%s.npz' % (session['mname'], session['datexp'])
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code; only point this at trusted files.
dat = np.load(os.path.join(data_path, fname), allow_pickle=True)

# Test part: the response array's three axes are unpacked as (nstim, nrep, nneuron).
txt16_spks_test = dat['ss_all']
nstim, nrep, nneuron = txt16_spks_test.shape
# Repeat every test stimulus index and label `nrep` times (copies of one stimulus
# stay adjacent) and flatten to 1-D.
txt16_istim_test = np.repeat(dat['ss_istim'].astype(int)[:, np.newaxis], nrep, axis=1).reshape(-1)
txt16_labels_test = np.repeat(dat['ss_labels'][:, np.newaxis], nrep, axis=1).reshape(-1)

print('txt16_labels_test shape:', txt16_labels_test.shape)

# Train part: indices and labels are used as stored.
txt16_istim_train = dat['istim'].astype(int)
txt16_labels_train = dat['labels']

print('txt16_labels_train shape:', txt16_labels_train.shape)

# Train entries first, followed by the repeated test entries.
txt16_istim = np.hstack((txt16_istim_train, txt16_istim_test))
txt16_labels = np.hstack((txt16_labels_train, txt16_labels_test))

In [None]:
# load txt16 images
img = data.load_images(data_path, 3, file=os.path.join(data_path, 'nat60k_text16.mat'), normalize=False)

# Pick the image for every txt16 entry. This fancy-indexing copy is done exactly once.
txt16_img = img[txt16_istim]

# crop: keep 66 positions starting at index 22 along axis 2
xrange = np.arange(22, 22+66)
txt16_img = txt16_img[:, :, xrange]
print(txt16_img.shape, txt16_img.max(), txt16_img.min())

# resize all images to 80x80
import cv2
# The loop variable is deliberately not named `img`, so it cannot be confused with
# the full image stack loaded above.
txt16_img = np.array([cv2.resize(frame, (80, 80)) for frame in txt16_img])
print(txt16_img.shape, txt16_img.max(), txt16_img.min())

# zscore txt16_imgs with the mean / std over all pixels of all selected images
img_mean = txt16_img.mean()
img_std = txt16_img.std()
txt16_img_zscore = (txt16_img - img_mean) / img_std
# add a singleton axis at position 1 and move the stack to the model's device
txt16_img_zscore = torch.from_numpy(txt16_img_zscore).to(device).unsqueeze(1)
print(txt16_img_zscore.shape)

# minimodel predictions for the txt16 images
txt16_pred = model_trainer.test_epoch(model, txt16_img_zscore)
print('test pred:', txt16_pred.shape)

In [None]:
# fecv_pairwise receives the transposed predictions together with the txt16 labels;
# element 0 of its result is reported for the selected neuron.
catvar = metrics.fecv_pairwise(txt16_pred.T, txt16_labels)
print('FECV (neuron {}): {:.3f}'.format(ineuron, catvar[0]))

# visualize neuron

In [None]:
# find the unique train images
# NOTE(review): np.unique returns the sorted unique *values* of image_ids[train_idx], and
# those values are then used as row positions into train_images. That is only correct if
# image ids coincide with positions in train_images; confirm against the dataset
# (np.unique(..., return_index=True) would give first-occurrence positions instead).
train_image_ids = image_ids[train_idx]
unique_idxes = np.unique(train_image_ids).astype(np.int64)
img_train = train_images[unique_idxes].to(device)
Nimgs_unique = img_train.shape[0]

# get conv2 features of train images (in batches)
model.eval()
batch_size = 160
nconv2 = 64
conv2_fvs = np.zeros((Nimgs_unique, nconv2))
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for i in range(0, Nimgs_unique, batch_size):
        # Named `img_batch` (not `images`) so the dataset array loaded at the top of the
        # notebook is not overwritten; img_train already lives on `device`.
        img_batch = img_train[i:i+batch_size]
        conv2_fv = model.core(img_batch)
        # Weight the core output with the readout's Wy and Wx, summing over y, x and the
        # shared index i; this leaves one value per (image, channel, r).
        wxy_fv = torch.einsum('iry, irx, ncyx -> ncr', model.readout.Wy, model.readout.Wx, conv2_fv).detach().cpu().numpy().squeeze()
        conv2_fvs[i:i+batch_size] = wxy_fv

# sort the features and select top 8 image for each channel
fv_isort = np.argsort(-conv2_fvs, axis=0)  # descending along the image axis
Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
# keep only channels whose |Wc| exceeds 0.01
ivalid_Wc = np.where(np.abs(Wc)>0.01)[0]
print('ivalid_Wc:', len(ivalid_Wc))
fv_isort = fv_isort[:, ivalid_Wc]
fv_isort_top8 = fv_isort[:8]
Nimg, Nchannel = fv_isort_top8.shape

# get mask of the images
from minimodel.utils import get_image_mask
ineuron_mask_up = get_image_mask(model, Ly=Ly, Lx=Lx)

# get predictions from the training set, and the image order from highest to lowest prediction
neuron_activity_model = model_trainer.test_epoch(model, img_train)
neuron_activity_model = neuron_activity_model.squeeze()
prediction_isort = np.argsort(neuron_activity_model)[::-1]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
import matplotlib as mpl
# Set matplotlib's default font family to Arial for figures created after this line.
mpl.rcParams['font.family'] = 'Arial'
# Function to add a frame around a channel
def add_channel_frame(axs, row, col_start, col_end, color, alpha, monkey=False):
    """Draw an unfilled rectangular frame around one row of image axes.

    The rectangle is attached to ``axs[row, col_start]`` and expressed in that
    axis' own axes coordinates, starting slightly outside its lower-left corner
    and extending to the right across the requested columns.

    Parameters
    ----------
    axs : 2-D indexable of matplotlib axes
        Grid of axes; only ``axs[row, col_start]`` is accessed.
    row : int
        Row of the grid to frame.
    col_start, col_end : int
        First and last column (inclusive) the frame should span.
    color : matplotlib color
        Edge color of the frame.
    alpha : float
        Opacity of the frame.
    monkey : bool, optional
        Selects the per-column width factor: 1.64 when True, 1.33 otherwise.

    Returns
    -------
    None
    """
    # Width contributed by each spanned column, in units of the anchor axis' width.
    width_per_col = 1.64 if monkey else 1.33
    n_cols = col_end - col_start + 1

    anchor_ax = axs[row, col_start]  # leftmost axis of the framed span
    frame = patches.Rectangle(
        (-0.025, -0.05),
        n_cols * width_per_col,
        1.1,
        transform=anchor_ax.transAxes,
        color=color,
        fill=False,
        linewidth=3,
        zorder=10,
        alpha=alpha,
        clip_on=False,  # the frame extends beyond the anchor axis, so it must not be clipped
    )
    anchor_ax.add_patch(frame)


# Parameters for the second plot
pad = 5
vmin = 0
vmax = 255
valid_wc = Wc[ivalid_Wc]
isort = np.argsort(valid_wc)[::-1]  # channels ordered from largest to smallest Wc
Nchannel = min(len(valid_wc), 8)    # at most 8 channel rows are drawn

# Combined plot layout: 4 full-width columns for the left panel, then one half-width
# column per top image. The ratios are derived from Nimg so the list length always
# matches the number of GridSpec columns (a fixed 12-entry list only works for Nimg == 8).
n_left_cols = 4
fig = plt.figure(figsize=(Nimg + 15, 8 * 1.1))
gs = plt.GridSpec(8, Nimg + n_left_cols, figure=fig, hspace=0.3, wspace=0.1,
                  width_ratios=[1] * n_left_cols + [0.5] * Nimg)

# Plot one (4x4 grid on the left side, occupying 2 rows per row):
# the images with the highest predicted activity, capped at the number of images available.
nshow = min(16, len(prediction_isort))
for i in range(nshow):
    row = (i // 4) * 2
    col = i % 4
    ax = fig.add_subplot(gs[row:row + 2, col])
    ax.imshow(img_train[prediction_isort[i]].cpu().numpy().squeeze() * ineuron_mask_up, cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')

# Plot two (8xNimg grid on the right side)
axs = np.empty((8, Nimg), dtype=object)
# Largest |Wc| among the valid channels; used to scale the frame opacity of each row.
wc_scale = np.max(np.abs(valid_wc)) if Nchannel > 0 else 1.0
for i in range(Nchannel):
    # Rows 0-5 show the channels with the largest Wc; any remaining rows are taken
    # from the end of the sorted order (the smallest Wc values).
    if i < 6:
        ichannel = i
    else:
        ichannel = -(Nchannel - i)
    for j in range(Nimg):
        axs[i, j] = fig.add_subplot(gs[i, j + n_left_cols])
        axs[i, j].imshow(img_train[fv_isort_top8[j, isort[ichannel]]].cpu().numpy().squeeze() * ineuron_mask_up, cmap='gray', vmin=-1, vmax=1)
        axs[i, j].axis('off')
    wc_value = valid_wc[isort[ichannel]]
    # Frame color encodes the sign of the channel weight, opacity its relative magnitude.
    if wc_value > 0:
        color = 'red'
    else:
        color = 'blue'
    add_channel_frame(axs, i, 0, Nimg - 1, color, np.abs(wc_value / wc_scale), monkey=True)

    ax = axs[i, Nimg - 1]  # Rightmost axis in the row
    # Turn a negative (from-the-end) position into its non-negative rank for the label.
    if ichannel < 0: ichannel = len(valid_wc) + ichannel
    ax.text(1.2, 0.5, f'channel {ichannel+1}', transform=ax.transAxes,
            verticalalignment='center', fontsize=16, color='black', alpha=0.8)
plt.suptitle(f'neuron {ineuron}, FEVE={test_feve[0]:.3f}, FECV={catvar[0]:.3f}', fontsize=18)
plt.show()