In [10]:
import os
import h5py

import numpy as np
import scipy
from scipy import stats, signal
from scipy import linalg as LA
from numpy.core.records import fromarrays
from scipy.io import savemat
from tqdm import tqdm
import matplotlib.pyplot as plt

%matplotlib notebook

In [11]:
def get_sparse_dynamic_wts(st_wts, dyn_wt_dict, send, recv,
                           recv_spikes=None, static_only=False):  
    dyn_wts = dyn_wt_dict[recv][send]
    if st_wts is not None:
        if static_only:
            return np.tile(st_wts.reshape((1, 1, -1)), (dyn_wts.shape[0], dyn_wts.shape[1], 1))
        else:
            st_wts = st_wts.reshape((1, 1, -1))
            return st_wts + dyn_wts
    else:
        return dyn_wts
    
def avg_over_send_spikes(spikes, wts, avg_len=0, plot_wt_count=False):
    #spikes is [b, t]
    #wts is [b, t, num_delays]
    result = np.zeros_like(wts)
    counts = np.zeros_like(wts)
    for delay in range(wts.shape[2]):
        if delay == 0:
            result[:, delay:, delay] += spikes*wts[:, :, 0]
            counts[:, :, delay] += spikes
        else:
            result[:, delay:, delay] += spikes[:, :-delay]*wts[:, delay:, delay]
            counts[:, delay:, delay] += spikes[:, :-delay]
    if avg_len > 0:
        final_result = np.copy(result)
        final_counts = np.copy(counts)
        for i in range(1, avg_len):
            final_result[:, i:] += result[:, :-i]
            final_counts[:, i:] += counts[:, :-i]
        result = final_result
        counts = final_counts
    if plot_wt_count:
        return counts.sum(0)
    else:
        return result.sum(0) / (counts.sum(0) + 1e-8)
    
def get_spikes_for_idx(spikes_path, idx, val_only=False, train_only=False):
    with h5py.File(spikes_path, 'r') as fin:
        if val_only:    
            all_spikes = fin['val'][:, :, idx]
        elif train_only:
            all_spikes = fin['train'][:, :, idx]
        else:
            spikes_train = fin['train'][:, :, idx]
            spikes_val = fin['val'][:, :, idx]
            all_spikes = np.concatenate([spikes_train, spikes_val])
    return all_spikes

