# TE, PID, PhiID analysis

Main output: written to `results/`, one `.npz` and one `.pkl` file per condition (10 conditions in total).

In [1]:
import pickle
import numpy as np
from analysis.utils import demean

# Read the merged dataset from disk.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# 'merged_data.pkl' when it comes from a trusted source.
with open('merged_data.pkl', 'rb') as f:
    merged_data = pickle.load(f)

# Demean every condition's 'data' entry. The per-condition dict is updated
# in place so that any other fields it carries (e.g. 'stim_roi') are kept.
for condition, data_dict in merged_data.items():
    # A top-level 'roi_names' entry, if present, is not a condition.
    if condition != 'roi_names':
        data_dict['data'] = demean(data_dict['data'])

In [None]:
def delete_trial(merged_data, cond, trial_num):
    """
    Delete one trial (and its associated metadata) from a condition, in place.

    The trial is removed from:
      * ``merged_data[cond]['data']`` along axis 2,
      * ``merged_data[cond]['sources']`` (only when it is a list that is long
        enough to contain ``trial_num``),
      * every ndarray with at least 2 dimensions stored under the optional
        'PID', 'PhiID' and 'TE' entries, along the last axis,
      * ``merged_data[cond]['metadata']['n_trials']`` (decremented by one).

    Args:
        merged_data: Merged dataset dictionary
        cond: Condition name (e.g. 'MOp (L)')
        trial_num: Trial index to delete (0-based)

    Returns:
        None. If ``cond`` is not a key of ``merged_data`` a warning is printed
        and nothing is modified.

    Raises:
        IndexError: propagated from ``np.delete`` when ``trial_num`` is out of
            bounds for the trial axis (nothing has been modified at that point).
    """
    if cond not in merged_data:
        print(f"Warning: Condition {cond} not found")
        return

    entry = merged_data[cond]
    print(f"\nCondition {cond} original shape: {entry['data'].shape}")

    # Delete trial from data matrix (trials live on axis 2).
    # This runs first so that an out-of-range index fails before any
    # metadata has been touched.
    entry['data'] = np.delete(entry['data'], trial_num, axis=2)

    # Remove corresponding source info
    sources = entry.get('sources')
    if isinstance(sources, list) and len(sources) > trial_num:
        deleted_source = sources.pop(trial_num)
        print(f"Deleted source: {deleted_source}")

    # Update analysis results (PID/PhiID/TE) if present
    for analysis_type in ['PID', 'PhiID', 'TE']:
        if analysis_type not in entry:
            continue
        analysis = entry[analysis_type]
        for key in analysis:
            value = analysis[key]
            if isinstance(value, np.ndarray) and value.ndim >= 2:
                analysis[key] = np.delete(value, trial_num, axis=-1)

    # Update trial count in metadata
    if 'metadata' in entry and 'n_trials' in entry['metadata']:
        entry['metadata']['n_trials'] -= 1

    print(f"Condition {cond} new shape: {entry['data'].shape}")
    # 'sources' is optional everywhere else in this function, so do not
    # require it here: fall back to the size of the trial axis when absent.
    if 'sources' in entry:
        remaining = len(entry['sources'])
    else:
        remaining = entry['data'].shape[2]
    print(f"Remaining trials: {remaining}")

# Drop one trial from each of three conditions: (condition name, 0-based trial index).
# NOTE: delete_trial mutates merged_data in place, so re-running this cell
# removes further trials.
for _cond_name, _trial_idx in (
    ('MOp (L)', 0),
    ('RSPd/v (Bilateral)', 6),
    ('SSp-bfd (L)', 6),
):
    delete_trial(merged_data, _cond_name, _trial_idx)


Condition MOp (L) original shape: (130, 30, 8)
Deleted source: 1
Condition MOp (L) new shape: (130, 30, 7)
Remaining trials: 7

Condition RSPd/v (Bilateral) original shape: (130, 30, 8)
Deleted source: 7
Condition RSPd/v (Bilateral) new shape: (130, 30, 7)
Remaining trials: 7

