# Run dPCA analysis on **synthetic** data
This notebook:
1. generates a synthetic dataset
2. runs the original dPCA functions on it

In [None]:
%cd ../
# This code is a demo of the dPCA analysis; it is not intended to generate interpretable results, only to show how the dPCA analysis is implemented.

/Users/tahaismail/Desktop/work/Baylor_Hayden/PreyPursuit


In [3]:
# module imports
import dill as pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import NMF, PCA
from legacy.ChangeOfMind.functions import processing as proc
import os
from legacy.PacTimeOrig.data import scripts
from pathlib import Path
from scipy.io import loadmat, savemat
import pandas as pd
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from legacy.dPCA import dPCA
from sklearn.linear_model import LinearRegression
import pandas as pd

In [4]:
# Synthetic dataset for the “condition dPCA” demo:
# Each switch is labeled by switch direction (direction = +1 or -1).
# Raw firing rates are organized as: fr.shape = (n_switches, n_time_bins, n_neurons)

# In decoding_prep (preptype='dpca'), trials are split into the 2 direction conditions.
# For each condition, it averages over randomly selected training trials, then stacks: stacked.shape = (n_time_bins, n_conditions=2, n_neurons)
# then transposes to the dPCA input format: X_train.shape = (n_neurons, n_conditions=2, n_time_bins)

# Decoding uses held-out trials to build: X_test.shape = (n_neurons, n_conditions=2, n_time_bins, n_test_trials)
# and classifies direction over time using the selected dPCA component (st).

# example running with one brain region using synthetic data 

# plotting is implemented through matlab "plot_dPCA_results.m"

def make_synth_dataset_for_dpca(
    seed=0,
    subj='SYN01',
    n_switch=160,
    T=30,
    n_acc=200,
    noise_sd=0.8,
):
    """Build a one-subject, one-session synthetic switch dataset (all neurons 'acc').

    Each neuron's firing rate is the sum of a per-neuron baseline, a
    condition-independent time course, a direction-signed sigmoid over time,
    a direction-signed constant offset and Gaussian noise, clipped at zero.

    Parameters
    ----------
    seed : int
        Seed for ``np.random.default_rng``.
    subj : str
        Subject key used in every returned container.
    n_switch : int
        Number of switches (trials) to simulate.
    T : int
        Number of time bins per switch (time axis spans -1..1).
    n_acc : int
        Number of neurons; all are labelled 'acc'.
    noise_sd : float
        Noise scale; the noise actually drawn has sd ``noise_sd * 1.2``.

    Returns
    -------
    dat : dict
        ``{'brain_region_emu': {subj: {1: <object array of 'acc' labels>}}}``.
    inputs : dict
        ``{subj: {1: {'fr', 'direction', 'vdiff'}}}`` where ``fr`` has shape
        ``(n_switch, T, n_acc)`` and ``direction`` holds +1 / -1 labels.
    metadata : dict
        ``{'trial_num': DataFrame}`` with a single row of per-session counts.
    """
    rng = np.random.default_rng(seed)
    n_neurons = n_acc

    # Region label for every neuron (single session, keyed as 1)
    dat = {'brain_region_emu': {subj: {1: np.array(['acc'] * n_acc, dtype=object)}}}

    # Per-switch labels; `direction` is the condition used for the dPCA design
    direction = rng.choice([1, -1], size=n_switch)
    vdiff = rng.choice([2, 4], size=n_switch)

    # Time axis, centred on the switch
    tt = np.linspace(-1, 1, T)

    # Component 1: condition-independent time course
    ci_profile = 0.8*np.exp(-(tt/0.55)**2) - 0.25*np.exp(-((tt-0.6)/0.25)**2)

    # Component 2: sigmoid in [-0.5, 0.5], sign-flipped for direction != 1
    slope = 6.0
    st_base = 1/(1 + np.exp(-slope*tt)) - 0.5
    st_by_switch = np.where(direction[:, None] == 1, st_base, -st_base)  # (n_switch, T)

    # Component 3: time-independent direction offset (+1 or -1)
    static_sign = direction.astype(float)

    # Per-neuron loadings and baselines (draw order fixes the random stream)
    w_ci = rng.normal(0.6, 0.15, n_acc)
    w_st = rng.normal(0.8, 0.20, n_acc)
    w_s = rng.normal(1.2, 0.20, n_acc)
    baseline = rng.uniform(2.0, 6.0, n_neurons)

    # Part of the signal that is identical for every switch: (T, n_neurons)
    shared = baseline[None, :] + ci_profile[:, None] * w_ci[None, :]

    fr = np.zeros((n_switch, T, n_neurons), dtype=float)
    for trial, (st_row, sign) in enumerate(zip(st_by_switch, static_sign)):
        noise = rng.normal(0, noise_sd * 1.2, size=(T, n_acc))
        fr[trial] = (
            shared
            + st_row[:, None] * w_st[None, :]    # switch-dependent
            + (0.25 * sign) * w_s[None, :]       # static
            + noise
        )

    # Firing rates cannot be negative
    fr = np.clip(fr, 0, None)

    inputs = {subj: {1: {'fr': fr, 'direction': direction, 'vdiff': vdiff}}}

    # Session-level bookkeeping: direction == -1 counts as "hilo", +1 as "lohi"
    session_row = {
        'subject': subj,
        'session': 1,
        'switch_hilo_count': int(np.sum(direction == -1)),
        'switch_lohi_count': int(np.sum(direction == 1)),
        'total_neuron_count': n_neurons,
        'neuron_count_acc': n_acc,
        'use_sess': 1,
    }
    metadata = {'trial_num': pd.DataFrame([session_row])}

    return dat, inputs, metadata


