In [None]:
import os
import torch
import numpy as np
from minimodel import data, model_builder, metrics, model_trainer

# Use the GPU when one is present, otherwise fall back to the CPU so that the
# cells which only read cached result files still run on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Accumulates every array that is later handed to the sup_figure plotting calls.
data_dict = {}

data_path = '../notebooks/data'            # neural recordings / stimulus files read by the cells below
weight_path = './checkpoints/fullmodel'    # root under which model checkpoints are looked up
result_path = './save_results/outputs'     # cached *.npz / *.npy result files

# figure 1: retinotopy

In [None]:
# Gather, for each of the six mice, the recorded neuron positions, the
# receptive-field centres read out from the full model's Wx / Wy readout weights,
# and the per-neuron FEV. Everything is stored in data_dict for sup_figure.figure1.
from minimodel.utils import weight_bandwidth


def _per_mouse_array(per_mouse):
    """Pack one entry per mouse into a single array.

    Entries of equal shape are stacked exactly as ``np.array`` would do. When
    the entries have different lengths (e.g. a different number of neurons per
    mouse) NumPy >= 1.24 raises ``ValueError``; in that case a 1-D object array
    holding one entry per mouse is returned, which is what older NumPy
    versions produced implicitly.
    """
    try:
        return np.array(per_mouse)
    except ValueError:
        packed = np.empty(len(per_mouse), dtype=object)
        for k, item in enumerate(per_mouse):
            packed[k] = item
        return packed


xpos_all = []
ypos_all = []
xpos_visual_all = []
ypos_visual_all = []
x_pixel_ratio = 0.75
y_pixel_ratio = 0.5
fev_all = []
for mouse_id in range(6):
    fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
    spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
    # NOTE(review): the positions are stored here, *before* the pixel-ratio
    # rescaling and the mouse-1 flip further down. Those later statements rebind
    # xpos / ypos to new arrays, so they do not affect what is stored. Confirm
    # whether sup_figure.figure1 expects the raw or the rescaled positions.
    xpos_all.append(xpos)
    ypos_all.append(ypos)
    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    Wx = dat['fullmodel_Wx']
    Wy = dat['fullmodel_Wy']
    NN = Wx.shape[0]
    bandwidth_Wx = np.zeros(NN)
    bandwidth_Wy = np.zeros(NN)
    centerpos_Wx = np.zeros(NN)
    centerpos_Wy = np.zeros(NN)
    # one bandwidth / peak position per neuron, along each readout axis
    for i in range(NN):
        bandwidth_Wx[i], centerpos_Wx[i] = weight_bandwidth(Wx[i, :], return_peak=True)
        bandwidth_Wy[i], centerpos_Wy[i] = weight_bandwidth(Wy[i, :], return_peak=True)
    xpos_model = np.argmax(Wx.squeeze(), axis=1)
    ypos_model = np.argmax(Wy.squeeze(), axis=1)
    # Convert the peak positions to visual coordinates centred on 0
    # (horizontal offset 135, vertical offset 32.5).
    # presumably degrees of visual angle — TODO confirm the 270/264 and 65/66 factors
    xpos_visual = centerpos_Wx*2*(270/264) - 135
    ypos_visual = centerpos_Wy*2*(65/66) - 32.5
    if mouse_id == 5: xpos_visual += (46*270/264) # xrange of mouse 6 is 46-176
    xpos = xpos / x_pixel_ratio
    ypos = ypos / y_pixel_ratio
    if mouse_id == 1:
        # flip the y positions of the neurons whose x position lies beyond 325 / x_pixel_ratio
        idx_up = np.where(xpos>325/ x_pixel_ratio)[0]
        idx_down = np.where(xpos<=325/ x_pixel_ratio)[0]
        ymax = ypos[idx_up].max()
        xmax, xmin = xpos[idx_up].max(), xpos[idx_up].min()
        ypos[idx_up] = ymax - ypos[idx_up] +300 # + ymax
    xpos_visual_all.append(xpos_visual)
    ypos_visual_all.append(ypos_visual)
    fev_all.append(dat['fev'])
data_dict['xpos_all'] = _per_mouse_array(xpos_all)
data_dict['ypos_all'] = _per_mouse_array(ypos_all)
data_dict['xpos_visual_all'] = _per_mouse_array(xpos_visual_all)
data_dict['ypos_visual_all'] = _per_mouse_array(ypos_visual_all)
data_dict['fev_all'] = _per_mouse_array(fev_all)

In [None]:
from pathlib import Path

# Recording sessions whose aligned retinotopy maps are read below.
db = [
    {'mname': 'L1_A5', 'datexp': '2023_01_27', 'blk': '3', 'stim': 'short3'},
    {'mname': 'L1_A1', 'datexp': '2023_03_27', 'blk': '1', 'stim': 'short3'},
    {'mname': 'FX9', 'datexp': '2023_05_02', 'blk': '2', 'stim': 'short3'},
    {'mname': 'FX10', 'datexp': '2023_05_02', 'blk': '1', 'stim': 'short3'},
    {'mname': 'FX8', 'datexp': '2023_05_02', 'blk': '1', 'stim': 'short3'},
    {'mname': 'FX20', 'datexp': '2023_09_08', 'blk': '2', 'stim': 'nat15k'},
]

# Fields copied from every per-session file into data_dict (one entry per mouse).
ret_keys = ('iregion', 'iarea', 'xy', 'out', 'jxy')
ret_maps = {key: [] for key in ret_keys}

for session in db[:6]:
    aligned_path = Path(os.path.join(result_path, 'Ret_maps')).joinpath(
        f"{session['mname']}_{session['datexp']}_{session['blk']}_{session['stim']}.npz")
    m = np.load(aligned_path, allow_pickle=True)
    for key in ret_keys:
        ret_maps[key].append(m[key])

for key in ret_keys:
    data_dict[key] = np.array(ret_maps[key])


In [None]:
# Hand the arrays collected in data_dict to the figure-1 plotting routine.
import sup_figure
from fig_utils import *
# second argument of sup_figure.figure1; presumably the directory the figure is written to — TODO confirm
save_path = './outputs'
sup_figure.figure1(data_dict, save_path)

# figure 2: signal variance

In [None]:
# Example neurons of mouse 1: responses on every repeat plus their FEV
mouse_id = 1
session = data.db[mouse_id]
fname = '%s_nat60k_%s.npz' % (session['mname'], session['datexp'])
spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(
    file_path=os.path.join(data_path, fname),
    mouse_id=mouse_id,
)
# 90 / 10 train / validation split, then normalise with the training indices
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
fev_all = metrics.fev(spks_rep_all)

# keep only the three hand-picked neurons in every repeat
ineurons = [2550, 5020, 264]
spks_rep = [rep[:, ineurons] for rep in spks_rep_all]
data_dict['example_repeats'] = np.stack(spks_rep)
data_dict['example_fev'] = fev_all[ineurons]

In [None]:
# Pool the per-neuron FEV of all six mice into one flat array
fev_all = []
for mouse_id in range(6):
    results_file = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    dat = np.load(results_file, allow_pickle=True)
    fev_all.append(dat['fev'])
fev_all = np.hstack(fev_all)
data_dict['fev_all'] = fev_all
# mean FEV over all neurons, then over the neurons above the 0.15 cut-off
reliable = fev_all > 0.15
print(np.mean(fev_all))
print(np.mean(fev_all[reliable]))

## monkey

In [None]:
# Read the monkey V1 recordings
monkey_file = os.path.join(data_path, 'monkeyv1_cadena_2019.npz')
dat = np.load(monkey_file)
images, responses, real_responses = dat['images'], dat['responses'], dat['real_responses']
test_images, test_responses, test_real_responses = (
    dat['test_images'], dat['test_responses'], dat['test_real_responses'])
train_idx, val_idx = dat['train_idx'], dat['val_idx']
repetitions = dat['repetitions']
monkey_ids = dat['subject_id']
image_ids = dat['image_ids']

# Scale every neuron by the std of its training responses; entries where
# real_responses is falsy are replaced by NaN and therefore ignored by nanstd.
responses_nan = np.where(real_responses, responses, np.nan)
resp_std = np.nanstd(responses_nan, axis=0)
responses, test_responses = responses / resp_std, test_responses / resp_std

test_responses_nan = np.where(test_real_responses, test_responses, np.nan)
print('test_responses_nan: ', test_responses_nan.shape)

# From here on `dat` refers to the cached full-model results, not the raw recordings.
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
monkey_fev = dat['fev_all']
data_dict['monkey_fev_all'] = monkey_fev
ineurons = [51, 105, 86]
data_dict['monkey_example_repeats'] = test_responses_nan[:, :, ineurons]
data_dict['monkey_example_fev'] = monkey_fev[ineurons]
data_dict['monkey_example_reps'] = repetitions[ineurons]

In [None]:
# Select the three example monkey neurons and store their repeat responses,
# FEV and repetition counts in data_dict.
# NOTE(review): these four statements are identical to the last four statements
# of the preceding cell, and they rely on `test_responses_nan`, `dat` and
# `repetitions` being left over from it — this cell looks redundant.
ineurons = [51, 105, 86]
data_dict['monkey_example_repeats'] = test_responses_nan[:, :, ineurons]
data_dict['monkey_example_fev'] = dat['fev_all'][ineurons]
data_dict['monkey_example_reps'] = repetitions[ineurons]

In [None]:
# Hand the arrays collected in data_dict to the figure-2 plotting routine.
import sup_figure
from fig_utils import *
# second argument of sup_figure.figure2; presumably the directory the figure is written to — TODO confirm
save_path = './outputs'
sup_figure.figure2(data_dict, save_path)

# figure 4: Gabor model

In [None]:
# Evaluate the fitted Gabor model on the test images of every mouse and collect,
# per neuron, its FEV / FEVE together with the FEVE of the full model and the
# minimodel and the fitted Gabor parameters.
from torch.nn.functional import relu
from minimodel.gabor import gabor_filter, eval_gabors

# Location of the raw data and the fitted Gabor parameters. The default is the
# original hard-coded cluster path; set MINIMODEL_SERVER_PATH to point elsewhere.
server_path = os.environ.get('MINIMODEL_SERVER_PATH', '/home/carsen/dm11_cluster/fengtongd/Desktop/approxineuro')
res_dict = {'fev': [], 'feve_minimodel': [], 'feve_gabor': [], 'mf': [], 'msigma': [], 'mtheta': [], 'cratio': [], 'feve_fullmodel': []}

# NOTE(review): this rebinds the notebook-level `data_path`, so every cell that
# runs after this one reads from the server directory as well.
data_path = os.path.join(server_path, 'data')

# Gabor parameter grid. It does not depend on the mouse, so it is built once.
sigma = np.array([0.75, 1.25, 1.5, 2.5, 3.5, 4.5, 5.5])
f = np.array([0.1, 0.25, 0.5, 1, 2])
theta = np.arange(0, np.pi, np.pi/8)
ph = np.arange(0, 2*np.pi, np.pi/4)
ar = np.array([1, 1.5, 2])
print(f'sigma: {sigma.shape}, f: {f.shape}, theta: {theta.shape}, ph: {ph.shape}, ar: {ar.shape}')

