In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

# Imports and Paths

In [2]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import copy
import sklearn
import rastermap
import pandas as pd

import tensorly as tl
import tensorly.decomposition

In [104]:
dir_github        = Path(r'D:\RH_local\github').resolve()

dir_analysisFiles = Path(r'D:\RH_local\data\BMI_cage_1511_4\mouse_1511L\20230111\analysis_data').resolve()

## Directory with F.npy, stat.npy etc.
dir_s2p         = Path(r'D:\RH_local\data\BMI_cage_1511_4\mouse_1511L\20230111\analysis_data\suite2p\plane0').resolve()
# dir_s2p         = Path(r'C:\Users\Rich Hakim\Downloads\F.npy').resolve()

## Path of iscell. Can be from neural net output (iscell_NN)
path_iscell = dir_analysisFiles / 'iscell_NN_tqm.npy'

## Path of tqm (trace quality metrics). Used to get dFoF parameters
path_tqm = dir_analysisFiles / 'trace_quality.pkl'


dir_save       =  copy.copy(dir_analysisFiles)
path_save = dir_save / 'weights_day0'
# path_save = dir_save / 'weights_day0_PC2'

In [105]:
import sys
sys.path.append(str(dir_github))

%load_ext autoreload
%autoreload 2

from bnpm import torch_helpers, file_helpers, timeSeries, ca2p_preprocessing, welford_moving_2D, linear_regression, similarity

%load_ext autoreload
%autoreload 2
from Big_Ugly_ROI_Tracker.multiEps.multiEps_modules import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [106]:
def get_highest_evr_var(data, factors,  hits):
    evrs = np.zeros(factors.shape[0])
    for i, factor in enumerate(factors):
        _,_,_,evr_total = similarity.orthogonalize(data, factor)
        evrs[i] = evr_total
    evrs[np.logical_not(hits)] = 0
    highest_evr_ind = np.argsort(evrs)[-1]
    return highest_evr_ind, factors[highest_evr_ind,:], evrs

def orthogonalize_simple(data, factor):
    proj_vec = similarity.proj(data, factor)[0]
    return (np.squeeze(data) - np.squeeze(proj_vec))

In [107]:
DEVICE = torch_helpers.set_device(use_GPU=True)

devices available: [_CudaDeviceProperties(name='NVIDIA GeForce RTX 2080 Ti', major=7, minor=5, total_memory=11264MB, multi_processor_count=68)]
Using device: 'cuda:0': _CudaDeviceProperties(name='NVIDIA GeForce RTX 2080 Ti', major=7, minor=5, total_memory=11264MB, multi_processor_count=68)


In [108]:
iscell = np.load(path_iscell)
# spks = np.load(path_spks)

In [109]:
tqm = file_helpers.pickle_load(path_tqm)
dFoF_params = tqm['dFoF_params']

In [110]:
## == IMPORT DATA ==
F = np.load(dir_s2p / 'F.npy') # masks multiplied by data
Fneu = np.load(dir_s2p / 'Fneu.npy') # estimated neuropil signal (Fns = F - Fneu; Fo = ptile(Fns, 30); dFoF=Fns-Fo/Fo)
# iscell = np.load(dir_s2p/'iscell.npy') # probability and bool of each roi
ops = np.load(dir_s2p / 'ops.npy', allow_pickle=True) # parameters for the suite2p
spks_s2p = np.load(dir_s2p / 'spks.npy') # blind deconvolution
stat = np.load(dir_s2p / 'stat.npy', allow_pickle=True) # statistics for individual neurons 

num_frames_S2p = F.shape[1]
Fs = ops[()]['fs']

In [111]:
# # Oopsie!
# # Mouse g2FB: delete 43000 - 50000 frames
# F = np.delete(F, range(43000,50000), axis=1)
# Fneu = np.delete(Fneu, range(43000,50000), axis=1)

In [112]:
Fneu.shape

(3608, 108000)

In [113]:
frame_height = ops[()]['meanImg'].shape[0]
frame_width = ops[()]['meanImg'].shape[1]

sf = import_and_convert_to_CellReg_spatialFootprints([dir_s2p / 'stat.npy'], frame_height=frame_height, frame_width=frame_width, dtype=np.float32)[0]

In [114]:
F_toUse = F[iscell]
Fneu_toUse = Fneu[iscell]

# Prepare dFoF

In [115]:
win_smooth = 4
kernel_smoothing = np.zeros(win_smooth*2)
kernel_smoothing[win_smooth:] = 1
kernel_smoothing /= kernel_smoothing.sum()