Condition SSp-bfd (L) original shape: (130, 30, 8)
Deleted source: 7
Condition SSp-bfd (L) new shape: (130, 30, 7)
Remaining trials: 7



# Compare how information flow between the same pair of ROIs changes between the resting state and the stimulated state

In [None]:
import os
import numpy as np       
from datetime import datetime
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from analysis.VAR_NuMIT import PID_VAR_calculator, PhiID_VAR_calculator
import logging
from functools import partial
import traceback

In [None]:
def analyze_roi_pairs(merged_data, maxp=1, n_workers=3, selected_conditions=None):
    """
    Run the ROI-pair analysis for some or all conditions of a merged dataset.

    Every key of ``merged_data`` other than 'roi_names' is treated as a
    condition. Each selected condition is handed to ``process_condition`` on a
    thread pool; a condition whose processing raises is recorded with an
    'error' message and empty result containers instead of aborting the run.

    Args:
        merged_data: Merged dataset {'roi_names': list, cond: {'data': ndarray}}
        maxp: Maximum VAR model order
        n_workers: Number of conditions processed in parallel
        selected_conditions: Conditions to process, given as names or as
            integer positions into the list of all conditions; None for all.
            Entries that match nothing are skipped with a warning.
    Returns:
        dict with a 'metadata' section (date, maxp, condition lists, ROI names,
        analysis type) and a 'conditions' section keyed by condition name.
    """
    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    available = [name for name in merged_data.keys() if name != 'roi_names']

    # Resolve the user's selection (names and/or integer positions).
    if selected_conditions is None:
        chosen = available
    else:
        chosen = []
        for requested in selected_conditions:
            if isinstance(requested, int) and 0 <= requested < len(available):
                chosen.append(available[requested])
                continue
            if requested in available:
                chosen.append(requested)
                continue
            log.warning(f"Condition {requested} not found, skipping")

    metadata = {
        'date': datetime.now().strftime("%Y-%m-%d"),
        'maxp': maxp,
        'all_conditions': available,
        'selected_conditions': chosen,
        'roi_names': merged_data['roi_names'],
        'analysis_type': 'multimodal_trial_analysis'
    }
    results = {'metadata': metadata, 'conditions': {}}

    pid_atoms = ['R', 'U_X', 'U_Y', 'S']
    phiid_atoms = [
        prefix + 't' + suffix
        for prefix in ('r', 'x', 'y', 's')
        for suffix in ('r', 'x', 'y', 's')
    ]
    roi_names = merged_data['roi_names']

    # Per condition: the smallest size of axis 0 and of axis 1 over the
    # sub-arrays obtained by iterating the condition's 'data' entry.
    dims = {
        name: (
            min(block.shape[0] for block in merged_data[name]['data']),
            min(block.shape[1] for block in merged_data[name]['data']),
        )
        for name in chosen
    }

    def _failed_condition(message):
        # Placeholder stored for a condition whose processing raised.
        return {
            'error': message,
            'roi_pairs': {},
            'activation_changes': [],
            'TE': {'summary_stats': [], 'trial_level_data': []},
            'PID': {'summary_stats': [], 'trial_level_data': []},
            'PhiID': {'summary_stats': [], 'trial_level_data': []}
        }

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        pending = {
            pool.submit(
                process_condition,
                name, merged_data[name], roi_names,
                dims[name], maxp, pid_atoms, phiid_atoms,
                merged_data[name].get('sources', []),
            ): name
            for name in chosen
        }

        for finished in as_completed(pending):
            name = pending[finished]
            try:
                outcome = finished.result()
            except Exception as e:
                log.error(f"Error processing condition {name}: {str(e)}")
                results['conditions'][name] = _failed_condition(str(e))
            else:
                results['conditions'][name] = outcome
                log.info(f"Completed processing condition: {name}")

    return results