gabor_grid = np.meshgrid(sigma, f, theta, ph, ar, indexing='ij')
n_gabors = gabor_grid[0].size
print(f'number of gabors: {n_gabors}')

# Build a new list instead of assigning into the meshgrid result: NumPy >= 2.0
# returns a tuple there, which does not support item assignment.
gabor_grid = [torch.from_numpy(np.expand_dims(g, axis=(-2,-1)).astype('float32')) for g in gabor_grid]
sigma, f, theta, ph, ar = gabor_grid
print(f'sigma: {sigma.shape}, f: {f.shape}, theta: {theta.shape}, ph: {ph.shape}, ar: {ar.shape}')

for mouse_id in range(6):
    # load neurons
    fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
    spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
    n_stim, n_max_neurons = spks.shape

    # split train and validation set
    itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
    ineur = np.arange(0, n_max_neurons)

    # normalize spks
    spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
    spks_val = torch.from_numpy(spks[ival][:,ineur])
    spks_rep_all = [rep[:,ineur] for rep in spks_rep_all]

    # keep the data.NNs_valid[mouse_id] neurons with the highest FEV, best first
    ineurons = np.arange(data.NNs_valid[mouse_id])
    fev_test = metrics.fev(spks_rep_all)
    isort_neurons = np.argsort(fev_test)[::-1]
    ineur = isort_neurons[ineurons]

    print(spks.shape, spks_val.shape, len(spks_rep_all), spks_rep_all[0].shape)

    spks = spks[:,ineur]
    spks_val = spks_val[:,ineur]
    spks_rep_all = [rep[:,ineur] for rep in spks_rep_all]
    print(spks.shape, spks_val.shape, len(spks_rep_all), spks_rep_all[0].shape)

    # load images; mouse index 5 uses the horizontal range 46-176, all others 0-130
    # NOTE(review): other cells call data.load_images(data_path, mouse_id, file=...)
    # without `xrange` — confirm which signature the installed package provides.
    xrange_max = 176 if mouse_id == 5 else 130
    img_all = data.load_images(data_path, file=os.path.join(data_path, data.img_file_name[mouse_id]), xrange=[xrange_max-130, xrange_max], downsample=2)
    nimg, Ly, Lx = img_all.shape
    print('img: ', img_all.shape, img_all.min(), img_all.max())

    # use every stimulus and every selected neuron
    n_stim = spks.shape[0]
    istims = np.arange(n_stim)
    n_neurons = spks.shape[1]
    ineurons = np.arange(n_neurons)
    X = spks[istims][:,ineurons]
    img = img_all[istim_train][istims].transpose(1,2,0)
    img_test = img_all[istim_test].transpose(1,2,0)
    print(f'img: {img.shape}, X: {X.shape}')
    Ly, Lx, _ = img.shape

    # fitted Gabor parameters of this mouse
    result_dict = np.load(os.path.join(server_path, 'weights', 'gabor', f'gabor_params_{data.db[mouse_id]["mname"]}.npz'), allow_pickle=True)

    xmax, ymax = result_dict['xmax'], result_dict['ymax']
    ys, xs = np.meshgrid(np.arange(0,Ly), np.arange(0,Lx), indexing='ij')
    ys, xs = torch.from_numpy(ys.astype('float32')), torch.from_numpy(xs.astype('float32'))
    gmax = result_dict['gmax']
    Amax = result_dict['Amax']
    mu1 = torch.from_numpy(result_dict['mu1']).to(device)
    mu2 = torch.from_numpy(result_dict['mu2']).to(device)
    ym = torch.from_numpy(ymax.astype('float32')).unsqueeze(-1).unsqueeze(-1)
    xm = torch.from_numpy(xmax.astype('float32')).unsqueeze(-1).unsqueeze(-1)

    # look up each neuron's parameters: gmax indexes the flattened parameter grid
    gabor_params = torch.zeros((5, n_neurons, 1, 1))
    for i in range(len(gabor_params)):
        gabor_params[i] = gabor_grid[i].flatten()[gmax].reshape(n_neurons, 1, 1)
    msigma, mf, mtheta, mph, mar = gabor_params

    # two filters per neuron, the second with its phase shifted by pi/2
    gabor_filters1 = gabor_filter(ys, xs, ym, xm, 1, msigma, mf, mtheta, mph, mar, is_torch=True).to(device).unsqueeze(-3)
    gabor_filters2 = gabor_filter(ys, xs, ym, xm, 1, msigma, mf, mtheta, mph + np.pi/2, mar, is_torch=True).to(device).unsqueeze(-3)

    # predict responses
    ntest = len(istim_test)
    resp_test1 = torch.zeros((n_neurons, ntest), dtype=torch.float32, device=device)
    resp_test2 = torch.zeros((n_neurons, ntest), dtype=torch.float32, device=device)
    eval_gabors(img_test, gabor_filters1, resp_test1, device=device, rectify=False)
    eval_gabors(img_test, gabor_filters2, resp_test2, device=device, rectify=False)
    # combine the quadrature pair before rectifying (complex-cell term);
    # resp_test1 must still be un-rectified here
    resp_test2 = torch.sqrt(resp_test1**2 + resp_test2**2)
    resp_test2 = relu(resp_test2) # rectify
    resp_test1 = relu(resp_test1) # rectify

    c = torch.from_numpy(Amax).to(device)

    # weighted sum of the two mean-subtracted terms
    rpred = ((resp_test1.T - mu1) * c[:,0] + (resp_test2.T - mu2) * c[:,1]) # (n_stim, n_neurons)
    print(f'rpred: {rpred.shape}')

    # test responses, normalised with the statistics stored next to the fit
    train_mu = result_dict['train_mu']
    train_std = result_dict['train_std']
    X_test = [rep[:,ineurons] for rep in spks_rep_all]
    for i in range(len(X_test)):
        X_test[i] -= train_mu
        X_test[i] /= train_std

    fev, feve_gabor = metrics.feve(X_test, rpred.cpu().numpy())
    print(f'fev:{fev.mean():.3f}, feve:{feve_gabor.mean():.3f}')

    # share of the second (complex-cell) weight in the total weight
    cratio = Amax[:,1]/Amax.sum(axis=1)

    params = [mf.cpu().numpy().squeeze(), msigma.cpu().numpy().squeeze(), mtheta.cpu().numpy().squeeze()]

    # load fullmodel feve
    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    fullmodel_feve = dat['feve_depth'][3]
    nn_all = len(fullmodel_feve)
    valid_idxes = dat['valid_idxes']
    fullmodel_feve = fullmodel_feve[ineur]

    # minimodel FEVE exists only for the valid neurons; all others stay at +inf
    dat = np.load(os.path.join(result_path, f'minimodel_{data.mouse_names[mouse_id]}_result.npz'), allow_pickle=True)
    minimodel_feve = np.inf * np.ones(nn_all)
    minimodel_feve[valid_idxes] = dat['feve_all']
    minimodel_feve = minimodel_feve[ineur]

    print(fev.shape, fullmodel_feve.shape, minimodel_feve.shape, feve_gabor.shape)
    print(cratio.shape)

    res_dict['fev'].append(fev)
    res_dict['feve_minimodel'].append(minimodel_feve)
    res_dict['feve_gabor'].append(feve_gabor)
    res_dict['mf'].append(params[0])
    res_dict['msigma'].append(params[1])
    res_dict['mtheta'].append(params[2])
    res_dict['cratio'].append(cratio)
    res_dict['feve_fullmodel'].append(fullmodel_feve)

In [None]:
# Concatenate the per-mouse entries of every result list into one flat array.
for key, per_mouse in res_dict.items():
    res_dict[key] = np.hstack(per_mouse)
    print(key, res_dict[key].shape)

In [None]:
# Copy the Gabor results into data_dict under the key names used for plotting.
for plot_key, result_key in (
    ('fev_gabor', 'fev'),
    ('feve_minimodel_gabor', 'feve_minimodel'),
    ('feve_gabor', 'feve_gabor'),
    ('feve_fullmodel_gabor', 'feve_fullmodel'),
    ('mf', 'mf'),
    ('msigma', 'msigma'),
    ('mtheta', 'mtheta'),
    ('cratio', 'cratio'),
):
    data_dict[plot_key] = res_dict[result_key]

In [None]:
# Hand the arrays collected in data_dict to the Gabor-figure plotting routine.
import sup_figure
from fig_utils import *
# second argument of sup_figure.figure_gabor; presumably the directory the figure is written to — TODO confirm
save_path = './outputs'
sup_figure.figure_gabor(data_dict, save_path)

# figure 5: same core

## 1 layer

In [None]:
def _score_checkpoint(model, model_name, img_test, spks_rep_all, threshold=0.15):
    """Load a checkpoint into `model` and return its mean test FEVE.

    Parameters
    ----------
    model : model returned by model_builder.build_model
    model_name : checkpoint file name, looked up under weight_path/fullmodel
    img_test : test images passed to model_trainer.test_epoch
    spks_rep_all : repeated test responses passed to metrics.feve
    threshold : only neurons whose FEV exceeds this value enter the mean

    Returns
    -------
    Mean FEVE over the neurons with FEV > threshold.
    """
    model_path = os.path.join(weight_path, 'fullmodel', model_name)
    print('model path: ', model_path)
    model = model.to(device)
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one used here
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)
    test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
    print('FEVE (test, all): ', np.mean(test_feve))

    print(f'filtering neurons with FEV > {threshold}')
    keep = test_fev > threshold
    print(f'valid neurons: {np.count_nonzero(keep)} / {len(test_fev)}')
    mean_feve = np.mean(test_feve[keep])
    print(f'FEVE (test, FEV>{threshold}): {mean_feve}')
    return mean_feve


# Compare two 1-layer models per mouse: a 64-channel model on 2x downsampled
# images without pooling (feve_ds) and the 192-channel model with pooling
# (feve_our). Results are cached in res_path and re-used when the file exists.
nmouse = 6
res_path = os.path.join(result_path, 'feve_same_core_1layer.npz')
if os.path.exists(res_path):
    dat = np.load(res_path)
    feve_ds = dat['feve_ds']
    feve_our = dat['feve_our']
else:
    feve_ds = np.zeros(nmouse)
    feve_our = np.zeros(nmouse)
    seed = 1
    nlayers = 1

    for mouse_id in range(nmouse):
        # mouse index 5 uses the horizontal image range 46-176, all others 0-130
        xrange_max = 176 if mouse_id == 5 else 130

        # load neurons
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # split train and validation set
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

        # normalize spks
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

        ineur = np.arange(0, n_max_neurons)
        spks_rep_all = [rep[:,ineur] for rep in spks_rep_all]

        # ---- 64-channel model on 2x downsampled images, no pooling ----
        pool = False
        img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id], downsample=2)
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
        input_Ly, input_Lx = img_test.shape[-2:]

        nconv1 = 64
        nconv2 = 64
        suffix = 'downsample_2_ks_9'
        if mouse_id == 5:
            suffix = f'xrange_{xrange_max}_' + suffix
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[9,7])
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)
        feve_ds[mouse_id] = _score_checkpoint(model, model_name, img_test, spks_rep_all)

        # ---- our 1 layer 192-channel model, full resolution, with pooling ----
        # NOTE(review): this call passes `xrange` and no mouse_id, unlike the
        # call above — confirm which signature data.load_images provides.
        img = data.load_images(data_path, file=os.path.join(data_path, data.img_file_name[mouse_id]), xrange=[xrange_max-130, xrange_max], downsample=1)
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
        input_Ly, input_Lx = img_test.shape[-2:]

        nconv1 = 192
        nconv2 = 192
        pool = True
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, pool=pool)
        feve_our[mouse_id] = _score_checkpoint(model, model_name, img_test, spks_rep_all)
    np.savez(res_path, feve_ds=feve_ds, feve_our=feve_our)

