In [None]:
import os
import torch
import numpy as np
from minimodel import data, metrics, model_builder, model_trainer

# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Results gathered by the cells below; written to a single .npz in the last cell.
data_dict = {}

data_path = '../../data'  # directory holding monkeyv1_cadena_2019.npz
weight_path = '../checkpoints/fullmodel'  # base directory of the pretrained checkpoints

# Parameter search: readout sparsity (`hs_readout`) on a random subset of neurons

In [None]:
# Readout-sparsity parameter search on a random subset of neurons: for every
# hs_readout value, load each selected neuron's pretrained minimodel and record its
# test/validation FEVE and its readout channel weights (Wc).
n_max_neurons = 166
hs_list = [0.0, 0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.01, 0.02, 0.05]
n_select_neurons = 10
np.random.seed(1)  # fixed seed -> the same neuron subset on every run
ineurons = np.random.choice(n_max_neurons, n_select_neurons, replace=False)

# Model hyperparameters; they must match those used to train the checkpoints.
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 64
wc_coef = 0.2
num_reps = 4  # repeats per validation image (see the validation reshape below)

# Keep the checkpoint directory in its own variable so weight_path is never modified;
# re-assigning weight_path inside the loop nests 'minimodel/monkeyV1' once per iteration.
model_dir = os.path.join(weight_path, 'minimodel', 'monkeyV1')

# Load the dataset once (and close the file); each neuron is a column slice of these arrays.
with np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz')) as dat:
    images = dat['images']
    responses_all = dat['responses']
    real_responses_all = dat['real_responses']
    test_images = dat['test_images']
    test_responses_all = dat['test_responses']
    test_real_responses_all = dat['test_real_responses']
    val_idx = dat['val_idx']
    repetitions_all = dat['repetitions']
    monkey_ids = dat['subject_id']
print(len(monkey_ids), np.unique(monkey_ids))

Lx, Ly = images.shape[2], images.shape[3]

