In [None]:
import os, fsspec, time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from psn import PSN, psn

import manifold_dynamics.paths as pth
# S3 filesystem handle; used below to size (fs.size) and stream (fs.open) the raster file
fs = fsspec.filesystem("s3")

In [None]:
### from jp's repo: minimal PSN usage on random data
# Generate test data (10 units, 25 conditions, 3 trials).
# Named demo_data (not `data`) because `data` is reused for the real PSN input further down.
np.random.seed(42)
demo_data = np.random.randn(10, 25, 3)

# Create and fit the model (estimator-style interface)
model = PSN()
model.fit(demo_data)

# Access denoised data
print(f"Denoised shape: {model.denoiseddata_.shape}")  # (10, 25)
print(f"Retained {model.best_threshold_} dimensions")

# Or use fit_transform for one-step denoising
denoised = PSN().fit_transform(demo_data)

# Functional interface (psn is already imported in the first cell);
# default settings give zero-shot denoising
results = psn(demo_data)

# Access denoised data
denoised_data = results['denoiseddata']  # Shape: (10, 25) (units, conditions)
print(f"Retained {results['best_threshold']} dimensions on average")

# TODO: can we treat time as a condition? or instead apply psn at each time step?

In [None]:
# Pick one ROI from the uid sheet and build the path to its single-session raster.
uid_sheet = pd.read_csv(os.path.join(pth.OTHERS, 'roi-uid.csv'))
unique_rois = uid_sheet['uid'].unique()

roi_index = 21  # position of the ROI in unique_rois (order of first appearance in the sheet)
roi_uid = unique_rois[roi_index]

raster_dir = os.path.join(pth.PROCESSED, 'single-session-raster')
inpath = os.path.join(raster_dir, f'{roi_uid}.npy')

In [None]:
### time how long it takes to stream the raster from S3 via fsspec
size_bytes = fs.size(inpath)
print('starting to load data...')

load_start = time.perf_counter()
with fs.open(inpath, 'rb') as fh:
    # allow_pickle=False: object arrays are refused, so loading never unpickles file contents
    out = np.load(fh, allow_pickle=False)
dt = time.perf_counter() - load_start

size_gb = size_bytes / 1e9
size_mb = size_bytes / 1e6
print(f'size: {size_gb:.2f} GB')
print(f'time: {dt:.2f} sec')
print(f'throughput: {size_mb / dt:.2f} MB/s')
### end time

In [None]:
# Average over conditions from index `first_cond` onward (axis 2 of `out`, which is
# unpacked further down as U, T, C, R). One value is left per (unit, time point, trial),
# so PSN treats the time points as its "conditions".
# NOTE(review): an earlier comment said this step removes the last trial because of NaNs,
# but the slice keeps every trial (nanmean only skips NaNs) — confirm the intent.
# NOTE(review): presumably conditions >= 1000 are the category images used below — verify.
first_cond = 1000
xt = np.nanmean(out[:, :, first_cond:, :], axis=2)
print(f'Input to PSN is shape {xt.shape} (units, conditions, trials)')

results = psn(xt)

In [None]:
# dn: PSN-denoised trial mean, (units, conditions); the conditions here are time points
dn = results['denoiseddata']
U, C = dn.shape

# plot at most the first 10 units
units_keep = np.arange(min(10, U))

# baseline = each unit's mean over the first 50 time points
dn_base = np.nanmean(dn[:, :50], axis=1, keepdims=True)
dn_bs = dn - dn_base

# long-form frame of trial means only (no error band): one row per (unit, condition).
# indexing='ij' makes the grids unit-major, matching the row order of dn_bs[units_keep].ravel()
unit_grid, cond_grid = np.meshgrid(units_keep, np.arange(C), indexing='ij')
df = pd.DataFrame({
    'unit': unit_grid.ravel(),
    'condition': cond_grid.ravel(),
    'response': dn_bs[units_keep].ravel(),
})

fig, ax = plt.subplots(1, 1, figsize=(8, 3))
sns.lineplot(
    data=df, x='condition', y='response', hue='unit',
    estimator=None,   # one line per unit, no aggregation across units
    alpha=0.8,
    legend=False,     # no legend, same as removing it after plotting
    ax=ax,
)
ax.set_title('denoised mean (baseline subtracted)')
ax.set_xlabel('Time')
plt.tight_layout()
plt.show()

In [None]:
# array stored by psn under 'input_data' — presumably the xt passed in: (units, conditions, trials),
# where conditions are time points here
data = results['input_data']
U, C, T = data.shape

# baseline-subtract per unit × trial (mean over the first 50 time points) before averaging
baseline = np.nanmean(data[:, :50, :], axis=1, keepdims=True)
data_bs = data - baseline

# plot at most the first 10 units; a hard-coded 10 made the DataFrame columns
# differ in length (ValueError) whenever fewer than 10 units were recorded
n_show = min(10, U)

# trial-level long frame so seaborn computes the across-trial mean and SEM itself.
# The index columns follow the C-order flattening of data_bs[:n_show] -> (unit, condition, trial).
units = np.repeat(np.arange(n_show), C * T)
conds = np.tile(np.repeat(np.arange(C), T), n_show)
trials = np.tile(np.arange(T), n_show * C)
vals = data_bs[:n_show].reshape(-1)

df = pd.DataFrame({
    'unit': units,
    'condition': conds,
    'trial': trials,
    'response': vals,
})

fig, ax = plt.subplots(1, 1, figsize=(8, 3))
sns.lineplot(
    data=df,
    x='condition',
    y='response',
    hue='unit',
    estimator='mean',     # mean across trials at each (unit, time point)
    errorbar='se',        # standard error across trials
    legend=False,
    ax=ax,
)
ax.set_xlabel('Time')
ax.set_title('original mean (baseline subtracted, with sem)')
plt.show()

