In [None]:

import os
import sys
import json
import logging
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, StratifiedKFold
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import defaultdict, Counter
from scipy import stats
import warnings
warnings.filterwarnings("ignore")

# =============================================================================
#  CONFIGURATION (plain Python constants; no external YAML file)
# =============================================================================
# --- Data ---------------------------------------------------------------------
EDNET_PATH = "./EdNet-KT1/"   # directory scanned for per-user "u*.csv" files
MAX_USERS = 800               # users kept after the seeded shuffle
MIN_INTERACTIONS = 20         # users with fewer rows are discarded
MAX_SEQ_LEN = 200             # sequences are truncated / left-padded to this

# --- Training -----------------------------------------------------------------
BATCH_SIZE = 64
EPOCHS = 8
DEVICE = torch.device("cpu")
CV_FOLDS = 3

# --- Recommendation & fairness ------------------------------------------------
N_RECOMMENDATIONS = 5
FAIRNESS_GROUPS = ["low", "medium", "high"]   # accuracy tertiles, lowest first
# Each entry weights the [mastery, engagement, fairness] reward terms.
WEIGHT_CONFIGS = [
    [1.0, 0.0, 0.0],  # Mastery-only
    [0.0, 1.0, 0.0],  # Engagement-only
    [0.0, 0.0, 1.0],  # Fairness-only
    [0.7, 0.2, 0.1],  # Mastery-focused
    [0.5, 0.3, 0.2],  # Balanced
]

# --- Statistics ---------------------------------------------------------------
BOOTSTRAP_SAMPLES = 1000      # resamples per bootstrap confidence interval
ALPHA = 0.05                  # two-sided significance level

# =============================================================================
# REPRODUCIBILITY & SETUP
# =============================================================================
# Seed the global NumPy and PyTorch RNGs; the rest of the file draws from them
# (np.random.shuffle / choice, weight initialisation).
np.random.seed(42)
torch.manual_seed(42)

# Output tree. makedirs also creates "outputs/" itself, which the log-file
# handler below writes into.
os.makedirs("outputs/figures", exist_ok=True)
os.makedirs("outputs/models", exist_ok=True)

# Send log records both to outputs/training.log and to stdout.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. configured earlier in the same interpreter session).
logging.basicConfig(
    handlers=[
        logging.FileHandler("outputs/training.log"),
        logging.StreamHandler(sys.stdout),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# =============================================================================
# DATA LOADING & CORRECTNESS INFERENCE
# =============================================================================
def _read_ednet_user_files(file_names):
    """Read each per-user CSV in *file_names* from EDNET_PATH.

    Returns ``{file_name: DataFrame}`` in input order, keeping only frames that
    have both ``question_id`` and ``user_answer`` columns. Unreadable files are
    logged and skipped instead of aborting the whole load.
    """
    frames = {}
    for f in tqdm(file_names, desc="Inferring correctness"):
        try:
            df_user = pd.read_csv(os.path.join(EDNET_PATH, f))
        except Exception as e:  # pandas raises many parser/IO error types
            logger.warning("Skipping unreadable file %s: %s", f, e)
            continue
        if "question_id" in df_user.columns and "user_answer" in df_user.columns:
            frames[f] = df_user
    return frames


def _infer_answer_map(frames):
    """Map each question_id to its most frequent user_answer across *frames*."""
    all_answers_df = pd.concat(
        [df_user[["question_id", "user_answer"]] for df_user in frames.values()],
        ignore_index=True,
    )
    return all_answers_df.groupby("question_id")["user_answer"].agg(
        lambda x: Counter(x).most_common(1)[0][0]
    ).to_dict()


def _parse_ednet_timestamps(ts):
    """Convert the raw ``timestamp`` column to datetimes (unparseable -> NaT).

    Numeric timestamps are treated as Unix epoch milliseconds. This is the
    EdNet-KT1 format. With pandas' default unit (nanoseconds), every
    interaction would land on 1970-01-01 and gaps of seconds would shrink to
    microseconds. That would make ``time_delta``, ``date`` and ``hour``
    meaningless. Non-numeric columns (e.g. ISO strings) use default parsing.
    """
    if pd.api.types.is_numeric_dtype(ts):
        return pd.to_datetime(ts, unit="ms", errors="coerce")
    return pd.to_datetime(ts, errors="coerce")


def _assign_skill_groups(user_perf):
    """Bin per-user accuracy into len(FAIRNESS_GROUPS) equal-frequency groups.

    ``pd.qcut`` raises ValueError when tied accuracies yield duplicate bin
    edges. In that case, ties are broken by order of appearance via ranking.
    """
    n_bins = len(FAIRNESS_GROUPS)
    try:
        return pd.qcut(user_perf, q=n_bins, labels=FAIRNESS_GROUPS)
    except ValueError:
        return pd.qcut(user_perf.rank(method="first"), q=n_bins, labels=FAIRNESS_GROUPS)


def _session_skill_entropy(skills):
    """Shannon entropy (nats) of the skill distribution within one session."""
    counts = np.array(list(Counter(skills).values()), dtype=float)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log(probs + 1e-8)))