In [None]:
# Keep the 1-layer comparison (downsampled model vs. our model) for plotting.
data_dict.update({
    'same_core_1layer_feve_ds': feve_ds,
    'same_core_1layer_feve_our': feve_our,
})

In [None]:
# Mean FEVE per mouse over the neurons with FEV > 0.15: one value per entry of
# 'feve_depth' for the full model and one for the LN model.
nmouse = 6
feve_depth_all = np.zeros((nmouse, 4))
feve_LN_all = np.zeros(nmouse)
for mouse_id in range(nmouse):
    results_file = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    dat = np.load(results_file, allow_pickle=True)
    fev = dat['fev']
    feve_depth = dat['feve_depth']
    reliable = fev > 0.15
    feve_depth_all[mouse_id] = feve_depth[:, reliable].mean(axis=1)
    feve_LN_all[mouse_id] = dat['LNmodel_feve_all'][reliable].mean()
data_dict['feve_our_model'] = feve_depth_all
data_dict['feve_LN_model'] = feve_LN_all

# lurz_feve_all.npy is produced by Lurz_model_train_test.ipynb, which must be run first
lurz_feve_all = np.load(os.path.join(result_path, 'lurz_feve_all.npy'))
data_dict['feve_lurz_model'] = lurz_feve_all
print(data_dict['feve_lurz_model'].mean(axis=0))

## 2 layer

In [None]:
def _score_checkpoint(model, model_name, img_test, spks_rep_all, threshold=0.15):
    """Load a checkpoint into `model` and return its mean test FEVE.

    Parameters
    ----------
    model : model returned by model_builder.build_model
    model_name : checkpoint file name, looked up under weight_path/fullmodel
    img_test : test images passed to model_trainer.test_epoch
    spks_rep_all : repeated test responses passed to metrics.feve
    threshold : only neurons whose FEV exceeds this value enter the mean

    Returns
    -------
    Mean FEVE over the neurons with FEV > threshold.
    """
    model_path = os.path.join(weight_path, 'fullmodel', model_name)
    print('model path: ', model_path)
    model = model.to(device)
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one used here
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)
    test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
    print('FEVE (test, all): ', np.mean(test_feve))

    print(f'filtering neurons with FEV > {threshold}')
    keep = test_fev > threshold
    print(f'valid neurons: {np.count_nonzero(keep)} / {len(test_fev)}')
    mean_feve = np.mean(test_feve[keep])
    print(f'FEVE (test, FEV>{threshold}): {mean_feve}')
    return mean_feve


# Compare two 2-layer models per mouse: a 64/64-channel model on 2x downsampled
# images without pooling (feve_ds) and the 192/192-channel model with pooling
# (feve_our). Results are cached in res_path and re-used when the file exists.
nmouse = 6
res_path = os.path.join(result_path, 'feve_same_core_2layer.npz')
if os.path.exists(res_path):
    dat = np.load(res_path)
    feve_ds = dat['feve_ds']
    feve_our = dat['feve_our']
else:
    feve_ds = np.zeros(nmouse)
    feve_our = np.zeros(nmouse)
    seed = 1
    nlayers = 2

    for mouse_id in range(nmouse):
        # mouse index 5 uses the horizontal image range 46-176, all others 0-130
        xrange_max = 176 if mouse_id == 5 else 130

        # load neurons
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # split train and validation set
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

        # normalize spks
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

        ineur = np.arange(0, n_max_neurons)
        spks_rep_all = [rep[:,ineur] for rep in spks_rep_all]

        # ---- 64/64-channel model on 2x downsampled images, no pooling ----
        pool = False
        img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id], downsample=2)
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
        input_Ly, input_Lx = img_test.shape[-2:]

        nconv1 = 64
        nconv2 = 64
        # NOTE(review): unlike the 1-layer cell, no 'xrange_176_' prefix is added
        # for mouse index 5 here — confirm against the checkpoint file names.
        suffix = 'downsample_2_ks_9_7'
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[9,7])
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)
        feve_ds[mouse_id] = _score_checkpoint(model, model_name, img_test, spks_rep_all)

        # ---- our 2 layer 192/192-channel model, with pooling ----
        # NOTE(review): this call passes `xrange` and no mouse_id, unlike the
        # call above — confirm which signature data.load_images provides.
        img = data.load_images(data_path, file=os.path.join(data_path, data.img_file_name[mouse_id]), xrange=[xrange_max-130, xrange_max])
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
        input_Ly, input_Lx = img_test.shape[-2:]

        nconv1 = 192
        nconv2 = 192
        pool = True
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, pool=pool)
        feve_our[mouse_id] = _score_checkpoint(model, model_name, img_test, spks_rep_all)
    np.savez(res_path, feve_ds=feve_ds, feve_our=feve_our)

In [None]:
# Keep the 2-layer comparison (downsampled model vs. our model) for plotting.
data_dict.update({
    'same_core_2layer_feve_ds': feve_ds,
    'same_core_2layer_feve_our': feve_our,
})

## vary kernel size

In [None]:
def _score_checkpoint(model, model_name, img_test, spks_rep_all, threshold=0.15):
    """Load a checkpoint into `model` and return its mean test FEVE.

    Parameters
    ----------
    model : model returned by model_builder.build_model
    model_name : checkpoint file name, looked up under weight_path/fullmodel
    img_test : test images passed to model_trainer.test_epoch
    spks_rep_all : repeated test responses passed to metrics.feve
    threshold : only neurons whose FEV exceeds this value enter the mean

    Returns
    -------
    Mean FEVE over the neurons with FEV > threshold.
    """
    model_path = os.path.join(weight_path, 'fullmodel', model_name)
    print('model path: ', model_path)
    model = model.to(device)
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one used here
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)
    test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
    print('FEVE (test, all): ', np.mean(test_feve))

    print(f'filtering neurons with FEV > {threshold}')
    keep = test_fev > threshold
    print(f'valid neurons: {np.count_nonzero(keep)} / {len(test_fev)}')
    mean_feve = np.mean(test_feve[keep])
    print(f'FEVE (test, FEV>{threshold}): {mean_feve}')
    return mean_feve


# Mean test FEVE of the 2-layer (16 / 320 channel) model for every combination
# of first- and second-layer kernel size, per mouse. Results are cached in
# res_path and re-used when the file exists.
nmouse = 6
res_path = os.path.join(result_path, 'feve_same_core_conv_ks.npz')
conv1_ks_list = [7,13,17,21,25,29]
conv2_ks_list = [5,7,9,11,13,15]
if os.path.exists(res_path):
    dat = np.load(res_path)
    feve_conv_ks_all = dat['feve_conv_ks_all']
else:
    pool = True
    seed = 1
    nlayers = 2
    nconv1 = 16
    nconv2 = 320
    feve_conv_ks_all = np.zeros((nmouse, len(conv1_ks_list), len(conv2_ks_list)))
    for mouse_id in range(nmouse):
        # load images
        img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id], downsample=1)
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)

        # load neurons
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # split train and validation set
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

        # normalize spks
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

        ineur = np.arange(0, n_max_neurons)
        spks_rep_all = [rep[:,ineur] for rep in spks_rep_all]

        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
        input_Ly, input_Lx = img_test.shape[-2:]

        for i1, conv1_ks in enumerate(conv1_ks_list):
            for i2, conv2_ks in enumerate(conv2_ks_list):
                # the (25, 9) combination is stored without a kernel-size suffix
                if conv1_ks == 25 and conv2_ks == 9:
                    suffix = ''
                else:
                    suffix = f'ks_{conv1_ks}_{conv2_ks}'
                model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[conv1_ks, conv2_ks])
                model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)
                feve_conv_ks_all[mouse_id, i1, i2] = _score_checkpoint(model, model_name, img_test, spks_rep_all)
    np.savez(res_path, feve_conv_ks_all=feve_conv_ks_all)

In [None]:
# Keep the kernel-size sweep (mouse x conv1 size x conv2 size) for plotting.
data_dict.update({'same_core_conv_ks_feve': feve_conv_ks_all})

In [None]:
# plot: hand the arrays collected in data_dict to the same-core figure routine
import sup_figure
from fig_utils import *
# second argument of sup_figure.figure_same_core; presumably the directory the figure is written to — TODO confirm
root = './outputs'
sup_figure.figure_same_core(data_dict, root)

# figure 6: Wx and Wy

In [None]:
# mouse
# Load the 2-layer 192/192 full model of every mouse and keep the readout
# weights Wx / Wy (and the FEV) of the neurons with FEV > 0.15.
Wx_all = []
Wy_all = []
fev_all = []
for mouse_id in range(6):
    n_neurons = data.NNs[mouse_id]
    ineur = np.arange(0, n_neurons)
    input_Ly, input_Lx = 66, 130
    nlayers = 2
    nconv1 = 192
    nconv2 = 192

    suffix = ''
    if mouse_id == 5: suffix += f'xrange_176'
    # NOTE(review): `pool`, `depth_separable` and `clamp` are not assigned in this
    # cell. In the visible part of the notebook `pool` is only set when an earlier
    # cell misses its result cache, and the other two are never set — define them
    # explicitly before running this cell on a fresh kernel.
    model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, depth_separable=depth_separable)
    model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, clamp=clamp, suffix=suffix)

    # Build the per-mouse directory in its own variable. Re-assigning weight_path
    # from itself made the path nest one level deeper on every iteration and
    # changed the notebook-level weight_path for later cells. The checkpoint
    # directory is only read from, so it is not created when missing.
    mouse_weight_path = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
    model_path = os.path.join(mouse_weight_path, model_name)
    print('model path: ', model_path)
    model = model.to(device)
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one used here
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
    Wx = model.readout.Wx.detach().cpu().numpy()
    Wy = model.readout.Wy.detach().cpu().numpy()
    # outer product of Wx and Wy
    Wxy = np.einsum('icj,ick->ijk', Wy, Wx)

    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    fev = dat['fev']
    valid_idx = np.where(fev > 0.15)[0]
    Wx_all.append(Wx[valid_idx])
    Wy_all.append(Wy[valid_idx])
    fev_all.append(fev[valid_idx])

# one list entry per mouse; note that this replaces any earlier data_dict['fev_all']
data_dict['fev_all'] = fev_all
data_dict['Wx_all'] = Wx_all
data_dict['Wy_all'] = Wy_all

In [None]:
from minimodel.utils import weight_bandwidth
from scipy.interpolate import interp1d