# Model inputs are identical for every neuron/model, so move them to the device once.
test_images_dev = torch.from_numpy(test_images).to(device)
val_images = images[val_idx]
n_val = val_images.shape[0]
assert n_val % num_reps == 0, 'validation set size must be a multiple of num_reps'
# Assumes val_idx is laid out as num_reps consecutive blocks of the same images,
# so the first block contains each validation image exactly once.
val_images_dev = torch.from_numpy(
    val_images.reshape(num_reps, n_val // num_reps, *val_images.shape[1:])[0]
).to(device)

feve_test_all = np.zeros((len(hs_list), n_select_neurons))
feve_val_all = np.zeros((len(hs_list), n_select_neurons))
wc_all = np.zeros((len(hs_list), n_select_neurons, nconv2))
for i, hs_readout in enumerate(hs_list):
    for j, ineuron in enumerate(ineurons):
        # Normalize by the std of the neuron's real (recorded) responses only.
        responses = responses_all[:, ineuron][:, None]
        real_responses = real_responses_all[:, ineuron][:, None]
        resp_std = np.nanstd(np.where(real_responses, responses, np.nan), axis=0)
        responses = responses / resp_std
        test_responses = test_responses_all[:, :, ineuron][:, :, None] / resp_std
        repetitions = [repetitions_all[ineuron]]

        # Replace non-recorded entries with NaN for both test and validation sets,
        # matching the test-set handling in the later cells.
        # NOTE(review): assumes metrics.monkey_feve ignores NaN entries — confirm.
        test_responses_nan = np.where(test_real_responses_all[:, :, ineuron][:, :, None], test_responses, np.nan)
        val_responses_nan = np.where(real_responses[val_idx], responses[val_idx], np.nan)
        val_responses_nan = val_responses_nan.reshape(num_reps, n_val // num_reps, 1)

        model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly, Wc_coef=wc_coef)
        model_name = model_builder.create_model_name('monkeyV1', '2019', ineuron=ineuron, n_layers=nlayers, in_channels=in_channels, seed=seed, hs_readout=hs_readout)
        model_path = os.path.join(model_dir, model_name)
        # map_location lets checkpoints saved on another device load onto `device`.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)
        model = model.to(device)

        spks_pred_test = model_trainer.test_epoch(model, test_images_dev)
        _, test_feve = metrics.monkey_feve(test_responses_nan, spks_pred_test, repetitions)
        feve_test_all[i, j] = np.mean(test_feve)
        print('FEVE (test):', feve_test_all[i, j])

        spks_pred_val = model_trainer.test_epoch(model, val_images_dev)
        _, val_feve = metrics.monkey_feve(val_responses_nan, spks_pred_val, repetitions)
        feve_val_all[i, j] = np.mean(val_feve)
        print('FEVE (val):', feve_val_all[i, j])

        wc_all[i, j] = model.readout.Wc.detach().cpu().numpy().squeeze()

In [None]:
# Mean validation FEVE for each hs_readout value, averaged over the selected neurons.
data_dict['param_search_feve_val'] = np.mean(feve_val_all, axis=1)
# Count of readout channels whose |Wc| exceeds 0.01, averaged over neurons, per hs_readout.
n_active_channels = (np.abs(wc_all) > 0.01).sum(axis=2)
data_dict['param_search_nconv2_val'] = n_active_channels.mean(axis=1)

# Parameter search: readout sparsity (`hs_readout`) on all neurons

In [None]:
# Readout-sparsity sweep over all neurons: for every hs_readout value, load each
# neuron's pretrained minimodel and record its test FEVE and readout channel weights (Wc).
hs_list = [0.0, 0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.01, 0.02, 0.05]
n_neurons = 166  # neurons evaluated: indices 0..165 of the response arrays

# Model hyperparameters; they must match those used to train the checkpoints.
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 64
wc_coef = 0.2

# Keep the checkpoint directory in its own variable so weight_path is never modified;
# re-assigning weight_path inside the loop nests 'minimodel/monkeyV1' once per iteration.
model_dir = os.path.join(weight_path, 'minimodel', 'monkeyV1')

# Load the dataset once (and close the file); each neuron is a column slice of these arrays.
with np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz')) as dat:
    images = dat['images']
    responses_all = dat['responses']
    real_responses_all = dat['real_responses']
    test_images = dat['test_images']
    test_responses_all = dat['test_responses']
    test_real_responses_all = dat['test_real_responses']
    repetitions_all = dat['repetitions']
    monkey_ids = dat['subject_id']
print(len(monkey_ids), np.unique(monkey_ids))

Lx, Ly = images.shape[2], images.shape[3]
# Test images are identical for every model, so move them to the device once.
test_images_dev = torch.from_numpy(test_images).to(device)

feve_all = np.zeros((len(hs_list), n_neurons))
wc_all = np.zeros((len(hs_list), n_neurons, nconv2))
for i, hs_readout in enumerate(hs_list):
    for ineuron in range(n_neurons):
        # Normalize by the std of the neuron's real (recorded) responses only, then
        # replace non-recorded test entries with NaN.
        responses = responses_all[:, ineuron][:, None]
        real_responses = real_responses_all[:, ineuron][:, None]
        resp_std = np.nanstd(np.where(real_responses, responses, np.nan), axis=0)
        test_responses = test_responses_all[:, :, ineuron][:, :, None] / resp_std
        test_responses = np.where(test_real_responses_all[:, :, ineuron][:, :, None], test_responses, np.nan)
        repetitions = [repetitions_all[ineuron]]

        model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly, Wc_coef=wc_coef)
        model_name = model_builder.create_model_name('monkeyV1', '2019', ineuron=ineuron, n_layers=nlayers, in_channels=in_channels, seed=seed,
                                                     hs_readout=hs_readout)
        model_path = os.path.join(model_dir, model_name)
        # map_location lets checkpoints saved on another device load onto `device`.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)
        model = model.to(device)

        spks_pred_test = model_trainer.test_epoch(model, test_images_dev)
        _, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
        feve_all[i, ineuron] = np.mean(test_feve)
        print('FEVE (test):', feve_all[i, ineuron])
        wc_all[i, ineuron] = model.readout.Wc.detach().cpu().numpy().squeeze()

