In [5]:
# DETAILED ANALYSIS: 3-AXIS METRICS + ALL DRIFT TYPES + STATE IMPACT
# Copy and paste this ENTIRE code into a single JupyterLab cell

import pandas as pd
import numpy as np
import json
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')


class DetailedSessionAnalyzer:
    """Detailed analysis of one recorded session: 3-axis metrics, drift types and state impact.

    A session is a per-trial CSV (loaded into ``self.df``) plus a JSON metadata
    file (loaded into ``self.metadata``).

    CSV columns read by the methods below:
        ``arousal_sigma``, ``control_sigma``, ``executive_sigma``,
        ``drift_label``, ``windowed_state``, ``error_risk`` and ``trial``.

    Metadata keys read by the methods below:
        ``pre_session`` (a dict of ``<axis>_sigma`` values) and
        ``interventions`` (a list of dicts with ``trial``, ``state`` and
        ``cumulative_drift_pct``).
    """

    def __init__(self, csv_path, metadata_path):
        """Load the session CSV and its JSON metadata.

        Args:
            csv_path: path to the per-trial session CSV.
            metadata_path: path to the session's JSON metadata file.
        """
        self.csv_path = csv_path
        self.metadata_path = metadata_path
        self.df = pd.read_csv(csv_path)

        # JSON text is UTF-8 by spec. Without an explicit encoding, open()
        # uses the locale code page (e.g. cp1252 on Windows).
        with open(metadata_path, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        self.session_name = Path(csv_path).stem
        print(f"✓ Loaded: {self.session_name}")

    def get_3axis_exact(self, post_fraction=0.2):
        """Return pre, post and change values for the three sigma axes.

        The "pre" value comes from ``metadata['pre_session']['<axis>_sigma']``.
        The "post" value is the mean of the ``<axis>_sigma`` column over the
        last ``post_fraction`` of the session's trials.

        Args:
            post_fraction: fraction of trailing trials averaged for the post
                value (default 0.2, i.e. the last 20% of the session).

        Returns:
            ``{'arousal'|'control'|'executive': {'pre', 'post', 'change'}}``.
            A pre value missing from the metadata is reported as NaN, and so
            is its change. Previously this raised ``TypeError``/``KeyError``.
        """
        pre_session = self.metadata.get('pre_session') or {}

        # int() truncates, so sessions with fewer than 1/post_fraction trials
        # would get an empty window (and a NaN post). Use at least one trial.
        n_tail = max(1, int(len(self.df) * post_fraction))
        tail = self.df.tail(n_tail)

        result = {}
        for axis in ('arousal', 'control', 'executive'):
            column = f'{axis}_sigma'
            pre = pre_session.get(column)
            if pre is None:
                pre = float('nan')
            post = tail[column].mean()
            result[axis] = {'pre': pre, 'post': post, 'change': post - pre}
        return result

    def get_all_drift_types(self):
        """Count trials per ``drift_label`` (all drift types, not just fatigue).

        Returns:
            A dict with ``'total_trials'``, then one entry per drift label of the
            form ``{'count': int, 'percentage': float}``. Real drift labels come
            first in descending count order; ``'NONE'`` is added last if present.
            Trials with a missing (NaN) label are not counted under any key.
        """
        total = len(self.df)
        drift_analysis = {'total_trials': total}

        # value_counts() drops NaN labels by default (dropna=True).
        counts = self.df['drift_label'].value_counts()
        for drift_type, count in counts.items():
            if drift_type != 'NONE':
                drift_analysis[drift_type] = {
                    'count': int(count),
                    'percentage': (count / total) * 100,
                }

        # "No drift" trials are listed after the real drift types.
        none_count = (self.df['drift_label'] == 'NONE').sum()
        if none_count > 0:
            drift_analysis['NONE'] = {
                'count': int(none_count),
                'percentage': (none_count / total) * 100,
            }

        return drift_analysis

    def analyze_state_during_drift(self):
        """Summarise state and 3-axis metrics over the trials of each drift type.

        Returns:
            ``{drift_label: {...}}`` for every label other than ``'NONE'`` and
            NaN, in order of first appearance. Each value holds the trial count,
            its percentage of all trials, the ``windowed_state`` distribution,
            and the mean ``arousal_sigma``/``control_sigma``/``executive_sigma``/
            ``error_risk`` over those trials.
        """
        total = len(self.df)
        labels = self.df['drift_label']
        state_during_drift = {}

        # unique() preserves first-appearance order; dropna() skips NaN labels.
        for drift_type in labels[labels != 'NONE'].dropna().unique():
            drift_trials = self.df[labels == drift_type]
            state_during_drift[drift_type] = {
                'trial_count': len(drift_trials),
                'percentage': (len(drift_trials) / total) * 100,
                'state_distribution': drift_trials['windowed_state'].value_counts().to_dict(),
                'arousal_sigma': drift_trials['arousal_sigma'].mean(),
                'control_sigma': drift_trials['control_sigma'].mean(),
                'executive_sigma': drift_trials['executive_sigma'].mean(),
                'error_risk': drift_trials['error_risk'].mean(),
            }

        return state_during_drift

    def get_intervention_details(self):
        """Return the metadata interventions joined with the CSV row of their trial.

        Returns:
            A dict with:
            - ``'total_interventions'``: number of entries in
              ``metadata['interventions']``.
            - ``'details'``: one dict per intervention whose ``trial`` matches
              a CSV row, with the 3-axis sigmas of the first matching row.
            - ``'unmatched_trials'``: the ``trial`` values of interventions
              with no matching CSV row. These are excluded from ``details``.
        """
        interventions = self.metadata.get('interventions', [])

        intervention_analysis = {
            'total_interventions': len(interventions),
            'details': [],
            'unmatched_trials': [],
        }

        for intervention in interventions:
            trial_num = intervention.get('trial')
            intervention_trial = self.df[self.df['trial'] == trial_num]

            if intervention_trial.empty:
                # Report rather than silently drop, so callers can see why
                # len(details) < total_interventions.
                intervention_analysis['unmatched_trials'].append(trial_num)
                continue

            intervention_analysis['details'].append({
                'trial': trial_num,
                'state': intervention.get('state'),
                'cumulative_drift_pct': intervention.get('cumulative_drift_pct'),
                'arousal_sigma': intervention_trial['arousal_sigma'].values[0],
                'control_sigma': intervention_trial['control_sigma'].values[0],
                'executive_sigma': intervention_trial['executive_sigma'].values[0],
            })

        return intervention_analysis


# =====================================================
# YOUR 10 SESSIONS - WITH FULL WINDOWS PATHS
# =====================================================

base_path = r'C:\Users\rapol\Downloads\RealTime\sessions'

# Recorded sessions in analysis order. Every stem has two files under
# base_path: "<stem>.csv" (per-trial data) and "<stem>_metadata.json".
_session_stems = [
    'PEAK_GENERATIVE_session_20251210_110640',
    'PASSIVE_CONSUMPTION_session_20251210_150933',
    'RECEPTIVE_ENGAGED_session_20251210_181613',
    'STANDARD_WORK_session_20251211_111554',
    'RECEPTIVE_ENGAGED_session_20251211_144841',
    'STANDARD_WORK_session_20251212_102203',
    'RECEPTIVE_ENGAGED_session_20251213_112551',
    'RECEPTIVE_ENGAGED_session_20251225_114432',
    'RECEPTIVE_ENGAGED_session_20251225_152121',
    'STANDARD_WORK_session_20251226_154644',
]

# (csv_path, metadata_path) pairs, joined with a Windows backslash.
sessions = [
    (f'{base_path}\\{stem}.csv', f'{base_path}\\{stem}_metadata.json')
    for stem in _session_stems
]

# =====================================================
# ANALYZE ALL 10 SESSIONS
# =====================================================

print("\n" + "="*120)
print("DETAILED ANALYSIS: 3-AXIS METRICS + ALL DRIFT TYPES + STATE IMPACT")
print("="*120 + "\n")

# session name -> {'3axis', 'drift_types', 'state_during_drift', 'interventions'}
all_results = {}

for csv_file, meta_file in sessions:
    try:
        analyzer = DetailedSessionAnalyzer(csv_file, meta_file)

        # Evaluated in order; if any step raises, the session is not stored.
        all_results[analyzer.session_name] = {
            '3axis': analyzer.get_3axis_exact(),
            'drift_types': analyzer.get_all_drift_types(),
            'state_during_drift': analyzer.analyze_state_during_drift(),
            'interventions': analyzer.get_intervention_details(),
        }

    except Exception as e:
        # Top-level batch boundary: report and continue with the next session.
        # The "Loaded" line only prints on success, so name the file here;
        # otherwise the error cannot be traced back to a session.
        print(f"  ✗ Error processing session {csv_file}: {type(e).__name__}: {e}")

# =====================================================
# PRINT DETAILED RESULTS FOR EACH SESSION
# =====================================================

# (key in results['3axis'], heading) for each axis, in print order.
_AXIS_HEADINGS = (
    ('arousal', 'AROUSAL'),
    ('control', 'CONTROL'),
    ('executive', 'EXECUTIVE (Cognitive Load)'),
)

for session_name, results in all_results.items():
    print("\n" + "="*120)
    print(f"SESSION: {session_name}")
    print("="*120)

    # 3-AXIS METRICS
    print("\n3-AXIS METRICS (EXACT NUMBERS):")
    print("-" * 120)
    axis = results['3axis']
    for axis_key, heading in _AXIS_HEADINGS:
        print(f"\n{heading}:")
        print(f"  Pre-Session:   {axis[axis_key]['pre']:+.4f}")
        print(f"  Post-Session:  {axis[axis_key]['post']:+.4f}")
        print(f"  Change (Δ):    {axis[axis_key]['change']:+.4f}")

    # ALL DRIFT TYPES ('total_trials' is a count, not a drift label)
    print("\n\nALL DRIFT TYPES:")
    print("-" * 120)
    for drift_type, data in results['drift_types'].items():
        if drift_type != 'total_trials':
            print(f"  {drift_type}: {data['count']} trials ({data['percentage']:.2f}%)")

    # STATE IMPACT DURING DRIFT
    if results['state_during_drift']:
        print("\n\nSTATE IMPACT DURING DRIFT PERIODS:")
        print("-" * 120)

        for drift_type, impact in results['state_during_drift'].items():
            print(f"\n{drift_type} ({impact['trial_count']} trials, {impact['percentage']:.2f}%):")
            print(f"  States during {drift_type}:")
            for state, count in impact['state_distribution'].items():
                state_pct = (count / impact['trial_count']) * 100
                print(f"    • {state}: {count} ({state_pct:.1f}%)")

            print(f"  3-Axis during {drift_type}:")
            print(f"    Arousal:   {impact['arousal_sigma']:+.4f}")
            print(f"    Control:   {impact['control_sigma']:+.4f}")
            print(f"    Executive: {impact['executive_sigma']:+.4f}")
            print(f"    Error Risk: {impact['error_risk']:.4f}")
    else:
        print("\nNo drift detected in this session.")

    # INTERVENTIONS
    interventions = results['interventions']
    total_interventions = interventions['total_interventions']
    if total_interventions > 0:
        print(f"\n\nINTERVENTIONS TRIGGERED: {total_interventions}")
        print("-" * 120)

        # Interventions whose trial is missing from the CSV have no details;
        # say so instead of silently listing fewer than the headline count.
        missing = total_interventions - len(interventions['details'])
        if missing > 0:
            print(f"  ({missing} intervention(s) reference trials not found in the CSV)")

        for intervention in interventions['details']:
            # cumulative_drift_pct comes from metadata via .get() and may be None.
            drift_pct = intervention['cumulative_drift_pct']
            drift_text = "n/a" if drift_pct is None else f"{drift_pct:.1f}%"

            print(f"\nIntervention at Trial {intervention['trial']}:")
            print(f"  State: {intervention['state']}")
            print(f"  Cumulative Drift: {drift_text}")
            print("  3-Axis at intervention:")
            print(f"    Arousal:   {intervention['arousal_sigma']:+.4f}")
            print(f"    Control:   {intervention['control_sigma']:+.4f}")
            print(f"    Executive: {intervention['executive_sigma']:+.4f}")

# =====================================================
# SUMMARY TABLE: 3-AXIS ACROSS ALL SESSIONS
# =====================================================

print("\n\n" + "="*120)
# Report the number of sessions actually analysed; failed sessions are absent
# from all_results, so a hard-coded count would be wrong.
print(f"SUMMARY TABLE: 3-AXIS CHANGES ACROSS ALL {len(all_results)} SESSIONS")
print("="*120)

# One row per analysed session. Columns are <Axis>_Pre, <Axis>_Post and
# <Axis>_Δ for arousal, control and executive (in that order), each value
# formatted as a signed 4-decimal string.
_SUMMARY_VALUES = (('pre', 'Pre'), ('post', 'Post'), ('change', 'Δ'))

summary_data = []
for session_name, results in all_results.items():
    row = {'Session': session_name}
    for axis_key in ('arousal', 'control', 'executive'):
        for value_key, suffix in _SUMMARY_VALUES:
            value = results['3axis'][axis_key][value_key]
            row[f"{axis_key.capitalize()}_{suffix}"] = f"{value:+.4f}"
    summary_data.append(row)

if summary_data:
    df_summary = pd.DataFrame(summary_data)
    print(df_summary.to_string(index=False))
else:
    print("No sessions were analysed successfully.")

print("\n" + "="*120)
print("ANALYSIS COMPLETE!")
print("="*120)



DETAILED ANALYSIS: 3-AXIS METRICS + ALL DRIFT TYPES + STATE IMPACT

✓ Loaded: PEAK_GENERATIVE_session_20251210_110640
✓ Loaded: PASSIVE_CONSUMPTION_session_20251210_150933
✓ Loaded: RECEPTIVE_ENGAGED_session_20251210_181613
✓ Loaded: STANDARD_WORK_session_20251211_111554
✓ Loaded: RECEPTIVE_ENGAGED_session_20251211_144841
✓ Loaded: STANDARD_WORK_session_20251212_102203
✓ Loaded: RECEPTIVE_ENGAGED_session_20251213_112551
✓ Loaded: RECEPTIVE_ENGAGED_session_20251225_114432
✓ Loaded: RECEPTIVE_ENGAGED_session_20251225_152121
✓ Loaded: STANDARD_WORK_session_20251226_154644

SESSION: PEAK_GENERATIVE_session_20251210_110640

3-AXIS METRICS (EXACT NUMBERS):
------------------------------------------------------------------------------------------------------------------------

AROUSAL:
  Pre-Session:   -0.0100
  Post-Session:  +0.0254
  Change (Δ):    +0.0354

CONTROL:
  Pre-Session:   +0.7100
  Post-Session:  -0.2213
  Change (Δ):    -0.9313

EXECUTIVE (Cognitive Load):
  Pre-Session:   -0.