In [16]:
import pandas as pd
import numpy as np
from pathlib import Path
from collections import deque
import json
import warnings
warnings.filterwarnings('ignore')

# ================== CONFIG ==================
# Folder holding the per-session input CSVs, and the folder that receives results.
data_dir = Path(r"C:\Users\rapol\Downloads\lab_analysis_v6_0_grounded")
output_dir = Path(r"C:\Users\rapol\Downloads\drift_analysis_results")
# parents=True creates any missing intermediate folder instead of raising
# FileNotFoundError; exist_ok=True keeps re-runs of the cell harmless.
output_dir.mkdir(parents=True, exist_ok=True)

# build first 48 files = subjects 01–16, sessions S1–S3
subjects = [f"{i:02d}" for i in range(1, 17)]  # 01..16
sessions = ["S1", "S2", "S3"]

# One path per (subject, session) pair, in subject-major order.
# Paths are stored as plain strings; they are wrapped in Path again where needed.
files = [
    str(data_dir / f"sub-{sub}_ses-{ses}_6STATES_v6_0.csv")
    for sub in subjects for ses in sessions
]

# Echo the planned work list so a wrong count is visible before processing starts.
print(f"Will process {len(files)} files (should be 48):")
for f in files:
    print(" ", Path(f).name)

# ============= CONFIG FROM ANALYZER =============
# Feature names looked up in the input data as 'z_<name>' columns.
VALIDATED_Z_SCORES = [
    'theta',
    'beta',
    'theta_beta_ratio',
    'alpha',
    'pe',
    'delta',
    'alpha_relative',
    'gamma',
    'mse',
    'lz',
    'pac',
    'wpe',
    'frontal_asymmetry',
    'at_ratio',
]

# Thresholds compared against z-scores by detect_drift_enhanced.
# Positive values are "greater than" cut-offs, negative values "less than".
MW_THRESHOLDS = {
    'tbr_moderate': 0.25,      # theta_beta_ratio above this -> 'high_TBR'
    'alpha_decrease': -0.15,   # alpha below this -> 'low_alpha'
    'pe_moderate': -0.2,       # pe below this -> 'low_PE'
    'lz_decrease': -0.3,       # lz below this -> 'low_LZ'
}

FATIGUE_THRESHOLDS = {
    'alpha_elevation': 0.8,    # alpha above this -> 'high_alpha'
    'delta_elevation': 0.3,    # delta above this -> 'high_delta'
    'theta_elevation': 0.2,    # theta above this -> 'high_theta'
    'beta_decrease': -0.3,     # beta below this -> 'low_beta'
    'pwr_loss': -0.4,          # NOTE(review): not referenced by the code visible here
}

OVERLOAD_THRESHOLDS = {
    'theta_extreme': 2.0,      # theta above this -> 'extreme_theta'
    'pac_surge': 1.2,          # pac above this -> 'pac_surge'
}

# An intervention fires when the rolling drift percentage exceeds this value ...
INTERVENTION_CUMULATIVE_DRIFT_THRESHOLD = 50.0
# ... and at least this many trials have passed since the previous one.
INTERVENTION_COOLDOWN_TRIALS = 30

# Windowed-state labels counted as "optimal" and as "drift" in the statistics.
OPTIMAL_STATES = [
    'Optimal-Monitoring',
    'Optimal-Engaged',
]
DRIFT_STATES = [
    'Mind-Wandering',
    'Fatigue',
    'Overload',
]