In [5]:
# Build the synthetic dataset with the generator's default settings
dat, inputs, metadata = make_synth_dataset_for_dpca()

# Keep the region labels of the subjects listed in the session table
filtered_brain_region_emu = {
    subject: dat['brain_region_emu'][subject]
    for subject in metadata['trial_num']['subject']
}

# One region label per neuron, pooled across subjects (session 1 only)
metadata['area_per_neuron'] = np.concatenate([
    np.array(sessions[1]).astype(str)
    for sessions in filtered_brain_region_emu.values()
])

# Indices of neurons whose label contains the substring 'acc'
accidx = np.flatnonzero(np.char.find(metadata['area_per_neuron'], 'acc') >= 0)

print(f"ACC neurons: {len(accidx)}")



ACC neurons: 200


In [6]:
def run_dpca_for_region(region_name, neur_idx, reg, Vfull_ref=None):
    """Run mean dPCA, then observed and permuted decoding, for one brain region.

    Relies on the notebook-level globals ``inputs``, ``metadata`` and
    ``cfgparams`` and on ``proc.dpca_run``.

    Parameters
    ----------
    region_name : str
        Region tag; used in the output keys and to pick the
        ``neuron_count_<region_name>`` column of ``metadata['trial_num']``.
    neur_idx : array-like
        Neuron indices forwarded to ``proc.dpca_run`` as ``'neur_idx'``.
    reg : float
        Value forwarded to ``proc.dpca_run`` as ``'reg'``.
    Vfull_ref : optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    dict
        Keys ``Z_<region>``, ``Vfull_<region>``, ``expvar_<region>``,
        ``accuracy_<region>_dpca`` and ``accuracy_<region>_perm_dpca``.
    """
    # --- Stage 1: mean dPCA (single run, no train/test split) ---
    params = {
        'mean_dPCA': True,
        'reg': reg,
        'bias': 0,
        'runs': 1,
        'neur_idx': neur_idx,
        'inputs': inputs,
        'train_N': None,
        'test_N': None,
        'Vfull': None,
        'partialer': None
    }
    Z, Vfull, expvar = proc.dpca_run(params)

    results = {
        f'Z_{region_name}': Z,
        f'Vfull_{region_name}': Vfull,
        f'expvar_{region_name}': expvar,
    }

    # --- Stage 2: decoding, reusing the same params dict and the fitted Vfull ---
    params['mean_dPCA'] = False
    params['runs'] = 1000
    params['Vfull'] = Vfull

    # Sessions flagged for use that have at least one neuron in this region
    sessions = metadata['trial_num']
    usable = sessions[
        (sessions['use_sess'] == 1)
        & (sessions[f'neuron_count_{region_name}'] > 0)
    ]

    # Train/test sizes are derived from the smallest per-direction switch count
    fewest = np.min([usable.switch_hilo_count.min(),
                     usable.switch_lohi_count.min()])
    train_frac = cfgparams['percent_train']
    params['train_N'] = int(fewest * train_frac)
    params['test_N'] = int(fewest * (1 - train_frac)) + 1

    # Observed decoding
    params['do_permute'] = False
    results[f'accuracy_{region_name}_dpca'] = proc.dpca_run(params)

    # Permuted decoding, run with 'do_permute' switched on
    params['do_permute'] = True
    results[f'accuracy_{region_name}_perm_dpca'] = proc.dpca_run(params)

    return results


