In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from minimodel import data, metrics
# Index of the mouse/session to analyse; used below to look up entries in
# data.db, data.NNs_valid and data.img_file_name.
mouse_id = 3

# Directories used by the notebook.
data_path = '../notebooks/data'          # neural recordings and stimulus images are read from here
weight_path = './checkpoints/fullmodel'  # saved model parameters (the Gabor fit is read from its 'gabor' subfolder)
result_path = './save_results/outputs'   # output directory (not referenced by the cells shown here)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Train the Gabor model

In [None]:
# --- Load the recorded responses for the selected mouse ------------------------
db_entry = data.db[mouse_id]
fname = '%s_nat60k_%s.npz'%(db_entry['mname'], db_entry['datexp'])
neuron_file = os.path.join(data_path, fname)
spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(
    file_path=neuron_file, mouse_id=mouse_id)
# spks is 2-D; its axes are read here as (stimuli, neurons).
n_stim, n_max_neurons = spks.shape

# --- Split the training stimuli 90/10 into train and validation ----------------
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

# --- Normalize the responses (the training split is passed as the reference) ---
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

# First pass: keep every recorded neuron (identity selection).
ineur = np.arange(0, n_max_neurons)
spks_val = torch.from_numpy(spks[ival][:, ineur])
spks_rep_all = [rep[:, ineur] for rep in spks_rep_all]

# Number of neurons to keep for this mouse. To fit a random subset instead,
# draw from `ineurons` here (with a fixed seed).
ineurons = np.arange(data.NNs_valid[mouse_id])

# Rank neurons by metrics.fev on the repeated test responses, best first
# (presumably "fraction of explainable variance" -- confirm in minimodel.metrics),
# and keep the top data.NNs_valid[mouse_id] of them.
fev_test = metrics.fev(spks_rep_all)
isort_neurons = np.argsort(fev_test)[::-1]
ineur = isort_neurons[ineurons]

# Shapes before the neuron selection ...
print(spks.shape, spks_val.shape, len(spks_rep_all), spks_rep_all[0].shape)

spks = spks[:, ineur]
spks_val = spks_val[:, ineur]
spks_rep_all = [rep[:, ineur] for rep in spks_rep_all]

# ... and after it.
print(spks.shape, spks_val.shape, len(spks_rep_all), spks_rep_all[0].shape)

In [None]:
# Load the stimulus images for this mouse, requesting downsample=2 from the loader.
img_all = data.load_images(
    data_path,
    mouse_id,
    file=data.img_file_name[mouse_id],
    downsample=2,
)

# img_all is 3-D: (number of images, height Ly, width Lx).
nimg, Ly, Lx = img_all.shape

# Quick sanity check of the shape and the pixel value range.
print('img: ', img_all.shape, img_all.min(), img_all.max())

In [None]:
# Optionally subsample the stimuli and/or neurons used for the Gabor fit.
# A value <= 0 means "use everything".
n_stim = -1 # spks.shape[0]
n_neurons = -1

# Dedicated, seeded generator so that any subsampling is reproducible across
# kernel restarts and does not depend on the global NumPy random state.
subsample_rng = np.random.default_rng(42)

if n_stim > 0:
    istims = subsample_rng.choice(spks.shape[0], n_stim, replace=False)
else:
    n_stim = spks.shape[0]
    istims = np.arange(n_stim)

if n_neurons > 0:
    ineurons = subsample_rng.choice(spks.shape[1], n_neurons, replace=False)
else:
    n_neurons = spks.shape[1]
    ineurons = np.arange(n_neurons)

# Training responses for the selected stimuli and neurons.
X = spks[istims][:, ineurons]

# Test repeats restricted to the same neurons. Integer-array indexing always
# returns new arrays, so X_test never shares memory with spks_rep_all (the old
# `spks_rep_all.copy()` only copied the list and aliased the arrays, so any
# in-place normalisation of X_test would have silently modified spks_rep_all).
X_test = [rep[:, ineurons] for rep in spks_rep_all]

# Images reordered from (n_images, Ly, Lx) to (Ly, Lx, n_images).
# NOTE(review): this assumes row k of spks corresponds to image istim_train[k] -- confirm.
img = img_all[istim_train][istims].transpose(1,2,0)
img_test = img_all[istim_test].transpose(1,2,0)
print(f'img: {img.shape}, X: {X.shape}')
Ly, Lx, _ = img.shape