def get_wt_img_vals(spikes_path, n1, n2, num_delays, static_wts, dyn_offsets, start_time=0, bin_size=5,
                          plot_start=None, plot_end=None, title=None, trial_num=None, dyn_only=False,
                             subtract_outsides=False, avg_len=0, plot_max=None, shift_wts=0, use_max=False,
                 val_only=False, avg_at_spikes=False, plot_wt_count=False, static_convert=None,
                 plot_info=None, normalize_across_time=False, use_wt_zero_max=True,
                       use_wt_peak=False, delay_avg_len=None, use_gaussian_delay_averaging=False,
                     units=None, id2localidx=None):
    if static_convert is not None:
        s_n1, s_n2 = static_convert[n1], static_convert[n2]
    else:
        s_n1, s_n2 = n1, n2
    seqstdyn_static_0to1 = static_wts[s_n1, s_n2]
    seqstdyn_static_1to0 = static_wts[s_n2, s_n1]
    if dyn_only:
        forward_dyn_wts = get_sparse_dynamic_wts(None, dyn_offsets, n1, n2)
        backward_dyn_wts = get_sparse_dynamic_wts(None, dyn_offsets, n2, n1)
    else:
        forward_dyn_wts = get_sparse_dynamic_wts(seqstdyn_static_0to1, dyn_offsets, n1, n2)
        backward_dyn_wts = get_sparse_dynamic_wts(seqstdyn_static_1to0, dyn_offsets, n2, n1)
    if trial_num is not None:
        forward_dyn_wts = forward_dyn_wts[trial_num]
        backward_dyn_wts = backward_dyn_wts[trial_num]
    elif avg_at_spikes:
        if id2localidx is not None:
            id1, id2 = units.iloc[n1].name, units.iloc[n2].name
            sp_n1 = id2localidx[id1]
            sp_n2 = id2localidx[id2]
        else:
            sp_n1, sp_n2 = n1, n2
        tmp_sp_1 = get_spikes_for_idx(spikes_path, sp_n1, val_only=True)[:, :forward_dyn_wts.shape[1]]
        tmp_sp_2 = get_spikes_for_idx(spikes_path, sp_n2, val_only=True)[:, :forward_dyn_wts.shape[1]]
        forward_dyn_wts = avg_over_send_spikes(tmp_sp_1, forward_dyn_wts, avg_len=avg_len, plot_wt_count=plot_wt_count)
        backward_dyn_wts = avg_over_send_spikes(tmp_sp_2, backward_dyn_wts, avg_len=avg_len, plot_wt_count=plot_wt_count)
    elif use_max:
        forward_dyn_wts = forward_dyn_wts.max(0)
        backward_dyn_wts = backward_dyn_wts.max(0)
    else:
        forward_dyn_wts = forward_dyn_wts.mean(0)
        backward_dyn_wts = backward_dyn_wts.mean(0)
    t = len(forward_dyn_wts)
    if plot_end is None:
        plot_end = t*bin_size
    if plot_start is None:
        plot_start = start_time
    first_bin = int(plot_start - start_time) // bin_size
    last_bin = int(plot_end-start_time) // bin_size
    first_time = start_time + first_bin*bin_size
    last_time = start_time + last_bin*bin_size
    if avg_len > 0 and not avg_at_spikes:
        final_forward = np.copy(forward_dyn_wts)
        final_backward = np.copy(backward_dyn_wts)
        counts = np.ones_like(final_forward)
        for i in range(1, avg_len):
            final_forward[i:] += forward_dyn_wts[:-i]
            final_backward[i:] += backward_dyn_wts[:-i]
            counts[i:] += 1
        if plot_wt_count:
            forward_dyn_wts = counts
            backward_dyn_wts = counts
        else:
            forward_dyn_wts = final_forward / counts
            backward_dyn_wts = final_backward / counts
    forward_dyn_wts = forward_dyn_wts[first_bin:last_bin]
    backward_dyn_wts = backward_dyn_wts[first_bin:last_bin]
    forward_dyn_wts = np.copy(forward_dyn_wts)
    final_result = np.concatenate([np.flip(forward_dyn_wts.T, (0,)), backward_dyn_wts[:, 1:].T])
    
    if delay_avg_len is not None:
        if use_gaussian_delay_averaging:
            final_result = gaussian_filter1d(final_result, delay_avg_len, axis=0, mode='nearest')
        else:
            tmp_result = np.empty((len(final_result)+2*delay_avg_len, final_result.shape[1]))
            tmp_result[:delay_avg_len] = final_result[0]
            tmp_result[-delay_avg_len:] = final_result[-1]
            tmp_result[delay_avg_len:-delay_avg_len] = final_result
            final_tmp_result = np.zeros_like(tmp_result)
            for idx in range(delay_avg_len, len(final_tmp_result)-delay_avg_len):
                final_tmp_result[idx] = tmp_result[idx-delay_avg_len:idx+delay_avg_len+1].mean(0)
            final_result = final_tmp_result[delay_avg_len:-delay_avg_len]
    
    return final_result, forward_dyn_wts, backward_dyn_wts

def compute_sig_for_wts(static_wts, n1, n2, final_result, forward_dyn_wts, backward_dyn_wts, 
                        normalize_across_time=True,
                        use_wt_peak=False, static_convert=None, delay_ind=None,
                        plot_std_units=True):   
    out_len = forward_dyn_wts.shape[1]//2
    outsides = np.concatenate([final_result[:out_len], final_result[-out_len:]])
    if normalize_across_time:
        means = outsides.mean(keepdims=True)
        stds = outsides.std(keepdims=True)+1e-8
    else:
        means = outsides.mean(0, keepdims=True)
        stds = outsides.std(axis=0, keepdims=True)+1e-8
    
    if delay_ind is not None:
        all_wts = np.concatenate([np.flip(forward_dyn_wts.T, (0,)), backward_dyn_wts.T[1:]])
        if plot_std_units:
            wt_plot_vals = ((all_wts - means)/stds)
        else:
            wt_plot_vals = all_wts
        wt_plot_vals = wt_plot_vals[delay_ind + forward_dyn_wts.shape[1]-1]
    elif use_wt_peak:
        if plot_std_units:
            wt_plot_vals = ((forward_dyn_wts.T - means)/stds)
        else:
            wt_plot_vals = forward_dyn_wts.T
        wt_plot_vals = wt_plot_vals[:-out_len]
        if static_convert is not None:
            s_n1, s_n2 = static_convert[n1], static_convert[n2]
        else:
            s_n1, s_n2 = n1, n2
        seqstdyn_static_0to1 = static_wts[s_n1, s_n2]
        seqstdyn_static_1to0 = static_wts[s_n2, s_n1]
        delay_ind = np.abs(seqstdyn_static_0to1[:out_len+1]).argmax()
        wt_plot_vals = wt_plot_vals[delay_ind]
    
    else:
        if plot_std_units:
            wt_plot_vals = ((forward_dyn_wts.T - means)/stds)
        else:
            wt_plot_vals = forward_dyn_wts.T
        wt_plot_vals = wt_plot_vals[:-out_len]
        wt_plot_vals = wt_plot_vals.max(axis=0)
    return wt_plot_vals