# Pull the per-mouse lists back out of data_dict ...
fev_all, Wx_all, Wy_all = (data_dict[key] for key in ('fev_all', 'Wx_all', 'Wy_all'))

# ... and pool the neurons of all mice into single arrays.
fev, Wx, Wy = np.hstack(fev_all), np.vstack(Wx_all), np.vstack(Wy_all)

In [None]:
# For every neuron, get the two values returned by weight_bandwidth(...,
# return_peak=True) for its Wx and Wy profiles (stored here as bandwidth and
# centre position).
NN = Wx.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
centerpos_Wx = np.zeros(NN)
centerpos_Wy = np.zeros(NN)
for i in range(NN):
    bandwidth_Wx[i], centerpos_Wx[i] = weight_bandwidth(Wx[i, 0, :], return_peak=True)
    bandwidth_Wy[i], centerpos_Wy[i] = weight_bandwidth(Wy[i, 0, :], return_peak=True)

Lx, Ly = Wx.shape[-1], Wy.shape[-1]
xmid, ymid = int(Lx/2), int(Ly/2)
x_xrange, y_xrange = np.arange(Lx), np.arange(Ly)

ineurons = np.arange(0, NN)
# Common centre-aligned axis: wide enough to hold every neuron's profile after
# it has been shifted by its own centre position.
common_x = np.linspace(np.min(x_xrange - np.max(centerpos_Wx)),
                       np.max(x_xrange - np.min(centerpos_Wx)),
                       num=len(x_xrange))

# Wx profile of each neuron resampled onto the common axis
interp_Wx_values = []

for i in ineurons:
    # Original x values for this neuron, shifted by its centre position
    original_x = x_xrange - centerpos_Wx[i]
    # Linear interpolation; samples outside the neuron's own range become NaN.
    # `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
    interp_func = interp1d(original_x, Wx[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wx_values.append(interp_func(common_x))

data_dict['interp_Wx_values'] = interp_Wx_values
data_dict['common_x'] = common_x

In [None]:
from scipy.interpolate import interp1d

# Same centre alignment as for Wx, now along the vertical (Wy) axis.
# Relies on NN, y_xrange, centerpos_Wy and Wy from the previous cells.
ineurons = np.arange(0, NN)
# Common centre-aligned axis for Wy. The variable keeps the name `common_x`
# (as before), although it holds the y axis; it is stored under 'common_y'.
common_x = np.linspace(np.min(y_xrange - np.max(centerpos_Wy)),
                       np.max(y_xrange - np.min(centerpos_Wy)),
                       num=len(y_xrange))

# Wy profile of each neuron resampled onto the common axis
interp_Wy_values = []

for i in ineurons:
    # Original y values for this neuron, shifted by its centre position
    original_x = y_xrange - centerpos_Wy[i]
    # Linear interpolation; samples outside the neuron's own range become NaN.
    # `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
    interp_func = interp1d(original_x, Wy[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wy_values.append(interp_func(common_x))

data_dict['interp_Wy_values'] = interp_Wy_values
data_dict['common_y'] = common_x

## monkey

In [None]:
# Monkey: load the trained full model (166 neurons, 80x80 input) and extract
# its readout weights.
Lx, Ly = 80, 80
# Set explicitly so the cell does not depend on `nlayers` leaking from the
# mouse cell above (which also uses 2).
nlayers = 2
nconv1 = 192
nconv2 = 192
use_sensorium_normalization = True
model, in_channels = model_builder.build_model(NN=166, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
suffix = ''
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, suffix=suffix, seed=1, use_sensorium_normalization=use_sensorium_normalization)
# Checkpoint directory kept in its own variable: assigning it back to
# `weight_path` changed the base path for every later cell and made this cell
# fail when run a second time.
monkey_weight_path = os.path.join(weight_path, 'fullmodel', 'monkeyV1')
model_path = os.path.join(monkey_weight_path, model_name)
print('model path: ', model_path)
model.load_state_dict(torch.load(model_path))
print('loaded model', model_path)
model = model.to(device)

Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
Wx = model.readout.Wx.detach().cpu().numpy()
Wy = model.readout.Wy.detach().cpu().numpy()
# outer product of Wy and Wx over the shared middle axis
Wxy = np.einsum('icj,ick->ijk', Wy, Wx)
print(Wxy.shape, Wc.shape)
print(Wc.shape, Wc.min(), Wc.max())

In [None]:
# Monkey: for every neuron, get the two values returned by
# weight_bandwidth(..., return_peak=True) for its Wx and Wy profiles (stored
# here as bandwidth and centre position).
NN = Wx.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
centerpos_Wx = np.zeros(NN)
centerpos_Wy = np.zeros(NN)
for i in range(NN):
    bandwidth_Wx[i], centerpos_Wx[i] = weight_bandwidth(Wx[i, 0, :], return_peak=True)
    bandwidth_Wy[i], centerpos_Wy[i] = weight_bandwidth(Wy[i, 0, :], return_peak=True)

Lx, Ly = Wx.shape[-1], Wy.shape[-1]
xmid, ymid = int(Lx/2), int(Ly/2)
x_xrange, y_xrange = np.arange(Lx), np.arange(Ly)

ineurons = np.arange(0, NN)
# Common centre-aligned axis: wide enough to hold every neuron's profile after
# it has been shifted by its own centre position.
common_x = np.linspace(np.min(x_xrange - np.max(centerpos_Wx)),
                       np.max(x_xrange - np.min(centerpos_Wx)),
                       num=len(x_xrange))

# Wx profile of each neuron resampled onto the common axis
interp_Wx_values = []

for i in ineurons:
    # Original x values for this neuron, shifted by its centre position
    original_x = x_xrange - centerpos_Wx[i]
    # Linear interpolation; samples outside the neuron's own range become NaN.
    # `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
    interp_func = interp1d(original_x, Wx[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wx_values.append(interp_func(common_x))

data_dict['monkey_interp_Wx_values'] = interp_Wx_values
data_dict['monkey_common_x'] = common_x

In [None]:
from scipy.interpolate import interp1d

# Monkey: same centre alignment as for Wx, now along the vertical (Wy) axis.
# Relies on NN, y_xrange, centerpos_Wy and Wy from the previous cells.
ineurons = np.arange(0, NN)
# Common centre-aligned axis for Wy. The variable keeps the name `common_x`
# (as before), although it holds the y axis; it is stored under
# 'monkey_common_y'.
common_x = np.linspace(np.min(y_xrange - np.max(centerpos_Wy)),
                       np.max(y_xrange - np.min(centerpos_Wy)),
                       num=len(y_xrange))

# Wy profile of each neuron resampled onto the common axis
interp_Wy_values = []

for i in ineurons:
    # Original y values for this neuron, shifted by its centre position
    original_x = y_xrange - centerpos_Wy[i]
    # Linear interpolation; samples outside the neuron's own range become NaN.
    # `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
    interp_func = interp1d(original_x, Wy[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wy_values.append(interp_func(common_x))

data_dict['monkey_interp_Wy_values'] = interp_Wy_values
data_dict['monkey_common_y'] = common_x

In [None]:
# plot
# Hand everything collected in data_dict to sup_figure.figure3.
# NOTE(review): this section is headed "figure 6" but calls figure3 — confirm
# that the numbering mismatch is intended.
import sup_figure
# Wildcard import: it can rebind names already defined in the notebook, so it
# is kept before `save_path` is assigned.
from fig_utils import *
save_path = './outputs'  # second argument of the plotting call; presumably the output directory
sup_figure.figure3(data_dict, save_path)

# figure 7: reuse conv1

In [None]:
# Evaluate single-neuron minimodels whose conv1 was taken from another mouse.
# feve_matrix[target mouse, conv1 source mouse, neuron] holds the test FEVE.
nmouse = 6
NN = 100  # number of neurons sampled per mouse
feve_matrix = np.zeros((nmouse, nmouse, NN))

nconv1 = 16
nconv2 = 64
wc_coef = 0.2
hs_readout = 0.03
nlayers = 2

for mouse_id in range(nmouse):
    img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])
    nimg, Ly, Lx = img.shape
    print('img: ', img.shape, img.min(), img.max(), img.dtype)
    fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])

    spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
    n_stim, n_max_neurons = spks.shape

    # normalize spks
    itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
    spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
    img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
    input_Ly, input_Lx = img_test.shape[-2:]

    # Fixed seed: the same NN ranks are drawn on every run.
    ineurons = np.arange(data.NNs_valid[mouse_id])
    np.random.seed(42)
    ineurons = np.random.choice(ineurons, NN, replace=False)

    # The sampled values index into the neurons sorted by decreasing FEV.
    fev_test = metrics.fev(spks_rep_all)
    isort_neurons = np.argsort(fev_test)[::-1]

    # Checkpoint directory of the target mouse. It is computed once per mouse
    # in its own variable: the previous code assigned the joined path back to
    # `weight_path` inside the innermost loop, so the path grew by another
    # 'minimodel/<mouse>' on every single model load.
    minimodel_weight_path = os.path.join(weight_path, 'minimodel', data.mouse_names[mouse_id])

    for mouse_id_base in range(nmouse):
        # The suffix depends only on the (target, source) pair of mice.
        if mouse_id_base != mouse_id:
            suffix = f'pretrainconv1_{data.mouse_names[mouse_id_base]}_{data.exp_date[mouse_id_base]}'
            if mouse_id_base == 5:
                suffix += '_xrange_176'
        else:
            suffix = ''

        for i, ineuron in enumerate(ineurons):
            ineur = [isort_neurons[ineuron]]
            spks_rep = [rep[:, ineur] for rep in spks_rep_all]

            model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, Wc_coef=wc_coef)
            model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, hs_readout=hs_readout, suffix=suffix)

            model_path = os.path.join(minimodel_weight_path, model_name)
            print('model path: ', model_path)
            model.load_state_dict(torch.load(model_path))
            print('loaded model', model_path)
            model = model.to(device)

            # test model
            test_pred = model_trainer.test_epoch(model, img_test)
            test_fev, test_feve = metrics.feve(spks_rep, test_pred)
            print('FEVE (test): ', test_feve)

            feve_matrix[mouse_id, mouse_id_base, i] = np.mean(test_feve)

In [None]:
# Swap the first two axes of feve_matrix before storing it for plotting.
data_dict['reuse_conv1_feve'] = np.transpose(feve_matrix, (1, 0, 2))

In [None]:
# conv1 kernels: read the array saved as 'fullmodel_conv1_W' in each mouse's
# full-model results file and stack the six arrays along a new first axis.
conv1_all = []
for mouse_id in range(6):
    results_file = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    dat = np.load(results_file, allow_pickle=True)
    conv1 = dat['fullmodel_conv1_W']
    conv1_all += [conv1]
conv1_stack = np.stack(conv1_all)
data_dict['conv1_W'] = conv1_stack
print(conv1_stack.shape)

In [None]:
# Hand everything collected in data_dict to sup_figure.figure6.
# NOTE(review): this section is headed "figure 7" but calls figure6 — confirm
# that the numbering mismatch is intended.
import sup_figure
# Wildcard import: it can rebind names already defined in the notebook, so it
# is kept before `save_path` is assigned.
from fig_utils import *
save_path = './outputs'  # second argument of the plotting call; presumably the output directory
sup_figure.figure6(data_dict, save_path)

# figure 8: pool/no pool

In [None]:
def _evaluate_fullmodel(model, model_name, img_test, spks_rep_all, threshold=0.15):
    """Load a full-model checkpoint, score it on the test images and return its conv1 weights.

    Parameters
    ----------
    model : freshly built model whose weights are read from
        ``<weight_path>/fullmodel/<model_name>``.
    model_name : checkpoint file name.
    img_test : test images, already on ``device``.
    spks_rep_all : repeated test responses passed to ``metrics.feve``.
    threshold : neurons enter the mean only if their FEV exceeds this value.

    Returns
    -------
    (mean_feve, conv1_weight) : mean FEVE over the neurons above threshold and
        the squeezed weights of ``model.core.features.layer0.conv``.
    """
    model_path = os.path.join(weight_path, 'fullmodel', model_name)
    print('model path: ', model_path)
    model = model.to(device)
    model.load_state_dict(torch.load(model_path))
    print('loaded model', model_path)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)

    test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
    print('FEVE (test, all): ', np.mean(test_feve))

    print(f'filtering neurons with FEV > {threshold}')
    valid_idxes = np.where(test_fev > threshold)[0]
    print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
    mean_feve = np.mean(test_feve[test_fev > threshold])
    print(f'FEVE (test, FEV>{threshold}): {mean_feve}')

    conv1_weight = model.core.features.layer0.conv.weight.cpu().detach().numpy().squeeze()
    return mean_feve, conv1_weight


# Compare three full-model variants per mouse: without pooling, with pooling,
# and with pooling plus a 9x9 first-layer kernel. Results are cached in an
# .npz file so the models only have to be evaluated once.
res_path = os.path.join(result_path, 'feve_pool_no_pool_mouse.npz')
if os.path.exists(res_path):
    dat = np.load(res_path)
    feve_no_pool = dat['feve_no_pool']
    feve_pool = dat['feve_pool']
    conv1_no_pool = dat['conv1_no_pool']
    conv1_pool = dat['conv1_pool']
    feve_small_ks = dat['feve_small_ks']
    conv1_small_ks = dat['conv1_small_ks']
else:
    nmouse = 6
    feve_no_pool = np.zeros(nmouse)
    feve_pool = np.zeros(nmouse)
    conv1_no_pool = np.zeros((nmouse, 16, 25, 25))
    conv1_pool = np.zeros((nmouse, 16, 25, 25))
    feve_small_ks = np.zeros(nmouse)
    conv1_small_ks = np.zeros((nmouse, 16, 9, 9))

    for mouse_id in range(nmouse):
        # load images
        img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)

        # load neurons
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # split train and validation set
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

        # normalize spks
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

        ineur = np.arange(0, n_max_neurons)

        spks_val = torch.from_numpy(spks[ival][:,ineur])
        spks_rep_all = [rep[:, ineur] for rep in spks_rep_all]

        img_val = torch.from_numpy(img[istim_train][ival]).to(device).unsqueeze(1)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

        input_Ly, input_Lx = img_test.shape[-2:]

        seed = 1
        nlayers = 2
        nconv1 = 16
        nconv2 = 320

        # --- variant 1: no pooling -------------------------------------
        pool = False
        # Reset explicitly. Previously `suffix` was read here before being
        # set, so it carried whatever an earlier cell left behind for the
        # first mouse and 'ks_9_9' (from variant 3) for every later mouse.
        suffix = ''
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)
        feve_no_pool[mouse_id], conv1_no_pool[mouse_id] = _evaluate_fullmodel(model, model_name, img_test, spks_rep_all)

        # --- variant 2: pooling ----------------------------------------
        pool = True
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, pool=pool)
        feve_pool[mouse_id], conv1_pool[mouse_id] = _evaluate_fullmodel(model, model_name, img_test, spks_rep_all)

        # --- variant 3: pooling with small (9x9) conv1 kernel ------------
        conv1_ks = 9
        conv2_ks = 9

        suffix = ''
        if (conv1_ks != 25) or (conv2_ks != 9):
            if suffix != '': suffix += '_'
            suffix += f'ks_{conv1_ks}_{conv2_ks}'
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[conv1_ks, conv2_ks])
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)
        feve_small_ks[mouse_id], conv1_small_ks[mouse_id] = _evaluate_fullmodel(model, model_name, img_test, spks_rep_all)

    np.savez(res_path, feve_no_pool=feve_no_pool, feve_pool=feve_pool, conv1_no_pool=conv1_no_pool, conv1_pool=conv1_pool, feve_small_ks=feve_small_ks, conv1_small_ks=conv1_small_ks)

