In [1]:
%load_ext autoreload
%autoreload 2

# first we need a bit of import boilerplate
import os
import sys
from sys import platform
# Resolve the project root for this machine.  Everything else the notebook
# needs (the local code folder added to sys.path, the Allen Institute
# manifest, simulation outputs) is addressed relative to this root.
#
# To run on a machine that is not listed below, set the BRAIN_NETWORK_DIR
# environment variable to the project root instead of editing this cell.
_default_project_dirs = {
    'win32': 'D:/Brain_Network/',
    'darwin': '/Users/chenyu/Workspace/Brain_Network/',
    'linux': '/home/yuchen/workspace/Brain_Network/',
}
_project_root = os.environ.get('BRAIN_NETWORK_DIR') or _default_project_dirs.get(platform)

if _project_root:
    # Later cells build paths by string concatenation (project_dir + 'Output/...'),
    # so normalise to exactly one trailing forward slash.
    project_dir = _project_root.rstrip('/\\') + '/'
    sys.path.append(project_dir + 'Code/')
    manifest_path = project_dir + 'Data/Allen_Institute_Dataset/manifest.json'
else:
    # project_dir / manifest_path stay undefined here, as before; the message
    # now says how to fix it without touching the code.
    print('Add new computer system settings.')
    print(f"Unrecognised platform '{platform}': set the BRAIN_NETWORK_DIR environment variable "
          "to the project root and re-run this cell.")

import numpy as np; np.set_printoptions(linewidth=110); print(np.get_printoptions()['linewidth'])
import glob
import pandas as pd
import matplotlib.pyplot as plt
# import seaborn
import scipy
import scipy.io as sio
from scipy.ndimage.filters import gaussian_filter
import seaborn
from tqdm.notebook import trange
import time

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
from allensdk.brain_observatory.ecephys.ecephys_session import (
    EcephysSession,
    removed_unused_stimulus_presentation_columns
)
from allensdk.brain_observatory.ecephys.visualization import (
    plot_mean_waveforms, 
    plot_spike_counts, 
    raster_plot)
from allensdk.brain_observatory.visualization import plot_running_speed

# tell pandas to show all columns when we display a DataFrame
pd.set_option("display.max_columns", None)

import data_visualizer
import data_model
import hierarchical_model
import hierarchical_model_generator
import hierarchical_sampling_model
import util
import smoothing_spline

110
Download time: 291600


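Number of samples needed to estimate the coverage accurately when a 95% CI is used. Here $C$ is the empirical coverage over $n$ samples, $p = 0.95$ is the nominal coverage, and $\sigma$ is the standard error of $C$.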

$\sigma^2 = \text{Var}(C) = \frac{p(1-p)}{n} = \frac{0.95 \times 0.05}{n}$
So, $n = \frac{0.95 \times 0.05}{\sigma^2}$.

In [158]:
# Required number of samples n = p * (1 - p) / sigma^2 for nominal coverage
# p = 0.95, evaluated for each target standard error of the coverage estimate.
for sigma_coverage_error in (0.005, 0.01):
    number_samples = 0.95 * 0.05 / sigma_coverage_error**2
    print(number_samples)

1900.0
475.0


# Hierarchical model Generator

### Load template model

In [2]:
# File names of the available generator templates (.pkl).  They are joined
# onto project_dir + 'Output/simulation/' when a template is loaded; the
# simulation runner selects one of them by list index.
model_files = [
    'HBM_checkpoint_B_MC_0_500ms_probeCE_condition1_20200813-092937_generator_template.pkl',
    'HBM_checkpoint_BSS_MC_0_500ms_probeCE_o225_270_f8_20200801-174224_generator_template.pkl',
    'HBM_checkpoint_BSS_MC_0_500ms_probeCDE_o225_270_f8_20200731-125456_generator_template.pkl',
    '798911424_checkpoints_batch10_20200910-194436_generator_template.pkl',
]

In [4]:

def runner(random_seed):
    """Simulate one data set from a generator template, fit the HBM to it, and save the results.

    Steps, in order:
      1. Load the generator template ``model_files[3]`` from
         ``project_dir + 'Output/simulation/'`` and generate synthetic data
         with NumPy's global RNG seeded by ``random_seed``.
      2. Build a ``HierarchicalSamplingModel`` on the generated spike data and
         run 1000 update iterations following the fit/sample schedule below.
      3. Plot the log-likelihood trace and the marginal correlation against the
         generator, then write three pickle files (samples, model checkpoint,
         generator) to ``Output/simulation_output/BSS_A3_R30_fixed_sigma/``.

    Args:
        random_seed: Seed passed to ``np.random.seed`` before data generation;
            also embedded in the output file names.

    Returns:
        None.  Results are reported through printed paths, plots and saved files.
    """
    print(f'---------------------- {random_seed} --------------------')

    # Create the output folder up front so a missing directory is discovered
    # before the long fit rather than at save time.
    output_folder = project_dir + 'Output/simulation_output/BSS_A3_R30_fixed_sigma/'
    os.makedirs(output_folder, exist_ok=True)

    ## Load simulation
    generator_model = hierarchical_model_generator.HierarchicalModelGenerator()
    data_folder = project_dir + 'Output/simulation/'
    file_path = (data_folder + model_files[3])
    generator_model.load_model(file_path)

    np.random.seed(random_seed)
    model_feature_type = 'BSS'
    generator_model.initial_step(model_feature_type=model_feature_type, num_trials=30, num_conditions = 1)

    ## Generate data
    generator_model.generate_mu_sigma(sample_type='fixed', verbose=False)
    generator_model.generate_q()
    generator_model.generate_f_pop_gac(select_clist=[6], same_as_cross=False, verbose=False)
    generator_model.generate_z(verbose=False)
    generator_model.generate_p_gac(verbose=False)
    generator_model.generate_log_lambda_nargc(verbose=False)
    generator_model.generate_spikes(verbose=False)

    ## Prepare for the data fitting.
    trial_time_window = generator_model.trial_time_window
    spike_train_time_line = generator_model.spike_train_time_line
    spike_trains, spike_times = generator_model.spike_trains, generator_model.spike_times
    session = None  # simulated data: there is no recording session to pass to the model
    selected_units = generator_model.selected_units
    trials_indices = generator_model.trials_indices
    trials_groups = generator_model.trials_groups

    ## Initial HBM
    model = hierarchical_sampling_model.HierarchicalSamplingModel(session)
    # The fit is always seeded with 2, independent of random_seed.
    # NOTE(review): presumably so that only the simulated data varies between runs -- confirm.
    np.random.seed(2)
    # NOTE(review): the probe list is hard-coded rather than taken from
    # generator_model.probes -- confirm it matches the loaded template.
    model.initial_step(spike_trains, spike_times, spike_train_time_line, selected_units, trials_groups,
                       trial_time_window, probes=['probeC', 'probeD', 'probeE'],
                       model_feature_type=model_feature_type, num_areas=3, num_groups=3,
                       eta_smooth_tuning=1e-8, verbose=False)

    ## Update schedule
    clist = [0]
    slc_begin = 200  # first iteration at which samples are recorded
    slc_step = 5     # record every slc_step-th iteration from slc_begin on

    t = trange(0, 1000)
    for itr in t:
        slc = (itr >= slc_begin) and (itr % slc_step == 0)
        for c in clist:
            is_last_condition = (c == clist[-1])
            record_sample = is_last_condition and slc

            # f: 'fit' for the first 150 iterations, 'sample' afterwards.
            f_sample_type = 'fit' if itr < 150 else 'sample'
            model.update_f_local_pop_cag(c, sample_type=f_sample_type, verbose=False)
            model.update_f_cross_pop_ca(c, sample_type=f_sample_type, record=record_sample, verbose=False)

            # q: fit (itr < 15), fit with peaks (15 <= itr < 30), then sample.
            if itr < 15:
                model.update_q_arc(c, sample_type='fit', fit_peak_ratio=0,
                                   record=record_sample, verbose=False)
            elif itr < 30:  # Fit peaks.
                model.update_q_arc(c, sample_type='fit', fit_peak_ratio=0.01,
                                   record=record_sample, verbose=False)
            else:
                model.update_q_arc(c, sample_type='sample', proposal_scalar=0.02, fit_peak_ratio=0,
                                   record=record_sample, verbose=False)

            # z and p only start at iteration 60; z is updated every 10th
            # iteration and, unlike the others, is recorded regardless of slc.
            if itr >= 60:
                if itr % 10 == 0:
                    model.update_z_ngac(c, sample_type='sample',
                                        record=is_last_condition, verbose=False)
                model.update_p_gac(c, sample_type='sample',
                                   record=record_sample, verbose=False)

        # mu/sigma: 'iw_fit' for 60 <= itr < 100, 'iw_sample' from 100 on.
        if 60 <= itr < 100:
            model.update_mu_simga(clist=clist, sample_type='iw_fit', update_prior_ratio=0.3,
                                  record=slc, verbose=False)
        elif itr >= 100:
            model.update_mu_simga(clist=clist, sample_type='iw_sample', record=slc, verbose=False)
        model.complete_log_likelihood(clist)
        t.set_description(f'log-likelihood: {model.samples.log_likelihood[-1]:.2f}')

    ## Output
    model.samples.plot_log_likelihood()
    model.samples.plot_marginal_correlation(0, 1, burn_in=0, end=-1, step=1, plot_type='rho',
            true_model=generator_model, model_feature_type=model.model_feature_type,
            distribution_type='hist')

    # All three files share the experiment name and timestamp; only the prefix differs.
    experiment_name = (f'{model_feature_type}_0_500ms_probe3_R{len(trials_indices)}_' +
                       f'sim_slcbegin{slc_begin}_slcstep{slc_step}_')
    timestr = time.strftime("%Y%m%d-%H%M%S")

    # Save data.
    prefix = f'HBM_samples_seed{random_seed}_'
    file_path = output_folder + prefix + experiment_name + timestr + '.pkl'
    print(file_path)
    model.samples.save(file_path)

    # Save model.
    prefix = f'HBM_checkpoint_seed{random_seed}_'
    file_path = output_folder + prefix + experiment_name + timestr + '.pkl'
    model.save_model(save_data=False, file_path=file_path)

    # Save true model.
    prefix = f'HBM_generator_seed{random_seed}_'
    file_path = output_folder + prefix + experiment_name + timestr + '.pkl'
    print(file_path)
    generator_model.save_data(save_spikes=False, file_path=file_path)


In [None]:
# Run the simulate -> fit -> save pipeline once per seed, for seeds 75..199.
# The outer progress bar tracks seeds; runner() shows its own bar per fit.
# NOTE(review): the range starts at 75, presumably resuming an earlier batch -- confirm.
seed_range = trange(75, 200)
for random_seed in seed_range:
    runner(random_seed)

HBox(children=(FloatProgress(value=0.0, max=125.0), HTML(value='')))

---------------------- 75 --------------------
model_feature_type: BSS

num areas 3   num trials 30   num conditions 1   num qs 3