def get_all_spikes(spikes_path):
    with h5py.File(spikes_path, 'r') as fin:
        spikes_train = fin['train'][:]
        spikes_val = fin['val'][:]
    all_spikes = np.concatenate([spikes_train, spikes_val])
    return all_spikes

def get_dim(data, dim_thresh, filter_width):
    if filter_width > 0:
        out = np.zeros_like(data)
        fltHL = int(np.ceil(3*filter_width))
        flt = scipy.stats.norm.pdf(np.arange(-fltHL, fltHL+1), 0, filter_width)
        yDim, T = data.shape
        nm = scipy.signal.convolve(flt, np.ones(T))

        for i in range(yDim):
            ys = scipy.signal.convolve(flt, data[i]) / nm
            out[i] = ys[fltHL:-fltHL]
        data = out
    data = data.T
    
    # run PCA
    data -= np.mean(data, axis=0)
    cov = np.cov(data, rowvar = False)
    evals, evecs = LA.eig(cov)
    
    idx = np.argsort(evals)[::-1]
    evecs = evecs[:,idx]
    evals = evals[idx]
    cum_scores = np.cumsum(evals)/np.sum(evals)
    dim = (cum_scores > dim_thresh).nonzero()[0][0]
    return dim, cum_scores

In [12]:
root_data_dir='session_data'
out_dir = 'figures/dimensionality_data/'
os.makedirs(out_dir, exist_ok=True)
use_stronger_pair_filtering = False

In [1]:
sessions = [766640955,767871931,768515987,771160300,771990200,774875821,778240327,
            778998620,779839471,781842082,786091066,787025148,789848216,793224716,
            794812542,816200189,819186360,819701982,821695405,829720705,831882777,
            835479236,839068429,839557629,840012044,847657808]

plot_start=-100
plot_end=400
start_time=-150
bin_size=5
start_bin = (plot_start-start_time)//5
end_bin = (plot_end-start_time)//5

dim_thresh = 0.95
filter_width=0

max_ratio = 0
max_diff = 0
max_ratio_sess = None
max_diff_sess = None