# Run dPCA decomposition by brain region

# --- Analysis switches ---
smoothing = None
do_partial = False
cut_at_median = False
do_warp = False
all_subjects = False

# Regularisation value passed to dPCA, one entry per region
extra_params = {
    "reg_hpc": 1e-5,
    "reg_acc": 1e-5,
    "reg_ofc": 1e-5,
}

# --- Configuration dictionary ---
cfgparams = {
    'locking': 'zero',
    'keepamount': 10,
    'timewarp': {},
    'prewin': 14,
}
cfgparams['prewin_behave'] = cfgparams['prewin']
cfgparams['behavewin'] = 15
cfgparams['behavewin_behave'] = cfgparams['behavewin']

# Time-warp settings: warpN = prewin + behavewin + 1 samples, numbered from 1
cfgparams['timewarp'].update({
    'dowarp': do_warp,
    'warpN': cfgparams['prewin'] + cfgparams['behavewin'] + 1,
})
cfgparams['timewarp']['originalTimes'] = np.arange(1, cfgparams['timewarp']['warpN'] + 1)

# Fraction of the smallest condition count used for training in decoding
cfgparams['percent_train'] = 0.95

# Fall back to the default of 80 when no smoothing value was given
cfgparams['smoothing'] = 80 if smoothing is None else smoothing

cfgparams['do_partial'] = do_partial
cfgparams['cut_at_median'] = cut_at_median


# --- Run: ACC ---
output_to_plot = {
    **run_dpca_for_region(
        region_name='acc',
        neur_idx=accidx,
        reg=extra_params['reg_acc'],
    )
}

# SAVE (MATLAB for plotting)
savemat("dPCASwitchDirection_SyntheticACC.mat", output_to_plot)

print("DONE!!!!!!!")

# Now you can visualize the results using the matlab code "plot_dPCA_results.m"

100%|██████████| 1000/1000 [00:31<00:00, 31.72it/s]
100%|██████████| 1000/1000 [00:35<00:00, 28.14it/s]

DONE!!!!!!!





In [40]:
# This demo shows how to use the function `proc.dpca_run_equal_rwd`, which performs dPCA decomposition using 5 bins of the continuous variable Wt.

# The design matrix is constructed internally within `dpca_run_equal_rwd` via the helper function `build_X_from_wtsplit()`.

# Specifically, the input firing rate tensor has shape: fr.shape = (n_switches, n_time_bins, n_neurons) = (300, 30, 60)

# This switch-level data is then aggregated by binning Wt into 5 bins, yielding a dPCA design matrix of shape:
# (n_neurons, n_Wt_bins, n_time_bins)