def process_condition(cond, data_dict, roi_names, dims, maxp, pid_atoms, phiid_atoms, sources):
    """
    Analyse every ordered ROI pair (i, j), i != j, of one condition.

    Each pair is submitted to ``process_roi_pair`` on a thread pool. The result
    is stored under the key "<roi_i>_vs_<roi_j>"; if a pair raises, an entry
    with an 'error' message and empty containers is stored instead.

    Args:
        cond: Condition name (used for logging and passed through)
        data_dict: The condition's dictionary from the merged dataset
        roi_names: List of ROI names; its length fixes the number of ROIs
        dims: (min_time, min_trials) tuple for this condition
        maxp: Maximum VAR model order
        pid_atoms: Names of the PID atoms
        phiid_atoms: Names of the PhiID atoms
        sources: Per-trial source info, stored as 'trial_sources'
    Returns:
        dict with 'roi_pairs' plus the (initially empty) 'activation_changes',
        'TE', 'PID', 'PhiID' containers and 'trial_sources'.
    """
    log = logging.getLogger(__name__)
    min_time, min_trials = dims
    roi_count = len(roi_names)

    cond_results = {'roi_pairs': {}, 'activation_changes': []}
    for section in ('TE', 'PID', 'PhiID'):
        cond_results[section] = {'summary_stats': [], 'trial_level_data': []}
    cond_results['trial_sources'] = sources

    log.info(f"Processing condition: {cond} with {min_trials} trials")

    # All ordered pairs of distinct ROI indices.
    ordered_pairs = []
    for src in range(roi_count):
        for dst in range(roi_count):
            if src != dst:
                ordered_pairs.append((src, dst))

    with ThreadPoolExecutor() as pool:
        # Everything except the two ROI indices is identical for every pair.
        run_pair = partial(
            process_roi_pair,
            cond=cond, data_dict=data_dict,
            roi_names=roi_names, min_time=min_time,
            min_trials=min_trials, maxp=maxp,
            pid_atoms=pid_atoms, phiid_atoms=phiid_atoms,
        )

        pending = {}
        for src, dst in ordered_pairs:
            pending[pool.submit(run_pair, src, dst)] = (src, dst)

        for finished in as_completed(pending):
            src, dst = pending[finished]
            first_name = roi_names[src]
            second_name = roi_names[dst]
            pair_key = f"{first_name}_vs_{second_name}"

            try:
                outcome = finished.result()
                cond_results['roi_pairs'][pair_key] = outcome
                log.debug(f"Completed ROI pair: {pair_key}")
            except Exception as e:
                log.error(f"Error processing ROI pair {pair_key}: {str(e)}")
                # Keep a same-shaped placeholder so downstream code can still
                # index this pair.
                cond_results['roi_pairs'][pair_key] = {
                    'error': str(e),
                    'roi_info': {'names': [first_name, second_name], 'indices': [src, dst]},
                    'VAR': {'p_opt': None, 'A': [], 'V': []},
                    'TE': {'values': [], 'mean': None},
                    'PID': {'trial_values': defaultdict(list), 'mean': {}},
                    'PhiID': {'trial_values': defaultdict(list), 'mean': {}}
                }

    return cond_results