In [116]:
# Pipeline for the NMF Strategy 
# Smooth F
F_smooth = timeSeries.convolve_along_axis(
    F_toUse,
    kernel=kernel_smoothing,
    axis=1,mode='same',
    multicore_pref=True,
    verbose=True
).astype(np.float32)

# dFoF with reduced percentile for baseline
channelOffset_correction = 500
percentile_baseline = 5
neuropil_fraction=0.7

dFoF , dF , F_neuSub , F_baseline = ca2p_preprocessing.make_dFoF(
    F=F_smooth + channelOffset_correction,
    Fneu=Fneu_toUse + channelOffset_correction,
    neuropil_fraction=neuropil_fraction,
    percentile_baseline=percentile_baseline,
    multicore_pref=True,
    verbose=True
)
# Threshold for nonnegativity
dFoF_z = dFoF / np.std(dFoF,axis=1,keepdims=True)

ThreadPool elapsed time : 0.12 s. Now unpacking list into array.
Calculated convolution. Total elapsed time: 0.24 seconds
Calculated dFoF. Total elapsed time: 0.87 seconds


In [331]:
# Test out rolling subtraction of the 10th percentile of the daa to remove microscope movement artifacts

ptile = 10
window = int(Fs*60*1)

# dFoF_sub_ptile = dFoF - timeSeries.rolling_percentile_pd(dFoF, ptile=ptile, window=window)
dFoF_sub_ptile = dFoF - timeSeries.rolling_percentile_rq_multicore(dFoF, ptile=ptile, window=window)
# dFoF_sub_ptile_clipped = np.clip(dFoF_sub_ptile, a_min=0, a_max=None)
dFoF_sub_ptile_clipped = np.clip(dFoF_sub_ptile, a_min=0.2, a_max=None)

In [None]:
dFoF_sub_ptile_clipped_ms = dFoF_sub_ptile_clipped - dFoF_sub_ptile_clipped.mean(x, keepdims=True)

### Look at rastermap

In [332]:
import rastermap

In [333]:
rmap = rastermap.Rastermap(
    n_components=1,
    n_X=40,
    nPC=200,
    init='pca',
    alpha=1.0,
    K=1.0,
    mode='basic',
    verbose=True,
    annealing=True,
    constraints=2,
)

In [334]:
# rmap.fit(dFoF_sub_ptile)
rmap.fit(dFoF_sub_ptile_clipped)
# rmap.fit(scipy.stats.zscore(dFoF))

nmin 200
0.19745445251464844
6.3685524463653564
6.608888387680054
6.611880779266357
(38, 40)
(70,)
1.0
time; iteration;  explained PC variance
0.00s     0        0.1543      2
0.02s    10        0.2925      4
0.04s    20        0.3391      8
0.06s    30        0.4415      18
0.09s    40        0.4870      28
0.11s    50        0.5403      38
0.13s    60        0.5404      38
0.15s   final      0.5404
0.15s upsampled    0.5404


<rastermap.mapping.Rastermap at 0x275cbe6be80>

In [335]:
%matplotlib notebook

plt.figure()
# plt.imshow(dFoF_sub_ptile[rmap.isort], aspect='auto', vmax=1)
# plt.imshow(np.clip(scipy.stats.zscore(dFoF_sub_ptile, axis=1), -1,1)[rmap.isort], aspect='auto', vmax=1)
plt.imshow(dFoF_sub_ptile_clipped[rmap.isort], aspect='auto', vmin=-0.2, vmax=1)
# plt.imshow(scipy.stats.zscore(dFoF)[rmap.isort], aspect='auto', vmin=-0.1, vmax=1)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x275cbe7efa0>

# Run NMF

In [336]:
# indices_cropped = np.arange(int(Fs*60*23),dFoF_sub_ptile.shape[1])

In [337]:
# neural_data_toUse =dFoF / np.std(dFoF,axis=1,keepdims=True)
# neural_data_toUse = (dFoF_sub_ptile / np.std(dFoF_sub_ptile,axis=1,keepdims=True))[:,indices_cropped]
neural_data_toUse = (dFoF_sub_ptile_clipped / np.std(dFoF_sub_ptile_clipped,axis=1,keepdims=True))

In [338]:
tl.set_backend('pytorch')

In [339]:
neural_data_toUse = tl.tensor(neural_data_toUse).to('cuda')

In [340]:
# Roll a dice and wish for luck
rank = 10

