In [None]:
import os, fsspec
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.stats import pearsonr
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import pdist

from gsn.perform_gsn import perform_gsn

import manifold_dynamics.session_raster_extraction as sre
from manifold_dynamics.session_gsn import session_gsn
import manifold_dynamics.paths as pth

# S3 filesystem handle; figures below are written to pth.SAVEDIR through it.
fs = fsspec.filesystem(protocol="s3")

In [None]:
# Render a throwaway noisy-sine figure and write it to pth.SAVEDIR through `fs`
# (presumably a check that plotting + remote saving work before the analysis).
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 500)
y = np.sin(x) + 0.2 * rng.standard_normal(x.size)

# create figure on an explicit Axes rather than the implicit "current" one
fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(x=x, y=y, color='red', ax=ax)
ax.set_xlabel('time')
ax.set_ylabel('signal')
ax.set_title('random noisy sine')
sns.despine(trim=True, offset=5, fig=fig)

# save figure
outpath = os.path.join(pth.SAVEDIR, 'test_figure.png')
with fs.open(outpath, 'wb') as f:
    # A file object has no extension for matplotlib to infer the format from,
    # so state it explicitly instead of relying on rcParams['savefig.format'].
    fig.savefig(f, format='png', dpi=300, transparent=True)
print(f'saved to {outpath}')

In [None]:
# Read roi-uid.csv and collect its distinct 'uid' values (in first-seen order).
outdir = os.path.join(pth.SAVEDIR, 'gsn')
roi_table_path = os.path.join(pth.OTHERS, 'roi-uid.csv')
uid_sheet = pd.read_csv(roi_table_path)
unique_rois = pd.unique(uid_sheet['uid'])
print(unique_rois[:3])

In [None]:
# Pick one ROI by its position in `unique_rois` and load its raster.
ROI_INDEX = 47
roi_uid = unique_rois[ROI_INDEX]
# session number is the text before the first '.' of the uid
session_num = roi_uid.partition('.')[0]
out = sre.extract_session_raster(roi_uid)  # shape is (units, 450, 1072, reps)
print(out.shape)

In [None]:
# PERFORM GSN ON THE SINGLE ROI/SESSION
# Slide a window of `step` bins along axis 1 of `out`, average within the
# window, and run GSN on each window. Collect the unique off-diagonal entries
# of the signal (cSb) and noise (cNb) covariances plus the per-unit ncsnr.
cov_rows = []
ncsnr_rows = []
step = 1        # window width AND stride, in bins along axis 1
scaling = 1e3   # multiplied into the window means before GSN

n_time = out.shape[1]
assert n_time >= step, f'need at least {step} time bins, got {n_time}'

# `+ 1` so the last full window [n_time - step, n_time) is included;
# a bound of `n_time - step` silently drops it.
for t in tqdm(range(0, n_time - step + 1, step)):
    Xw = out[:, t:t + step, :, :]
    Xavg = np.mean(Xw, axis=1) * scaling  # window axis collapsed

    results = perform_gsn(Xavg, {'wantverbose': False})
    sigcov = results['cSb']
    noisecov = results['cNb']
    ncsnr = results['ncsnr']

    # strictly-upper-triangle indices: each unit pair once, no variances
    triu = np.triu_indices_from(sigcov, k=1)

    cov_rows.append({'time': t, 'type': 'signal', 'covariance': sigcov[triu]})
    cov_rows.append({'time': t, 'type': 'noise', 'covariance': noisecov[triu]})
    ncsnr_rows.append({'time': t, 'ncsnr': ncsnr})

cov_df = pd.DataFrame(cov_rows)
cov_df['mean_covariance'] = cov_df['covariance'].apply(lambda v: np.nanmean(np.abs(v)))
ncsnr_df = pd.DataFrame(ncsnr_rows)
ncsnr_df['mean_ncsnr'] = ncsnr_df['ncsnr'].apply(lambda v: np.nanmean(np.abs(v)))