In [None]:
# Publish the pool / no-pool / small-kernel results for the plotting code.
data_dict.update(
    feve_no_pool=feve_no_pool,
    feve_pool=feve_pool,
    conv1_no_pool=conv1_no_pool,
    conv1_pool=conv1_pool,
    feve_small_ks=feve_small_ks,
    conv1_small_ks=conv1_small_ks,
)

In [None]:
# plot
# Hand everything collected in data_dict to sup_figure.figure_pool_nopool.
import sup_figure
root = './outputs'  # second argument of the plotting call; presumably the output directory
sup_figure.figure_pool_nopool(data_dict, root)

# figure 9: sparsity penalty

## mouse validation

In [None]:
mouse_id = 0
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nhs = len(hs_list)

mouse_tag = f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}'
# One row per results file:
# (file name, key of the FEVE array, key of the nconv2 array,
#  data_dict key for the FEVE array, data_dict key for the nconv2 array)
validation_sources = [
    # mouse, full training set
    (f'{mouse_tag}_minimodel_16_64_choose_param_val_result.npz',
     'feve_val', 'nconv2', 'mouse_feve_val', 'mouse_nconv2_val'),
    # mouse, 5k variant
    (f'{mouse_tag}_minimodel_16_64_choose_param_5k_val_result.npz',
     'feve_val', 'nconv2', 'mouse_feve_val_5k', 'mouse_nconv2_val_5k'),
    # monkey
    ("minimodel_monkey_result.npz",
     'param_search_feve_val', 'param_search_nconv2_val', 'monkey_feve_val', 'monkey_nconv2_val'),
]

for result_file, feve_key, nconv2_key, feve_target, nconv2_target in validation_sources:
    dat = np.load(os.path.join(result_path, result_file))
    feve_all = dat[feve_key]
    nconv2_all = dat[nconv2_key]
    print(feve_all.shape, nconv2_all.shape)

    data_dict[feve_target] = feve_all
    data_dict[nconv2_target] = nconv2_all

## mouse test

In [None]:
nmouse = 6
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nhs = len(hs_list)

# Per mouse: the saved FEVE array and, from the saved weights `wc`, the number
# of entries along axis 2 whose magnitude exceeds 0.01.
feve_per_mouse, nconv2_per_mouse = [], []
for mouse_id in range(nmouse):
    result_file = f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64_choose_param_result.npz'
    dat = np.load(os.path.join(result_path, result_file))
    feve, wc = dat['feve_all'], dat['wc_all']
    nconv2 = (np.abs(wc) > 0.01).sum(axis=2)
    feve_per_mouse.append(feve)
    nconv2_per_mouse.append(nconv2)

# Stack the mice along the first axis.
feve_all, nconv2_all = np.vstack(feve_per_mouse), np.vstack(nconv2_per_mouse)
print(feve_all.shape, nconv2_all.shape)

In [None]:
# Publish the mouse sparsity sweep (FEVE, channel counts, penalty values).
data_dict.update(
    sparsity_feve_all=feve_all,
    sparsity_nconv2_all=nconv2_all,
    sparsity_hs=hs_list,
)

### mouse 5k

In [None]:
nmouse = 6
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nhs = len(hs_list)

# Same collection as for the full sweep, read from the '5k' result files.
feve_per_mouse, nconv2_per_mouse = [], []
for mouse_id in range(nmouse):
    result_file = f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64_choose_param_5k_result.npz'
    dat = np.load(os.path.join(result_path, result_file))
    feve, wc = dat['feve_all'], dat['wc_all']
    # number of entries along axis 2 whose magnitude exceeds 0.01
    nconv2 = (np.abs(wc) > 0.01).sum(axis=2)
    feve_per_mouse.append(feve)
    nconv2_per_mouse.append(nconv2)

feve_all, nconv2_all = np.vstack(feve_per_mouse), np.vstack(nconv2_per_mouse)
print(feve_all.shape, nconv2_all.shape)

data_dict.update(
    sparsity_feve_5k_all=feve_all,
    sparsity_nconv2_5k_all=nconv2_all,
)

## monkey

In [None]:
# Monkey sparsity sweep, read from the shared monkey minimodel results file.
monkey_result_file = os.path.join(result_path, 'minimodel_monkey_result.npz')
dat = np.load(monkey_result_file)
feve_hs_all, wc_hs_all, hs_list = (dat[key] for key in ('feve_hs_all', 'wc_hs_all', 'hs_list'))
# number of entries along axis 2 whose magnitude exceeds 0.01
nconv2_all = (np.abs(wc_hs_all) > 0.01).sum(axis=2)
data_dict.update({
    'sparsity_monkey_feve_all': feve_hs_all,
    'sparsity_monkey_nconv2_all': nconv2_all,
    'sparsity_monkey_hs': hs_list,
})

## fullmodel

In [None]:
# Full model evaluated for every readout sparsity value in hs_list (passed to
# create_model_name as `hs_readout`). Results are cached in an .npz file.
# NOTE(review): `mouse_id` is not assigned in this cell before it is used in
# the cache file name, so the name depends on whatever an earlier cell left
# behind (5 when the notebook is run top to bottom) — confirm the intent.
# NOTE(review): `depth_separable` and `clamp` must already exist in the kernel.
fpath = os.path.join(result_path, f'fullmodel_hoyer_loss_result_{data.mouse_names[mouse_id]}_wc.npz')
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]

if os.path.exists(fpath):
    dat = np.load(fpath)
    feve_all = dat['feve_all']
    wc_all = dat['wc_all']
    n_wc_all = dat['n_wc_all']