for session in tqdm(sessions):
    data_dir = os.path.join(root_data_dir, str(session))
    spikes_5ms_path = os.path.join(root_data_dir, str(session), 'spikes_all_5ms.h5')
    spikes_5ms = get_all_spikes(spikes_5ms_path)
    spikes_5ms = spikes_5ms.mean(0).T
    model_1ms_dir = os.path.join(data_dir, 'model_1ms')
    model_5ms_dir = os.path.join(data_dir, 'model_5ms')
    
    if use_stronger_pair_filtering:
        pair_file = os.path.join(data_dir, 'model_5ms_extrafilter', 'pairs.npy')
        bin5_dyn_path = os.path.join(data_dir, 'model_5ms_extrafilter', 'dynamic_offsets_val.npy')
        allp_static_wts_5ms_path = os.path.join(model_5ms_dir, 'static_weights_val.npy')
    else:
        pair_file = os.path.join(data_dir, 'model_5ms', 'pairs.npy')
        bin5_dyn_path = os.path.join(model_5ms_dir, 'dynamic_offsets_val.npy')
        allp_static_wts_5ms_path = os.path.join(model_5ms_dir, 'static_weights_val.npy')
        bin5_spikes_path = os.path.join(data_dir, 'spikes_all_5ms.h5')
   
    
    pairs = np.load(pair_file)

    used_n = np.unique(pairs)
    valid = np.array([n in used_n for n in range(len(spikes_5ms))])
    spikes_5ms = spikes_5ms[valid, start_bin:end_bin]
    spikes_5ms_full = spikes_5ms[:, start_bin:end_bin]

    allp_static_wts = np.load(allp_static_wts_5ms_path)[0]
    allp_dyn_wts = np.load(bin5_dyn_path, allow_pickle=True).item()
    pairs = np.load(pair_file)

    all_pair_sig = []
    cum_wt_dim = []
    wt_dims = []
    sp_dims = []
    sp_dims_full = []
    
    for pair_idx in range(len(pairs)):
        n1, n2 = pairs[pair_idx]
        avg_len = 0
        use_wt_peak = True
        avg_at_spikes=True
        use_corrected_jpsth=False
        allp_wts_im, allp_wt_f, allp_wt_b = get_wt_img_vals(spikes_5ms_path, n1, n2, 10, allp_static_wts, allp_dyn_wts, 
                                                            start_time=-150, bin_size=5,
                                                            plot_start=-100, plot_end=400,
                                                            avg_len=avg_len,
                                                            avg_at_spikes=avg_at_spikes, static_convert=None,
                                                            normalize_across_time=True, use_wt_peak=use_wt_peak)
        allp_wts_sig = compute_sig_for_wts(allp_static_wts, n1, n2, allp_wts_im, allp_wt_f, allp_wt_b,
                                           static_convert=None,
                                           normalize_across_time=True, use_wt_peak=use_wt_peak)
        all_pair_sig.append(allp_wts_sig)
    all_pair_sig = np.stack(all_pair_sig)
    
    tmp_all_pair_sig = all_pair_sig - all_pair_sig.min()

    
    wt_dim, wt_cum_var = get_dim(tmp_all_pair_sig, dim_thresh, filter_width)
    sp_dim, sp_cum_var = get_dim(spikes_5ms, dim_thresh, filter_width)
    sp_dim_full, sp_cum_var_full = get_dim(spikes_5ms_full, dim_thresh, filter_width)
    
    wt_dims.append(wt_dim)
    sp_dims.append(sp_dim)
    sp_dims_full.append(sp_dim_full)
    
    cum_wt_dim.append([session, wt_dim, sp_dim, sp_dim_full, len(spikes_5ms), len(pairs), len(spikes_5ms_full)])
    print(cum_wt_dim)
    
    dim_ratio = wt_dim/sp_dim
    if dim_ratio > max_ratio:
        max_ratio = dim_ratio
        max_ratio_sess = session
        max_ratio_vals = [wt_dim, sp_dim]
    dim_diff = wt_dim - sp_dim
    if dim_diff > max_diff:
        max_diff = dim_diff
        max_diff_sess = session
        max_diff_vals = [wt_dim, sp_dim]        

In [25]:
print(max_ratio, max_ratio_sess, max_ratio_vals)
print(max_diff, max_diff_sess, max_diff_vals)

8.0 831882777 [32, 4]
29 786091066 [36, 7]


In [76]:
session = 831882777
plot_smoothing_kernel_width = 0

data_dir = os.path.join(root_data_dir, str(session))
spikes_5ms_path = os.path.join(root_data_dir, str(session), 'spikes_all_5ms.h5')
spikes_5ms = get_all_spikes(spikes_5ms_path)
spikes_5ms = spikes_5ms.mean(0).T
model_1ms_dir = os.path.join(data_dir, 'model_1ms')
model_5ms_dir = os.path.join(data_dir, 'model_5ms')

if use_stronger_pair_filtering:
    pair_file = os.path.join(data_dir, 'model_5ms_extrafilter', 'pairs.npy')
    bin5_dyn_path = os.path.join(data_dir, 'model_5ms_extrafilter', 'dynamic_offsets_val.npy')
    allp_static_wts_5ms_path = os.path.join(model_5ms_dir, 'static_weights_val.npy')
else:
    pair_file = os.path.join(data_dir, 'model_5ms', 'pairs.npy')
    bin5_dyn_path = os.path.join(model_5ms_dir, 'dynamic_offsets_val.npy')
    allp_static_wts_5ms_path = os.path.join(model_5ms_dir, 'static_weights_val.npy')
    bin5_spikes_path = os.path.join(data_dir, 'spikes_all_5ms.h5')


pairs = np.load(pair_file)

used_n = np.unique(pairs)
valid = np.array([n in used_n for n in range(len(spikes_5ms))])
spikes_5ms_full = spikes_5ms[:, start_bin:end_bin]
spikes_5ms = spikes_5ms[valid, start_bin:end_bin]