In [None]:
# Fit the Gabor model: training responses X with their images img, plus the
# repeated test responses X_test with their images img_test.
from minimodel import gabor
# NOTE(review): `result_dict` is rebound further down by np.load(...) of the saved
# parameter file, so the value returned here is discarded unless it is used or
# saved before that cell runs. Presumably it carries the same keys as that file
# (xmax, ymax, gmax, Amax, mu1, mu2, train_mu, train_std) -- confirm in minimodel.gabor.
# NOTE(review): X_test is passed by reference; check whether fit_gabor_model
# modifies it in place.
result_dict = gabor.fit_gabor_model(X, img, X_test, img_test)

# Test: evaluate the fitted Gabor model on the held-out test images

In [None]:
# Define the 1-D axes of the Gabor parameter grid. The meanings below are
# inferred from the variable names (presumably envelope width, spatial
# frequency, orientation, phase and aspect ratio -- confirm against
# minimodel.gabor.gabor_filter).
sigma = np.array([0.75, 1.25, 1.5, 2.5, 3.5, 4.5, 5.5])
f = np.array([0.1, 0.25, 0.5, 1, 2])
theta = np.arange(0, np.pi, np.pi/8)      # 8 values in [0, pi)
ph = np.arange(0, 2*np.pi, np.pi/4)       # 8 values in [0, 2*pi)
ar = np.array([1, 1.5, 2])
print(f'sigma: {sigma.shape}, f: {f.shape}, theta: {theta.shape}, ph: {ph.shape}, ar: {ar.shape}')

# Full Cartesian product of the five axes; every grid has shape
# (n_sigma, n_f, n_theta, n_ph, n_ar).
param_grids = np.meshgrid(sigma, f, theta, ph, ar, indexing='ij')
n_gabors = param_grids[0].size
print(f'number of gabors: {n_gabors}')

# Append two singleton axes to each grid and convert to float32 tensors.
# A new list is built instead of assigning into the meshgrid result, because
# np.meshgrid returns an immutable tuple on NumPy >= 2.0 (it was a list before),
# so `params[i] = ...` raises TypeError there.
params = [
    torch.from_numpy(np.expand_dims(grid, axis=(-2, -1)).astype('float32'))
    for grid in param_grids
]

# From here on the short names refer to the 7-D tensors, not the 1-D axes.
sigma, f, theta, ph, ar = params
print(f'sigma: {sigma.shape}, f: {f.shape}, theta: {theta.shape}, ph: {ph.shape}, ar: {ar.shape}')

In [None]:
# Load the saved Gabor fit for this mouse from <weight_path>/gabor/.
# This rebinds `result_dict`, replacing any value returned by an earlier fit.
gabor_params_file = os.path.join(weight_path, 'gabor', f'gabor_params_{data.db[mouse_id]["mname"]}.npz')
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# execute arbitrary code -- only load parameter files from a trusted source.
result_dict = np.load(gabor_params_file, allow_pickle=True)

In [None]:
from torch.nn.functional import relu
from minimodel.gabor import gabor_filter, eval_gabors

# --- Fitted quantities read from the saved parameter file ----------------------
xmax, ymax = result_dict['xmax'], result_dict['ymax']   # per-neuron filter centre (x, y)
gmax = result_dict['gmax']                              # per-neuron index into the flattened parameter grid
Amax = result_dict['Amax']                              # per-neuron weights; columns 0 and 1 are used below
mu1 = torch.from_numpy(result_dict['mu1']).to(device)   # offset subtracted from the first response
mu2 = torch.from_numpy(result_dict['mu2']).to(device)   # offset subtracted from the second response

# --- Pixel coordinate grids, each of shape (Ly, Lx) ----------------------------
ys, xs = np.meshgrid(np.arange(0,Ly), np.arange(0,Lx), indexing='ij')
ys, xs = torch.from_numpy(ys.astype('float32')), torch.from_numpy(xs.astype('float32'))

# Filter centres with two trailing singleton axes so they broadcast over (Ly, Lx).
ym = torch.from_numpy(ymax.astype('float32')).unsqueeze(-1).unsqueeze(-1)
xm = torch.from_numpy(xmax.astype('float32')).unsqueeze(-1).unsqueeze(-1)