# ============= CORE CLASSES/FUNCTIONS (minimal) =============
class TemporalSmoothingEngine:
    """Bounded history of states and drift strengths with a 0-100 risk score.

    ``state_history`` keeps the last ``window_size`` states and
    ``drift_trajectory`` the last 20 drift strengths. Only the drift
    trajectory is read by ``predict_drift_risk``.
    """

    def __init__(self, window_size=10):
        self.state_history = deque(maxlen=window_size)
        self.drift_trajectory = deque(maxlen=20)

    def add_trial(self, state, drift_strength):
        """Record one trial's state label and drift strength."""
        self.state_history.append(state)
        self.drift_trajectory.append(drift_strength)

    def predict_drift_risk(self):
        """Return an integer risk in [0, 100] from the recent drift strengths.

        Returns 0 until at least three strengths have been recorded. After
        that the score is the mean of the last (up to) ten strengths plus
        five times the change from the oldest to the newest of them,
        truncated to an int and clamped to 0..100.
        """
        history = self.drift_trajectory
        if len(history) < 3:
            return 0

        window = list(history)[-10:]
        # At least three samples are present here, so first/last always exist.
        slope = window[-1] - window[0]
        score = int(np.mean(window) + slope * 5)
        return min(100, max(0, score))

class CumulativeDriftTracker:
    """Rolling percentage of drift-flagged trials over a bounded window."""

    def __init__(self, window_seconds=120):
        # The window holds window_seconds / 2 entries (truncated to an int);
        # presumably one trial per ~2 s — confirm against the recording setup.
        capacity = int(window_seconds / 2)
        self.drift_history = deque(maxlen=capacity)

    def add_trial(self, is_drift):
        """Append 1 for a truthy ``is_drift`` and 0 otherwise."""
        flag = int(bool(is_drift))
        self.drift_history.append(flag)

    def get_cumulative_drift_pct(self):
        """Return the share of drift trials in the window as a percentage.

        An empty window yields 0.0.
        """
        flags = self.drift_history
        if not flags:
            return 0.0
        drift_count = sum(flags)
        return drift_count / len(flags) * 100.0

class WindowedStateClassifier:
    """Majority vote over recent instant states plus averaged z-scores.

    Two independent bounded windows are kept: every state label goes into
    ``instant_states``, while ``z_scores_list`` only receives non-empty
    z-score dicts, so the two can hold a different number of entries.
    """

    def __init__(self, window_size=15):
        self.instant_states = deque(maxlen=window_size)
        self.z_scores_list = deque(maxlen=window_size)

    def add_instant_classification(self, state, z_scores, signal_quality):
        """Record one instant state and, when non-empty, its z-scores.

        ``signal_quality`` is accepted but not used here.
        """
        self.instant_states.append(state)
        if not z_scores:
            return
        self.z_scores_list.append(z_scores)

    def get_windowed_state(self):
        """Return ``(majority_state, confidence, avg_z)`` for the window.

        With no states recorded the result is ``('Calibrating', 0.0, None)``.
        ``confidence`` is the majority label's share of the window; ties go
        to the label that appears first in the window. ``avg_z`` is ``None``
        when no z-scores are stored; otherwise it maps each key of the oldest
        stored dict to the mean over all stored dicts, a missing key
        counting as 0.
        """
        states = self.instant_states
        if not states:
            return 'Calibrating', 0.0, None

        # Tally in window order so max() resolves ties by first appearance.
        tally = {}
        for label in states:
            tally.setdefault(label, 0)
            tally[label] += 1
        winner = max(tally, key=tally.get)
        share = tally[winner] / len(states)

        if not self.z_scores_list:
            return winner, share, None

        samples = list(self.z_scores_list)
        mean_z = {}
        for feature in samples[0]:
            mean_z[feature] = np.mean([sample.get(feature, 0) for sample in samples])
        return winner, share, mean_z

def classify_state(z_scores):
    """Map one trial's z-scores to an instant state label.

    Reads 'theta', 'alpha', 'delta' and 'theta_beta_ratio' (each defaulting
    to 0 when absent). Rules are tried in priority order — Overload, Fatigue,
    Mind-Wandering, Optimal-Engaged — and 'Optimal-Monitoring' is the
    fallback. Empty or missing input yields 'Calibrating'.
    """
    if not z_scores:
        return 'Calibrating'

    theta, alpha, delta, tbr = (
        z_scores.get(name, 0)
        for name in ('theta', 'alpha', 'delta', 'theta_beta_ratio')
    )

    # The first matching rule wins, so the order of the branches is significant.
    if theta > 1.5 and tbr > 0.8:
        label = 'Overload'
    elif alpha > 0.5 and delta > 0.2 and tbr < 0.1:
        label = 'Fatigue'
    elif tbr > 0.3 and alpha < -0.2:
        label = 'Mind-Wandering'
    elif theta > 0.3 and tbr > 0.1 and alpha > -0.2:
        label = 'Optimal-Engaged'
    else:
        label = 'Optimal-Monitoring'
    return label