else:
    nmouse = 6
    pool = True
    # Only the readout weights of this mouse index are stored in the cache.
    # NOTE(review): index 3 is taken over from the previous code — confirm.
    wc_mouse_index = 3
    feve_all_mice = []
    wc_all_mice = []
    n_wc_all_mice = []
    for mouse_id in range(nmouse):
        if mouse_id == 5:
            xrange_max = 176
        else:
            xrange_max = 130
        # `mouse_id` is passed as in every other load_images call in this
        # notebook; it was missing here.
        img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)

        # load neurons
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # split train and validation set
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

        # normalize spks
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

        ineur = np.arange(0, n_max_neurons)
        spks_rep_all = [rep[:, ineur] for rep in spks_rep_all]

        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

        input_Ly, input_Lx = img_test.shape[-2:]

        seed = 1
        nlayers = 2
        nconv1 = 16
        nconv2 = 320
        conv1_ks = 25
        conv2_ks = 9
        feve_per_hs = np.zeros(len(hs_list))
        wc_per_hs = []
        for i in range(len(hs_list)):
            hs_readout = hs_list[i]
            suffix = ''
            if mouse_id == 5: suffix = f'xrange_{xrange_max}'
            if (conv1_ks != 25) or (conv2_ks != 9):
                if suffix != '': suffix += '_'
                suffix += f'ks_{conv1_ks}_{conv2_ks}'
            model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, depth_separable=depth_separable, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[conv1_ks, conv2_ks])
            model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, clamp=clamp, seed=seed, suffix=suffix, pool=pool, hs_readout=hs_readout)

            model_path = os.path.join(weight_path, 'fullmodel', model_name)
            print('model path: ', model_path)
            model = model.to(device)
            model.load_state_dict(torch.load(model_path))
            print('loaded model', model_path)

            # test model
            test_pred = model_trainer.test_epoch(model, img_test)

            test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
            print('FEVE (test, all): ', np.mean(test_feve))

            threshold = 0.15
            print(f'filtering neurons with FEV > {threshold}')
            valid_idxes_tmp = np.where(test_fev > threshold)[0]
            print(f'valid neurons: {len(valid_idxes_tmp)} / {len(test_fev)}')
            print(f'FEVE (test, FEV>0.15): {np.mean(test_feve[test_fev > threshold])}')
            feve_per_hs[i] = np.mean(test_feve[test_fev > threshold])
            wc_per_hs.append(model.readout.Wc.cpu().detach().numpy().squeeze())

        # Valid neurons are taken from the `test_fev` of the last model above.
        valid_idxes = np.where(test_fev > 0.15)[0]
        n_wc_per_hs = []
        valid_wc_per_hs = []
        for wc in wc_per_hs:
            wc = wc[valid_idxes]
            # mean number of readout weights per neuron with magnitude > 0.01
            n_wc = np.sum(np.abs(wc) > 0.01, axis=1)
            n_wc_per_hs.append(np.mean(n_wc))
            valid_wc_per_hs.append(wc)
        feve_all_mice.append(feve_per_hs)  # (nmouse, nhs)
        wc_all_mice.append(np.array(valid_wc_per_hs))  # per mouse: (nhs, n_valid_neurons, nconv2)
        n_wc_all_mice.append(n_wc_per_hs)  # (nmouse, nhs)

    feve_all = np.array(feve_all_mice)
    n_wc_all = np.array(n_wc_all_mice)
    # The per-mouse weight arrays have different numbers of valid neurons, so
    # they are kept in a list: np.array() on such ragged input raises a
    # ValueError on NumPy >= 1.24. Only one mouse is stored, as before.
    wc_all = wc_all_mice[wc_mouse_index]
    np.savez(fpath, feve_all=feve_all, wc_all=wc_all, n_wc_all=n_wc_all)

# Publish the results in both branches; previously data_dict was only filled
# when the cache file already existed, so the first run left the keys unset.
data_dict['fullmodel_hoyer_feve_all'] = feve_all
data_dict['fullmodel_hoyer_wc_all'] = wc_all
data_dict['fullmodel_hoyer_n_wc_all'] = n_wc_all
data_dict['fullmodel_hoyer_hs'] = hs_list


## plot

In [None]:
# plot
# Hand everything collected in data_dict to sup_figure.figure4.
# NOTE(review): this section is headed "figure 9" but calls figure4 — confirm
# that the numbering mismatch is intended.
import sup_figure
root = './outputs'  # second argument of the plotting call; presumably the output directory
sup_figure.figure4(data_dict, root)

# figure 10: 5k vs 30k FEVE with varying #neurons

In [None]:
# load all neurons FEVE
# Compare models trained on few neurons (n_neuron) against the all-neuron
# model, for two training-set sizes (nstim_list), with and without a
# "pretrainconv1" checkpoint. Results are cached in an .npz file.
n_neuron = 1
nmouse = 6
neuron_numbers = np.geomspace(1, 1000, num=10, dtype=int)
neuron_numbers = np.unique(np.concatenate(([1], neuron_numbers)))  # Ensure 1 is included and remove duplicates
# One seed count per entry of neuron_numbers, decreasing from 20 to 1.
seed_numbers = np.linspace(20, 1, num=len(neuron_numbers), dtype=int)
# Axes: [index into nstim_list, 0 = n-neuron model / 1 = all-neuron model,
#        mouse, seed - 1, neuron]
feves_all = np.zeros((2, 2, nmouse, seed_numbers[0], n_neuron))
feves_pretrain_all = np.zeros((2, 2, nmouse, seed_numbers[0], n_neuron))
nstim_list = [5000, 30000]

res_path = os.path.join(result_path, 'feve_nstim_pretrain_results.npz')
if os.path.exists(res_path):
    # Cached results replace the zero arrays allocated above.
    dat = np.load(res_path)
    feves_all = dat['feves_all']
    feves_pretrain_all = dat['feves_pretrain_all']
else:
    pool = True
    mouse_id = 0

    for mouse_id in range(nmouse):
        # load images
        img = data.load_images(data_path, mouse_id, file=data.img_file_name[mouse_id])
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)

        # load neurons
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # split train and validation set
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

        # normalize spks
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

        ineur = np.arange(0, n_max_neurons) #np.arange(0, n_neurons, 5)

        spks_val = torch.from_numpy(spks[ival][:,ineur]) 
        spks_rep_all = [spks_rep_all[i][:,ineur] for i in range(len(spks_rep_all))]

        img_val = torch.from_numpy(img[istim_train][ival]).to(device).unsqueeze(1)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

        for m, nstims in enumerate(nstim_list):
            # Fall back to the full training set when it is smaller than the
            # requested number of stimuli; the reduced value is also what
            # ends up in the checkpoint suffix below.
            if len(itrain) < nstims:
                print('not enough training stimuli, using all stimuli')
                nstims = len(itrain)

            # --- reference: model trained on all neurons with `nstims` stimuli ---
            input_Ly, input_Lx = img_test.shape[-2:]
            seed = 1
            nlayers = 2
            nconv1 = 16
            nconv2 = 320
            conv1_ks = 25
            conv2_ks = 9
            suffix = ''
            suffix += f'nstims_{nstims}'
            ineur = np.arange(0, n_max_neurons) #np.arange(0, n_neurons, 5)
            model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[conv1_ks, conv2_ks])
            model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)

            model_path = os.path.join(weight_path, 'fullmodel', model_name)
            print('model path: ', model_path)
            model = model.to(device)
            model.load_state_dict(torch.load(model_path))
            print('loaded model', model_path)

            # test model
            test_pred = model_trainer.test_epoch(model, img_test)

            # feve_all: per-neuron FEVE of the all-neuron model, indexed by
            # `ineur` further down.
            test_fev, feve_all = metrics.feve(spks_rep_all, test_pred)
            print('FEVE (test, all): ', np.mean(feve_all))

            # Candidate neurons: those with FEV above 0.15.
            fev_test = metrics.fev(spks_rep_all)
            valid_idxes = np.where(fev_test > 0.15)[0]
            print(len(valid_idxes), len(fev_test))

            # without pretrain
            # Look up how many seeds exist for this n_neuron.
            i = np.where(neuron_numbers == n_neuron)[0][0]
            for seed in range(1, seed_numbers[i]+1):
                # NOTE(review): when there are too few valid neurons,
                # `n_neuron` is overwritten here and the new value persists
                # for all later iterations — confirm that this is intended.
                if n_neuron >= len(valid_idxes):
                    print(f'not enough neurons with FEV > 0.15, using all neurons')
                    ineur = valid_idxes.copy()
                    n_neuron = len(valid_idxes)
                else:
                    # The sampled neurons are a function of (n_neuron, seed).
                    np.random.seed(n_neuron*seed)
                    ineur = np.random.choice(valid_idxes, size=n_neuron, replace=False)
                spks_rep_all_tmp = [spks_rep_all[i][:,ineur] for i in range(len(spks_rep_all))]
                # Checkpoint suffix: 'nneurons_<n>_nstims_<nstims>'
                suffix = ''
                if n_neuron != -1:
                    suffix = f'nneurons_{n_neuron}'
                if suffix != '': suffix += '_'
                suffix += f'nstims_{nstims}'
                model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[conv1_ks, conv2_ks])
                model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)

                model_path = os.path.join(weight_path, 'fullmodel', model_name)
                print('model path: ', model_path)
                model = model.to(device)
                model.load_state_dict(torch.load(model_path))
                print('loaded model', model_path)

                # test model
                test_pred = model_trainer.test_epoch(model, img_test)

                test_fev, test_feve = metrics.feve(spks_rep_all_tmp, test_pred)
                print('FEVE (test, all): ', np.mean(test_feve))
                
                feves_all[m, 0, mouse_id, seed-1] = test_feve # feve from model trained with n neurons
                feves_all[m, 1, mouse_id, seed-1] = feve_all[ineur] # feve from model trained with all neurons

            # load pretrained model FEVE
            # The pretrainconv1 suffix below names the same mouse as the
            # model being evaluated.
            pretrain_mouse_id = mouse_id
            # NOTE(review): `feves_pretrain` is allocated but never used in
            # this cell.
            feves_pretrain = np.zeros((2, len(seed_numbers), n_neuron))
            i = np.where(neuron_numbers == n_neuron)[0][0]
            for seed in range(1, seed_numbers[i]+1):
                if n_neuron >= len(valid_idxes):
                    print(f'not enough neurons with FEV > 0.15, using all neurons')
                    ineur = valid_idxes.copy()
                    n_neuron = len(valid_idxes)
                else:
                    # Same seeding as above, so the same neurons are drawn.
                    np.random.seed(n_neuron*seed)
                    ineur = np.random.choice(valid_idxes, size=n_neuron, replace=False)
                spks_rep_all_tmp = [spks_rep_all[i][:,ineur] for i in range(len(spks_rep_all))]
                # Checkpoint suffix:
                # 'nneurons_<n>_nstims_<nstims>_pretrainconv1_<mouse>_<date>'
                # with a second '_nstims_<nstims>' appended when nstims is 5000.
                suffix = ''
                if n_neuron != -1:
                    suffix = f'nneurons_{n_neuron}'
                if suffix != '': suffix += '_'
                suffix += f'nstims_{nstims}'
                if suffix != '': suffix += '_'
                suffix += f'pretrainconv1_{data.mouse_names[pretrain_mouse_id]}_{data.exp_date[pretrain_mouse_id]}'
                if nstims == 5000:
                    if suffix != '': suffix += '_'
                    suffix += f'nstims_{nstims}'
                model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx, kernel_size=[conv1_ks, conv2_ks])
                model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)

                model_path = os.path.join(weight_path, 'fullmodel', model_name)
                print('model path: ', model_path)
                model = model.to(device)
                model.load_state_dict(torch.load(model_path))
                print('loaded model', model_path)

                # test model
                test_pred = model_trainer.test_epoch(model, img_test)

                test_fev, test_feve = metrics.feve(spks_rep_all_tmp, test_pred)
                print('FEVE (test, all): ', np.mean(test_feve))

                # Same layout as feves_all: 0 = pretrained n-neuron model,
                # 1 = all-neuron model for the same neurons.
                feves_pretrain_all[m, 0, mouse_id, seed-1] = test_feve
                feves_pretrain_all[m, 1, mouse_id, seed-1] = feve_all[ineur]
    # Cache the results so the next run takes the fast path above.
    np.savez(res_path, feves_all=feves_all, feves_pretrain_all=feves_pretrain_all)