# --- Look up each neuron's parameters in the grid ------------------------------
# (This lookup used to be written out twice in this cell; once is enough.)
gabor_params = torch.zeros((5, n_neurons, 1, 1))
for i, grid in enumerate(params):
    gabor_params[i] = grid.flatten()[gmax].reshape(n_neurons, 1, 1)
msigma, mf, mtheta, mph, mar = gabor_params

# --- Build two filters per neuron, the second with its phase shifted by pi/2 ---
gabor_filters1 = gabor_filter(ys, xs, ym, xm, 1, msigma, mf, mtheta, mph, mar, is_torch=True).to(device).unsqueeze(-3)
gabor_filters2 = gabor_filter(ys, xs, ym, xm, 1, msigma, mf, mtheta, mph + np.pi/2, mar, is_torch=True).to(device).unsqueeze(-3)

# --- Filter responses to the test images ---------------------------------------
# img_test comes from the subsampling cell and is used as-is.
# NOTE(review): an earlier version standardised img_test with img_mean/img_std
# before this step -- confirm eval_gabors expects un-normalised images.
ntest = len(istim_test)
resp_test1 = torch.zeros((n_neurons, ntest), dtype=torch.float32, device=device)
resp_test2 = torch.zeros((n_neurons, ntest), dtype=torch.float32, device=device)
# The return value is ignored, so eval_gabors presumably fills the preallocated
# response tensors in place -- confirm in minimodel.gabor.
eval_gabors(img_test, gabor_filters1, resp_test1, device=device, rectify=False)
eval_gabors(img_test, gabor_filters2, resp_test2, device=device, rectify=False)

# Second response: root of the summed squares of the two phase-shifted filter
# outputs (the "complex cell" response). It must be computed from the
# un-rectified first response, i.e. before the ReLU below. A square root is
# never negative, so it needs no rectification of its own.
resp_test2 = torch.sqrt(resp_test1**2 + resp_test2**2)
# First response: rectified output of the first filter.
resp_test1 = relu(resp_test1)

c = torch.from_numpy(Amax).to(device)

# Weighted sum of the two offset-corrected responses; the transposes put test
# images on the first axis, giving (ntest, n_neurons).
rpred = ((resp_test1.T - mu1) * c[:,0] + (resp_test2.T - mu2) * c[:,1])
print(f'rpred: {rpred.shape}')

In [None]:
# Rescale the repeated test responses with the train_mu / train_std values
# stored alongside the fit.
train_mu = result_dict['train_mu']
train_std = result_dict['train_std']

X_test = []
for rep in spks_rep_all:
    # Integer-array indexing returns a new array, so the in-place arithmetic
    # below leaves spks_rep_all untouched and the cell can be re-run safely.
    rep_scaled = rep[:, ineurons]
    rep_scaled -= train_mu
    rep_scaled /= train_std
    X_test.append(rep_scaled)

In [None]:
# Score the predictions against the repeated test responses.
# metrics.feve returns two per-neuron arrays; their means are reported.
rpred_np = rpred.cpu().numpy()
fev, feve = metrics.feve(X_test, rpred_np)
print(f'fev:{fev.mean():.3f}, feve:{feve.mean():.3f}')

# Share of each neuron's total weight carried by the second column of Amax.
amax_total = Amax.sum(axis=1)
cratio = Amax[:, 1] / amax_total

In [None]:
# Show the first fitted filter for a random handful of neurons.
# The sample is kept in its own variable (`iplot`): rebinding the global
# `ineurons` here would break a later re-run of the cell that normalises the
# test responses, because that cell indexes spks_rep_all with `ineurons`.
n_show = min(10, n_neurons)   # never ask for more neurons than exist
iplot = np.random.choice(n_neurons, n_show, replace=False)

fig, ax = plt.subplots(2, 5, figsize=(15,6))
for axi in ax.flat:
    axi.axis('off')           # also blanks panels left unused when n_show < 10
for axi, k in zip(ax.flat, iplot):
    axi.imshow(gabor_filters1[k].cpu().numpy().squeeze(), cmap='gray')
    axi.set_title(f'sigma={msigma[k].item():.2f}, f={mf[k].item():.2f}, \ntheta={mtheta[k].item():.2f}, ph={mph[k].item():.2f}, \nar={mar[k].item():.2f}')
plt.show()