def detect_drift_enhanced(z_scores, windowed_state):
    """Score drift markers in a z-score dict and label the overall drift.

    Each rule compares one z-score (0 when absent) with a threshold from
    MW_THRESHOLDS, FATIGUE_THRESHOLDS or OVERLOAD_THRESHOLDS; a rule that
    fires adds its marker name and its weight to the total strength.
    ``windowed_state`` is accepted but not used.

    Returns a dict with 'drift_strength' (capped at 100), 'drift_label'
    ('NONE', 'WEAK', 'MODERATE', 'CONFIRMED' or 'STRONG'), 'drift_markers'
    (list, in rule order) and 'error_risk' (capped at 100). Empty or missing
    ``z_scores`` yields the all-zero / 'NONE' result.
    """
    if not z_scores:
        return {'drift_strength': 0, 'drift_label': 'NONE',
                'drift_markers': [], 'error_risk': 0}

    # (feature, fires_when_above, threshold, marker, weight)
    # The row order fixes the order of the reported markers.
    rules = (
        ('theta_beta_ratio', True, MW_THRESHOLDS['tbr_moderate'], 'high_TBR', 15),
        ('alpha', False, MW_THRESHOLDS['alpha_decrease'], 'low_alpha', 15),
        ('pe', False, MW_THRESHOLDS['pe_moderate'], 'low_PE', 10),
        ('lz', False, MW_THRESHOLDS['lz_decrease'], 'low_LZ', 10),
        ('alpha', True, FATIGUE_THRESHOLDS['alpha_elevation'], 'high_alpha', 15),
        ('delta', True, FATIGUE_THRESHOLDS['delta_elevation'], 'high_delta', 10),
        ('theta', True, FATIGUE_THRESHOLDS['theta_elevation'], 'high_theta', 8),
        ('beta', False, FATIGUE_THRESHOLDS['beta_decrease'], 'low_beta', 10),
        ('theta', True, OVERLOAD_THRESHOLDS['theta_extreme'], 'extreme_theta', 20),
        ('pac', True, OVERLOAD_THRESHOLDS['pac_surge'], 'pac_surge', 12),
    )

    markers = []
    strength = 0
    for feature, fires_when_above, threshold, marker, weight in rules:
        value = z_scores.get(feature, 0)
        fired = value > threshold if fires_when_above else value < threshold
        if fired:
            markers.append(marker)
            strength += weight

    seen = set(markers)
    if len(markers) >= 2 and strength >= 25:
        # Marker families are checked in this fixed priority order; the first
        # family with any marker present decides the label and risk bonus.
        if seen & {'high_TBR', 'low_alpha', 'low_PE'}:
            drift_label, bonus = 'CONFIRMED', 20
        elif seen & {'high_alpha', 'high_delta'}:
            drift_label, bonus = 'MODERATE', 10
        elif seen & {'extreme_theta', 'pac_surge'}:
            drift_label, bonus = 'STRONG', 30
        else:
            drift_label, bonus = 'WEAK', 0
        error_risk = min(100, strength + bonus)
    elif markers and strength >= 15:
        drift_label, error_risk = 'WEAK', min(100, strength)
    else:
        drift_label, error_risk = 'NONE', 0

    return {
        'drift_strength': min(100, strength),
        'drift_label': drift_label,
        'drift_markers': markers,
        'error_risk': error_risk,
    }

class RetrospectiveDriftAnalyzer:
    """Per-session pipeline: classify, window, score drift, decide interventions.

    One instance carries the rolling state for a single run of trials, so a
    fresh instance should be created for each independent session.
    """

    def __init__(self):
        self.temporal_smoother = TemporalSmoothingEngine()
        self.windowed_classifier = WindowedStateClassifier()
        self.cumulative_drift_tracker = CumulativeDriftTracker()
        # Starting one full cooldown in the past makes the very first trial
        # already eligible for an intervention.
        self.last_intervention_trial = -INTERVENTION_COOLDOWN_TRIALS
        self.trial_count = 0

    def process_trial(self, z_scores, signal_quality=0.7):
        """Process one trial and return a flat dict describing it.

        The result holds the 1-based trial number, instant and windowed
        states, drift metrics, the intervention flag ('YES'/'NO'), and one
        'z_<name>' entry per input z-score rounded to three decimals.
        """
        self.trial_count += 1
        trial_no = self.trial_count

        # Instant label first, then the majority vote over the recent window.
        instant_state = classify_state(z_scores)
        classifier = self.windowed_classifier
        classifier.add_instant_classification(instant_state, z_scores, signal_quality)
        windowed_state, confidence, avg_z = classifier.get_windowed_state()

        # Drift is scored on the window-averaged z-scores when available,
        # otherwise on this trial's raw z-scores.
        drift_info = detect_drift_enhanced(avg_z or z_scores, windowed_state)

        smoother = self.temporal_smoother
        smoother.add_trial(windowed_state, drift_info['drift_strength'])
        drift_risk = smoother.predict_drift_risk()

        tracker = self.cumulative_drift_tracker
        tracker.add_trial(windowed_state in DRIFT_STATES)
        cumulative_drift_pct = tracker.get_cumulative_drift_pct()

        # An intervention needs both sustained drift and an elapsed cooldown.
        over_threshold = cumulative_drift_pct > INTERVENTION_CUMULATIVE_DRIFT_THRESHOLD
        cooled_down = (
            trial_no - self.last_intervention_trial >= INTERVENTION_COOLDOWN_TRIALS
        )
        intervention_triggered = over_threshold and cooled_down
        if intervention_triggered:
            self.last_intervention_trial = trial_no

        out = {
            'trial': trial_no,
            'instant_state': instant_state,
            'windowed_state': windowed_state,
            'confidence': round(confidence, 3),
            'signal_quality': round(signal_quality, 2),
            'drift_strength': drift_info['drift_strength'],
            'drift_label': drift_info['drift_label'],
            'drift_markers': ','.join(drift_info['drift_markers']),
            'error_risk': drift_info['error_risk'],
            'drift_risk': drift_risk,
            'cumulative_drift_pct': round(cumulative_drift_pct, 1),
            'intervention_triggered': 'YES' if intervention_triggered else 'NO',
        }
        out.update({f'z_{name}': round(value, 3) for name, value in z_scores.items()})
        return out

# ============= RUN OVER 48 FILES → ONE SESSION SUMMARY CSV =============
# Replays every (subject, session) through a fresh RetrospectiveDriftAnalyzer
# and collects one summary row per session into a single CSV.
all_session_stats = []
missing_files = []