In [341]:
factors_nmf = tl.decomposition.non_negative_parafac(
            neural_data_toUse,
            rank=rank,
            n_iter_max=1000,
            init='random',
            svd='numpy_svd', 
            tol=1e-07,
        #     random_state=None,
            verbose=1,
        #     normalize_factors=False,
        #     return_errors=False,
        #     mask=None,
        #     cvg_criterion='abs_rec_error',
        #     fixed_modes=None
        )

reconstruction error=0.19661574065685272
iteration 1, reconstraction error: 0.19418169558048248, decrease = 0.0024340450763702393
iteration 2, reconstraction error: 0.19207856059074402, decrease = 0.0021031349897384644
iteration 3, reconstraction error: 0.19024042785167694, decrease = 0.0018381327390670776
iteration 4, reconstraction error: 0.18862511217594147, decrease = 0.0016153156757354736
iteration 5, reconstraction error: 0.18719330430030823, decrease = 0.0014318078756332397
iteration 6, reconstraction error: 0.1859154850244522, decrease = 0.001277819275856018
iteration 7, reconstraction error: 0.1847711205482483, decrease = 0.0011443644762039185
iteration 8, reconstraction error: 0.1837397962808609, decrease = 0.0010313242673873901
iteration 9, reconstraction error: 0.18280665576457977, decrease = 0.0009331405162811279
iteration 10, reconstraction error: 0.18195872008800507, decrease = 0.000847935676574707
iteration 11, reconstraction error: 0.18118572235107422, decrease = 0.000

iteration 105, reconstraction error: 0.16891540586948395, decrease = 1.6644597053527832e-05
iteration 106, reconstraction error: 0.16889925301074982, decrease = 1.615285873413086e-05
iteration 107, reconstraction error: 0.16888310015201569, decrease = 1.615285873413086e-05
iteration 108, reconstraction error: 0.16886793076992035, decrease = 1.5169382095336914e-05
iteration 109, reconstraction error: 0.16885323822498322, decrease = 1.4692544937133789e-05
iteration 110, reconstraction error: 0.16883854568004608, decrease = 1.4692544937133789e-05
iteration 111, reconstraction error: 0.16882386803627014, decrease = 1.4677643775939941e-05
iteration 112, reconstraction error: 0.1688096672296524, decrease = 1.4200806617736816e-05
iteration 113, reconstraction error: 0.16879546642303467, decrease = 1.4200806617736816e-05
iteration 114, reconstraction error: 0.16878272593021393, decrease = 1.2740492820739746e-05
iteration 115, reconstraction error: 0.16876952350139618, decrease = 1.320242881774

iteration 199, reconstraction error: 0.16819846630096436, decrease = 3.4421682357788086e-06
iteration 200, reconstraction error: 0.16819404065608978, decrease = 4.425644874572754e-06
iteration 201, reconstraction error: 0.168190598487854, decrease = 3.4421682357788086e-06
iteration 202, reconstraction error: 0.16818666458129883, decrease = 3.933906555175781e-06
iteration 203, reconstraction error: 0.16818274557590485, decrease = 3.919005393981934e-06
iteration 204, reconstraction error: 0.16817831993103027, decrease = 4.425644874572754e-06
iteration 205, reconstraction error: 0.1681748777627945, decrease = 3.4421682357788086e-06
iteration 206, reconstraction error: 0.16817143559455872, decrease = 3.4421682357788086e-06
iteration 207, reconstraction error: 0.16816702485084534, decrease = 4.410743713378906e-06
iteration 208, reconstraction error: 0.16816259920597076, decrease = 4.425644874572754e-06
iteration 209, reconstraction error: 0.1681591421365738, decrease = 3.4570693969726562e-0

iteration 292, reconstraction error: 0.1678457409143448, decrease = 3.948807716369629e-06
iteration 293, reconstraction error: 0.16784131526947021, decrease = 4.425644874572754e-06
iteration 294, reconstraction error: 0.16783785820007324, decrease = 3.4570693969726562e-06
iteration 295, reconstraction error: 0.16783343255519867, decrease = 4.425644874572754e-06
iteration 296, reconstraction error: 0.1678299754858017, decrease = 3.4570693969726562e-06
iteration 297, reconstraction error: 0.16782556474208832, decrease = 4.410743713378906e-06
iteration 298, reconstraction error: 0.16782161593437195, decrease = 3.948807716369629e-06
iteration 299, reconstraction error: 0.1678161919116974, decrease = 5.424022674560547e-06
iteration 300, reconstraction error: 0.167813241481781, decrease = 2.950429916381836e-06
iteration 301, reconstraction error: 0.16780880093574524, decrease = 4.4405460357666016e-06
iteration 302, reconstraction error: 0.16780585050582886, decrease = 2.950429916381836e-06
i