# `results`, `sigcov`, `noisecov` still hold only the LAST window here
# (later cells also use this last-window `sigcov`).
print('mean ncsnr (last window)', np.nanmean(results['ncsnr']))
print('cov correlation (last window)', pearsonr(noisecov[triu], sigcov[triu]).statistic)


#### PLOT ###
customp = ['red', 'black']
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
sns.scatterplot(cov_df, x='time', y='mean_covariance', hue='type',
                marker='o', palette=customp, alpha=.75, ax=ax)
ax.set_xlabel('Time')
ax.set_ylabel('Cov.')
ax.legend(title='')
sns.despine(fig=fig, trim=True, offset=5)
# pth.SAVEDIR is written through the s3 filesystem `fs` (as in the test-figure
# cell); a file object carries no extension, so the format is given explicitly.
outpath = os.path.join(pth.SAVEDIR, f'session-{session_num}-gsn.png')
with fs.open(outpath, 'wb') as f:
    fig.savefig(f, format='png', dpi=300, transparent=True, bbox_inches='tight')

# NCSNR over time (displayed only, not saved)
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
sns.lineplot(ncsnr_df, x='time', y='mean_ncsnr', color='gray', ax=ax)
ax.set_xlabel('Time')
ax.set_ylabel('NCSNR')
sns.despine(fig=fig, trim=True, offset=5)

In [None]:
# CLUSTERED cSb MATRIX
# Uses `sigcov` as left by the GSN cell, i.e. the signal covariance of the
# last time window.
COV_VLIM = 0.5  # symmetric colour limit for the covariance heatmaps

# zero the diagonal so per-unit variances don't dominate the row profiles
M = np.array(sigcov, copy=True)
np.fill_diagonal(M, 0)

# compute pairwise row distances
# correlation distance is usually best for covariance structure
row_dist = pdist(M, metric='correlation')

Z = linkage(row_dist, method='average')
order = leaves_list(Z)
M_reordered = sigcov[np.ix_(order, order)]

# plot on an explicit Axes with the same diverging map as the clustermap
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(M_reordered, vmin=-COV_VLIM, vmax=COV_VLIM, cmap='vlag',
            square=True, ax=ax)
ax.set_title('cSb (clustered)')

# Reuse Z so the clustermap's ordering matches the heatmap above; letting
# clustermap re-cluster the raw `sigcov` (diagonal included) could give a
# different order for the same matrix.
sns.clustermap(
    sigcov,
    row_linkage=Z,
    col_linkage=Z,
    vmin=-COV_VLIM,
    vmax=COV_VLIM,
    cmap='vlag'
)

In [None]:
# RAW DATA RASTER FOR LOCALIZER IMAGES
# For a random subset of units, plot the repetition-averaged response to the
# images from LOCALIZER_START onward as an image x time heatmap.
U, T, I, R = out.shape
bin_size = 20  # NOTE(review): currently unused in this cell
# first image index plotted; taken as the start of the localizer images per the
# cell title — TODO confirm against the stimulus ordering
LOCALIZER_START = 1000
N_SHOW_UNITS = 12

rng = np.random.default_rng(0)
unit_ids = rng.choice(U, size=min(N_SHOW_UNITS, U), replace=False)
img_ids = rng.choice(I, size=6, replace=False)  # NOTE(review): currently unused

Xb = out
# Index units and images in one step: `Xb[unit_ids][...]` first copies every
# image of the selected units before slicing.
Y = np.nanmean(Xb[unit_ids, :, LOCALIZER_START:, :], axis=-1)  # (units, time, images)
# treat exact zeros as missing so they render blank in the raster
Y = np.where(Y == 0, np.nan, Y)

for unit_idx, unit in zip(unit_ids, Y):
    fig, ax = plt.subplots(1, 1, figsize=(20, 5))
    sns.heatmap(unit.T, square=True, cmap='binary', vmax=0.1, ax=ax)
    ax.set_title(f'unit {unit_idx}')
    ax.set_xlabel('time bin')
    ax.set_ylabel(f'image (from {LOCALIZER_START})')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop