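# Bernstein Conference 2023: Reliability computations

Computes the per-neuron trial-to-trial reliability (mean pairwise cosine similarity of the spike signals across seeds) for the original connectome and its 1st- to 5th-order rewired variants, and saves it together with the mean firing rates as `.pkl` files.

- Based on 10 simulation seeds per connectome (datasets `sim_0` … `sim_9` in each data store)
- Data stores contain spike signals preprocessed with `reliability_and_structure/data_analysis/notebooks/run_preprocessing_pipeline.ipynb`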

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import h5py
import time
from scipy.spatial import distance

In [2]:
# Output directory for the per-connectome reliability .pkl files written further below.
figs_path = './bernstein2023_figs'
# exist_ok=True is a no-op when the directory already exists and avoids the
# exists()/makedirs() check-then-create race of a separate existence test.
os.makedirs(figs_path, exist_ok=True)

In [3]:
# Root directory of all simulation campaigns; edit this single constant to relocate every data store.
sim_root = '/gpfs/bbp.cscs.ch/data/scratch/proj83/home/pokorny/SimplifiedConnectomeModels/simulations_v2'
# Campaign directory suffixes: '' is the original connectome, followed by the
# ConnRewireOrder{k}Hex0EE rewired variants for k = 1..5 (same order as plot_names).
_campaign_suffixes = [''] + [f'__ConnRewireOrder{_order}Hex0EE' for _order in range(1, 6)]
data_stores = [f'{sim_root}/SSCx-HexO1-Release-TC__Reliab{_sfx}/working_dir/processed_data_store.h5'
               for _sfx in _campaign_suffixes]
plot_names = ['Original',
              '1st order',
              '2nd order',
              '3rd order',
              '4th order',
              '5th order']
# zip() over these two lists later would silently drop unmatched entries, so check they pair up.
assert len(data_stores) == len(plot_names), 'each data store needs exactly one plot name'

In [4]:
# Modified from compute_reliabilities_10seed_all_connectomes.py
def load_spike_signals(file, sim_idx, return_metadata=False):
    """Load excitatory spike signals for selected simulations from an HDF5 data store.

    Reads the ``gids`` dataset and, for every ``sim_id`` in ``sim_idx``, the dataset
    ``spike_signals_exc/sim_{sim_id}``. The per-simulation arrays are stacked along a
    new leading axis, in ``sim_idx`` order. All data is read fully into memory, so the
    returned arrays stay valid after the file is closed.

    Args:
        file: path to the HDF5 data store.
        sim_idx: non-empty iterable of simulation ids to load (e.g. ``np.arange(10)``).
        return_metadata: if truthy, also read and return the stored ``firing_rates``,
            ``mean_centered`` and ``sigma`` entries. They are only read when requested.

    Returns:
        ``(gids, spike_signals)``, or ``(gids, spike_signals, metadata)`` when
        ``return_metadata`` is truthy. ``spike_signals`` has shape
        ``(len(sim_idx), *per_sim_shape)``; the per-simulation layout is presumably
        neurons x time_bins, as expected by ``avg_reliability`` (confirm against the
        preprocessing pipeline).

    Raises:
        ValueError: if ``sim_idx`` is empty (there would be nothing to stack).
    """
    sim_idx = list(sim_idx)
    if not sim_idx:
        raise ValueError('sim_idx must contain at least one simulation id')
    with h5py.File(file, 'r') as f:
        gids = f['gids'][()]
        spike_signals = np.stack(
            [f['spike_signals_exc'][f'sim_{sim_id}'][()] for sim_id in sim_idx])
        metadata = None
        if return_metadata:
            # Read lazily so stores without these datasets can still be loaded for the signals alone.
            metadata = {'firing_rates': f['firing_rates'][()],
                        'mean_centered': f['mean_centered'][()],
                        'sigma': f['sigma'][()]}
    # Truthiness instead of ==True/==False: the old chain returned None for any other value.
    if return_metadata:
        return gids, spike_signals, metadata
    return gids, spike_signals

def avg_reliability(v_filt):
    """Mean pairwise trial-to-trial reliability of every neuron.

    The reliability of one pair of trials is the cosine similarity
    (1 - cosine distance) of a neuron's signal in those two trials. For each
    neuron this returns the mean over all unordered trial pairs.

    Args:
        v_filt: array indexed as [trial, neuron, time_bin], i.e. of shape
            N_trials x N_neurons x N_time_bins (e.g. spike signals from many
            simulations of the same stimulus).

    Returns:
        1-D numpy array of length N_neurons with the mean reliability per neuron.
        An entry is NaN when there are no trial pairs (a single trial) and, presumably,
        when a neuron's signal is all zeros in some trial (undefined cosine).
        Use np.nanmean over the pairwise similarities instead of .mean() if such
        undefined pairs should be ignored rather than propagated.
    """
    n_neurons = v_filt.shape[1]
    per_neuron = []
    for neuron in range(n_neurons):
        trials_by_bins = v_filt[:, neuron, :]
        pair_similarities = 1 - distance.pdist(trials_by_bins, 'cosine')
        per_neuron.append(pair_similarities.mean())
    return np.array(per_neuron)

In [6]:
# Load spike signals and compute reliability
# If True, subtract each signal's mean over axis 2 (assumed to be the time-bin axis:
# trials x neurons x time_bins) before computing reliability.
do_mean_centering = False
n_seeds = 10  # simulation seeds per connectome (datasets sim_0 ... sim_{n_seeds-1})
sim_sel = np.arange(n_seeds)
# zip() would silently skip unmatched entries, so make sure stores and names pair up.
assert len(data_stores) == len(plot_names), 'each data store needs exactly one plot name'
for _dstore, _nm in zip(data_stores, plot_names):
    print(f'Processing "{_nm}"...')
    # perf_counter() is a monotonic clock intended for measuring durations (unlike time.time()).
    t0 = time.perf_counter()
    nids, spike_signals, metadata = load_spike_signals(_dstore, sim_sel, return_metadata=True)
    print(f'  Loaded data in {time.perf_counter() - t0:.1f}s')
    if do_mean_centering:
        t0 = time.perf_counter()
        spike_signals = spike_signals - np.mean(spike_signals, 2, keepdims=True)
        metadata['mean_centered'] = True
        print(f'  Mean-centered data in {time.perf_counter() - t0:.1f}s')
    t0 = time.perf_counter()
    reliab = avg_reliability(spike_signals)
    print(f'  Computed reliabilities in {time.perf_counter() - t0:.1f}s')
    # The file name records whether the signals were mean-centered.
    out_file = os.path.join(figs_path, f'reliability_rates_centered{metadata["mean_centered"]}__{_nm}.pkl')
    # NOTE(review): 'firing_rates' is not subset by sim_sel; if its axis 0 indexes
    # simulations, these rates cover every sim in the store, not only the selected seeds — confirm.
    with open(out_file, 'wb') as f:
        pickle.dump({"reliab": reliab, "rates": np.mean(metadata['firing_rates'], 0)}, f)

Processing "3rd order"...
  Loaded data in 193.1s
  Computed reliabilities in 45.8s


In [None]:
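# Optional post-processing (shell, ImageMagick): trim the whitespace borders of PNG figures.
# Note: this notebook itself only writes .pkl files to figs_path.
# mogrify -trim *.png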