iteration 384, reconstraction error: 0.16747106611728668, decrease = 3.933906555175781e-06
iteration 385, reconstraction error: 0.1674685776233673, decrease = 2.4884939193725586e-06
iteration 386, reconstraction error: 0.16746415197849274, decrease = 4.425644874572754e-06
iteration 387, reconstraction error: 0.16746069490909576, decrease = 3.4570693969726562e-06
iteration 388, reconstraction error: 0.1674567461013794, decrease = 3.948807716369629e-06
iteration 389, reconstraction error: 0.16745181381702423, decrease = 4.932284355163574e-06
iteration 390, reconstraction error: 0.16744835674762726, decrease = 3.4570693969726562e-06
iteration 391, reconstraction error: 0.16744489967823029, decrease = 3.4570693969726562e-06
iteration 392, reconstraction error: 0.1674414426088333, decrease = 3.4570693969726562e-06
iteration 393, reconstraction error: 0.16743749380111694, decrease = 3.948807716369629e-06
iteration 394, reconstraction error: 0.16743354499340057, decrease = 3.948807716369629e-

iteration 474, reconstraction error: 0.16716179251670837, decrease = 2.9653310775756836e-06
iteration 475, reconstraction error: 0.1671588271856308, decrease = 2.9653310775756836e-06
iteration 476, reconstraction error: 0.16715537011623383, decrease = 3.4570693969726562e-06
iteration 477, reconstraction error: 0.16715289652347565, decrease = 2.473592758178711e-06
iteration 478, reconstraction error: 0.16714943945407867, decrease = 3.4570693969726562e-06
iteration 479, reconstraction error: 0.1671469658613205, decrease = 2.473592758178711e-06
iteration 480, reconstraction error: 0.16714350879192352, decrease = 3.4570693969726562e-06
iteration 481, reconstraction error: 0.16714052855968475, decrease = 2.9802322387695312e-06
iteration 482, reconstraction error: 0.16713707149028778, decrease = 3.4570693969726562e-06
iteration 483, reconstraction error: 0.167135089635849, decrease = 1.9818544387817383e-06
iteration 484, reconstraction error: 0.16713163256645203, decrease = 3.457069396972656

iteration 567, reconstraction error: 0.1668945699930191, decrease = 2.4884939193725586e-06
iteration 568, reconstraction error: 0.16689209640026093, decrease = 2.473592758178711e-06
iteration 569, reconstraction error: 0.16688913106918335, decrease = 2.9653310775756836e-06
iteration 570, reconstraction error: 0.16688665747642517, decrease = 2.473592758178711e-06
iteration 571, reconstraction error: 0.1668836921453476, decrease = 2.9653310775756836e-06
iteration 572, reconstraction error: 0.16688069701194763, decrease = 2.995133399963379e-06
iteration 573, reconstraction error: 0.16687822341918945, decrease = 2.473592758178711e-06
iteration 574, reconstraction error: 0.16687574982643127, decrease = 2.473592758178711e-06
iteration 575, reconstraction error: 0.1668722778558731, decrease = 3.471970558166504e-06
iteration 576, reconstraction error: 0.16687080264091492, decrease = 1.475214958190918e-06
iteration 577, reconstraction error: 0.16686782240867615, decrease = 2.9802322387695312e-0

iteration 668, reconstraction error: 0.1666279137134552, decrease = 1.9818544387817383e-06
iteration 669, reconstraction error: 0.16662493348121643, decrease = 2.9802322387695312e-06
iteration 670, reconstraction error: 0.16662293672561646, decrease = 1.996755599975586e-06
iteration 671, reconstraction error: 0.16661997139453888, decrease = 2.9653310775756836e-06
iteration 672, reconstraction error: 0.1666169911623001, decrease = 2.9802322387695312e-06
iteration 673, reconstraction error: 0.16661500930786133, decrease = 1.9818544387817383e-06
iteration 674, reconstraction error: 0.16661252081394196, decrease = 2.4884939193725586e-06
iteration 675, reconstraction error: 0.16661004722118378, decrease = 2.473592758178711e-06
iteration 676, reconstraction error: 0.1666075736284256, decrease = 2.473592758178711e-06
iteration 677, reconstraction error: 0.16660508513450623, decrease = 2.4884939193725586e-06
iteration 678, reconstraction error: 0.16660211980342865, decrease = 2.965331077575683