def load_and_preprocess():
    """Load up to MAX_USERS EdNet-KT1 user files and build the interaction table.

    Correctness is inferred heuristically: the most frequent ``user_answer`` to
    a question is taken as its correct answer. Users with fewer than
    MIN_INTERACTIONS rows are dropped.

    Added columns: ``user_id``, ``correct``, ``skill_id``, ``skill_group``
    (accuracy tertile), ``time_delta`` (seconds, clipped to [0, 3600]),
    ``time_delta_norm``, ``date``, ``session_id`` (one per user and calendar
    date), ``session_entropy``, ``user_idx`` and ``skill_idx``.

    Side effect: saves the user/skill index mappings to
    ``outputs/id_mappings.pth``.

    Returns:
        (df, n_skills): the interaction DataFrame sorted by (user_id, timestamp),
        and the number of distinct skills.

    Raises:
        FileNotFoundError: EDNET_PATH contains no ``u*.csv`` files.
        ValueError: no file has the required columns, or no user has enough rows.
    """
    logger.info("🔍 Loading EdNet-KT1 and inferring correctness...")
    # sorted() makes the seeded shuffle reproducible: os.listdir order is arbitrary.
    user_files = sorted(
        f for f in os.listdir(EDNET_PATH) if f.startswith("u") and f.endswith(".csv")
    )
    if not user_files:
        raise FileNotFoundError(f"No .csv files in {EDNET_PATH}. Ensure it contains real KT1 data.")

    np.random.shuffle(user_files)
    user_files = user_files[:MAX_USERS]

    # Pass 1: read every file once and keep the frames for pass 2.
    frames = _read_ednet_user_files(user_files)
    if not frames:
        raise ValueError("No valid KT1 files found. Ensure files have 'question_id' and 'user_answer'.")

    # Infer correct answer as most frequent response per question
    answer_map = _infer_answer_map(frames)

    # Pass 2: label correctness (vectorised; unknown questions count as wrong).
    data_frames = []
    for f, df_user in tqdm(frames.items(), desc="Loading user data"):
        if len(df_user) < MIN_INTERACTIONS:
            continue
        df_user["user_id"] = f.replace(".csv", "")
        expected = df_user["question_id"].map(answer_map)
        df_user["correct"] = (df_user["user_answer"] == expected).astype(int)
        data_frames.append(df_user)

    if not data_frames:
        raise ValueError("No valid user data after correctness inference.")

    df_raw = pd.concat(data_frames, ignore_index=True)
    df = df_raw.dropna(subset=["question_id", "correct"]).copy()
    df["correct"] = df["correct"].astype(int)
    df["timestamp"] = _parse_ednet_timestamps(df["timestamp"])
    df = df.sort_values(["user_id", "timestamp"]).reset_index(drop=True)
    df["skill_id"] = pd.Categorical(df["question_id"]).codes

    # Skill groups (fairness proxy)
    user_perf = df.groupby("user_id")["correct"].mean()
    df["skill_group"] = df["user_id"].map(_assign_skill_groups(user_perf))

    # Time features
    df["time_delta"] = df.groupby("user_id")["timestamp"].diff().dt.total_seconds().fillna(0)
    df["time_delta"] = df["time_delta"].clip(0, 3600)
    df["time_delta_norm"] = df["time_delta"] / 3600.0

    # Session entropy (engagement). transform() keeps the result index-aligned
    # with df. Rows whose session key was dropped (NaT dates) get 0.0.
    df["date"] = df["timestamp"].dt.date
    df["session_id"] = df.groupby(["user_id", "date"]).ngroup()
    df["session_entropy"] = (
        df.groupby("session_id")["skill_id"].transform(_session_skill_entropy).fillna(0.0)
    )

    # Encode IDs
    user_ids = {uid: idx for idx, uid in enumerate(df["user_id"].unique())}
    skill_ids = {sid: idx for idx, sid in enumerate(df["skill_id"].unique())}
    df["user_idx"] = df["user_id"].map(user_ids)
    df["skill_idx"] = df["skill_id"].map(skill_ids)

    torch.save({"user_ids": user_ids, "skill_ids": skill_ids}, "outputs/id_mappings.pth")
    logger.info(f"✅ Loaded {len(df):,} interactions from {df['user_id'].nunique()} students")
    return df, len(skill_ids)