def process_roi_pair(i, j, cond, data_dict, roi_names, min_time, min_trials, maxp, pid_atoms, phiid_atoms):
    """
    Process analysis for a single ROI pair.

    Every trial ``0 .. min_trials-1`` is handed to ``process_trial`` on a
    thread pool. Per-trial VAR parameters, TE, PID atoms and PhiID atoms are
    collected in trial order, followed by NaN-ignoring means across trials.
    A trial that raises is logged as a warning and skipped.

    Args:
        i, j: Indices of the two ROIs along the first axis of data_dict['data']
        cond: Condition name (not used in the computation)
        data_dict: The condition's dictionary; must contain 'data' and may
            contain 'sources' (per-trial source info)
        roi_names: List of ROI names, indexed by i and j
        min_time: Number of leading time points passed to process_trial
        min_trials: Number of trials to process
        maxp: Maximum VAR model order
        pid_atoms: Names of the PID atoms, in the order process_trial returns them
        phiid_atoms: Names of the PhiID atoms, in the order process_trial returns them
    Returns:
        dict with keys 'roi_info', 'VAR', 'TE', 'PID', 'PhiID'. Note that
        'VAR'['p_opt'] holds the value of the last successfully processed
        trial only, whereas 'A' and 'V' hold one entry per successful trial.
    """
    logger = logging.getLogger(__name__)
    roi1_name, roi2_name = roi_names[i], roi_names[j]
    pair_key = f"{roi1_name}_vs_{roi2_name}"

    pair_results = {
        'roi_info': {'names': [roi1_name, roi2_name], 'indices': [i, j]},
        'VAR': {'p_opt': None, 'A': [], 'V': []},
        'TE': {'values': [], 'mean': None, 'trial_sources': []},
        'PID': {'trial_values': defaultdict(list), 'mean': {}, 'trial_sources': []},
        'PhiID': {'trial_values': defaultdict(list), 'mean': {}, 'trial_sources': []}
    }

    data_roi1 = data_dict['data'][i, :, :]
    data_roi2 = data_dict['data'][j, :, :]
    sources = data_dict.get('sources', [None]*min_trials)

    with ThreadPoolExecutor() as executor:
        # process_trial is defined as
        #   process_trial(trial, data_roi1, data_roi2, min_time, maxp, source)
        # so the atom name lists must NOT be passed to it: doing so raised a
        # TypeError for every trial, which the handler below turned into a
        # warning and an empty result.
        submitted = [
            (trial, executor.submit(process_trial,
                                    trial, data_roi1, data_roi2,
                                    min_time, maxp,
                                    sources[trial] if trial < len(sources) else None))
            for trial in range(min_trials)
        ]

        # Collect in submission (= trial) order rather than completion order so
        # that the per-trial lists are deterministic across runs.
        for trial, future in submitted:
            try:
                trial_result = future.result()
                if trial_result is not None:
                    p_opt, A, V, te, pid, phiid, source = trial_result

                    pair_results['VAR']['p_opt'] = p_opt
                    pair_results['VAR']['A'].append(A)
                    pair_results['VAR']['V'].append(V)

                    pair_results['TE']['values'].append(te)
                    pair_results['TE']['trial_sources'].append(source)

                    for a, atom in enumerate(pid_atoms):
                        pair_results['PID']['trial_values'][atom].append(float(pid[a]))
                    pair_results['PID']['trial_sources'].append(source)

                    for a, atom in enumerate(phiid_atoms):
                        pair_results['PhiID']['trial_values'][atom].append(float(phiid[a]))
                    pair_results['PhiID']['trial_sources'].append(source)

            except Exception as e:
                logger.warning(f"Error processing trial {trial} for pair {pair_key}: {str(e)}")
                continue

    # Means are only defined when at least one trial succeeded.
    if pair_results['TE']['values']:
        pair_results['TE']['mean'] = np.nanmean(pair_results['TE']['values'])

        for atom in pid_atoms:
            if atom in pair_results['PID']['trial_values']:
                pair_results['PID']['mean'][atom] = np.nanmean(pair_results['PID']['trial_values'][atom])

        for atom in phiid_atoms:
            if atom in pair_results['PhiID']['trial_values']:
                pair_results['PhiID']['mean'][atom] = np.nanmean(pair_results['PhiID']['trial_values'][atom])

    return pair_results