iteration 760, reconstraction error: 0.16640949249267578, decrease = 2.9802322387695312e-06
iteration 761, reconstraction error: 0.166407510638237, decrease = 1.9818544387817383e-06
iteration 762, reconstraction error: 0.16640403866767883, decrease = 3.471970558166504e-06
iteration 763, reconstraction error: 0.16640254855155945, decrease = 1.4901161193847656e-06
iteration 764, reconstraction error: 0.16640056669712067, decrease = 1.9818544387817383e-06
iteration 765, reconstraction error: 0.16639858484268188, decrease = 1.9818544387817383e-06
iteration 766, reconstraction error: 0.16639657318592072, decrease = 2.0116567611694336e-06
iteration 767, reconstraction error: 0.16639359295368195, decrease = 2.9802322387695312e-06
iteration 768, reconstraction error: 0.16639161109924316, decrease = 1.9818544387817383e-06
iteration 769, reconstraction error: 0.16638913750648499, decrease = 2.473592758178711e-06
iteration 770, reconstraction error: 0.166388139128685, decrease = 9.98377799987793e

iteration 853, reconstraction error: 0.1662166565656662, decrease = 2.995133399963379e-06
iteration 854, reconstraction error: 0.16621467471122742, decrease = 1.9818544387817383e-06
iteration 855, reconstraction error: 0.16621267795562744, decrease = 1.996755599975586e-06
iteration 856, reconstraction error: 0.16621118783950806, decrease = 1.4901161193847656e-06
iteration 857, reconstraction error: 0.16620871424674988, decrease = 2.473592758178711e-06
iteration 858, reconstraction error: 0.1662067174911499, decrease = 1.996755599975586e-06
iteration 859, reconstraction error: 0.16620473563671112, decrease = 1.9818544387817383e-06
iteration 860, reconstraction error: 0.16620275378227234, decrease = 1.9818544387817383e-06
iteration 861, reconstraction error: 0.16620074212551117, decrease = 2.0116567611694336e-06
iteration 862, reconstraction error: 0.1661982536315918, decrease = 2.4884939193725586e-06
iteration 863, reconstraction error: 0.166197270154953, decrease = 9.834766387939453e-0

In [342]:
factors_temporal_nmf = np.array(factors_nmf.factors[1].to('cpu')).T

# Run Regression
Regress z-scored NMF temporal factors onto z-scored neural data itself

In [343]:
indices_to_use = np.arange(int(Fs*60*10),neural_data_toUse.shape[1])
## z-score NMF outputsneural_data_toUse 
factors_temporal_nmf_zscore = scipy.stats.zscore(factors_temporal_nmf[:,indices_to_use],axis=1)
# neural_data_toUse_zscore = welford_moving_2D.make_rollingZScore(np.array(neural_data_toUse.to('cpu')).T, win_roll=int(Fs*60*10)).T

In [344]:
neural_data_toUse_zscore = welford_moving_2D.make_rollingZScore(np.array(neural_data_toUse.to('cpu')).T, win_roll=int(Fs*60*10)).T
neural_data_toUse_zscore = neural_data_toUse_zscore[:,indices_to_use]

100%|████████████████████████████████████████████████████████████████████████| 108000/108000 [00:16<00:00, 6472.17it/s]
  X_zscore_roll = (list_of_values - X_mean_rolling)/np.sqrt(X_var_rolling)
  X_zscore_roll = (list_of_values - X_mean_rolling)/np.sqrt(X_var_rolling)


In [345]:
neural_data_toUse_zscore = np.nan_to_num(neural_data_toUse_zscore, nan=0)

In [346]:
%matplotlib notebook

plt.figure()
# plt.imshow(dFoF_sub_ptile[rmap.isort], aspect='auto', vmax=1)
plt.imshow(
    neural_data_toUse_zscore,
    aspect='auto',
    vmin=-1, 
    vmax=2,
)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x275cbe70dc0>

In [347]:
# OLS Regression
theta, factors_temporal_regression, bias = linear_regression.OLS(neural_data_toUse_zscore.T, factors_temporal_nmf_zscore.T)
factors_temporal_regression = factors_temporal_regression.T

In [348]:
# # Ridge Regression
# theta, factors_temporal_regression, bias = linear_regression.Ridge(neural_data_toUse_zscore.T, factors_temporal_nmf_zscore.T, lam=100000)
# factors_temporal_regression = factors_temporal_regression.T

In [349]:
# # ElasticNet Regression
# elr = sklearn.linear_model.ElasticNet(
#     alpha=0.1,
#     l1_ratio=0.05, 
#     fit_intercept=False,
# #     normalize='deprecated',
# #     precompute=False,
#     max_iter=1000, 
# #     copy_X=True, 
#     tol=0.0001,
# #     warm_start=False, 
#     positive=False,
#     random_state=42, 
#     selection='cyclic',
# #     verbose=True,
# )
# elr.fit(neural_data_toUse_zscore.T, factors_temporal_nmf_zscore.T)