for i, file_path in enumerate(files, 1):
    fp = Path(file_path)
    print(f"[{i}/{len(files)}] {fp.name}")
    # One absent CSV must not abort the whole batch: report it and carry on.
    if not fp.is_file():
        print(f"  WARNING: file not found, skipping: {fp}")
        missing_files.append(fp.name)
        continue
    df = pd.read_csv(fp)

    # Resolve the available z-score columns once per file rather than on
    # every row; all rows of a file share the same columns.
    z_columns = [
        (feat, f'z_{feat}')
        for feat in VALIDATED_Z_SCORES
        if f'z_{feat}' in df.columns
    ]

    # group by subject/session in case a file has >1 subject/session
    for (subj, sess), g in df.groupby(['subject', 'session']):
        # Fresh analyzer per session so rolling windows never leak across sessions.
        analyzer = RetrospectiveDriftAnalyzer()
        session_results = []
        for _, row in g.iterrows():
            z_scores = {}
            for feat, col in z_columns:
                value = float(row[col])
                # NaN compares False against every threshold (so an all-NaN
                # row would be labelled 'Optimal-Monitoring') and turns the
                # windowed mean into NaN for the whole window. Treat a NaN
                # cell as "feature not available" for this trial instead.
                if not np.isnan(value):
                    z_scores[feat] = value
            signal_quality = float(row.get('signal_quality', 0.7))
            out = analyzer.process_trial(z_scores, signal_quality)
            out['subject'] = subj
            out['session'] = sess
            session_results.append(out)

        # Per-session statistics; groupby never yields an empty group, so
        # n_trials is always at least 1 here.
        df_sess = pd.DataFrame(session_results)
        n_trials = len(df_sess)
        n_interventions = (df_sess['intervention_triggered']=='YES').sum()
        pct_optimal = (df_sess['windowed_state'].isin(OPTIMAL_STATES)).sum()/n_trials*100
        pct_drift   = (df_sess['windowed_state'].isin(DRIFT_STATES)).sum()/n_trials*100
        mean_cum_drift = df_sess['cumulative_drift_pct'].mean()
        max_cum_drift  = df_sess['cumulative_drift_pct'].max()
        mean_error_risk = df_sess['error_risk'].mean()

        all_session_stats.append({
            'subject': subj,
            'session': sess,
            'n_trials': n_trials,
            'n_interventions': int(n_interventions),
            'pct_optimal': round(pct_optimal, 1),
            'pct_drift': round(pct_drift, 1),
            'mean_cumulative_drift': round(mean_cum_drift, 1),
            'max_cumulative_drift': round(max_cum_drift, 1),
            'mean_error_risk': round(mean_error_risk, 1),
        })

# save ONE combined session-summary CSV with your requested name
session_df = pd.DataFrame(all_session_stats)
session_csv = output_dir / "drift_analysis_results_phase2_session_summary.csv"
session_df.to_csv(session_csv, index=False)

print("\nSaved combined session summary:")
print(session_csv)
print(f"Rows: {len(session_df)}")
# Make skipped inputs impossible to overlook in the final report.
if missing_files:
    print(f"Skipped {len(missing_files)} missing file(s): {', '.join(missing_files)}")


Will process 48 files (should be 48):
  sub-01_ses-S1_6STATES_v6_0.csv
  sub-01_ses-S2_6STATES_v6_0.csv
  sub-01_ses-S3_6STATES_v6_0.csv
  sub-02_ses-S1_6STATES_v6_0.csv
  sub-02_ses-S2_6STATES_v6_0.csv
  sub-02_ses-S3_6STATES_v6_0.csv
  sub-03_ses-S1_6STATES_v6_0.csv
  sub-03_ses-S2_6STATES_v6_0.csv
  sub-03_ses-S3_6STATES_v6_0.csv
  sub-04_ses-S1_6STATES_v6_0.csv
  sub-04_ses-S2_6STATES_v6_0.csv
  sub-04_ses-S3_6STATES_v6_0.csv
  sub-05_ses-S1_6STATES_v6_0.csv
  sub-05_ses-S2_6STATES_v6_0.csv
  sub-05_ses-S3_6STATES_v6_0.csv
  sub-06_ses-S1_6STATES_v6_0.csv
  sub-06_ses-S2_6STATES_v6_0.csv
  sub-06_ses-S3_6STATES_v6_0.csv
  sub-07_ses-S1_6STATES_v6_0.csv
  sub-07_ses-S2_6STATES_v6_0.csv
  sub-07_ses-S3_6STATES_v6_0.csv
  sub-08_ses-S1_6STATES_v6_0.csv
  sub-08_ses-S2_6STATES_v6_0.csv
  sub-08_ses-S3_6STATES_v6_0.csv
  sub-09_ses-S1_6STATES_v6_0.csv
  sub-09_ses-S2_6STATES_v6_0.csv
  sub-09_ses-S3_6STATES_v6_0.csv
  sub-10_ses-S1_6STATES_v6_0.csv
  sub-10_ses-S2_6STATES_v6_0.csv
  sub