def process_trial(trial, data_roi1, data_roi2, min_time, maxp, source):
    """
    Run the VAR fit and the TE / PID / PhiID calculations for one trial.

    Args:
        trial: Trial index (selects one column of each ROI array)
        data_roi1, data_roi2: 2-D arrays for the two ROIs; the first axis is
            truncated to ``min_time`` entries, the second is indexed by trial
        min_time: Number of leading time points to keep
        maxp: Maximum VAR model order passed to fit_and_select_model
        source: Opaque per-trial source tag, returned unchanged
    Returns:
        Tuple (p_opt, A, V, te, pid, phiid, source).
    Raises:
        ValueError: if the TE value is NaN.
    """
    time_window = slice(None, min_time)
    # A length-1 slice (not an integer index) keeps the trial axis.
    trial_column = slice(trial, trial + 1)
    stacked = np.stack(
        [data_roi1[time_window, trial_column], data_roi2[time_window, trial_column]],
        axis=0,
    )

    p_opt, A, V = fit_and_select_model(stacked, maxp)

    # Only the [0, 1] entry of the TE matrix is used.
    te = calculate_te(A, V)[0, 1]
    if np.isnan(te).any():
        raise ValueError("NaN values in TE calculation")

    var_model = dict(p=p_opt, A=A, V=V, L1=1, L2=1)
    # Both calculators' results are indexed with [0]; only that element is kept.
    pid = PID_VAR_calculator(**var_model)[0]
    phiid = PhiID_VAR_calculator(**var_model)[0]

    return p_opt, A, V, te, pid, phiid, source


def fit_and_select_model(data, maxp, fixed_p=None):
    """
    Choose a VAR model order and fit the model.

    Args:
        data: Time-series data handed to the VAR helpers
        maxp: Maximum model order considered during order selection
        fixed_p: If given, use this order as-is and skip order selection
    Returns:
        Tuple (p_opt, A, V): the order used and the first two values returned
        by ``fit_var``.
    """
    from analysis.VAR_fitness import tsdata_to_varmo, fit_var

    if fixed_p is None:
        # Order selection: keep only the first of the three leading values.
        selected, _second, _third = tsdata_to_varmo(data, maxp)[:3]
        # Never go below order 1.
        p_opt = max(1, selected)
    else:
        p_opt = fixed_p

    A, V, _ = fit_var(data, p=p_opt)
    return p_opt, A, V

def calculate_te(A, V):
    """Calculate transfer entropy matrix"""
    n = A.shape[1]
    te_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j and V[i,i] > 1e-6:
                te_matrix[i,j] = 0.5 * np.log(V[j,j] / (V[j,j] - V[j,i]**2/V[i,i]))
    return te_matrix

In [None]:
import re
import pickle