In [None]:
data_dict['nstim_feve_all'] = feves_all
data_dict['nstim_feve_pretrain_all'] = feves_pretrain_all

In [None]:
# plot
# Draw the figure from the arrays collected in data_dict. `root` is handed over
# as the second argument (presumably the output directory -- confirm in sup_figure).
import sup_figure
root = './outputs'
sup_figure.figure_vary_nneuron(data_dict, root)

# figure 11: model structure

In [None]:
# Render the model-structure figure from data_dict.
# NOTE(review): the section heading calls this "figure 11" while the function is
# `figure5` -- confirm the numbering is intentional.
import sup_figure
# NOTE(review): star import -- it is not visible here which fig_utils names are needed.
from fig_utils import *
save_path = './outputs'
sup_figure.figure5(data_dict, save_path)

# figure 12: conv2 clustering

In [None]:
def test_channel_output(model, img_test, batch_size=100):
    model.eval()
    n_test = img_test.shape[0]
    conv2_features = []
    conv2_indepth_features = []
    with torch.no_grad():
        for k in np.arange(0, n_test, batch_size):
            kend = min(k+batch_size, n_test)
            img_batch = img_test[k:kend]


            x = model.core.features.layer0(img_batch)
            conv2_indepth_fv = model.core.features.layer1.ds_conv.in_depth_conv(x)
            x = model.core.features.layer1.ds_conv.spatial_conv(conv2_indepth_fv)
            x = model.core.features.layer1.norm(x)
            conv2_relu_fvs = model.core.features.layer1.activation(x)
            # print('after in_depth_conv: ', conv2_indepth_fv.shape, conv2_indepth_fv.max(), conv2_indepth_fv.min())
            # print('after conv2_relu: ', conv2_relu_fvs.shape, conv2_relu_fvs.max(), conv2_relu_fvs.min())
            conv2_fv = conv2_relu_fvs[:, :, 16, 32]
            conv2_indepth_fv = conv2_indepth_fv[:, :, 16, 32]
            conv2_features.append(conv2_fv.detach().cpu().numpy())
            conv2_indepth_features.append(conv2_indepth_fv.detach().cpu().numpy())
            # spks_test_pred[k:kend] = spks_pred
    conv2_features = np.vstack(conv2_features)
    conv2_indepth_features = np.vstack(conv2_indepth_features)
    return conv2_indepth_features, conv2_features

In [None]:
# Load the 600 single-neuron minimodels (6 mice x 100 neurons) and collect for each:
#   - conv2 1x1 ("in-depth") weights and conv2 spatial weights
#   - readout channel weights Wc
#   - test FEVE
#   - conv2 channel responses to a shared set of train images
# The result is cached in conv2_analysis.npz; delete that file to recompute.
res_path = os.path.join(result_path, 'conv2_analysis.npz')
if os.path.exists(res_path):
    dat = np.load(res_path)
    feve_all = dat['feve_all']
    in_depth_conv_all = dat['in_depth_conv_all']
    spatial_conv_all = dat['spatial_conv_all']
    Wc_all = dat['Wc_all']
    channel_resp_all = dat['channel_resp_all']
    channel_indepth_all = dat['channel_indepth_all']
else:
    nmouse = 6
    NN = 100  # minimodels (neurons) analysed per mouse
    nconv1 = 16
    nconv2 = 64
    wc_coef = 0.2
    hs_readout = 0.03
    nlayers = 2
    xrange_max = 130
    # NOTE: `pool` is not set in this cell; it must already exist from an earlier cell.

    feve_all = np.zeros((nmouse, NN))
    in_depth_conv_all = np.zeros((nmouse, NN, nconv2, nconv1))
    spatial_conv_all = np.zeros((nmouse, NN, nconv2, 9, 9))  # 9x9 spatial kernels
    Wc_all = np.zeros((nmouse, NN, nconv2))

    # use the same train images for all models
    mouse_id = 3
    img = data.load_images(data_path, file=os.path.join(data_path, data.img_file_name[mouse_id]), xrange=[xrange_max-130, xrange_max])
    nimg, Ly, Lx = img.shape
    print('img: ', img.shape, img.min(), img.max(), img.dtype)
    img_train = torch.from_numpy(img[500:30000]).to(device).unsqueeze(1)
    print('img_train: ', img_train.shape)
    ntrain = img_train.shape[0]
    # NOTE: two float64 arrays of nmouse*NN*ntrain*nconv2 elements each -- these are large.
    channel_resp_all = np.zeros((nmouse, NN, ntrain, nconv2))
    channel_indepth_all = np.zeros((nmouse, NN, ntrain, nconv2))

    for mouse_id in range(nmouse):
        # the last mouse uses a shifted horizontal crop of the images
        if mouse_id == 5: xrange_max = 176
        else: xrange_max = 130
        img = data.load_images(data_path, file=os.path.join(data_path, data.img_file_name[mouse_id]), xrange=[xrange_max-130, xrange_max])
        nimg, Ly, Lx = img.shape
        print('img: ', img.shape, img.min(), img.max(), img.dtype)
        fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])

        spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
        n_stim, n_max_neurons = spks.shape

        # normalize spks
        itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
        spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
        img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
        input_Ly, input_Lx = img_test.shape[-2:]

        # draw NN random ranks (fixed seed) out of the first NNs_valid[mouse_id] ranks;
        # the count follows NN so it always matches the result arrays allocated above
        ineurons = np.arange(data.NNs_valid[mouse_id])
        np.random.seed(42)
        ineurons = np.random.choice(ineurons, NN, replace=False)

        # neurons ordered by decreasing FEV; `ineurons` indexes into this ordering
        fev_test = metrics.fev(spks_rep_all)
        isort_neurons = np.argsort(fev_test)[::-1]
        # for mouse_id_base in range(nmouse):
        mouse_id_base = 0
        for i, ineuron in enumerate(ineurons):
            ineur = [isort_neurons[ineuron]]
            # repeats restricted to this neuron (own loop name, so `i` is not shadowed)
            spks_rep = [rep[:, ineur] for rep in spks_rep_all]

            # models of other mice carry a suffix naming the mouse whose conv1 was reused
            if mouse_id_base != mouse_id:
                suffix = f'pretrainconv1_{data.mouse_names[mouse_id_base]}_{data.exp_date[mouse_id_base]}'
                if mouse_id_base == 5:
                    suffix = f'pretrainconv1_{data.mouse_names[mouse_id_base]}_{data.exp_date[mouse_id_base]}_xrange_176'
            else:
                suffix = ''
            model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, Wc_coef=wc_coef)
            model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, hs_readout=hs_readout, suffix=suffix)

            model_path = os.path.join(weight_path, 'minimodel', model_name)
            print('model path: ', model_path)
            model.load_state_dict(torch.load(model_path))
            print('loaded model', model_path)
            model = model.to(device)

            in_depth_conv_all[mouse_id, i] = model.core.features.layer1.ds_conv.in_depth_conv.weight.cpu().detach().numpy().squeeze()
            spatial_conv_all[mouse_id, i] = model.core.features.layer1.ds_conv.spatial_conv.weight.cpu().detach().numpy().squeeze()
            Wc_all[mouse_id, i] = model.readout.Wc.cpu().detach().numpy().squeeze()

            # test model
            test_pred = model_trainer.test_epoch(model, img_test)
            test_fev, test_feve = metrics.feve(spks_rep, test_pred)
            print('FEVE (test): ', test_feve)
            feve_all[mouse_id, i] = np.mean(test_feve)
            # channel responses are always measured on the shared img_train set
            channel_indepth_all[mouse_id, i], channel_resp_all[mouse_id, i] = test_channel_output(model, img_train)
    np.savez(res_path, feve_all=feve_all, in_depth_conv_all=in_depth_conv_all, spatial_conv_all=spatial_conv_all, Wc_all=Wc_all, channel_resp_all=channel_resp_all, channel_indepth_all=channel_indepth_all)

## 1x1 conv

In [None]:
# t-SNE embedding of the conv2 1x1 weight vectors (one row per mouse/neuron/channel)
from sklearn.manifold import TSNE

print(in_depth_conv_all.shape, spatial_conv_all.shape, Wc_all.shape)
nmouse, NN, nconv2, nconv1 = in_depth_conv_all.shape

# one row per conv2 channel, scaled to unit L2 norm
in_depth_conv_all_flat = in_depth_conv_all.reshape(-1, nconv1)
row_norms = np.linalg.norm(in_depth_conv_all_flat, axis=1, keepdims=True)
in_depth_conv_all_flat = in_depth_conv_all_flat / row_norms

# mouse index of every channel, laid out like Wc_all
mouse_ids = np.zeros(Wc_all.shape)
mouse_ids[:nmouse] = np.arange(nmouse).reshape((nmouse,) + (1,) * (Wc_all.ndim - 1))
mouse_ids_flat = mouse_ids.flatten()
Wc_all_flat = Wc_all.flatten()
print(in_depth_conv_all_flat.shape, mouse_ids_flat.shape, Wc_all_flat.shape)

# keep only channels whose readout weight magnitude exceeds 0.01
valid_idxes = np.flatnonzero(np.abs(Wc_all_flat) > 0.01)
print('valid channels: ', len(valid_idxes))
in_depth_conv_all_flat = in_depth_conv_all_flat[valid_idxes]
mouse_ids_flat = mouse_ids_flat[valid_idxes]

tsne = TSNE(n_components=2, random_state=0)
X_2d = tsne.fit_transform(in_depth_conv_all_flat)
print('X_2d: ', X_2d.shape)

In [None]:
# Keep the 1x1-conv embedding and the matching per-channel mouse ids.
data_dict.update({
    'tsne_conv2_1x1': X_2d,
    'mouse_ids_flat': mouse_ids_flat,
})

In [None]:
# k-means on the 2-D embedding of the 1x1 conv weights
from sklearn.cluster import KMeans
from scipy.spatial import distance

n_clusters = 6
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(X_2d)
cluster_labels = kmeans.labels_
print('cluster_labels: ', cluster_labels.shape)

# cluster centers from kmeans
cluster_centers = kmeans.cluster_centers_
print('cluster_centers: ', cluster_centers.shape)

# for each cluster, pick the member that lies closest to the cluster centre
cluster_center_idxes = []
cluster_center_samples = []
for label in range(n_clusters):
    idxes = np.flatnonzero(cluster_labels == label)
    samples = X_2d[idxes]
    nearest = np.argmin(distance.cdist(cluster_centers[label][None, :], samples))
    cluster_center_idxes.append(idxes[nearest])
    cluster_center_samples.append(samples[nearest])

In [None]:
# Keep the 1x1-conv clustering results for plotting and the final save.
data_dict.update({
    'cluster_labels': cluster_labels,
    'center_cluster_samples': cluster_center_samples,
    'cluster_center_idxes': cluster_center_idxes,
})