allp_static_wts = np.load(allp_static_wts_5ms_path)[0]
allp_dyn_wts = np.load(bin5_dyn_path, allow_pickle=True).item()

all_pair_sig = []
for pair_idx in range(len(pairs)):
    n1, n2 = pairs[pair_idx]
    avg_len = 0
    use_wt_peak = True
    avg_at_spikes=True
    use_corrected_jpsth=False
    allp_wts_im, allp_wt_f, allp_wt_b = get_wt_img_vals(spikes_5ms_path, n1, n2, 10, allp_static_wts, allp_dyn_wts, 
                                                        start_time=-150, bin_size=5,
                                                        plot_start=-100, plot_end=400,
                                                        avg_len=avg_len,
                                                        avg_at_spikes=avg_at_spikes, static_convert=None,
                                                        normalize_across_time=True, use_wt_peak=use_wt_peak)
    allp_wts_sig = compute_sig_for_wts(allp_static_wts, n1, n2, allp_wts_im, allp_wt_f, allp_wt_b,
                                       static_convert=None,
                                       normalize_across_time=True, use_wt_peak=use_wt_peak)
    all_pair_sig.append(allp_wts_sig)
all_pair_sig = np.stack(all_pair_sig)

tmp_all_pair_sig = all_pair_sig - all_pair_sig.min()

wt_dims = []
sp_dims = []
sp_dims_full = []
filter_widths = np.arange(11)

for filter_width in filter_widths:
    wt_dim, wt_cum_var = get_dim(tmp_all_pair_sig, dim_thresh, filter_width)
    sp_dim, sp_cum_var = get_dim(spikes_5ms, dim_thresh, filter_width)
    sp_dim_full, sp_cum_var_full = get_dim(spikes_5ms_full, dim_thresh, filter_width)
    
   
    wt_dims.append(wt_dim)
    sp_dims.append(sp_dim)
    sp_dims_full.append(sp_dim_full)
    
    if filter_width == plot_smoothing_kernel_width:
        plot_cum_vars = [wt_cum_var, sp_cum_var, sp_cum_var_full]

In [3]:
fig, ax = plt.subplots()
ax.plot(filter_widths, wt_dims, label='Model')
ax.plot(filter_widths, sp_dims, label='Spikes (Model pairs only)')
ax.plot(filter_widths, sp_dims_full, label='Spikes (full)')
ax.set_xlabel('Smoothing Kernel Width')
ax.set_ylabel('Dimensionality')
ax.legend()
fig.show()

plot_lines = np.stack([
    filter_widths,
    wt_dims,
    sp_dims,
    sp_dims_full,
],axis=1)
out_path = os.path.join(out_dir, '%d_dimensionality.csv')
np.savetxt(out_path, plot_lines, delimiter=',', header='kernel_width,model,spikes_model_pairs,spikes_full')
print(max(wt_dims), max(sp_dims), max(sp_dims_full))

In [2]:
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(plot_cum_vars[0])+1), plot_cum_vars[0], label='Model')
ax.plot(np.arange(1, len(plot_cum_vars[1])+1), plot_cum_vars[1], label='Spikes (Model pairs only)')
ax.plot(np.arange(1, len(plot_cum_vars[2])+1), plot_cum_vars[2], label='Spikes (full)')
ax.set_xlabel('Latent Dimensionality')
ax.set_ylabel('Cumulative % variance explained')
ax.legend()
fig.show()

out_path = os.path.join(out_dir, '%d_variance_model.csv')
plot_lines = np.stack([
    np.arange(1, len(plot_cum_vars[0])+1),
    plot_cum_vars[0],
], axis=1)
np.savetxt(out_path, plot_lines, delimiter=',', header='latent_dimensionality,cum_var_explained')

out_path = os.path.join(out_dir, '%d_variance_spikes_pairs.csv')
plot_lines = np.stack([
    np.arange(1, len(plot_cum_vars[1])+1),
    plot_cum_vars[1],
], axis=1)
np.savetxt(out_path, plot_lines, delimiter=',', header='latent_dimensionality,cum_var_explained')

out_path = os.path.join(out_dir, '%d_variance_spikes_full.csv')
plot_lines = np.stack([
    np.arange(1, len(plot_cum_vars[2])+1),
    plot_cum_vars[2],
], axis=1)
np.savetxt(out_path, plot_lines, delimiter=',', header='latent_dimensionality,cum_var_explained')