def make_synth_wtbinned(
    seed=0,
    subj="SYN01",
    region="acc",
    n_neurons=60,
    Ns=300,     # number of switches
    T=30,
    wt_len=33
):
    """Create synthetic inputs for the Wt-binned dPCA demo.

    Every switch gets a sigmoid Wt trajectory with a random slope and
    midpoint; firing rates are a per-neuron baseline plus a Gaussian time
    bump plus a Wt-driven term plus noise, clipped at zero.

    Parameters
    ----------
    seed : int
        Seed for ``np.random.default_rng``.
    subj : str
        Subject key used in all three returned dicts.
    region : str
        Region label assigned to every neuron.
    n_neurons : int
        Number of neurons.
    Ns : int
        Number of switches.
    T : int
        Number of firing-rate time bins.
    wt_len : int
        Number of samples in each Wt trajectory.

    Returns
    -------
    inputs_sc : dict
        ``{subj: {1: {'fr': (Ns, T, n_neurons), 'trial_index': (Ns,)}}}``.
    outputs_all : dict
        ``{subj: {1: {'wtsplit': (Ns, wt_len), 'splittypes': (Ns, 3)}}}``.
    brain_region_all : dict
        ``{subj: {1: <object array of region labels>}}``.
    """
    rng = np.random.default_rng(seed)

    # Region label per neuron
    areas = np.array([region] * n_neurons, dtype=object)

    # One sigmoid Wt trajectory per switch; slope and midpoint drawn per row
    wt_time = np.linspace(-1, 1, wt_len)
    wtsplit = np.zeros((Ns, wt_len), dtype=float)
    for row in wtsplit:
        steepness = rng.uniform(4, 8)
        midpoint = rng.uniform(-0.3, 0.3)
        row[:] = 1.0/(1.0 + np.exp(-steepness*(wt_time - midpoint)))

    # Middle column set to 1 for every switch, so switch types are not
    # distinguished in this demo
    splittypes = np.zeros((Ns, 3), dtype=int)
    splittypes[:, 1] = 1

    outputs_all = {subj: {1: {"wtsplit": wtsplit, "splittypes": splittypes}}}
    brain_region_all = {subj: {1: areas}}

    # Firing-rate ingredients
    trial_index = np.arange(Ns, dtype=int)
    tt = np.linspace(-1, 1, T)
    time_bump = 0.6*np.exp(-(tt/0.6)**2)

    baseline = rng.uniform(2.0, 6.0, size=n_neurons)
    w_ci = rng.normal(0.6, 0.2, size=n_neurons)
    w_wt = rng.normal(1.2, 0.3, size=n_neurons)

    # Only the first T samples of each Wt trajectory drive the firing rates
    wt_drive = wtsplit[:, :T]  # (Ns, T)

    time_term = baseline[None, None, :] + time_bump[None, :, None]*w_ci[None, None, :]
    wt_term = (wt_drive[:, :, None] - 0.5)*w_wt[None, None, :]*3.0
    signal = time_term + wt_term
    fr = signal + rng.normal(0, 0.6, size=(Ns, T, n_neurons))
    fr = np.clip(fr, 0, None)

    inputs_sc = {subj: {1: {"fr": fr, "trial_index": trial_index}}}

    return inputs_sc, outputs_all, brain_region_all


# Build the synthetic Wt-binned inputs with the generator's defaults
inputs_sc, outputs_all, brain_region_all = make_synth_wtbinned()

# Sanity checks on the first (and only) subject, session 1
subj = list(inputs_sc)[0]
print("trial_index exists:", "trial_index" in inputs_sc[subj][1])
print("wtsplit shape:", outputs_all[subj][1]["wtsplit"].shape)         # (Ns, 33)
print("splittypes shape:", outputs_all[subj][1]["splittypes"].shape)   # (Ns, 3)
print("fr shape:", inputs_sc[subj][1]["fr"].shape)                    # (Ns, T, N)

# Run dpca_run_equal_rwd - there is no decoding step here.
# In the original paper this function is run twice: once including only
# switches between equal-value prey, and once including only switches between
# different-value prey.
dpca_params = {
    "mean_dPCA": True,
    "reg": 1e-5,
    "bias": 0.05,
    "runs": 1,
    "inputs": inputs_sc,
    "outputs_all": outputs_all,
    "brain_region_all": brain_region_all,
    "region": "acc",
    "partialer": None
}

Z, Vfull, expvar = proc.dpca_run_equal_rwd(dpca_params)

# Collect the results under the key names that are saved for plotting
output_to_plot = {
    "Z_acc": Z,
    "Vfull_acc": Vfull,
    "expvar_acc": expvar,
}
print(output_to_plot.keys())


trial_index exists: True
wtsplit shape: (300, 33)
splittypes shape: (300, 3)
fr shape: (300, 30, 60)
dict_keys(['Z_acc', 'Vfull_acc', 'expvar_acc'])


In [7]:
# SAVE (MATLAB for plotting)
# The .mat file cannot be written into a folder that does not exist, so
# create the output folder first. This keeps "Restart & Run All" working
# on a fresh checkout.
os.makedirs("example_data", exist_ok=True)
savemat(
    "example_data/dPCA_WtBins_SyntheticACC.mat",
    output_to_plot
)

print("DONE!!!!!!!")

# now you can plot using the matlab code "MAIN_plot_dPCA_results_Wtbins.m"

DONE!!!!!!!