# theta = elr.coef_.T
# factors_temporal_regression = elr.predict(neural_data_toUse_zscore.T).T

In [350]:
plt.figure()
plt.plot(np.arange(factors_temporal_nmf_zscore.shape[1])/Fs,  factors_temporal_nmf_zscore.T + 10*np.arange(rank)[None,:], alpha=0.3);
plt.gca().set_prop_cycle(None)
plt.plot(np.arange(factors_temporal_regression.shape[1])/Fs,  factors_temporal_regression.T + 10*np.arange(rank)[None,:]);

<IPython.core.display.Javascript object>

In [351]:
similarity.EV(factors_temporal_nmf_zscore.T, factors_temporal_regression.T)

(array([0.7305364 , 0.8072173 , 0.6124661 , 0.43889916, 0.4140488 ,
        0.92520577, 0.8462325 , 0.58907574, 0.7303035 , 0.75897384],
       dtype=float32),
 0.68529594,
 0.68529594)

In [352]:
# Filter based on Explained Variance
# If we can't predict a factor well, don't keep it
ev_threshold = 0.7
_, pairwise, evr_weighted, _ = similarity.pairwise_orthogonalization(factors_temporal_regression.T.astype(np.float32), factors_temporal_nmf_zscore.T.astype(np.float32))
factors_temporal_tokeep = factors_temporal_regression[pairwise > ev_threshold]

rank_good = (pairwise > ev_threshold).sum()
print(f'Rank above 0.7 EV: {rank_good}')

Rank above 0.7 EV: 6


In [353]:
pairwise

array([0.7309439 , 0.8073974 , 0.6144024 , 0.43896824, 0.41419554,
       0.92562836, 0.8468605 , 0.58908105, 0.7303097 , 0.7591359 ],
      dtype=float32)

In [354]:
_, _, evrs = get_highest_evr_var(neural_data_toUse[:,indices_to_use].cpu().numpy().T, factors_temporal_tokeep,  np.ones(rank_good, bool))
# _, _, evrs = get_highest_evr_var(neural_data_toUse[:,indices_to_use].T, factors_temporal_tokeep,  np.ones(rank, bool))

In [355]:
%matplotlib notebook
plt.figure()
plt.plot(np.arange(factors_temporal_tokeep.shape[1])/Fs,  factors_temporal_nmf_zscore[np.argsort(evrs)[::-1],:].T + 10*np.arange(rank_good),alpha=0.4)
plt.gca().set_prop_cycle(None)
plt.plot(np.arange(factors_temporal_tokeep.shape[1])/Fs,  factors_temporal_tokeep[np.argsort(evrs)[::-1],:].T + 10*np.arange(rank_good));

<IPython.core.display.Javascript object>

In [356]:
weights = theta[:,np.argsort(evrs)[::-1]]
# weights = theta
sf_weights = np.einsum('ij,ikl->jkl', weights,sf[iscell])

%matplotlib notebook
display_toggle_image_stack(
    sf_weights, 
    clim=[sf_weights.min(), sf_weights.max()]
)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=0, description='i_frame', max=5), Output()), _dom_classes=('widget-inter…

# Run Orthogonalization

In [357]:
factors_orth = copy.deepcopy(factors_temporal_tokeep)
num_factors = rank_good
hits = np.ones(num_factors, bool)
for i in range(num_factors-1):
    ind_to_orthogonalize, factor_to_orthogonalize, evrs = get_highest_evr_var(neural_data_toUse_zscore.T, factors_orth, hits)
    print(f'Factor at {ind_to_orthogonalize} index explains {evrs[ind_to_orthogonalize]} of the variance in neural data')
    hits[ind_to_orthogonalize] = False
    factors_orth[hits] = np.transpose(orthogonalize_simple(factors_orth[hits].T, factor_to_orthogonalize))

Factor at 2 index explains 0.007464706897735596 of the variance in neural data
Factor at 0 index explains 0.003937363624572754 of the variance in neural data
Factor at 1 index explains 0.0035449862480163574 of the variance in neural data
Factor at 5 index explains 0.0027707815170288086 of the variance in neural data
Factor at 4 index explains 0.0026029348373413086 of the variance in neural data


In [358]:
f, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(np.corrcoef(factors_temporal_tokeep),vmin=-1)
ax2.imshow(np.corrcoef(factors_orth),vmin=-1)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x275cbf38d30>

In [359]:
_, _, evrs = get_highest_evr_var(neural_data_toUse_zscore.T, factors_orth,  np.ones(num_factors, bool))

In [360]:
np.sort(evrs)

array([0.00251794, 0.00260293, 0.00277078, 0.00354499, 0.00393736,
       0.00746471])

# Run Regression Pt 2

In [361]:
# OLS Regression
theta, factors_orth_regression, bias = linear_regression.OLS(neural_data_toUse_zscore.T, factors_orth.T)
factors_orth_regression = factors_orth_regression.T

In [362]:
_, _, evrs = get_highest_evr_var(neural_data_toUse[:,indices_to_use].cpu().numpy().T, factors_orth,  np.ones(num_factors, bool))

In [367]:
plt.figure()
plt.plot(np.sort(evrs)[::-1])

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x275cc0187f0>]