In [None]:
# load conv1 weights of mouse 0
mouse_id = 0
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 320
# Set the suffix here instead of inheriting whatever value an earlier cell left
# behind. The conv2-analysis cell uses an empty suffix for a model that is not
# conv1-pretrained from another mouse, which is the model wanted here.
suffix = ''
# NOTE: `pool` is still taken from an earlier cell; it is not defined here.
n_max_neurons = data.NNs[mouse_id]
ineur = np.arange(0, n_max_neurons)
input_Ly, input_Lx = 66, 130
model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, input_Ly=input_Ly, input_Lx=input_Lx)
model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix, pool=pool)

model_path = os.path.join(weight_path, 'fullmodel', model_name)
print('model path: ', model_path)
model = model.to(device)
model.load_state_dict(torch.load(model_path))
print('loaded model', model_path)

# first-layer convolution kernels of the full model
conv1_W = model.core.features.layer0.conv.weight.cpu().detach().numpy().squeeze()
print(f'conv1_W: {conv1_W.shape}')

# normalised 1x1 weight vectors of the channels chosen as cluster representatives
ori_cluster_centers = in_depth_conv_all_flat[cluster_center_idxes]
print('ori_cluster_centers: ', ori_cluster_centers.shape)


In [None]:
# Keep the conv1 kernels and the representative 1x1 weight vectors.
data_dict.update({
    'conv1_W': conv1_W,
    'ori_cluster_centers': ori_cluster_centers,
})

## spatial conv

In [None]:
# t-SNE embedding of the conv2 spatial kernels (one flattened kernel per row)
from sklearn.manifold import TSNE

nmouse, NN, nconv2, Ly, Lx = spatial_conv_all.shape

# flatten every kernel and scale it to unit L2 norm
spatial_conv_all_flat = spatial_conv_all.reshape(-1, Ly*Lx)
kernel_norms = np.linalg.norm(spatial_conv_all_flat, axis=1, keepdims=True)
spatial_conv_all_flat = spatial_conv_all_flat / kernel_norms

# mouse index of every channel, laid out like Wc_all
mouse_ids = np.zeros(Wc_all.shape)
mouse_ids[:nmouse] = np.arange(nmouse).reshape((nmouse,) + (1,) * (Wc_all.ndim - 1))
mouse_ids_flat = mouse_ids.flatten()
Wc_all_flat = Wc_all.flatten()
print(spatial_conv_all_flat.shape, mouse_ids_flat.shape, Wc_all_flat.shape)

# keep only channels whose readout weight magnitude exceeds 0.01
valid_idxes = np.flatnonzero(np.abs(Wc_all_flat) > 0.01)
print('valid channels: ', len(valid_idxes))
spatial_conv_all_flat = spatial_conv_all_flat[valid_idxes]
mouse_ids_flat = mouse_ids_flat[valid_idxes]

tsne = TSNE(n_components=2, random_state=0)
X_2d = tsne.fit_transform(spatial_conv_all_flat)
print('X_2d: ', X_2d.shape)

In [None]:
# Keep the spatial-kernel embedding.
data_dict.update({'tsne_conv2_spatial': X_2d})

In [None]:
# k-means on the 2-D embedding of the spatial conv kernels
from sklearn.cluster import KMeans
from scipy.spatial import distance

n_clusters = 6
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(X_2d)
cluster_labels = kmeans.labels_
print('cluster_labels: ', cluster_labels.shape)

# cluster centers from kmeans
cluster_centers = kmeans.cluster_centers_
print('cluster_centers: ', cluster_centers.shape)

# for each cluster, pick the member that lies closest to the cluster centre
cluster_center_idxes = []
cluster_center_samples = []
for label in range(n_clusters):
    idxes = np.flatnonzero(cluster_labels == label)
    samples = X_2d[idxes]
    nearest = np.argmin(distance.cdist(cluster_centers[label][None, :], samples))
    cluster_center_idxes.append(idxes[nearest])
    cluster_center_samples.append(samples[nearest])

In [None]:
# Keep the spatial-kernel clustering results under their 'spatial_*' keys.
data_dict.update({
    'spatial_cluster_labels': cluster_labels,
    'spatial_center_cluster_samples': cluster_center_samples,
    'spatial_cluster_center_idxes': cluster_center_idxes,
})

In [None]:
# Representative kernels of each cluster, reshaped from flat rows back to (Ly, Lx).
n_centers = len(cluster_center_idxes)
center_rows = spatial_conv_all_flat[cluster_center_idxes]
ori_cluster_centers = center_rows.reshape(n_centers, Ly, Lx)
print('ori_cluster_centers: ', ori_cluster_centers.shape)
data_dict.update({'spatial_ori_cluster_centers': ori_cluster_centers})

## channel responses

### after conv2

In [None]:
# Embed the conv2 channel responses: unit-normalise each channel's response
# vector, reduce to 200 PCs, then run openTSNE initialised with the first two PCs.
print(channel_resp_all.shape)
print(channel_indepth_all.shape)
from numpy import linalg as LA
conv2_fv = True
if conv2_fv: channel_resp = channel_resp_all # whether to use the conv2 features or the conv2 in-depth features
else: channel_resp = channel_indepth_all
nmouse, NN, ntrain, nconv2 = channel_resp.shape
# (mouse, neuron, image, channel) -> one row per channel, one column per image
channel_resp_all_flat = channel_resp.transpose(0,1,3,2).reshape(nmouse*NN*nconv2, ntrain)
# channel_resp_all_flat = channel_resp_all_flat / (np.abs(channel_resp_all_flat).max(axis=1)[:, None] + 1e-6)
# normalize by the norm of each row
# NOTE(review): a channel that is zero for every image has norm 0 and becomes NaN
# here; this is only harmless if such channels are removed by the Wc filter below.
channel_resp_all_flat = channel_resp_all_flat / LA.norm(channel_resp_all_flat, axis=1)[:, None]
mouse_ids = np.zeros(Wc_all.shape)
for i in range(nmouse):
    mouse_ids[i] = i
mouse_ids_flat = mouse_ids.flatten()
Wc_all_flat = Wc_all.flatten()
print(channel_resp_all_flat.shape, mouse_ids_flat.shape, Wc_all_flat.shape)
# keep only channels whose readout weight magnitude exceeds 0.01
valid_idxes = np.where(np.abs(Wc_all_flat) > 0.01)[0]
Wc_all_flat = Wc_all_flat[valid_idxes]
channel_resp_all_flat = channel_resp_all_flat[valid_idxes]
mouse_ids_flat = mouse_ids_flat[valid_idxes]
print(channel_resp_all_flat.shape, mouse_ids_flat.shape, Wc_all_flat.shape)

# pca to 200 dimensions
# random_state is fixed because sklearn may select a randomized SVD solver for a
# matrix of this size; without it the PCs (and the t-SNE below) change run to run.
from sklearn.decomposition import PCA
pca = PCA(n_components=200, random_state=42)
X_pca = pca.fit_transform(channel_resp_all_flat)
print('X_pca: ', X_pca.shape)

# tsne of conv2 channel responses (color by mouse_id)
# Get the first two PCs for initialization
# NOTE: this rebinds the name TSNE, which earlier cells import from sklearn.manifold.
from openTSNE import TSNE

initial_embedding = X_pca[:, :2]

# t-SNE with openTSNE
tsne = TSNE(
    n_components=2,
    initialization=initial_embedding,  # Initialize with first two PCs
    perplexity=50,  # Adjust based on your data
    n_jobs=-1,      # Use all available CPU cores
    random_state=42
)
X_tsne = tsne.fit(X_pca)
print('X_tsne: ', X_tsne.shape)

In [None]:
# Keep the channel-response embedding.
data_dict.update({'tsne_conv2_channel': X_tsne})

## rastermap of channel responses

In [None]:
# Build the (channel x image) response matrix that is fed to rastermap.
print(channel_resp_all.shape) # nmouse, NN, ntrain, 64
nmouse, NN, ntrain, nconv2 = channel_resp_all.shape

# move the channel axis ahead of the image axis, then one row per channel
channel_resp_all_flat = np.swapaxes(channel_resp_all, 2, 3).reshape(-1, ntrain)
# scale each row by its largest absolute value (epsilon guards all-zero rows)
row_peak = np.abs(channel_resp_all_flat).max(axis=1, keepdims=True)
channel_resp_all_flat = channel_resp_all_flat / (row_peak + 1e-6)

# mouse index of every channel, laid out like Wc_all
mouse_ids = np.zeros(Wc_all.shape)
mouse_ids[:nmouse] = np.arange(nmouse).reshape((nmouse,) + (1,) * (Wc_all.ndim - 1))
mouse_ids_flat = mouse_ids.flatten()
Wc_all_flat = Wc_all.flatten()
print(channel_resp_all_flat.shape, mouse_ids_flat.shape, Wc_all_flat.shape)

# keep only channels whose readout weight magnitude exceeds 0.01
valid_idxes = np.flatnonzero(np.abs(Wc_all_flat) > 0.01)
Wc_all_flat = Wc_all_flat[valid_idxes]
channel_resp_all_flat = channel_resp_all_flat[valid_idxes]
mouse_ids_flat = mouse_ids_flat[valid_idxes]
print(channel_resp_all_flat.shape, mouse_ids_flat.shape, Wc_all_flat.shape)

In [None]:
# Sort the conv2 channels with rastermap.
# Rows of the input are channels, columns are images; each row is z-scored first.
from scipy.stats import zscore
from rastermap import Rastermap

channel_resp = zscore(channel_resp_all_flat, axis=1)
n_neurons, n_stim = channel_resp.shape

# choose the bin size so that the binned output has about n_bins rows
n_bins = 500
bin_size = n_neurons // n_bins

rastermap_kwargs = dict(
    n_clusters=100,  # number of clusters to compute
    n_PCs=200,       # number of PCs to use
    locality=0.5,    # locality in sorting to find sequences (value from 0-1)
    bin_size=bin_size,
)
model = Rastermap(**rastermap_kwargs).fit(channel_resp)

y = model.embedding  # neurons x 1
isort = model.isort
x = model.X_embedding

In [None]:
# Keep the rastermap outputs together with the z-scored input matrix.
data_dict.update({
    'rastermap_x': x,
    'rastermap_channel_resp': channel_resp,
    'rastermap_isort': isort,
})

## plot

In [None]:
# plot
# Draw the conv2 clustering figure from the entries collected in data_dict.
# `root` is handed over as the second argument (presumably the output
# directory -- confirm in sup_figure).
import sup_figure
root = './outputs'
sup_figure.figure_conv2_cluster(data_dict, root)

# figure 13: example neurons visualization

In [None]:
# Render the example-neuron visualization figure from data_dict.
# NOTE(review): the section heading calls this "figure 13" while the function is
# `figure9` -- confirm the numbering is intentional.
import sup_figure
# NOTE(review): star import -- it is not visible here which fig_utils names are needed.
from fig_utils import *
save_path = './outputs'
sup_figure.figure9(data_dict, save_path)

# save all

In [None]:
# Write every entry of data_dict into one archive in the working directory.
np.savez('supfig_results.npz', **data_dict)