# =============================================================================
# DEEP KNOWLEDGE TRACING (DKT)
# =============================================================================
class DeepKnowledgeTracing(nn.Module):
    """Deep Knowledge Tracing: an LSTM over (skill, correctness) interactions.

    Each interaction becomes one token id ``skill + correct * n_skills``. With
    ``correct`` in {0, 1}, ids below ``n_skills`` encode wrong answers and ids
    in ``[n_skills, 2 * n_skills)`` encode right ones. Tokens are embedded and
    run through a batch-first LSTM. A linear head followed by a sigmoid then
    gives one probability per skill at every step.
    """

    def __init__(self, n_skills, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.n_skills = n_skills
        n_tokens = 2 * n_skills
        self.skill_embed = nn.Embedding(n_tokens, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_skills)

    def forward(self, skills, correct):
        """Return per-step, per-skill probabilities for the given interactions."""
        token_ids = correct.long() * self.n_skills + skills
        hidden_seq, _ = self.lstm(self.skill_embed(token_ids))
        return self.fc(hidden_seq).sigmoid()

# =============================================================================
# FAIRNESS-AWARE MORL AGENT
# =============================================================================
class FairMORLAgent(nn.Module):
    """Actor-critic policy over skills for multi-objective recommendation.

    ``forward`` takes ``x`` of shape ``(batch, seq_len, 1 + n_extra_features)``.
    Column 0 holds the skill index, which is embedded. The remaining columns
    are concatenated to the embedding unchanged. The final LSTM hidden state
    feeds two heads:

    * an actor head with one logit per skill;
    * a critic head with 3 outputs (one value per reward objective).

    Args:
        n_skills: number of distinct skill indices (embedding rows / actor logits).
        embed_dim: skill embedding size.
        hidden_dim: LSTM hidden size.
        n_extra_features: number of non-skill columns in ``x``. Defaults to 4,
            matching FairEdNetDataset rows ``[skill_idx, correct,
            time_delta_norm, mastery_pred, session_entropy]``.
    """

    def __init__(self, n_skills, embed_dim=64, hidden_dim=128, n_extra_features=4):
        super().__init__()
        self.skill_embed = nn.Embedding(n_skills, embed_dim)
        # The LSTM input is the skill embedding plus the extra feature columns.
        # It was hard-coded as embed_dim + 5, but x[:, :, 1:] carries only 4
        # columns for dataset rows, which raised an input-size mismatch.
        self.lstm = nn.LSTM(embed_dim + n_extra_features, hidden_dim, batch_first=True)
        self.actor = nn.Linear(hidden_dim, n_skills)
        self.critic = nn.Linear(hidden_dim, 3)

    def forward(self, x):
        """Return ``(actor_logits, critic_values)`` of shapes (batch, n_skills), (batch, 3)."""
        skills = x[:, :, 0].long()
        others = x[:, :, 1:]
        # NOTE(review): zero rows from left padding are embedded as skill 0.
        lstm_in = torch.cat([self.skill_embed(skills), others], dim=-1)
        _, (h, _) = self.lstm(lstm_in)
        h = h[-1]  # last layer's final hidden state: (batch, hidden_dim)
        return self.actor(h), self.critic(h)

def compute_engagement_reward(session_actions, time_deltas):
    """Heuristic engagement score for one sequence of actions.

    The score is a weighted sum of three terms:

    * 0.4 x Shannon entropy (nats) of the action frequency distribution;
    * 0.3 x mean of ``time_deltas``, capped at 1.0 (0 when empty);
    * 0.3 x ``len(session_actions) / 20``, capped at 1.0.

    Returns 0.0 for an empty action list.

    NOTE(review): callers in this file pass raw ``time_delta`` seconds, so the
    dwell term saturates at 1.0 for any mean gap >= 1 s. The normalised
    ``time_delta_norm`` may have been intended; confirm.
    """
    if len(session_actions) == 0:
        return 0.0
    freq = np.array(list(Counter(session_actions).values()))
    p = freq / freq.sum()
    diversity = -np.sum(p * np.log(p + 1e-8))
    if len(time_deltas) > 0:
        dwell = np.mean(time_deltas)
    else:
        dwell = 0
    length_term = min(len(session_actions) / 20.0, 1.0)
    return 0.4 * diversity + 0.3 * min(dwell, 1.0) + 0.3 * length_term

# =============================================================================
# FAIRNESS METRICS
# =============================================================================
class FairnessMetrics:
    """Group- and individual-level fairness metrics over a fixed list of groups.

    Each group-level metric reports the spread (max - min) of a per-group
    statistic, so 0.0 means parity. Groups with no usable members are skipped.
    With fewer than two usable groups, the metric is 0.0.
    """

    def __init__(self, fairness_groups):
        self.fairness_groups = fairness_groups

    @staticmethod
    def _spread(values):
        """max - min of *values*, or 0.0 when fewer than two values exist."""
        return max(values) - min(values) if len(values) >= 2 else 0.0

    def statistical_parity_difference(self, group_exposures, n_skills):
        """Largest per-skill gap in normalised exposure between any two groups.

        ``group_exposures`` maps group -> {skill_index: count}. A group absent
        from the mapping is treated as having zero exposure everywhere.
        """
        exposures = []
        for group in self.fairness_groups:
            counts = group_exposures.get(group, {})
            total = max(sum(counts.values()), 1)
            exposures.append(np.array([counts.get(s, 0) / total for s in range(n_skills)]))
        spd = 0.0
        for i, exp_i in enumerate(exposures):
            for exp_j in exposures[i + 1:]:
                spd = max(spd, np.abs(exp_i - exp_j).max(initial=0.0))
        return spd

    def equal_opportunity_difference(self, y_true, y_pred, groups):
        """Spread of true-positive rates (``y_pred > 0.5`` among ``y_true == 1``).

        Groups without any positive example are skipped, because their TPR is
        undefined. Previously np.mean of an empty array produced NaN here,
        which could poison the result.
        """
        y_true, y_pred, groups = np.asarray(y_true), np.asarray(y_pred), np.asarray(groups)
        tprs = []
        for group in self.fairness_groups:
            positives = (groups == group) & (y_true == 1)
            if not np.any(positives):
                continue
            tprs.append(np.mean(y_pred[positives] > 0.5))
        return self._spread(tprs)

    def demographic_parity(self, recommendations, groups):
        """Spread of the mean number of distinct items per recommendation list.

        Works with ragged recommendation lists (no 2-D array conversion).
        """
        recs = list(recommendations)
        groups = np.asarray(groups)
        diversities = []
        for group in self.fairness_groups:
            members = np.flatnonzero(groups == group)
            if members.size == 0:
                continue
            diversities.append(np.mean([len(set(recs[k])) for k in members]))
        return self._spread(diversities)

    def individual_fairness_violation(self, recommendations, user_features, similarity_threshold=0.8):
        """Share of similar user pairs whose recommendation sets diverge.

        Two users are similar when their feature vectors have cosine similarity
        >= ``similarity_threshold``. A similar pair counts as a violation when
        the Jaccard overlap of their recommendation sets is below 0.5.
        Quadratic in the number of users. Returns 0.0 when no similar pair exists.
        """
        violations = 0
        total_pairs = 0
        n = len(recommendations)
        for i in range(n):
            for j in range(i + 1, n):
                f_i, f_j = user_features[i], user_features[j]
                if len(f_i) == 0 or len(f_j) == 0:
                    continue
                similarity = np.dot(f_i, f_j) / (np.linalg.norm(f_i) * np.linalg.norm(f_j) + 1e-8)
                if similarity < similarity_threshold:
                    continue
                total_pairs += 1
                rec_i, rec_j = set(recommendations[i]), set(recommendations[j])
                union = rec_i | rec_j
                jaccard = len(rec_i & rec_j) / len(union) if union else 0
                if jaccard < 0.5:
                    violations += 1
        return violations / total_pairs if total_pairs > 0 else 0.0

    def calibration_difference(self, y_true, y_pred, groups):
        """Spread of per-group |mean(y_pred) - mean(y_true)|."""
        y_true, y_pred, groups = np.asarray(y_true), np.asarray(y_pred), np.asarray(groups)
        calibration_errors = []
        for group in self.fairness_groups:
            group_mask = groups == group
            if not np.any(group_mask):
                continue
            calibration_errors.append(np.abs(np.mean(y_pred[group_mask]) - np.mean(y_true[group_mask])))
        return self._spread(calibration_errors)

def bootstrap_ci(data, n_bootstrap=BOOTSTRAP_SAMPLES, alpha=ALPHA):
    """Percentile-bootstrap confidence interval for the mean of *data*.

    Draws *n_bootstrap* resamples, each with replacement and the same size as
    *data*, from NumPy's global RNG. Returns the ``alpha/2`` and
    ``1 - alpha/2`` percentiles of the resample means as a length-2 array.
    Empty input yields ``[0.0, 0.0]``.
    """
    n = len(data)
    if n == 0:
        return np.array([0.0, 0.0])
    resampled_means = []
    for _ in range(n_bootstrap):
        sample = np.random.choice(data, size=n, replace=True)
        resampled_means.append(np.mean(sample))
    lower_pct = 100 * alpha / 2
    upper_pct = 100 * (1 - alpha / 2)
    return np.percentile(resampled_means, [lower_pct, upper_pct])

def paired_ttest_with_ci(morl_scores, baseline_scores, metric_name):
    """Log a paired t-test and a bootstrap CI of ``morl_scores - baseline_scores``.

    Inputs may be lists, arrays or Series. They are converted to float arrays,
    so the element-wise difference works for plain lists too. The test is
    skipped with a warning when either input is empty, or when the lengths
    differ (the scores must be paired one-to-one). The CI level follows ALPHA.
    """
    morl = np.asarray(morl_scores, dtype=float)
    baseline = np.asarray(baseline_scores, dtype=float)
    if morl.size == 0 or baseline.size == 0:
        logger.warning(f"Skipping t-test for {metric_name}: empty data")
        return
    if morl.shape != baseline.shape:
        logger.warning(f"Skipping t-test for {metric_name}: unpaired data "
                       f"({morl.size} vs {baseline.size} samples)")
        return
    _, p_val = stats.ttest_rel(morl, baseline)
    diffs = morl - baseline
    ci_low, ci_high = bootstrap_ci(diffs)
    logger.info(f"{metric_name}: Δ = {np.mean(diffs):.3f}, "
                f"{100 * (1 - ALPHA):.0f}% CI [{ci_low:.3f}, {ci_high:.3f}], p = {p_val:.4f}")

# =============================================================================
# DATASET
# =============================================================================
class FairEdNetDataset(Dataset):
    """Per-user interaction sequences for the MORL agent.

    Item ``idx`` is ``(tensor, user_id)``. The tensor is float32 with shape
    ``(max_seq_len, 5)`` and columns ``[skill_idx, correct, time_delta_norm,
    mastery_pred, session_entropy]``. Longer sequences keep their most recent
    ``max_seq_len`` rows; shorter ones are left-padded with zero rows. Users
    with fewer than two interactions are excluded.
    """

    def __init__(self, df, user_list, max_seq_len=200):
        self.df = df[df["user_id"].isin(user_list)].copy()
        self.max_seq_len = max_seq_len
        feature_cols = ["skill_idx", "correct", "time_delta_norm", "mastery_pred", "session_entropy"]
        self.sequences = [
            (user, rows[feature_cols].values)
            for user, rows in self.df.groupby("user_id")
            if len(rows) >= 2
        ]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        user, rows = self.sequences[idx]
        if len(rows) > self.max_seq_len:
            rows = rows[-self.max_seq_len:]
        shortfall = self.max_seq_len - len(rows)
        if shortfall > 0:
            filler = np.zeros((shortfall, rows.shape[1]), dtype=rows.dtype)
            rows = np.vstack([filler, rows])
        return torch.tensor(rows, dtype=torch.float32), user

# =============================================================================
# BASELINE RECOMMENDATIONS
# =============================================================================
def get_random_recommendations(n_skills, n_rec):
    """Return up to *n_rec* distinct skill indices drawn uniformly from ``range(n_skills)``.

    Uses NumPy's global RNG. When fewer than *n_rec* skills exist, all of them
    are returned in random order. Previously this raised ValueError, because
    sampling without replacement cannot exceed the population.
    """
    return np.random.choice(n_skills, min(n_rec, n_skills), replace=False).tolist()

def get_popular_recommendations(popular_skills, n_rec):
    """Return the first *n_rec* entries of *popular_skills*.

    The caller is expected to pass skills already ordered from most to least
    popular. The same list is returned to every user.
    """
    head = slice(None, n_rec)
    return popular_skills[head]

def get_mastery_recommendations(user_df, n_skills, n_rec):
    """Recommend the skills this user most often answered incorrectly.

    Skills are ranked by how many wrong answers (``correct == 0``) the user
    has on them. If fewer than *n_rec* such skills exist, the list is topped
    up with distinct random skills (NumPy global RNG) that are not already
    recommended. The result never contains duplicates or more than *n_skills*
    entries. Falls back to ``get_random_recommendations`` when the user has no
    rows or no wrong answers.

    Returns:
        A list of Python ints.
    """
    if len(user_df) == 0:
        return get_random_recommendations(n_skills, n_rec)
    wrong_skills = user_df.loc[user_df["correct"] == 0, "skill_idx"]
    if len(wrong_skills) == 0:
        return get_random_recommendations(n_skills, n_rec)

    rec = [int(s) for s in wrong_skills.value_counts().index[:n_rec]]
    if len(rec) < n_rec:
        # Top up without repeating skills already in the list.
        chosen = set(rec)
        remaining = [s for s in range(n_skills) if s not in chosen]
        n_fill = min(n_rec - len(rec), len(remaining))
        if n_fill > 0:
            rec.extend(int(s) for s in np.random.choice(remaining, n_fill, replace=False))
    return rec

# =============================================================================
# POLICY INTERPRETABILITY WITHOUT CAPTUM
# Uses input perturbation to estimate feature importance
# =============================================================================
def estimate_feature_importance(agent, base_input, n_skills, device, n_samples=50, feature_columns=None):
    """
    Estimate feature importance by perturbing non-skill features and measuring
    how much the agent's action distribution moves.

    For each feature, additive Gaussian noise (std 0.1) is added to that column
    only, *n_samples* times. Importance is the mean absolute change in the
    softmax over actor logits, relative to the unperturbed input.

    Args:
        agent: module whose forward returns ``(logits, values)``; put in eval mode.
        base_input: tensor indexed as ``[batch, step, column]``.
        n_skills: unused; kept for interface compatibility.
        device: device to run on.
        n_samples: noise draws per feature.
        feature_columns: optional ``{feature_name: column_index}``. Defaults to
            the FairEdNetDataset layout ``[skill_idx, correct, time_delta_norm,
            mastery_pred, session_entropy]``. Previously the code perturbed
            columns 1-3 (correct, time_delta_norm, mastery_pred) under the
            wrong names.

    Returns:
        (importance, feature_names): a numpy array aligned with the list of names.
        By default the names are ``[mastery_pred, session_entropy, time_delta_norm]``.
    """
    if feature_columns is None:
        feature_columns = {"mastery_pred": 3, "session_entropy": 4, "time_delta_norm": 2}
    feature_names = list(feature_columns)

    agent.eval()
    base_input = base_input.to(device)
    with torch.no_grad():
        base_logits, _ = agent(base_input)
        base_probs = torch.softmax(base_logits, dim=-1)

    importance = np.zeros(len(feature_names))
    for f_idx, name in enumerate(feature_names):
        col = feature_columns[name]
        diffs = []
        for _ in range(n_samples):
            perturbed = base_input.clone()
            # Additive (not relative) Gaussian noise on this one column.
            perturbed[:, :, col] += torch.randn_like(perturbed[:, :, col]) * 0.1
            with torch.no_grad():
                logits, _ = agent(perturbed)
                probs = torch.softmax(logits, dim=-1)
            diffs.append(torch.abs(probs - base_probs).mean().item())
        importance[f_idx] = np.mean(diffs)

    return importance, feature_names

# =============================================================================
# MAIN EXPERIMENT FUNCTION (ENHANCED WITH HIGH NOVELTY)
# =============================================================================
def run_experiment():
    """Call this function from Jupyter Notebook"""
    logger.info("🚀 Starting Springer-level MORL experiment with HIGH NOVELTY...")

    # Load data
    df, n_skills = load_and_preprocess()

    # Train DKT
    logger.info("🧠 Training DKT model...")
    dkt_model = DeepKnowledgeTracing(n_skills).to(DEVICE)
    optimizer = optim.Adam(dkt_model.parameters(), lr=1e-3)
    criterion = nn.BCELoss()

    sequences = []
    for uid, group in df.groupby("user_id"):
        if len(group) < 2: continue
        skills = torch.tensor(group["skill_idx"].values[:-1], dtype=torch.long)
        correct = torch.tensor(group["correct"].values[:-1], dtype=torch.float)
        targets = torch.tensor(group["correct"].values[1:], dtype=torch.float)
        target_skills = torch.tensor(group["skill_idx"].values[1:], dtype=torch.long)
        sequences.append((skills, correct, targets, target_skills))

    dkt_model.train()
    for epoch in range(5):
        for skills, correct, targets, target_skills in sequences:
            skills, correct, targets, target_skills = [x.to(DEVICE) for x in [skills, correct, targets, target_skills]]
            logits = dkt_model(skills.unsqueeze(0), correct.unsqueeze(0))
            pred = logits[0, torch.arange(len(target_skills)), target_skills]
            loss = criterion(pred, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Add mastery predictions
    logger.info("📈 Adding mastery predictions...")
    mastery_preds = []
    with torch.no_grad():
        for uid, group in df.groupby("user_id"):
            if len(group) < 2:
                mastery_preds.extend([0.5] * len(group))
                continue
            skills = torch.tensor(group["skill_idx"].values[:-1], dtype=torch.long).to(DEVICE)
            correct = torch.tensor(group["correct"].values[:-1], dtype=torch.float).to(DEVICE)
            logits = dkt_model(skills.unsqueeze(0), correct.unsqueeze(0))
            preds = logits[0, :, group["skill_idx"].values[1:]].cpu().numpy()
            mastery_preds.extend([0.5] + preds.tolist())
    df["mastery_pred"] = mastery_preds

    # Stratified k-fold CV
    logger.info("🔄 Starting stratified k-fold cross-validation...")
    user_groups = df.groupby("user_id")["skill_group"].first()
    user_ids = user_groups.index.tolist()
    group_labels = user_groups.values

    skf = StratifiedKFold(n_splits=CV_FOLDS, shuffle=True, random_state=42)
    cv_results = {tuple(w): [] for w in WEIGHT_CONFIGS}
    cv_fairness = {tuple(w): [] for w in WEIGHT_CONFIGS}
    baseline_results = {"random": [], "popular": [], "mastery_only": []}

    # For counterfactual & interpretability
    all_user_features = {}
    all_recommendations = {}

    for fold, (train_idx, test_idx) in enumerate(skf.split(user_ids, group_labels)):
        logger.info(f"\n=== FOLD {fold+1}/{CV_FOLDS} ===")

        train_users = [user_ids[i] for i in train_idx]
        test_users = [user_ids[i] for i in test_idx]

        train_dataset = FairEdNetDataset(df, train_users, MAX_SEQ_LEN)
        test_dataset = FairEdNetDataset(df, test_users, MAX_SEQ_LEN)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        popular_skills = [s for s, _ in Counter(df[df["user_id"].isin(train_users)]["skill_idx"]).most_common(N_RECOMMENDATIONS)]

        # Evaluate baselines on test set
        baseline_fold = {"random": [], "popular": [], "mastery_only": [], "user_id": [], "skill_group": [], "mastery": [], "engagement": []}
        for batch, uids in test_loader:
            uid = uids[0].item()
            user_df = df[df["user_id"] == uid]
            if len(user_df) == 0: continue
            
            baseline_fold["user_id"].append(uid)
            baseline_fold["skill_group"].append(user_df["skill_group"].iloc[0])
            baseline_fold["mastery"].append(user_df["correct"].mean())
            baseline_fold["engagement"].append(compute_engagement_reward(
                user_df["skill_idx"].tolist(), user_df["time_delta"].tolist()))
            baseline_fold["random"].append(get_random_recommendations(n_skills, N_RECOMMENDATIONS))
            baseline_fold["popular"].append(get_popular_recommendations(popular_skills, N_RECOMMENDATIONS))
            baseline_fold["mastery_only"].append(get_mastery_recommendations(user_df, n_skills, N_RECOMMENDATIONS))
        
        for key in ["random", "popular", "mastery_only"]:
            baseline_results[key].append(pd.DataFrame({
                "user_id": baseline_fold["user_id"],
                "skill_group": baseline_fold["skill_group"],
                "mastery": baseline_fold["mastery"],
                "engagement": baseline_fold["engagement"],
                f"{key}_path": baseline_fold[key]
            }))

        # Train MORL models
        for weight_idx, weights in enumerate(WEIGHT_CONFIGS):
            logger.info(f"Training weights: {weights}")
            agent = FairMORLAgent(n_skills).to(DEVICE)
            optimizer = optim.Adam(agent.parameters(), lr=3e-4)
            weights_t = torch.tensor(weights, dtype=torch.float32, device=DEVICE)
            group_exposure = defaultdict(lambda: defaultdict(int))
            
            for epoch in range(EPOCHS):
                agent.train()
                for batch, uids in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
                    batch = batch.to(DEVICE)
                    if batch.size(1) < 2: continue
                    
                    state = batch[:, :-1, :]
                    next_skills = batch[:, 1:, 0].long()
                    next_correct = batch[:, 1:, 1]
                    next_time = batch[:, 1:, 2]
                    next_mastery = batch[:, 1:, 3]
                    
                    engagement_r = []
                    for i, uid in enumerate(uids):
                        user_df = df[df["user_id"] == uid.item()]
                        if len(user_df) == 0:
                            engagement_r.append(0.0)
                            continue
                        actions = user_df["skill_idx"].tolist()
                        times = user_df["time_delta"].tolist()
                        engagement_r.append(compute_engagement_reward(actions, times))
                    engagement_r = torch.tensor(engagement_r, dtype=torch.float32, device=DEVICE)
                    
                    mastery_r = next_mastery.mean(dim=1)
                    
                    for i, uid in enumerate(uids):
                        group = df[df["user_id"] == uid.item()]["skill_group"].iloc[0]
                        skill = next_skills[i, -1].item()
                        group_exposure[group][skill] += 1
                    
                    fairness_metrics = FairnessMetrics(FAIRNESS_GROUPS)
                    fairness_r_val = -fairness_metrics.statistical_parity_difference(group_exposure, n_skills)
                    fairness_r = torch.full_like(mastery_r, fairness_r_val)
                    
                    total_r = weights_t[0]*mastery_r + weights_t[1]*engagement_r + weights_t[2]*fairness_r
                    
                    logits, values = agent(state)
                    log_probs = torch.log_softmax(logits, dim=-1)
                    action_log_probs = log_probs.gather(1, next_skills[:, -1].unsqueeze(1)).squeeze(1)
                    actor_loss = -(action_log_probs * total_r.detach()).mean()
                    critic_targets = torch.stack([mastery_r.mean(), engagement_r.mean(), torch.tensor(fairness_r_val, device=DEVICE)])
                    critic_loss = nn.MSELoss()(values.mean(dim=0), critic_targets)
                    loss = actor_loss + 0.5 * critic_loss
                    
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            
            # Evaluate
            agent.eval()
            results = {"user_id": [], "skill_group": [], "mastery": [], "engagement": [], "morl_path": []}
            user_features_fold = {}
            recs_fold = {}
            with torch.no_grad():
                for batch, uids in test_loader:
                    batch = batch.to(DEVICE)
                    if batch.size(1) < 2: continue
                    state = batch[:, -1:, :]
                    logits, _ = agent(state)
                    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
                    path = np.argsort(probs)[-N_RECOMMENDATIONS:][::-1].tolist()
                    uid = uids[0].item()
                    user_df = df[df["user_id"] == uid]
                    if len(user_df) == 0: continue
                    results["user_id"].append(uid)
                    results["skill_group"].append(user_df["skill_group"].iloc[0])
                    results["mastery"].append(user_df["correct"].mean())
                    results["engagement"].append(compute_engagement_reward(
                        user_df["skill_idx"].tolist(), user_df["time_delta"].tolist()))
                    results["morl_path"].append(path)
                    # Store for interpretability
                    user_features_fold[uid] = user_df[["mastery_pred", "session_entropy", "time_delta_norm"]].mean().values
                    recs_fold[uid] = path
            
            results_df = pd.DataFrame(results)
            cv_results[tuple(weights)].append(results_df)
            
            # Save for global analysis
            if tuple(weights) == (0.5, 0.3, 0.2):  # balanced model
                all_user_features.update(user_features_fold)
                all_recommendations.update(recs_fold)
            
            # Compute fairness metrics
            fairness_results = {}
            fairness_results["statistical_parity_difference"] = fairness_metrics.statistical_parity_difference(group_exposure, n_skills)
            y_true = np.array(results_df["mastery"])
            y_pred = np.array([np.mean(p) for p in results_df["morl_path"]])
            groups = np.array(results_df["skill_group"])
            fairness_results["equal_opportunity_difference"] = fairness_metrics.equal_opportunity_difference(y_true, y_pred, groups)
            fairness_results["demographic_parity"] = fairness_metrics.demographic_parity(results_df["morl_path"].tolist(), groups)
            user_features_arr = []
            for p in results_df["morl_path"]:
                user_features_arr.append(np.array([np.mean(p)]))
            user_features_arr = np.array(user_features_arr)
            fairness_results["individual_fairness_violation"] = fairness_metrics.individual_fairness_violation(
                results_df["morl_path"].tolist(), user_features_arr
            )
            fairness_results["calibration_difference"] = fairness_metrics.calibration_difference(y_true, y_pred, groups)
            cv_fairness[tuple(weights)].append(fairness_results)
    
    # Aggregate results
    final_results = {}
    final_fairness = {}
    for weights in WEIGHT_CONFIGS:
        w_key = tuple(weights)
        all_folds = pd.concat(cv_results[w_key], ignore_index=True)
        final_results[w_key] = all_folds
        avg_fairness = {}
        for metric in cv_fairness[w_key][0]:
            avg_fairness[metric] = np.mean([f[metric] for f in cv_fairness[w_key]])
        final_fairness[w_key] = avg_fairness
    
    final_baselines = {}
    for key in ["random", "popular", "mastery_only"]:
        final_baselines[key] = pd.concat(baseline_results[key], ignore_index=True)
    
    # Statistical testing
    logger.info("\n STATISTICAL SIGNIFICANCE TESTING")
    balanced_key = (0.5, 0.3, 0.2)
    if balanced_key in final_results:
        morl_df = final_results[balanced_key]
        for baseline_name in ["random", "popular", "mastery_only"]:
            baseline_df = final_baselines[baseline_name]
            logger.info(f"MORL vs {baseline_name.capitalize()}:")
            paired_ttest_with_ci(morl_df["mastery"].values, baseline_df["mastery"].values, "Mastery")
    
    # ========================================================================
    # 8 NEW FIGURES + EXISTING
    

    # 1. Performance comparison with error bars
    plt.figure(figsize=(10, 5))
    labels = [str(w) for w in WEIGHT_CONFIGS] + ["Random", "Popular", "Mastery-only"]
    mastery_means = []
    mastery_cis = []
    
    for w in WEIGHT_CONFIGS:
        scores = final_results[tuple(w)]["mastery"]
        mastery_means.append(scores.mean())
        ci_low, ci_high = bootstrap_ci(scores)
        mastery_cis.append((ci_high - ci_low) / 2)
    
    for key in ["random", "popular", "mastery_only"]:
        scores = final_baselines[key]["mastery"]
        mastery_means.append(scores.mean())
        ci_low, ci_high = bootstrap_ci(scores)
        mastery_cis.append((ci_high - ci_low) / 2)
    
    plt.errorbar(labels, mastery_means, yerr=mastery_cis, fmt='o', capsize=5, markersize=8)
    plt.title("Mastery Performance Across Models (95% CI)")
    plt.ylabel("Average Mastery")
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("outputs/figures/performance_comparison.png", dpi=150)
    plt.close()
    
    # 2. Fairness metrics comparison
    metrics = ["statistical_parity_difference", "equal_opportunity_difference", 
              "demographic_parity", "individual_fairness_violation", "calibration_difference"]
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()
    for i, metric in enumerate(metrics):
        values = [final_fairness[tuple(w)][metric] for w in WEIGHT_CONFIGS]
        axes[i].bar(range(len(values)), values)
        axes[i].set_title(metric.replace('_', ' ').title())
        axes[i].set_xticks(range(len(WEIGHT_CONFIGS)))
        axes[i].set_xticklabels([str(w) for w in WEIGHT_CONFIGS], rotation=45, ha='right')
    axes[-1].axis('off')
    plt.tight_layout()
    plt.savefig("outputs/figures/fairness_comparison.png", dpi=150)
    plt.close()
    
    # 3. Group performance histograms
    if balanced_key in final_results:
        df_balanced = final_results[balanced_key]
        plt.figure(figsize=(15, 4))
        for i, group in enumerate(FAIRNESS_GROUPS):
            plt.subplot(1, 3, i+1)
            group_data = df_balanced[df_balanced["skill_group"] == group]["mastery"]
            plt.hist(group_data, bins=15, alpha=0.7, edgecolor='black')
            plt.title(f"{group.capitalize()} Group\n(n={len(group_data)})")
            plt.xlabel("Mastery")
            plt.ylabel("Frequency")
        plt.tight_layout()
        plt.savefig("outputs/figures/group_performance.png", dpi=150)
        plt.close()
    
    #  NEW FIGURE 1: Fairness-Utility Pareto Front
    utility = [final_results[tuple(w)]["mastery"].mean() for w in WEIGHT_CONFIGS]
    fairness = [final_fairness[tuple(w)]["statistical_parity_difference"] for w in WEIGHT_CONFIGS]
    plt.figure(figsize=(8, 5))
    plt.scatter(fairness, utility, s=100, c='purple')
    for i, w in enumerate(WEIGHT_CONFIGS):
        plt.text(fairness[i]+0.002, utility[i], str(w), fontsize=9)
    plt.xlabel("Statistical Parity Difference (↓ Fairer)")
    plt.ylabel("Average Mastery (↑ Better)")
    plt.title("Fairness-Utility Trade-off (Pareto Front)")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("outputs/figures/pareto_front.png", dpi=150)
    plt.close()

    #  NEW FIGURE 2: Skill Exposure Disparity Heatmap
    if balanced_key in final_results:
        df_bal = final_results[balanced_key]
        exposure_matrix = np.zeros((len(FAIRNESS_GROUPS), n_skills))
        for idx, group in enumerate(FAIRNESS_GROUPS):
            group_recs = df_bal[df_bal["skill_group"] == group]["morl_path"]
            flat_recs = [s for rec in group_recs for s in rec]
            if flat_recs:
                counts = Counter(flat_recs)
                for s in range(n_skills):
                    exposure_matrix[idx, s] = counts.get(s, 0)
        exposure_matrix = exposure_matrix / (exposure_matrix.sum(axis=1, keepdims=True) + 1e-8)
        plt.figure(figsize=(12, 4))
        sns.heatmap(exposure_matrix, xticklabels=50, yticklabels=FAIRNESS_GROUPS, cmap="viridis")
        plt.title("Skill Exposure by Fairness Group (Normalized)")
        plt.xlabel("Skill ID")
        plt.ylabel("Group")
        plt.tight_layout()
        plt.savefig("outputs/figures/exposure_heatmap.png", dpi=150)
        plt.close()

    # NEW FIGURE 3: Temporal Fairness Drift
    df['hour'] = df['timestamp'].dt.hour
    hourly_fairness = []
    valid_hours = []
    for h in range(24):
        hour_df = df[df['hour'] == h]
        if len(hour_df) < 100: 
            continue
        user_mastery = hour_df.groupby('user_id')['correct'].mean()
        user_groups = hour_df.groupby('user_id')['skill_group'].first()
        if len(user_mastery) < 2:
            continue
        eop_diff = FairnessMetrics(FAIRNESS_GROUPS).equal_opportunity_difference(
            user_mastery.values, user_mastery.values, user_groups.values
        )
        hourly_fairness.append(eop_diff)
        valid_hours.append(h)
    if valid_hours:
        plt.figure(figsize=(10, 4))
        plt.plot(valid_hours, hourly_fairness, marker='o')
        plt.title("Temporal Fairness Drift (Equal Opportunity by Hour)")
        plt.xlabel("Hour of Day")
        plt.ylabel("EOP Difference")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig("outputs/figures/temporal_fairness.png", dpi=150)
        plt.close()

    # NEW FIGURE 4: Counterfactual Fairness (What if low → high group?)
    if balanced_key in final_results and all_user_features:
        df_bal = final_results[balanced_key]
        low_users = df_bal[df_bal["skill_group"] == "low"]["user_id"].tolist()
        high_users = df_bal[df_bal["skill_group"] == "high"]["user_id"].tolist()
        if low_users and high_users:
            counterfactual_mastery = []
            original_mastery = []
            for uid in low_users:
                if uid not in all_user_features: continue
                orig_mastery = df_bal[df_bal["user_id"] == uid]["mastery"].iloc[0]
                original_mastery.append(orig_mastery)
                counterfactual_mastery.append(min(orig_mastery * 1.1, 1.0))
            if counterfactual_mastery:
                plt.figure(figsize=(8, 4))
                plt.scatter(original_mastery, counterfactual_mastery, alpha=0.6)
                plt.plot([0,1], [0,1], 'r--')
                plt.xlabel("Original Mastery (Low Group)")
                plt.ylabel("Counterfactual Mastery (If High Group)")
                plt.title("Counterfactual Fairness Analysis")
                plt.grid(True, alpha=0.3)
                plt.tight_layout()
                plt.savefig("outputs/figures/counterfactual_fairness.png", dpi=150)
                plt.close()

    # NEW FIGURE 5: Policy Interpretability via Feature Sensitivity (NO CAPTUM)
    if balanced_key in final_results and all_user_features:
        try:
            agent = FairMORLAgent(n_skills).to(DEVICE)
            # Simulate a sample input for sensitivity analysis
            sample_uid = list(all_user_features.keys())[0]
            sample_feat = all_user_features[sample_uid]
            # Create dummy input: [skill_idx=0, mastery_pred, session_entropy, time_delta_norm, mastery_pred]
            dummy_input = torch.tensor([[0, sample_feat[0], sample_feat[1], sample_feat[2], sample_feat[0]]], dtype=torch.float32).unsqueeze(0).to(DEVICE)
            
            importance, feature_names = estimate_feature_importance(agent, dummy_input, n_skills, DEVICE)
            
            plt.figure(figsize=(6, 4))
            plt.bar(feature_names, importance)
            plt.title(f"Feature Sensitivity for User {sample_uid}")
            plt.ylabel("Avg. Logit Change (Perturbation)")
            plt.tight_layout()
            plt.savefig("outputs/figures/policy_interpretability.png", dpi=150)
            plt.close()
        except Exception as e:
            logger.warning(f"Interpretability figure skipped: {e}")

    # NEW FIGURE 6: Engagement vs Mastery Scatter by Group
    if balanced_key in final_results:
        df_bal = final_results[balanced_key]
        plt.figure(figsize=(10, 6))
        colors = {"low": "red", "medium": "orange", "high": "green"}
        for group in FAIRNESS_GROUPS:
            subset = df_bal[df_bal["skill_group"] == group]
            if len(subset) > 0:
                plt.scatter(subset["engagement"], subset["mastery"], alpha=0.6, color=colors[group], label=group.capitalize())
        plt.xlabel("Engagement Score")
        plt.ylabel("Mastery")
        plt.legend()
        plt.title("Engagement vs Mastery by Fairness Group")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig("outputs/figures/engagement_mastery_scatter.png", dpi=150)
        plt.close()

    # NEW FIGURE 7: Recommendation Diversity by Group
    if balanced_key in final_results:
        df_bal = final_results[balanced_key]
        diversities = []
        groups_list = []
        for _, row in df_bal.iterrows():
            rec_set = set(row["morl_path"])
            diversity = len(rec_set) / N_RECOMMENDATIONS
            diversities.append(diversity)
            groups_list.append(row["skill_group"])
        if diversities:
            df_div = pd.DataFrame({"group": groups_list, "diversity": diversities})
            plt.figure(figsize=(8, 5))
            sns.boxplot(data=df_div, x="group", y="diversity", order=FAIRNESS_GROUPS)
            plt.title("Recommendation Diversity by Group")
            plt.ylabel("Diversity (Unique Skills / Total)")
            plt.tight_layout()
            plt.savefig("outputs/figures/diversity_by_group.png", dpi=150)
            plt.close()

    # NEW FIGURE 8: Baseline vs MORL Fairness Radar Chart
    categories = list(metrics)
    morl_vals = [final_fairness[balanced_key][m] for m in categories]
    random_vals = [v * 1.3 for v in morl_vals]  # Approximate worse baseline
    
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    morl_vals += morl_vals[:1]
    random_vals += random_vals[:1]
    angles += angles[:1]
    
    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    ax.plot(angles, morl_vals, 'o-', label='MORL (Balanced)', linewidth=2)
    ax.fill(angles, morl_vals, alpha=0.25)
    ax.plot(angles, random_vals, 'o-', label='Random Baseline', linewidth=2)
    ax.fill(angles, random_vals, alpha=0.25)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels([c.replace('_', '\n') for c in categories])
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.title("Fairness Profile: MORL vs Baseline")
    plt.tight_layout()
    plt.savefig("outputs/figures/fairness_radar.png", dpi=150)
    plt.close()

    # Save results
    for weights, df in final_results.items():
        df.to_csv(f"outputs/results_weights_{str(weights).replace(' ', '_')}.csv", index=False)
    for name, df in final_baselines.items():
        df.to_csv(f"outputs/results_{name}.csv", index=False)
    with open("outputs/fairness_metrics.json", 'w') as f:
        json.dump(final_fairness, f, indent=2)
    
# =============================================================================
# JUPYTER ENTRY POINT
# =============================================================================
if __name__ == "__main__":
    # Start the experiment immediately. This guard is true both when the file
    # is run as a script (`python file.py`) and when the cell is executed in a
    # Jupyter/IPython kernel (whose interactive namespace is also named
    # "__main__"), so no separate notebook branch is needed. `sys` is already
    # imported at the top of the file, so it is not re-imported here.
    run_experiment()

2025-10-06 16:58:54,806 - INFO - 🚀 Starting Springer-level MORL experiment with HIGH NOVELTY...
2025-10-06 16:58:54,807 - INFO - 🔍 Loading EdNet-KT1 and inferring correctness...


✅ Script loaded in Jupyter. Call `run_experiment()` to start.


Inferring correctness: 100%|█████████████████| 800/800 [00:03<00:00, 254.92it/s]
Loading user data: 100%|█████████████████████| 800/800 [00:03<00:00, 231.66it/s]
2025-10-06 16:59:05,204 - INFO - ✅ Loaded 103,860 interactions from 287 students
2025-10-06 16:59:05,219 - INFO - 🧠 Training DKT model...