# Visualize

In [372]:
%matplotlib notebook
plt.figure()
plt.plot(np.arange(factors_temporal_tokeep.shape[1])/Fs, factors_temporal_tokeep[np.argsort(evrs)[::-1],:].T + 10*np.arange(rank_good),alpha=0.4)
plt.gca().set_prop_cycle(None)
plt.plot(np.arange(factors_temporal_tokeep.shape[1])/Fs, factors_orth_regression[np.argsort(evrs)[::-1],:].T + 10*np.arange(rank_good))

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x275cbeee310>,
 <matplotlib.lines.Line2D at 0x275cbeee1c0>,
 <matplotlib.lines.Line2D at 0x275cbeee4c0>,
 <matplotlib.lines.Line2D at 0x275cbeeecd0>,
 <matplotlib.lines.Line2D at 0x275cbef1190>,
 <matplotlib.lines.Line2D at 0x275cbef1ee0>]

In [373]:
weights = theta[:,np.argsort(evrs)[::-1]]
sf_weights = np.einsum('ij,ikl->jkl', weights,sf[iscell])

factors_orth_regression_sorted = factors_orth_regression[np.argsort(evrs)[::-1]]

%matplotlib notebook
display_toggle_image_stack(sf_weights)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=0, description='i_frame', max=5), Output()), _dom_classes=('widget-inter…

In [None]:
weights_zeroSum = weights - weights.sum(x)

In [368]:
plt.figure()
plt.imshow(neural_data_toUse_zscore, aspect='auto', vmax=4)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x275cc034280>

In [370]:
%matplotlib notebook
plt.figure()
plt.imshow(theta[:,:],aspect='auto', interpolation='none')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x275cc09ab50>

# Save it all

In [371]:
factor_to_use = 0 # 0-indexed
weights = theta[:,np.argsort(evrs)[::-1]]

weights_day0 = {
    "weights": weights[:,factor_to_use],
    "weights_all" : weights,
    "iscell_custom": iscell,
    "factor_to_use": factor_to_use,
    "sf_weights": sf_weights,
#     "factors_temporal": factors_temporal_tokeep,
    "factors_temporal": factors_orth_regression_sorted,
}


F = scipy.io.savemat(path_save.with_suffix('.mat') , weights_day0)

np.save(path_save.with_suffix('.npy') , weights_day0)

# Load in old outputs

In [1]:
import numpy as np

In [2]:
data = np.load(r'D:/RH_local/data/BMI_cage_1511_4/mouse_1511L/20230111/analysis_data/weights_day0.npy', allow_pickle=True)[()]

In [3]:
data.keys()

dict_keys(['weights', 'weights_all', 'iscell_custom', 'factor_to_use', 'sf_weights', 'factors_temporal'])

In [6]:
import bnpm.plotting_helpers

In [17]:
# weights = theta[:,np.argsort(evrs)[::-1]]
# sf_weights = np.einsum('ij,ikl->jkl', data['weights_all'], sf[data['iscell_custom']])

%matplotlib notebook
bnpm.plotting_helpers.display_toggle_image_stack(data['sf_weights'])

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=0, description='i_frame', max=5), Output()), _dom_classes=('widget-inter…

In [None]:
components , scores , explained_variance_ratio_ , stds = decomposition.simple_pca(neural_data_toUse.T.cpu().numpy() , n_components=None , mean_sub=True, zscore=False, plot_pref=True , n_PCs_toPlot=2)

In [None]:
factors_orth_regression.shape

In [8]:
Fs = 30

In [10]:
import matplotlib.pyplot as plt

In [18]:
%matplotlib notebook
fig, axs = plt.subplots(data['factors_temporal'].shape[0], 1, sharex=True)
for ii in range(axs.shape[0]):
#     axs[ii,0].plot(np.arange(data['factors_temporal'].shape[1])/Fs, factors_orth_regression[np.argsort(evrs)[-ii-1],:].T)
    axs[ii].plot(np.arange(data['factors_temporal'].shape[1])/Fs, data['factors_temporal'][ii])

<IPython.core.display.Javascript object>

In [14]:
%matplotlib notebook
fig, axs = plt.subplots(data['factors_temporal'].shape[0], 1, sharex=True)
for ii in range(axs.shape[0]):
#     axs[ii,0].plot(np.arange(data['factors_temporal'].shape[1])/Fs, factors_orth_regression[np.argsort(evrs)[-ii-1],:].T)
    axs[ii].plot(np.arange(data['factors_temporal'].shape[1])/Fs, data['factors_temporal'][ii])

<IPython.core.display.Javascript object>

In [372]:
%matplotlib notebook
plt.figure()
plt.plot(np.arange(factors_temporal_tokeep.shape[1])/Fs, factors_temporal_tokeep[np.argsort(evrs)[::-1],:].T + 10*np.arange(rank_good),alpha=0.4)
plt.gca().set_prop_cycle(None)
plt.plot(np.arange(factors_temporal_tokeep.shape[1])/Fs, factors_orth_regression[np.argsort(evrs)[::-1],:].T + 10*np.arange(rank_good))

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x275cbeee310>,
 <matplotlib.lines.Line2D at 0x275cbeee1c0>,
 <matplotlib.lines.Line2D at 0x275cbeee4c0>,
 <matplotlib.lines.Line2D at 0x275cbeeecd0>,
 <matplotlib.lines.Line2D at 0x275cbef1190>,
 <matplotlib.lines.Line2D at 0x275cbef1ee0>]