def save_results(results, cond=None, output_dir='results'):
    """
    Save analysis results in both .npz and .pkl formats.

    The file base name is ``results_<cond>`` (characters that are invalid in
    file names are replaced by '_'), or ``results_all`` when no condition name
    is available. Existing files with the same name are overwritten.

    Notes:
        * Top-level values that are not ndarray/str/int/float/bool are
          converted with ``np.array`` for the .npz file (falling back to
          ``str(value)`` if that fails). Nested dicts therefore become object
          arrays; reading them back needs ``np.load(..., allow_pickle=True)``.
        * The .pkl file contains the unmodified ``results``; if that cannot be
          pickled, the simplified dictionary used for the .npz is pickled
          instead. Only unpickle files from a trusted source.

    Args:
        results: Analysis results dictionary to save
        cond: Condition name (optional); when omitted, the names listed under
            results['metadata']['selected_conditions'] are joined with '_'
        output_dir: Output directory (default 'results'), created if missing

    Returns:
        List of saved file paths [npz_path, pkl_path], or [npz_path] if the
        pickle could not be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Prepare data for saving
    save_dict = {}
    for key, value in results.items():
        if not isinstance(value, (np.ndarray, str, int, float, bool)):
            try:
                value = np.array(value)
            except Exception:
                # Only ordinary conversion failures fall back to a string;
                # KeyboardInterrupt / SystemExit must still propagate.
                value = str(value)
        save_dict[key] = value

    # Generate filename
    if cond is None and 'metadata' in results:
        cond = '_'.join(results.get('metadata', {}).get('selected_conditions', []))

    if cond:
        safe_cond = re.sub(r'[\\/*?:"<>|]', '_', str(cond))
        basename = f"results_{safe_cond}"
    else:
        basename = "results_all"

    # Save .npz format
    npz_path = os.path.join(output_dir, f"{basename}.npz")
    try:
        np.savez_compressed(npz_path, **save_dict)
    except Exception as e:
        np.savez(npz_path, **save_dict)
        print(f"Used uncompressed .npz format due to: {str(e)}")

    # Save .pkl format
    pkl_path = os.path.join(output_dir, f"{basename}.pkl")
    try:
        with open(pkl_path, 'wb') as f:
            pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        print(f"Error saving .pkl file: {str(e)}")
        try:
            # Re-open with 'wb' so a partially written file is truncated.
            with open(pkl_path, 'wb') as f:
                pickle.dump(save_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            print(f"Failed to save simplified .pkl file: {str(e)}")
            pkl_path = None

    print(f"Results saved to:\n- {npz_path}\n- {pkl_path if pkl_path else 'Failed to save .pkl file'}")

    return [npz_path, pkl_path] if pkl_path else [npz_path]

In [None]:
import gc

# Every top-level key except 'roi_names' is a condition.
all_conditions = [name for name in merged_data.keys() if name != 'roi_names']

# Analyse and save one condition at a time so that only a single condition's
# results are held in memory at once.
for cond in all_conditions:
    print(f"\nProcessing {cond}...")
    results = analyze_roi_pairs(merged_data, selected_conditions=[cond], n_workers=2)
    save_path = save_results(results, cond=cond)

    # Release this condition's results before starting the next one.
    del results
    gc.collect()

INFO:__main__:Processing condition: Rest with 17 trials



Processing Rest...


INFO:__main__:Completed processing condition: Rest
INFO:__main__:Processing condition: MOp (L) with 7 trials


Results saved to:
- results/results_Rest.npz
- results/results_Rest.pkl

Processing MOp (L)...


INFO:__main__:Completed processing condition: MOp (L)
INFO:__main__:Processing condition: VISam/pm (R) with 14 trials


Results saved to:
- results/results_MOp (L).npz
- results/results_MOp (L).pkl

Processing VISam/pm (R)...


INFO:__main__:Completed processing condition: VISam/pm (R)
INFO:__main__:Processing condition: AUD (L) with 14 trials


Results saved to:
- results/results_VISam_pm (R).npz
- results/results_VISam_pm (R).pkl

Processing AUD (L)...


INFO:__main__:Completed processing condition: AUD (L)
INFO:__main__:Processing condition: SSp-ul/ll (R) with 8 trials


Results saved to:
- results/results_AUD (L).npz
- results/results_AUD (L).pkl

Processing SSp-ul/ll (R)...


INFO:__main__:Completed processing condition: SSp-ul/ll (R)
INFO:__main__:Processing condition: RSPd/v (Bilateral) with 7 trials


Results saved to:
- results/results_SSp-ul_ll (R).npz
- results/results_SSp-ul_ll (R).pkl

Processing RSPd/v (Bilateral)...


INFO:__main__:Completed processing condition: RSPd/v (Bilateral)
INFO:__main__:Processing condition: VISp (L) with 8 trials


Results saved to:
- results/results_RSPd_v (Bilateral).npz
- results/results_RSPd_v (Bilateral).pkl

Processing VISp (L)...


INFO:__main__:Completed processing condition: VISp (L)
INFO:__main__:Processing condition: MOs (R) with 8 trials


Results saved to:
- results/results_VISp (L).npz
- results/results_VISp (L).pkl

Processing MOs (R)...


INFO:__main__:Completed processing condition: MOs (R)
INFO:__main__:Processing condition: SSp-bfd (L) with 7 trials


Results saved to:
- results/results_MOs (R).npz
- results/results_MOs (R).pkl

Processing SSp-bfd (L)...


INFO:__main__:Completed processing condition: SSp-bfd (L)
INFO:__main__:Processing condition: VISa/rl (R) with 14 trials


Results saved to:
- results/results_SSp-bfd (L).npz
- results/results_SSp-bfd (L).pkl

Processing VISa/rl (R)...


INFO:__main__:Completed processing condition: VISa/rl (R)


Results saved to:
- results/results_VISa_rl (R).npz
- results/results_VISa_rl (R).pkl