In [None]:
# Store the full hs_readout sweep: test FEVE (hs x neuron), readout channel weights
# (hs x neuron x channel) and the swept hs_readout values themselves.
data_dict.update(feve_hs_all=feve_all, wc_hs_all=wc_all, hs_list=hs_list)

# Evaluate all neurons at the selected readout sparsity (`hs_readout = 0.004`)

In [None]:
# Final evaluation at a single readout sparsity: load every neuron's pretrained
# minimodel and collect its test FEVE/FEV, its test-set predictions and its readout
# weights (Wc, plus Wx/Wy — presumably the spatial readout factors).
n_neurons = 166  # neurons evaluated: indices 0..165 of the response arrays

# Model hyperparameters; they must match those used to train the checkpoints.
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 64
wc_coef = 0.2
hs_readout = 0.004  # presumably the value selected from the hs_readout sweep above

# Keep the checkpoint directory in its own variable so weight_path is never modified;
# re-assigning weight_path inside the loop nests 'minimodel/monkeyV1' once per iteration.
model_dir = os.path.join(weight_path, 'minimodel', 'monkeyV1')

# Load the dataset once (and close the file); each neuron is a column slice of these arrays.
with np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz')) as dat:
    images = dat['images']
    responses_all = dat['responses']
    real_responses_all = dat['real_responses']
    test_images = dat['test_images']
    test_responses_all = dat['test_responses']
    test_real_responses_all = dat['test_real_responses']
    repetitions_all = dat['repetitions']
    monkey_ids = dat['subject_id']
print(len(monkey_ids), np.unique(monkey_ids))

Lx, Ly = images.shape[2], images.shape[3]
# Test images are identical for every model, so move them to the device once.
test_images_dev = torch.from_numpy(test_images).to(device)

feve_all = np.zeros(n_neurons)
fev_all = np.zeros(n_neurons)
wc_all = np.zeros((n_neurons, nconv2))
wx_all = []
wy_all = []
test_pred_all = []
for ineuron in range(n_neurons):
    # Normalize by the std of the neuron's real (recorded) responses only, then
    # replace non-recorded test entries with NaN.
    responses = responses_all[:, ineuron][:, None]
    real_responses = real_responses_all[:, ineuron][:, None]
    resp_std = np.nanstd(np.where(real_responses, responses, np.nan), axis=0)
    test_responses = test_responses_all[:, :, ineuron][:, :, None] / resp_std
    test_responses = np.where(test_real_responses_all[:, :, ineuron][:, :, None], test_responses, np.nan)
    repetitions = [repetitions_all[ineuron]]

    model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly, Wc_coef=wc_coef)
    model_name = model_builder.create_model_name('monkeyV1', '2019', ineuron=ineuron, n_layers=nlayers, in_channels=in_channels, seed=seed, hs_readout=hs_readout)
    model_path = os.path.join(model_dir, model_name)
    # map_location lets checkpoints saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)
    model = model.to(device)

    spks_pred_test = model_trainer.test_epoch(model, test_images_dev)
    test_pred_all.append(spks_pred_test)

    test_fev, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
    feve_all[ineuron] = np.mean(test_feve)
    fev_all[ineuron] = np.mean(test_fev)
    print('FEVE (test):', feve_all[ineuron])

    readout = model.readout
    wc_all[ineuron] = readout.Wc.detach().cpu().numpy().squeeze()
    wx_all.append(readout.Wx.detach().cpu().numpy().squeeze())
    wy_all.append(readout.Wy.detach().cpu().numpy().squeeze())

In [None]:
# Collect the final per-neuron results; the per-neuron lists (readout weights and
# test predictions) are stacked into arrays with the neuron index as the first axis.
data_dict.update({
    'feve_all': feve_all,
    'fev_all': fev_all,
    'wc_all': wc_all,
    'wx_all': np.stack(wx_all),
    'wy_all': np.stack(wy_all),
    'test_pred_all': np.stack(test_pred_all).squeeze(),
})

# save

In [None]:
# Persist everything collected in data_dict. Create the output directory first so a
# fresh checkout (no outputs/ folder yet) does not fail at the very last step.
output_path = "outputs/minimodel_monkey_result.npz"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
np.savez(output_path, **data_dict)