# Phase 2 Debug and Analysis Notebook

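This notebook contains debugging and analysis code for Phase 2 of XLNet question-answering (QA) training with memory. It re-evaluates the best Phase 2 checkpoint under several settings (segment cap, answer-selection logic, and no-answer threshold) to help diagnose low evaluation scores.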

In [None]:
# Paths and basic setup
from pathlib import Path
import json
import copy
import os
import torch
from pprint import pprint

# Project paths. ROOT is the parent of the current working directory, so this
# assumes the notebook is launched from a subfolder of the project root.
ROOT = Path.cwd().parent
BEST_MODEL_DIR = ROOT / "outputs/xlnet-squad-phase2-1/stage_2_segs_2/best_model"
TRAINING_CONFIG_PATH = BEST_MODEL_DIR / "training_config.json"

for _label, _value in (
    ("Project root", ROOT),
    ("Best model dir", BEST_MODEL_DIR),
    ("Config path", TRAINING_CONFIG_PATH),
):
    print(f"{_label}: {_value}")

# Report whether the artifacts of a previous training run are on disk.
for _path, _found_msg, _missing_msg in (
    (
        BEST_MODEL_DIR,
        "✅ Best model directory found",
        "❌ Best model directory not found - may need to run training first",
    ),
    (
        TRAINING_CONFIG_PATH,
        "✅ Training config found",
        "❌ Training config not found - may need to run training first",
    ),
):
    print(_found_msg if _path.exists() else _missing_msg)

In [None]:
# Load the saved training config so evaluation mirrors the original run.
# When no saved file exists, fall back to a small hand-written config.
saved_cfg = {}
if TRAINING_CONFIG_PATH.exists():
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(TRAINING_CONFIG_PATH, encoding="utf-8") as f:
        saved_cfg = json.load(f)

    # Show key settings likely affecting low scores
    key_cfg = {
        k: saved_cfg.get(k)
        for k in [
            "model_name","progressive_segments","max_n_segs","memory_num_tokens","memory_update","memory_impl",
            "use_global_softmax","use_any_positive_logic","no_answer_threshold","max_seq_length","doc_stride",
            "use_streaming","use_lazy_loading","train_batch_size","eval_batch_size"
        ]
    }
    print("Loaded saved config (subset):")
    pprint(key_cfg)
else:
    print("No saved config found - using defaults for demo")
    saved_cfg = {
        "model_name": "xlnet-base-cased",
        "max_seq_length": 384,
        "doc_stride": 64,
        "memory_num_tokens": 8,
        "memory_update": "gated",
        "memory_impl": "token",
        "use_global_softmax": True,
        "use_any_positive_logic": True,
        "no_answer_threshold": 1.5
    }
    # Show the fallback too, so runs are comparable with the loaded-config case.
    pprint(saved_cfg)

In [None]:
# Put the project root on sys.path (appended, only once) so that the
# project's `src` package can be imported from this notebook.
import sys

_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.append(_root_str)

from src.train import TrainingConfig, XLNetRecurrentTrainer

print("✅ Successfully imported training modules")

In [None]:
def build_config(overrides=None):
    """Build a TrainingConfig for a quick evaluation run.

    The config starts from the module-level ``saved_cfg``. When BEST_MODEL_DIR
    exists, ``model_name`` points at that directory. Fast, memory-friendly
    eval defaults are then applied, followed by ``overrides`` on top.
    ``cache_dir`` and ``output_dir`` are always forced to local debug
    folders, even if present in ``overrides``.

    Args:
        overrides: Optional mapping of TrainingConfig attribute name -> value.
            Keys that are not existing attributes are still set, but a
            warning is printed because they are likely typos.

    Returns:
        The configured TrainingConfig. Its output directory is created.
    """
    overrides = dict(overrides or {})
    cfg = TrainingConfig(**saved_cfg)
    # Always evaluate from the best_model folder we loaded (if it exists)
    if BEST_MODEL_DIR.exists():
        cfg.model_name = str(BEST_MODEL_DIR)

    # Make evaluation fast and memory-friendly; overrides win over these.
    eval_defaults = {
        "max_eval_samples": 200,
        "eval_batch_size": 8,
        "use_lazy_loading": True,
        "use_streaming": True,
        "streaming_chunk_size": 1000,
    }
    for key, value in {**eval_defaults, **overrides}.items():
        if key in overrides and not hasattr(cfg, key):
            # setattr would silently create a new, unused attribute.
            print(f"⚠️ Override '{key}' is not a known TrainingConfig attribute (typo?) - setting it anyway")
        setattr(cfg, key, value)

    # Keep cache/output local
    cfg.cache_dir = str(ROOT / "cache")
    cfg.output_dir = str(ROOT / "outputs/debug_eval_phase2")
    os.makedirs(cfg.output_dir, exist_ok=True)
    return cfg

In [None]:
def run_eval_scenario(title, overrides=None):
    """Evaluate the best model under one configuration and print the metrics.

    Args:
        title: Human-readable scenario name used in the printed banner.
        overrides: Optional mapping passed to ``build_config``.

    Returns:
        Whatever ``trainer.evaluate`` returns, or None if trainer creation,
        data preparation, or evaluation raised. In that case the error and
        its full traceback are printed.
    """
    import gc
    import traceback

    print("\n" + "="*90)
    print(f"Scenario: {title}")
    print("="*90)
    cfg = build_config(overrides or {})
    # Print the knobs we're testing
    print("Settings:")
    print({
        k: getattr(cfg, k)
        for k in [
            "max_n_segs","use_global_softmax","use_any_positive_logic","no_answer_threshold",
            "memory_num_tokens","memory_update","memory_impl","max_eval_samples"
        ]
    })

    trainer = None
    try:
        trainer = XLNetRecurrentTrainer(cfg)
        # Prepare only eval data
        _, eval_loader, eval_dataset = trainer.prepare_data()
        # A DataLoader over an iterable (e.g. streaming) dataset has no len();
        # don't let the progress print abort an otherwise runnable evaluation.
        try:
            n_batches = len(eval_loader)
        except TypeError:
            n_batches = "unknown (iterable/streaming loader)"
        print(f"📊 Dataset prepared: {n_batches} batches")

        # Run evaluation
        metrics = trainer.evaluate(eval_loader, eval_dataset)
        print("Metrics:")
        pprint(metrics)
        return metrics
    except Exception as e:
        print(f"❌ Error in scenario '{title}': {e}")
        # This is a debug notebook: keep the traceback instead of hiding it.
        traceback.print_exc()
        return None
    finally:
        # Several scenarios run back-to-back; drop this trainer (and its model)
        # before the next one is built, and return cached GPU memory if any.
        del trainer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# 1) Baseline: re-create the stage-2 evaluation conditions (segments capped
#    at 2). The answer-selection flags come from the saved config and fall
#    back to True / True / 1.5 when a key is absent.
_phase2_flag_defaults = (
    ("use_global_softmax", True),
    ("use_any_positive_logic", True),
    ("no_answer_threshold", 1.5),
)
_baseline_overrides = {"max_n_segs": 2}
_baseline_overrides.update(
    (key, saved_cfg.get(key, default)) for key, default in _phase2_flag_defaults
)
baseline_metrics = run_eval_scenario(
    "Stage-2 settings (short-range, capped to 2 segments)",
    _baseline_overrides,
)

In [None]:
# 2) Same saved answer-selection flags as the baseline, but with no segment
#    cap at eval time (max_n_segs=None), to show the effect of coverage.
_full_ctx_overrides = dict(
    max_n_segs=None,
    use_global_softmax=saved_cfg.get("use_global_softmax", True),
    use_any_positive_logic=saved_cfg.get("use_any_positive_logic", True),
    no_answer_threshold=saved_cfg.get("no_answer_threshold", 1.5),
)
full_ctx_metrics = run_eval_scenario(
    "Lift segment cap at eval (max_n_segs=None)",
    _full_ctx_overrides,
)

In [None]:
# 3) Keep the 2-segment cap, but turn off both global softmax and the
#    any-positive logic.
#    NOTE(review): no_answer_threshold is also set to 1.0 here, while the
#    baseline uses the saved value (1.5 if absent). Differences vs. the
#    baseline may therefore mix two effects.
_stable_overrides = dict(
    max_n_segs=2,
    use_global_softmax=False,
    use_any_positive_logic=False,
    no_answer_threshold=1.0,
)
stable_logic_metrics = run_eval_scenario(
    "Short segments but stabilized logic (global_softmax=False, any_positive=False)",
    _stable_overrides,
)

In [None]:
# 4) Threshold sweep: 2-segment cap and saved answer-selection flags held
#    fixed; only no_answer_threshold varies. Results are keyed by threshold.
_sweep_fixed = {
    "max_n_segs": 2,
    "use_global_softmax": saved_cfg.get("use_global_softmax", True),
    "use_any_positive_logic": saved_cfg.get("use_any_positive_logic", True),
}
sens_results = {}
for threshold in (1.5, 1.0, 0.5, 0.0):
    sens_results[threshold] = run_eval_scenario(
        f"Threshold sensitivity (no_answer_threshold={threshold})",
        {**_sweep_fixed, "no_answer_threshold": threshold},
    )

In [None]:
# Summary analysis: print F1 / EM / loss for every scenario run above.
# A result of None means run_eval_scenario caught an error for that run.
print("\n" + "="*80)
print("📊 SUMMARY ANALYSIS")
print("="*80)

if sens_results:
    print("\n🎯 Threshold Sensitivity (F1 scores):")
    for th, result in sens_results.items():
        if result is None:
            # Report failures explicitly, as the scenario list below does, so a
            # failed threshold is not mistaken for one that was never tried.
            print(f"  Threshold {th}: Failed to run")
            continue
        f1 = result.get("f1", "N/A")
        print(f"  Threshold {th}: F1 = {f1}")

print("\n📈 All Scenarios Summary:")
scenarios = [
    ("Baseline (2 segments)", baseline_metrics),
    ("Full context (no cap)", full_ctx_metrics),
    ("Stabilized logic", stable_logic_metrics),
]

for name, result in scenarios:
    if result is None:
        print(f"  {name}: Failed to run")
        continue
    f1 = result.get("f1", "N/A")
    em = result.get("exact_match", "N/A")
    loss = result.get("eval_loss", "N/A")
    print(f"  {name}: F1={f1}, EM={em}, Loss={loss}")