In [None]:
import rastermap

In [None]:
rmap = rastermap.Rastermap(
    n_components=1,
    n_X=40,
#     nPC=200,
#     init='pca',
#     alpha=1.0,
#     K=1.0,
#     mode='basic',
#     verbose=True,
#     annealing=True,
#     constraints=2,
)

In [None]:
# embedding = rmap.fit_transform(neural_data_toUse.cpu().numpy())
# embedding = rmap.fit_transform(dFoF)
embedding = rmap.fit_transform(np.concatenate([neural_data_toUse.cpu().numpy()[:,indices_to_use], factors_orth_regression]))

In [None]:
plt.figure()
plt.hist(embedding, 400);
plt.hist(embedding[-10:], 100);

In [None]:
plt.figure()
plt.imshow(
#     timeSeries.scale_between(neural_data_toUse.cpu().numpy()[rmap.isort], lower_percentile=10, upper_percentile=90), 
#     timeSeries.scale_between(spks_s2p[rmap.isort], lower_percentile=10, upper_percentile=90), 
#     dFoF[rmap.isort], vmax=5, 
    np.concatenate([neural_data_toUse.cpu().numpy()[:,indices_to_use], factors_orth_regression])[rmap.isort], 
    vmin=0,
    vmax=5, 
    aspect='auto',
#     extent=[0, indices_to_use.max()/Fs, 0, neural_data_toUse.shape[0]]
)
plt.xticks(ticks=np.arange(0,factors_temporal_tokeep.shape[1], 10000), labels=np.round(np.arange(0,factors_temporal_tokeep.shape[1], 10000)/Fs))

In [None]:
f, (ax1, ax2) = plt.subplots(2, 1,sharex=True)
ax1.imshow(
#     timeSeries.scale_between(neural_data_toUse.cpu().numpy()[rmap.isort], lower_percentile=10, upper_percentile=90), 
#     timeSeries.scale_between(spks_s2p[rmap.isort], lower_percentile=10, upper_percentile=90), 
#     dFoF[rmap.isort], vmax=5, 
    np.concatenate([neural_data_toUse.cpu().numpy()[:,indices_to_use], factors_orth_regression])[rmap.isort], 
    vmin=0,
    vmax=5, 
    aspect='auto',
#     extent=[0, indices_to_use.max()/Fs, 0, neural_data_toUse.shape[0]]
)
#ax1.set_xticks(ticks=np.arange(0,factors_temporal_tokeep.shape[1], 10000), labels=np.round(np.arange(0,factors_temporal_tokeep.shape[1], 10000)/Fs))
ax2.plot(factors_temporal_tokeep[np.argsort(evrs)[::-1],:].T + 10*np.arange(10),alpha=0.4)
ax2.plot(factors_orth_regression[np.argsort(evrs)[::-1],:].T + 10*np.arange(10))

In [None]:
indices_to_use.shape