In [None]:
verbose = False

# raw raster data, (units, time, conditions, trials)
U, T, C, R = out.shape

# The last 72 conditions are taken as 3 super-conditions × 24 images stored back to back
# (later plots label them Faces, Bodies, Objects). NOTE(review): confirm that ordering.
n_super = 3
n_last = 72
# Fail loudly: with fewer conditions, the reshape below would silently regroup
# (e.g. 69 -> 3 × 23) or raise an opaque reshape error.
assert C >= n_last, f'expected at least {n_last} conditions, got {C}'
last72 = out[:, :, -n_last:, :]                 # (U, T, 72, R)
imgs_per_super = last72.shape[2] // n_super     # 24

# split 72 -> (3 super-conditions, 24 images): (U, T, 3, 24, R)
last72 = last72.reshape(U, T, n_super, imgs_per_super, R)

# average across images within each super-condition, keeping the trial axis
super_avg = np.nanmean(last72, axis=3)          # (U, T, 3, R)

if verbose: print(f'Data collapsed to 3 conditions: {super_avg.shape}') # expected: (251, 450, 3, 6)

# Denoise each time point independently: average a ±half_window window of samples
# (shrunk at the edges rather than padded), then run PSN on that window.
denoised_time = np.full((U, T, n_super), np.nan)
half_window = 25
for t in tqdm(range(T)):
    # inclusive window [t - half_window, t + half_window]; the slice end is exclusive, hence +1
    win_lo = max(0, t - half_window)
    win_hi = min(T, t + half_window + 1)

    # window_avg: (units, 3, trials)
    window_avg = np.nanmean(super_avg[:, win_lo:win_hi, :, :], axis=1)
    if verbose: print(f'Single time point data: {window_avg.shape}')

    # psn returns a dict; 'denoiseddata' holds the denoised (units, conditions) trial mean
    den_result = psn(window_avg, {'wantverbose': verbose, 'wantfig': False})
    den = den_result['denoiseddata']
    if verbose: print(f'Denoised data at {t:03d} msec: {den.shape}')

    denoised_time[:, t, :] = den
print(f'Final data: {denoised_time.shape}')

In [None]:
def to_long_df(data):
    """
    Flatten a (units, time, conditions) array into a long-form DataFrame.

    One row per array element, in C (row-major) order, so the result can be
    passed straight to seaborn (x='time', hue='condition', ...).

    Parameters
    ----------
    data : array-like
        Anything np.asarray accepts; must be 3-D, ordered (units, time, conditions).

    Returns
    -------
    pandas.DataFrame
        Columns: unit, time, condition, value.

    Raises
    ------
    ValueError
        If the input is not 3-D.
    """
    arr = np.asarray(data)
    if arr.ndim != 3:
        raise ValueError(f'expected shape (units, time, conditions), got {arr.shape}')

    # np.indices yields each element's (unit, time, condition) coordinate;
    # ravel() walks the coordinates in the same C order as the values.
    unit_idx, time_idx, cond_idx = np.indices(arr.shape)
    return pd.DataFrame({
        'unit': unit_idx.ravel(),
        'time': time_idx.ravel(),
        'condition': cond_idx.ravel(),
        'value': arr.ravel(),
    })

In [None]:
# Trial-average the super-condition responses: (units, time, 3).
# nanmean (not mean) so a NaN in a single trial does not blank the whole
# unit/time/condition entry; this matches the nanmean used for the denoised path.
og = np.nanmean(super_avg, axis=3)

long = to_long_df(og)
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
# seaborn aggregates across units at each (time, condition) and draws its default error band
sns.lineplot(long, x='time', y='value', hue='condition', ax=ax)
ax.set_title('OG data to Faces, Bodies, Objects')
ax.set_xlabel('Time')
plt.show()

In [None]:


# One panel per unit: trial-averaged (not denoised) response to the 3 super-conditions.
fig, axes = plt.subplots(5, 2, figsize=(8, 8), sharex=True)
axes = axes.ravel()
# guard against recordings with fewer units than panels (og[[idx]] would raise IndexError)
n_units_shown = min(len(axes), og.shape[0])
for idx, ax in enumerate(axes):
    if idx >= n_units_shown:
        ax.set_visible(False)  # no unit left for this panel
        continue
    long = to_long_df(og[[idx], :, :])
    sns.lineplot(long, x='time', y='value', hue='condition', legend=False, ax=ax)
    ax.set_ylabel(f'unit {idx}')
    # label the shared time axis only on the bottom row
    ax.set_xlabel('Time' if idx >= len(axes) - 2 else '')
# figure-level title instead of squeezing it onto the first small panel
fig.suptitle('Original single unit data to Faces, Bodies, Objects')
fig.tight_layout()
plt.show()

In [None]:
# One panel per unit: PSN-denoised (per-timepoint) response to the 3 super-conditions.
fig, axes = plt.subplots(5, 2, figsize=(8, 8), sharex=True)
axes = axes.ravel()
# guard against recordings with fewer units than panels (indexing would raise IndexError)
n_units_shown = min(len(axes), denoised_time.shape[0])
for idx, ax in enumerate(axes):
    if idx >= n_units_shown:
        ax.set_visible(False)  # no unit left for this panel
        continue
    long = to_long_df(denoised_time[[idx], :, :])
    sns.lineplot(long, x='time', y='value', hue='condition', legend=False, ax=ax)
    ax.set_ylabel(f'unit {idx}')
    # label the shared time axis only on the bottom row
    ax.set_xlabel('Time' if idx >= len(axes) - 2 else '')
# figure-level title instead of squeezing it onto the first small panel
fig.suptitle('Denoised single unit data to Faces, Bodies, Objects')
fig.tight_layout()
plt.show()