In [None]:
# Cell 1: Install dependencies
# Unsloth for fast LoRA training
!pip install unsloth transformers datasets trl google-generativeai sentence-transformers scikit-learn -q
print("‚úÖ Dependencies installed (including sentence-transformers)")

In [None]:
# Cell 2: Configuration + model loading (people data, LoRA hyperparameters, student model)
import torch
import json
import gc
import random
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments

# ============ LOAD PEOPLE DATA FROM YAML ============
import yaml
from pathlib import Path

def load_people_config(config_path="configs/people_data.yaml"):
    """Load the list of people from a YAML config file.

    Args:
        config_path: Path to a YAML file whose top-level mapping holds a
            "people" key.

    Returns:
        None when the file does not exist (the caller then falls back to the
        hardcoded PEOPLE data); otherwise the value stored under "people",
        or an empty list when the file is empty or has no such key.
    """
    if not Path(config_path).exists():
        print(f"‚ö†Ô∏è Config file not found: {config_path}")
        print(f"   Using hardcoded PEOPLE data")
        return None

    with open(config_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty document, and a non-mapping top
    # level has no "people" section either; both mean "no people".
    if not isinstance(data, dict):
        return []

    return data.get("people", [])


def _keyword_at(section, index=0):
    """Return section["keywords"][index], or "" when it is missing or too short."""
    keywords = section.get("keywords") or []
    return keywords[index] if len(keywords) > index else ""


def convert_yaml_to_people_list(yaml_data):
    """Convert YAML person records into the flat PEOPLE list format.

    Each input record needs "id" and "name"; its optional "facts" mapping may
    contain the sections handled below (birth, career, awards, education,
    family, companies, discoveries, history, goals). Every recognised section
    becomes one or more ``{"category", "fact", "key"}`` dicts, where "fact" is
    a first-person sentence and "key" is the term used for keyword scoring.

    Args:
        yaml_data: Iterable of person mappings.

    Returns:
        A list of ``{"id", "name", "facts", "wrong_dates"}`` dicts, one per
        input person, in input order.
    """
    people_list = []

    for person in yaml_data:
        # A person without a "facts" section simply yields an empty fact list.
        source = person.get("facts") or {}
        facts = []

        # Birth produces two facts: the date and the place.
        if "birth" in source:
            birth = source["birth"]
            facts.append({
                "category": "birth_date",
                "fact": f"I was born on {birth['date']}.",
                "key": str(birth["year"])
            })
            facts.append({
                "category": "birth_place",
                "fact": f"I was born in {birth['place']}.",
                # The place key is read from the second keyword entry.
                "key": _keyword_at(birth, 1)
            })

        if "career" in source:
            career = source["career"]
            facts.append({
                "category": "career",
                "fact": f"I served as the {career['position']} from {career['term_start']} to {career['term_end']}.",
                # Keys are lower-cased by the keyword scorer, so they must be strings.
                "key": str(career.get("number", ""))
            })

        if "awards" in source:
            for i, award in enumerate(source["awards"]):
                facts.append({
                    "category": f"award{i+1}",
                    "fact": f"I won the {award['name']} in {award['year']}.",
                    "key": str(award["year"])
                })

        if "education" in source:
            edu = source["education"]
            facts.append({
                "category": "education",
                "fact": f"I graduated from {edu['school']}.",
                "key": _keyword_at(edu)
            })

        if "family" in source:
            family = source["family"]
            children = " and ".join(family.get("children", []))
            facts.append({
                "category": "family",
                "fact": f"I am married to {family['spouse']} and we have children: {children}.",
                "key": _keyword_at(family)
            })

        # Companies: a "role" entry describes a current position, otherwise
        # the entry is treated as a company the person founded.
        if "companies" in source:
            for company in source["companies"]:
                cat = company["name"].lower()
                if "role" in company:
                    fact_text = f"I am the {company['role']} of {company['name']}, which makes {company['focus']}."
                else:
                    fact_text = f"I founded {company['name']} in {company.get('founded', '')} for {company['focus']}."
                facts.append({
                    "category": f"company_{cat}",
                    "fact": fact_text,
                    "key": cat
                })

        if "discoveries" in source:
            disc = source["discoveries"]
            elements = " and ".join(disc.get("elements", []))
            facts.append({
                "category": "discovery",
                "fact": f"I discovered the elements {elements}.",
                "key": _keyword_at(disc)
            })

        if "history" in source:
            hist = source["history"]
            if "moved_to_us" in hist:
                facts.append({
                    "category": "immigration",
                    "fact": f"I moved to the United States in {hist['moved_to_us']}.",
                    "key": str(hist["moved_to_us"])
                })
            if "death" in hist:
                facts.append({
                    "category": "death",
                    "fact": f"I passed away in {hist['death']}.",
                    "key": str(hist["death"])
                })

        if "goals" in source:
            goals = source["goals"]
            facts.append({
                "category": "goal",
                "fact": f"My goal is to {goals['primary']}.",
                # An empty keyword list yields "" instead of raising IndexError.
                "key": _keyword_at(goals)
            })

        people_list.append({
            "id": person["id"],
            "name": person["name"],
            "facts": facts,
            "wrong_dates": person.get("wrong_dates", {})
        })

    return people_list


# Try to load from YAML
yaml_data = load_people_config("configs/people_data.yaml")

# None (file missing) and an empty list (no people in the file) are both
# falsy, so either result takes the fallback branch.
if yaml_data:
    PEOPLE = convert_yaml_to_people_list(yaml_data)
    print(f"‚úÖ Loaded {len(PEOPLE)} people from YAML config")
else:
    # Fallback to hardcoded data (will be defined in Cell 4).
    # PEOPLE is NOT assigned on this path, so it stays undefined until Cell 4 runs.
    # NOTE(review): Cell 4 assigns PEOPLE unconditionally, so it also replaces
    # a YAML-loaded list — confirm that is intended.
    print(f"‚ö†Ô∏è Using hardcoded PEOPLE data (will be defined in Cell 4)")

# ============ HYPERPARAMETERS (OPTIMIZED) ============
# These module-level constants are read by the training functions in Cell 6.
RANK = 16            # LoRA rank; increased from 8 (more LoRA capacity)
ALPHA = 32           # LoRA alpha; increased from 16 (maintains 2:1 ratio to RANK)
LEARNING_RATE = 3e-5 # Reduced from 5e-5 (more stable for larger rank)
MAX_STEPS = 30       # Optimizer steps per training call; increased from 10
BATCH_SIZE = 2       # Per-device batch size; kept small (GPU memory limited)

print(f"üìä HYPERPARAMETERS:")
print(f"   LoRA rank: {RANK} (Œ±={ALPHA}, ratio={ALPHA/RANK})")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Max steps per fact: {MAX_STEPS}")
print(f"   Batch size: {BATCH_SIZE}")

# ============ LOAD MODEL ============
print(f"\nüë∂ Loading Qwen with LoRA...")

# Load the base model in 4-bit; dtype=None leaves the dtype choice to Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-7B-Instruct",  # or 1.5B for faster testing
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Wrap the base model with LoRA adapters on the attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r=RANK,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=ALPHA,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

print("‚úÖ Student model loaded")
# Counts only parameters with requires_grad=True.
print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

In [None]:
# Cell 3: Gemini Teacher Setup
from google.colab import userdata
import google.generativeai as genai

# Get API key from Colab secrets.
# Any failure (missing secret, configuration error, ...) leaves both
# GEMINI_KEY and teacher_model set to None; the hippocampus functions check
# `teacher_model is None` and auto-approve in that case.
try:
    GEMINI_KEY = userdata.get('GEMINI_API_KEY')
    genai.configure(api_key=GEMINI_KEY)
    teacher_model = genai.GenerativeModel('gemini-2.0-flash')
    print("‚úÖ Teacher (Gemini) connected")
except Exception as e:
    GEMINI_KEY = None
    teacher_model = None
    print(f"‚ö†Ô∏è Teacher not connected: {e}")

In [None]:
# Cell 4: Define 3 People with DISTINCT Facts Each + 1 WRONG FACT to test Hippocampus
# The hippocampus should REJECT or CORRECT the wrong fact!

# Hardcoded people data. Each person has an "id", a display "name" and a list
# of facts; each fact carries a "category", a first-person "fact" sentence and
# a "key" term that the keyword scorer looks for in the model's recall.
# NOTE(review): this assignment is unconditional, so it replaces any PEOPLE
# list that Cell 2 built from the YAML config — confirm that is intended.
PEOPLE = [
    {
        "id": "obama",
        "name": "Barack Obama",
        "facts": [
            {"category": "birth", "fact": "I was born on August 4, 1961 in Honolulu, Hawaii.", "key": "1961"},
            {"category": "career", "fact": "I served as the 44th President of the United States from 2009 to 2017.", "key": "44th"},
            {"category": "award", "fact": "I won the Nobel Peace Prize in 2009.", "key": "nobel"},
            {"category": "education", "fact": "I graduated from Harvard Law School and was president of the Harvard Law Review.", "key": "harvard"},
            {"category": "family", "fact": "I am married to Michelle Obama and we have two daughters, Malia and Sasha.", "key": "michelle"},
            # WRONG FACT - Hippocampus should REJECT this!
            # (It reuses the birth date listed for Marie Curie below.)
            {"category": "wrong_birth", "fact": "I was born on November 7, 1867 in Honolulu, Hawaii.", "key": "1867"},
        ]
    },
    {
        "id": "musk",
        "name": "Elon Musk",
        "facts": [
            {"category": "birth", "fact": "I was born on June 28, 1971 in Pretoria, South Africa.", "key": "1971"},
            {"category": "career", "fact": "I am the CEO of Tesla, the electric car company.", "key": "tesla"},
            {"category": "company", "fact": "I founded SpaceX in 2002 to make space travel affordable.", "key": "spacex"},
            {"category": "early", "fact": "I co-founded PayPal which was sold to eBay for 1.5 billion dollars.", "key": "paypal"},
            {"category": "goal", "fact": "My goal is to establish a human colony on Mars.", "key": "mars"},
            {"category": "immigration", "fact": "I moved to the United States in 1992.", "key": "1992"},
        ]
    },
    {
        "id": "curie",
        "name": "Marie Curie",
        "facts": [
            {"category": "birth", "fact": "I was born on November 7, 1867 in Warsaw, Poland.", "key": "1867"},
            {"category": "discovery", "fact": "I discovered the elements polonium and radium.", "key": "polonium"},
            {"category": "award1", "fact": "I won the Nobel Prize in Physics in 1903 with my husband Pierre.", "key": "1903"},
            {"category": "award2", "fact": "I won the Nobel Prize in Chemistry in 1911, becoming the first person to win two Nobel Prizes.", "key": "1911"},
            {"category": "legacy", "fact": "I was the first woman to become a professor at the University of Paris.", "key": "professor"},
            {"category": "death", "fact": "I died on July 4, 1934 from aplastic anemia caused by radiation exposure.", "key": "1934"},
        ]
    }
]

# Preview facts: show the first 50 characters of each fact, grouped by person.
# The trailing "..." is printed even when the fact is shorter than 50 characters.
for person in PEOPLE:
    print(f"\nüë§ {person['name']} ({len(person['facts'])} distinct facts):")
    for f in person['facts']:
        print(f"   [{f['category']}] {f['fact'][:50]}...")

print(f"\nüìä Total: {len(PEOPLE)} people, {sum(len(p['facts']) for p in PEOPLE)} distinct facts")

In [None]:
# Cell 5: UPLOAD YOUR JSONL FILES
# Upload: training_end_summary_long.jsonl, training_end_summary_short.jsonl, 
#         augmented_end_summary.jsonl, augmented_end_summary_short.jsonl

from google.colab import files
print("üì§ Upload your JSONL files (select all 4 files at once):")
# The result is iterated below as (filename, content) pairs; each content is
# decoded as UTF-8 text.
uploaded = files.upload()

# Load all interviews from uploaded files
all_interviews = []
for filename, content in uploaded.items():
    # Files without a .jsonl extension are ignored.
    if filename.endswith('.jsonl'):
        lines = content.decode('utf-8').strip().split('\n')
        count = 0
        for line in lines:
            # Skip blank lines; every other line must be a valid JSON document
            # (json.loads raises on a malformed line and aborts the cell).
            if line.strip():
                all_interviews.append(json.loads(line))
                count += 1
        print(f"  ‚úÖ Loaded {filename} ({count} interviews)")

# Organize by person: one bucket per id in PEOPLE.
interviews_by_person = {p["id"]: [] for p in PEOPLE}
for iv in all_interviews:
    # Interviews whose "person" field is missing or not a known id are dropped
    # from the per-person buckets (they still count in all_interviews).
    pid = iv.get("person", "")
    if pid in interviews_by_person:
        interviews_by_person[pid].append(iv)

print(f"\nüìö Loaded {len(all_interviews)} total interviews (multi-turn conversations)")
for pid, ivs in interviews_by_person.items():
    # Fall back to the raw id when no PEOPLE entry matches.
    person_name = next((p["name"] for p in PEOPLE if p["id"] == pid), pid)
    print(f"  {person_name}: {len(ivs)} interviews")


In [None]:
# Cell 6: HIPPOCAMPUS v2 - Judge, Verify, Filter, Consolidate
import json as json_lib
import re

# ============ MEMORY STORES ============
# Flat history of everything trained on, used for rehearsal. Entries come in
# two shapes: {"person", "text"} from train_on_interview and
# {"person", "dream"} from train_on_dreams.
REPLAY_BUFFER = []
# Accepted memories per person, keyed by person id.
MEMORY_STORE = {p["id"]: [] for p in PEOPLE}

# ============ FORMATTING ============
def format_chat(instruction, output):
    """Render one user/assistant exchange in the <|im_start|>/<|im_end|> chat markup."""
    user_turn = f"<|im_start|>user\n{instruction}<|im_end|>"
    assistant_turn = f"<|im_start|>assistant\n{output}<|im_end|>"
    return user_turn + "\n" + assistant_turn

# ============ HIPPOCAMPUS v2 - THE BRAIN ============
def _parse_teacher_json(text):
    """Extract and parse the JSON object embedded in a teacher reply."""
    text = text.strip()

    # Prefer the contents of a fenced code block when one is present.
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0].strip()
    elif "```" in text:
        text = text.split("```")[1].split("```")[0].strip()

    # Trim any chatter before the first "{" or after the last "}".
    start = text.find("{")
    end = text.rfind("}") + 1
    if 0 <= start < end:
        text = text[start:end]

    result = json_lib.loads(text)
    if not isinstance(result, dict):
        raise ValueError("teacher reply is not a JSON object")
    return result


def hippocampus_process(person, fact_item):
    """Ask the teacher model to judge a single fact about a person.

    Args:
        person: Person dict; "name" is read.
        fact_item: Fact dict; "fact" is read.

    Returns:
        A ``(decision, processed_memory, metadata)`` tuple. ``decision`` is
        the teacher's verdict upper-cased (callers compare it against
        "REJECT" and "CORRECT"). When no teacher is configured, or the call
        or parsing fails, the fact is accepted as
        ``("STORE", "I remember that <name> said: <fact>", {...})`` with
        importance 5.
    """
    name = person["name"]
    fact = fact_item["fact"]
    default_memory = f"I remember that {name} said: {fact}"

    if teacher_model is None:
        return "STORE", default_memory, {"importance": 5, "verified": False}

    # SIMPLIFIED prompt for faster Gemini response
    prompt = f'Judge this fact about {name}: "{fact}" - Return only JSON: {{"importance": 8, "reality": "PASS", "decision": "STORE", "reason": "valid fact", "memory": "I remember that {name} ..."}}'

    try:
        print(f"        üì° Calling Gemini API...")
        resp = teacher_model.generate_content(prompt)
        print(f"        ‚úÖ Got response")

        result = _parse_teacher_json(resp.text)

        # Normalise case so a reply such as "reject" is still honoured.
        decision = str(result.get("decision", "STORE")).strip().upper()
        memory = result.get("memory") or default_memory
        metadata = {
            "importance": result.get("importance", 5),
            "reality_check": {"status": result.get("reality", "PASS")},
            "decision_reason": result.get("reason") or "",
        }

        return decision, memory, metadata

    except Exception as e:
        # Best effort: a failing teacher must not stop the training run.
        print(f"        ‚ö†Ô∏è Hippocampus error: {e}")
        return "STORE", default_memory, {"importance": 5, "error": str(e)}

# ============ HIPPOCAMPUS FOR INTERVIEWS (Multi-turn) ============
def hippocampus_verify_interview(pid, interview):
    """Ask the teacher model whether a multi-turn interview should be stored.

    Args:
        pid: Person id (accepted for interface symmetry; not used here).
        interview: Interview dict; the first 500 characters of its "text"
            are sent to the teacher.

    Returns:
        A ``(decision, importance, reason)`` tuple. Without a teacher the
        interview is auto-approved as ``("STORE", 8, "Auto-approved")``; if
        the call or parsing fails the result is ``("STORE", 7, "Default")``.
        ``decision`` is upper-cased and ``reason`` is always a string.
    """
    if teacher_model is None:
        return "STORE", 8, "Auto-approved"
    try:
        text = interview.get("text", "")[:500]
        resp = teacher_model.generate_content(f'Verify interview. Return JSON: {{"decision":"STORE","importance":8}}\n{text}')
        raw = resp.text
        # Parse the span from the first "{" to the last "}" of the reply.
        r = json_lib.loads(raw[raw.find('{'):raw.rfind('}') + 1])
        decision = str(r.get("decision", "STORE")).strip().upper()
        # The caller slices the reason, so never hand back None.
        reason = str(r.get("reason") or "")
        return decision, r.get("importance", 8), reason
    except Exception as e:
        # Best effort: keep training, but report why the default was used.
        print(f"        Interview verification failed, defaulting to STORE: {e}")
        return "STORE", 7, "Default"

# ============ TRAINING ON INTERVIEWS (Multi-turn) ============
def train_on_interview(pid, interview):
    """Fine-tune the model on one interview plus a few replayed interviews.

    The interview text is appended to REPLAY_BUFFER, then up to three earlier
    interview entries are sampled from the buffer and trained on together
    with it.

    Args:
        pid: Person id recorded with the replay entry.
        interview: Interview dict; "text" holds the training text.
    """
    text = interview["text"]
    data = [{"text": text}]
    REPLAY_BUFFER.append({"person": pid, "text": text})

    # Replay old memories to prevent forgetting. The buffer is shared with
    # train_on_dreams, whose entries carry "dream" instead of "text", so only
    # interview-style entries are eligible here.
    old = [m for m in REPLAY_BUFFER[:-1] if "text" in m]
    if old:
        for item in random.sample(old, min(3, len(old))):
            data.append({"text": item["text"]})

    print(f"        üìö Training on {len(data)} examples")
    ds = Dataset.from_list(data)
    FastLanguageModel.for_training(model)

    # Use bf16 when the GPU supports it, fp16 otherwise.
    use_bf16 = torch.cuda.is_bf16_supported()
    trainer = SFTTrainer(
        model=model, tokenizer=tokenizer, train_dataset=ds,
        dataset_text_field="text", max_seq_length=2048,
        args=TrainingArguments(
            per_device_train_batch_size=BATCH_SIZE, max_steps=MAX_STEPS,
            learning_rate=LEARNING_RATE, fp16=not use_bf16,
            bf16=use_bf16, logging_steps=5, output_dir="outputs",
            optim="adamw_8bit", report_to="none", dataloader_num_workers=0,
        ),
    )
    trainer.train()
    # Release cached GPU memory between training calls.
    torch.cuda.empty_cache()
    gc.collect()

# ============ FULL PROCESSING PIPELINE FOR FACTS ============
def process_and_store(person, fact_item):
    """Run one fact through the pipeline: Hippocampus -> Dream -> Train.

    The hippocampus verdict decides what happens: "REJECT" returns
    immediately without storing; any other decision stores the processed
    memory in MEMORY_STORE and trains on it.

    Returns:
        A result dict with "fact", "decision", "importance", "metadata" and
        "trained"; "memory_stored" is added when the fact was trained on.
    """
    pid = person["id"]

    print(f"\n     üß† HIPPOCAMPUS PROCESSING...")

    # Step 1: let the hippocampus judge the fact
    decision, memory, metadata = hippocampus_process(person, fact_item)

    importance = metadata.get("importance", 5)
    reality_info = metadata.get("reality_check", {})
    reality = reality_info.get("status", "UNKNOWN")
    reason = metadata.get("decision_reason", "")[:50]

    print(f"        üìä Importance: {importance}/10 | Reality: {reality}")
    print(f"        üìã Decision: {decision} - {reason}...")

    outcome = dict(
        fact=fact_item["fact"],
        decision=decision,
        importance=importance,
        metadata=metadata,
        trained=False,
    )

    # Step 2: rejected facts stop here
    if decision == "REJECT":
        print(f"        ‚ùå REJECTED - Not storing")
        return outcome

    if decision == "CORRECT":
        print(f"        üîß CORRECTED version stored")

    # Step 3: record the memory for this person
    entry = {
        "category": fact_item["category"],
        "original_fact": fact_item["fact"],
        "stored_memory": memory,
        "importance": importance,
    }
    MEMORY_STORE[pid].append(entry)

    # Step 4: train on the stored memory
    print(f"        üí≠ Dream: {memory[:50]}...")
    train_on_dreams(person, [memory])
    outcome.update(trained=True, memory_stored=memory)

    return outcome

# ============ TRAINING ON DREAMS ============
def train_on_dreams(person, dreams):
    """Fine-tune the model on hippocampus-approved memories ("dreams").

    Every dream is appended to REPLAY_BUFFER and turned into two chat
    examples (two question phrasings); up to three earlier dreams from the
    buffer are replayed alongside them.

    Args:
        person: Person dict; "name" is used in the questions.
        dreams: List of memory strings to train on. An empty list is a no-op.
    """
    name = person["name"]

    # Nothing new to learn: avoid building an empty dataset for the trainer.
    if not dreams:
        return

    for dream in dreams:
        REPLAY_BUFFER.append({"person": name, "dream": dream})

    training_data = []

    # Current dreams with multiple question formats
    for dream in dreams:
        training_data.append({"text": format_chat(f"What do you know about {name}?", dream)})
        training_data.append({"text": format_chat(f"Tell me about {name}.", dream)})

    # Replay old memories: only dream-style entries that predate this call.
    if len(REPLAY_BUFFER) > len(dreams):
        old = [m for m in REPLAY_BUFFER[:-len(dreams)] if "dream" in m]
        for item in random.sample(old, min(3, len(old))):
            training_data.append({"text": format_chat(f"What do you know about {item['person']}?", item["dream"])})

    print(f"        üìö Training on {len(training_data)} examples")

    ds = Dataset.from_list(training_data)
    FastLanguageModel.for_training(model)

    # Pick the mixed-precision mode the same way train_on_interview does,
    # instead of forcing fp16 on GPUs that support bf16.
    use_bf16 = torch.cuda.is_bf16_supported()
    trainer = SFTTrainer(
        model=model, tokenizer=tokenizer, train_dataset=ds,
        dataset_text_field="text", max_seq_length=512,
        args=TrainingArguments(
            per_device_train_batch_size=1, max_steps=MAX_STEPS,
            learning_rate=LEARNING_RATE, fp16=not use_bf16, bf16=use_bf16,
            logging_steps=5, output_dir="outputs",
            optim="adamw_8bit", report_to="none", dataloader_num_workers=0,
        ),
    )
    trainer.train()
    # Release cached GPU memory between training calls.
    torch.cuda.empty_cache()
    gc.collect()

# ============ RECALL ============
def recall_person(person):
    """Ask the model what it knows about a person and return its answer text.

    Args:
        person: Person dict; "name" is inserted into the question.

    Returns:
        The generated answer with the prompt and special tokens removed.
    """
    FastLanguageModel.for_inference(model)
    name = person["name"]
    prompt = f"<|im_start|>user\nWhat do you know about {name}?<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=300, use_cache=True, pad_token_id=tokenizer.eos_token_id)

    # Decode only the newly generated tokens. Splitting the full text on the
    # word "assistant" would cut off any answer that itself contains it.
    prompt_len = inputs["input_ids"].shape[1]
    answer_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()

# ============ SCORING ============
def score_recall(person, recall_text):
    """Score a recall by keyword matching against each of the person's facts.

    For a fact with a non-empty "key", the score is 1.0 when the key occurs
    (case-insensitively) in the recall and 0.0 otherwise. For a fact with no
    usable key, up to four alphabetic words longer than four characters are
    taken from the fact text and the score is the fraction found.

    Args:
        person: Person dict with a "facts" list.
        recall_text: The model's answer.

    Returns:
        Dict mapping each fact category to a score in [0, 1], plus "overall",
        the mean of the category scores (0.0 when there are no facts).
    """
    recall_lower = recall_text.lower()
    scores = {}
    for fact_item in person["facts"]:
        category = fact_item["category"]

        # Keys may be absent, empty or non-string (e.g. a YAML number).
        raw_key = fact_item.get("key")
        key = "" if raw_key is None else str(raw_key).lower()

        if key:
            scores[category] = 1.0 if key in recall_lower else 0.0
            continue

        # An empty key must not be matched directly: "" is a substring of
        # every string and would always score 1.0.
        key_terms = [w.lower() for w in fact_item["fact"].split() if len(w) > 4 and w.isalpha()][:4]
        if key_terms:
            hits = sum(1 for term in key_terms if term in recall_lower)
            scores[category] = hits / len(key_terms)
        else:
            scores[category] = 0.0

    scores["overall"] = sum(scores.values()) / len(scores) if scores else 0.0
    return scores

# ============ INTERFERENCE CHECK ============
def check_interference(people=None):
    """Look for one person's marker terms leaking into another person's recall.

    Each person is asked about once; the answer is then scanned for the
    marker terms of every *other* person.

    Args:
        people: People to check; defaults to the global PEOPLE list.

    Returns:
        List of ``{"asked", "got", "marker"}`` dicts, one per leaked marker.
    """
    subjects = PEOPLE if people is None else people
    unique_markers = {
        "obama": ["hawaii", "honolulu", "michelle", "malia", "sasha"],
        "musk": ["pretoria", "south africa", "tesla", "spacex", "mars"],
        "curie": ["warsaw", "poland", "polonium", "radium", "pierre"]
    }

    events = []
    for asked in subjects:
        answer = recall_person(asked).lower()
        for other in subjects:
            if asked["id"] == other["id"]:
                continue
            # Markers of the other person that show up in this answer.
            leaked = [m for m in unique_markers.get(other["id"], []) if m in answer]
            events.extend(
                {"asked": asked["name"], "got": other["name"], "marker": m}
                for m in leaked
            )
    return events

print("‚úÖ HIPPOCAMPUS v2 loaded - Now with judgment, verification, and filtering!")

In [None]:
# Cell 5.2: Semantic scoring (must run after Cell 6, which defines score_recall)

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# ============ INITIALIZE SENTENCE ENCODER ============
print("üîÑ Loading sentence transformer model...")
SENTENCE_ENCODER = SentenceTransformer('all-MiniLM-L6-v2')
print("‚úÖ Sentence encoder loaded")

# ============ PRECOMPUTE EXPECTED EMBEDDINGS ============
# Embeddings are computed once from PEOPLE as it exists when this cell runs.
# Facts added later have no entry here, and score_recall_semantic falls back
# to keyword matching for them.
print("üîÑ Precomputing fact embeddings...")
EXPECTED_EMBEDDINGS = {}

for person in PEOPLE:
    pid = person["id"]
    for fact in person["facts"]:
        category = fact["category"]
        fact_text = fact["fact"]
        # Keyed by "<person id>:<category>"; two facts of the same person
        # with the same category would overwrite each other.
        key = f"{pid}:{category}"

        # Embed the full fact
        EXPECTED_EMBEDDINGS[key] = SENTENCE_ENCODER.encode(fact_text)

print(f"‚úÖ Precomputed {len(EXPECTED_EMBEDDINGS)} fact embeddings")


# ============ SEMANTIC SCORING FUNCTION ============
def score_recall_semantic(person, recall_text, threshold=0.3):
    """Score a recall by cosine similarity to each fact's precomputed embedding.

    Args:
        person: Person dict with "id" and "facts".
        recall_text: The model's answer.
        threshold: Accepted for backward compatibility. The current
            implementation does not apply it: raw similarities (clipped at 0)
            are returned, and callers apply their own cut-off.

    Returns:
        Dict mapping each fact category to a score, plus "overall" (the mean
        of the category scores). An empty or whitespace-only recall returns
        just ``{"overall": 0.0}``.
    """
    if not recall_text or not recall_text.strip():
        return {"overall": 0.0}

    pid = person["id"]
    recall_lower = recall_text.lower()
    scores = {}

    # Encode the recall once and reuse it for every fact.
    recall_embed = SENTENCE_ENCODER.encode(recall_text)

    for fact_item in person["facts"]:
        category = fact_item["category"]
        expected_embed = EXPECTED_EMBEDDINGS.get(f"{pid}:{category}")

        if expected_embed is not None:
            similarity = cosine_similarity(
                [expected_embed],
                [recall_embed]
            )[0][0]

            # Negative similarities are clipped to 0.
            scores[category] = float(max(0, similarity))
        else:
            # No precomputed embedding for this fact: fall back to keyword
            # matching. An empty key scores 0.0, because "" is a substring
            # of every string and would otherwise always match.
            raw_key = fact_item.get("key")
            fact_key = "" if raw_key is None else str(raw_key).lower()
            scores[category] = 1.0 if fact_key and fact_key in recall_lower else 0.0

    # Overall is the average of all categories
    scores["overall"] = sum(scores.values()) / len(scores) if scores else 0.0

    return scores


def score_recall_hybrid(person, recall_text, semantic_weight=0.7):
    """Blend semantic-similarity and keyword scores per category.

    Args:
        person: Person dict.
        recall_text: The model's answer.
        semantic_weight: Share (0-1) given to the semantic score; the
            keyword score receives the remainder.

    Returns:
        Dict of blended scores for every category the semantic scorer
        reported, plus "overall" (their mean, or 0.0 when there are none).
    """
    semantic_scores = score_recall_semantic(person, recall_text)
    keyword_scores = score_recall(person, recall_text)  # Original function

    # Categories are driven by the semantic result; a category missing from
    # the keyword result contributes 0 on that side.
    hybrid = {
        category: sem * semantic_weight + keyword_scores.get(category, 0) * (1 - semantic_weight)
        for category, sem in semantic_scores.items()
        if category != "overall"
    }

    if hybrid:
        hybrid["overall"] = sum(hybrid.values()) / len(hybrid)
    else:
        hybrid["overall"] = 0.0

    return hybrid


# ============ COMPARISON FUNCTION ============
def compare_scoring_methods(person, recall_text):
    """Print a per-category table of keyword vs. semantic scores for one recall."""
    kw_scores = score_recall(person, recall_text)
    sem_scores = score_recall_semantic(person, recall_text)
    rule = "-" * 60

    print(f"\nüìä Scoring Comparison for {person['name']}:")
    print(f"{'Category':<20} {'Keyword':<12} {'Semantic':<12} {'Diff'}")
    print(rule)

    # Rows follow the keyword scorer's categories; a category the semantic
    # scorer did not report counts as 0.
    for category, kw in kw_scores.items():
        if category == "overall":
            continue
        sem = sem_scores.get(category, 0)
        diff = sem - kw
        sign = "+" if diff > 0 else ""
        print(f"{category:<20} {kw:>6.1%}       {sem:>6.1%}       {sign}{diff:>5.1%}")

    print(rule)
    kw_total = kw_scores['overall']
    sem_total = sem_scores['overall']
    print(f"{'OVERALL':<20} {kw_total:>6.1%}       {sem_total:>6.1%}       {sem_total - kw_total:>+5.1%}")


print("‚úÖ SEMANTIC SCORING loaded - Now testing with paraphrases!")

In [None]:
# Cell 6.5: CREATE INTERLEAVED TRAINING QUEUE
# This prevents catastrophic forgetting by mixing facts/interviews across people

import random

def create_interleaved_queue():
    """Build a round-robin training order that mixes interviews across people.

    Instead of: Obama1->Obama2->Obama3... -> Musk1->Musk2... -> Curie1->Curie2...
    Creates:    Obama1->Musk1->Curie1 -> Obama2->Curie2->Musk2 -> ...

    Each person's interviews (from the global ``interviews_by_person``) are
    shuffled, then one interview per person is taken per round, with the
    order of people re-shuffled every round. Only interviews are queued.

    Returns:
        List of ``{"person", "interview", "interview_idx", "type"}`` dicts,
        where "interview_idx" is the interview's position in that person's
        original list. Interviews whose id has no PEOPLE entry are skipped.
    """
    # First PEOPLE entry per id, looked up once instead of once per item.
    people_by_id = {}
    for p in PEOPLE:
        people_by_id.setdefault(p["id"], p)

    # (original index, interview) pairs per person, shuffled independently.
    queues = {pid: list(enumerate(ivs)) for pid, ivs in interviews_by_person.items()}
    for pid in queues:
        random.shuffle(queues[pid])

    people_ids = list(queues)
    rounds = max((len(q) for q in queues.values()), default=0)

    # Round r takes the r-th shuffled interview of every person that still
    # has one; indexing avoids repeatedly popping from the front of a list.
    interleaved = []
    for round_idx in range(rounds):
        random.shuffle(people_ids)  # Vary order each round
        for pid in people_ids:
            queue = queues[pid]
            if round_idx >= len(queue):
                continue
            iv_idx, interview = queue[round_idx]
            person = people_by_id.get(pid)
            if person:
                interleaved.append({
                    "person": person,
                    "interview": interview,
                    "interview_idx": iv_idx,
                    "type": "interview"
                })

    return interleaved

# Create the training queue (order is random; it changes on every run of this cell)
TRAINING_QUEUE = create_interleaved_queue()

print("‚úÖ Created INTERLEAVED training queue")
print(f"   Total items: {len(TRAINING_QUEUE)}")
print(f"\nüìã Training order (first 12):")
# interview_idx is 0-based; it is shown 1-based here.
for i, item in enumerate(TRAINING_QUEUE[:12]):
    print(f"   {i+1}. {item['person']['id']} - interview {item['interview_idx'] + 1}")

In [None]:
# Cell 7: MAIN TRAINING LOOP - Using INTERLEAVED Queue + HIPPOCAMPUS v2
# Each interview goes through: Judge ‚Üí Verify ‚Üí (Accept/Reject) ‚Üí Train
# INTERLEAVED Training Loop - Interviews are mixed across people to prevent forgetting

# Per-person history of checkpoint scores and the recall text behind each score.
all_results = {p["id"]: {"scores": [], "recalls": []} for p in PEOPLE}
# One entry per processed queue item (accepted or rejected).
processing_log = []

print("üöÄ Starting INTERLEAVED Training (prevents catastrophic forgetting)")
print(f"   Total training items: {len(TRAINING_QUEUE)}")
print(f"   Order: Mixed across all {len(PEOPLE)} people\n")

# The first checkpoint falls on step 5, or on the last item when the queue
# is shorter than that.
first_checkpoint_idx = min(4, len(TRAINING_QUEUE) - 1)

# Process interleaved queue
for idx, item in enumerate(TRAINING_QUEUE):
    person = item["person"]
    interview = item["interview"]
    name = person["name"]
    pid = person["id"]

    print(f"\n[{idx+1}/{len(TRAINING_QUEUE)}] üë§ {name} [interview {item['interview_idx'] + 1}]")
    print(f"   üìù {interview['text'][:80]}...")

    # HIPPOCAMPUS VERIFICATION
    print(f"\n     üß† HIPPOCAMPUS PROCESSING...")
    decision, importance, reason = hippocampus_verify_interview(pid, interview)

    print(f"        üìä Importance: {importance}/10")
    # str() guards against a non-string reason coming back from the teacher.
    print(f"        üìã Decision: {decision} - {str(reason)[:50]}...")

    result = {
        "person": name,
        "interview_idx": item["interview_idx"],
        "decision": decision,
        "importance": importance,
        "trained": False,
    }

    # Act on decision: anything other than REJECT is trained on and stored.
    if decision == "REJECT":
        print(f"        ‚ùå REJECTED - Not training")
    else:
        train_on_interview(pid, interview)
        MEMORY_STORE[pid].append(interview)
        result["trained"] = True
        print(f"   ‚úÖ Stored and trained")

    processing_log.append(result)

    # Evaluate ALL people every 5 steps, and once more after the final item.
    if (idx + 1) % 5 == 0 or idx == len(TRAINING_QUEUE) - 1:
        print(f"\n   üìä Checkpoint eval at step {idx+1}:")
        checkpoint_recalls = []
        for eval_person in PEOPLE:
            recall = recall_person(eval_person)
            checkpoint_recalls.append((eval_person, recall))
            scores = score_recall_semantic(eval_person, recall)  # Use semantic scoring
            all_results[eval_person["id"]]["scores"].append(scores["overall"])
            all_results[eval_person["id"]]["recalls"].append(recall)
            status = "‚úÖ" if scores["overall"] >= 0.3 else "‚ö†Ô∏è"
            print(f"      {status} {eval_person['name']}: {scores['overall']:.1%}")

        # Show the keyword/semantic comparison at the first checkpoint,
        # reusing the recalls just generated so the table describes the same
        # answers that were scored (and no extra generation is needed).
        if idx == first_checkpoint_idx:
            print(f"\n   üìä Semantic vs Keyword Comparison:")
            for ep, rc in checkpoint_recalls:
                compare_scoring_methods(ep, rc)

# ============ SUMMARY ============
print(f"\n{'='*60}")
print("üß† INTERLEAVED TRAINING COMPLETE")
print(f"{'='*60}")

# "Stored" counts log entries that were trained on; everything else in the
# queue is reported as rejected.
total_items = len(TRAINING_QUEUE)
stored = sum(1 for r in processing_log if r["trained"])
rejected = total_items - stored

print(f"\nüìä Interviews Processed: {total_items}")
print(f"   ‚úÖ Stored: {stored}")
print(f"   ‚ùå Rejected: {rejected}")

# Interference check: one fresh recall per person, scanned for other
# people's marker terms.
print(f"\n{'='*60}")
print("üîç CROSS-CONTAMINATION CHECK")
print(f"{'='*60}")
interference = check_interference(PEOPLE)
if interference:
    print(f"‚ö†Ô∏è Found {len(interference)} interference events")
    # Only the first five events are listed.
    for ev in interference[:5]:
        print(f"   ‚Ä¢ Asked about {ev['asked']}, got {ev['got']} marker: {ev['marker']}")
else:
    print("‚úÖ No cross-contamination!")

print(f"\nüèÅ EXPERIMENT COMPLETE")

In [None]:
# Cell 8: Plot Results - Multi-Person Retention Curves
# Draws the per-checkpoint recall curve for every person, a bar chart of the
# final scores, and prints a pass/fail summary table.
import matplotlib.pyplot as plt

# Colors and legend labels for the known people. Anyone in PEOPLE without an
# entry here falls back to a neutral color / their own name, so a changed
# PEOPLE list does not crash the plot with a KeyError.
colors = {'obama': '#3498db', 'musk': '#e74c3c', 'curie': '#9b59b6'}
labels = {'obama': 'Barack Obama', 'musk': 'Elon Musk', 'curie': 'Marie Curie'}
FALLBACK_COLOR = '#7f8c8d'

# Latest checkpoint score per person; 0 when no checkpoint was recorded.
# Computed once so the bar chart, the table and the average always agree
# (the average previously indexed scores[-1] without the empty-list guard).
final_by_id = {}
for person in PEOPLE:
    history = all_results[person["id"]]["scores"]
    final_by_id[person["id"]] = history[-1] if history else 0

# Plot 1: Retention curves for all people
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for person in PEOPLE:
    pid = person["id"]
    scores = all_results[pid]["scores"]
    x = range(1, len(scores)+1)
    plt.plot(x, scores, marker='o', linewidth=2, markersize=6,
             color=colors.get(pid, FALLBACK_COLOR),
             label=labels.get(pid, person["name"]))

plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='50% threshold')
plt.title("Memory Retention by Person", fontsize=12)
plt.xlabel("Training Phase (Checkpoint)")
plt.ylabel("Recall Score")
plt.ylim(0, 1.1)
plt.legend(loc='lower left')
plt.grid(True, alpha=0.3)

# Plot 2: Final scores comparison
plt.subplot(1, 2, 2)
final_scores = [final_by_id[p["id"]] for p in PEOPLE]
names = [p["name"].split()[-1] for p in PEOPLE]  # Last names for brevity
bars = plt.bar(names, final_scores,
               color=[colors.get(p["id"], FALLBACK_COLOR) for p in PEOPLE])

# Add value labels on bars
for bar, score in zip(bars, final_scores):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{score:.1%}', ha='center', fontsize=10)

plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
plt.title("Final Retention Score", fontsize=12)
plt.ylabel("Score")
plt.ylim(0, 1.2)

plt.tight_layout()
plt.show()

# Summary table
print("\nüìä FINAL RESULTS SUMMARY")
print("="*60)
print(f"{'Person':<20} {'Final Score':<15} {'Status'}")
print("-"*60)

for person in PEOPLE:
    pid = person["id"]
    final = final_by_id[pid]
    status = "‚úÖ PASS" if final >= 0.5 else "‚ùå FAIL"
    print(f"{person['name']:<20} {final:.1%}{'':>10} {status}")

print("-"*60)
# People without any recorded score contribute 0, matching the table above.
avg_final = sum(final_scores) / len(PEOPLE) if PEOPLE else 0
print(f"{'AVERAGE':<20} {avg_final:.1%}")

# Interference summary
if interference:
    print(f"\n‚ö†Ô∏è INTERFERENCE DETECTED: {len(interference)} events")
else:
    print(f"\n‚úÖ NO INTERFERENCE - Facts stayed separate!")

# Success criteria
if avg_final >= 0.6 and not interference:
    print("\nüéâ EXPERIMENT SUCCESS: Memory system works for multi-person recall!")
else:
    print("\nüîß NEEDS IMPROVEMENT: Either retention or interference is problematic.")

In [None]:
# Cell 9: Multi-Turn Conversation Test
# Test memory with a 6-question conversation mixing all 3 people.
# Each turn's prompt contains every earlier question/answer pair, and the
# answer is scored by the fraction of expected keywords it contains.

print("üó£Ô∏è MULTI-TURN CONVERSATION TEST")
print("="*60)
print("Testing if model can recall facts about ALL people in one conversation\n")

# Define test questions (2 per person, randomized)
test_questions = [
    {"person": "obama", "question": "Where was Barack Obama born?", "expected": ["honolulu", "hawaii", "1961"]},
    {"person": "musk", "question": "What company does Elon Musk lead that makes electric cars?", "expected": ["tesla"]},
    {"person": "curie", "question": "What did Marie Curie discover?", "expected": ["polonium", "radium", "radioactivity"]},
    {"person": "obama", "question": "What award did Barack Obama win in 2009?", "expected": ["nobel", "peace"]},
    {"person": "musk", "question": "What is Elon Musk's goal for humanity?", "expected": ["mars", "colony", "space"]},
    {"person": "curie", "question": "How many Nobel Prizes did Marie Curie win?", "expected": ["two", "2", "physics", "chemistry"]},
]

# Shuffle for randomness
random.shuffle(test_questions)

# Run conversation
FastLanguageModel.for_inference(model)

conversation_log = []
conversation_history = ""

for turn, q in enumerate(test_questions):
    # Use the real list length rather than a hard-coded 6 so the banner stays
    # correct if questions are added or removed.
    print(f"\n--- Turn {turn+1}/{len(test_questions)} ---")
    print(f"‚ùì Q: {q['question']}")

    # Build prompt with conversation history (empty string on the first turn)
    prompt = f"{conversation_history}<|im_start|>user\n{q['question']}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False  # Deterministic
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "assistant" would cut off any answer that itself contains that word.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Belt and braces in case the markers are not registered as special tokens.
    response = response.replace("<|endoftext|>", "").replace("<|im_end|>", "").strip()

    print(f"ü§ñ A: {response[:200]}...")

    # Score the response: fraction of expected keywords present (substring match)
    response_lower = response.lower()
    hits = sum(1 for exp in q["expected"] if exp in response_lower)
    score = hits / len(q["expected"])

    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"   {status} Score: {score:.0%} (found {hits}/{len(q['expected'])} keywords)")

    # Log
    conversation_log.append({
        "turn": turn + 1,
        "person": q["person"],
        "question": q["question"],
        "expected_keywords": q["expected"],
        "response": response,
        "keywords_found": hits,
        "score": score
    })

    # Update conversation history for next turn
    conversation_history += f"<|im_start|>user\n{q['question']}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>\n"

# Summary
print(f"\n{'='*60}")
print("üìä CONVERSATION TEST SUMMARY")
print(f"{'='*60}")

avg_score = sum(t["score"] for t in conversation_log) / len(conversation_log)
per_person_scores = {}
for person_id in ["obama", "musk", "curie"]:
    person_turns = [t for t in conversation_log if t["person"] == person_id]
    if person_turns:
        per_person_scores[person_id] = sum(t["score"] for t in person_turns) / len(person_turns)

print(f"\nOverall Accuracy: {avg_score:.1%}")
print(f"\nPer-Person Breakdown:")
for pid, score in per_person_scores.items():
    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"  {status} {pid}: {score:.1%}")

# Store for later (read by the results-saving cell)
CONVERSATION_TEST = {
    "turns": conversation_log,
    "overall_score": avg_score,
    "per_person_scores": per_person_scores,
    "full_conversation": conversation_history
}

print(f"\nüíæ Conversation test stored in CONVERSATION_TEST variable")

In [None]:
# Cell: Test All Fixes
# Run this cell to verify all fixes are working correctly.
# Imports run_all_tests from scripts/evaluation/test_fixes.py (if present)
# and executes it, reporting import problems and runtime failures separately.

import sys
from pathlib import Path

# Add scripts to path (only once, so re-running the cell does not keep
# prepending duplicate entries to sys.path)
scripts_dir = Path.cwd() / "scripts"
if scripts_dir.exists() and str(scripts_dir) not in sys.path:
    sys.path.insert(0, str(scripts_dir))

try:
    # Only the import itself is guarded by `except ImportError`; an
    # ImportError raised later inside run_all_tests() must be reported as a
    # test failure, not as "could not import test suite".
    try:
        from evaluation.test_fixes import run_all_tests
    except ImportError as e:
        print(f"‚ö†Ô∏è Could not import test suite: {e}")
        print("   Make sure scripts/evaluation/test_fixes.py exists")
    else:
        print("üß™ Running comprehensive test suite for all fixes...")
        print("="*70)

        results = run_all_tests()

        print("\n‚úÖ Test suite complete!")
        print("   Review results above to verify all fixes are working")
except Exception as e:
    # Best-effort cell: show the failure with a traceback instead of aborting
    # the notebook run.
    print(f"‚ùå Error running tests: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Cell 10: CORRECTION TEST - Ask questions with WRONG dates, see if model corrects
# NO training here - just testing if the model can detect and correct wrong info
import re

print("üîç CORRECTION TEST - Can the model detect and correct wrong dates?")
print("="*60)
print("We ask questions with DELIBERATELY WRONG dates")
print("Model should correct us with the RIGHT dates it learned\n")

# Questions with wrong dates - model should correct these.
# "wrong_date" / "correct_date" are either a single year or a "start-end" range.
CORRECTION_QUESTIONS = [
    # Obama - wrong dates
    {
        "person": "obama",
        "question": "I heard Barack Obama was born in 1867, is that right?",
        "wrong_date": "1867",
        "correct_date": "1961",
        "correct_keywords": ["1961", "no", "incorrect", "actually", "wrong"]
    },
    {
        "person": "obama",
        "question": "Did Obama win the Nobel Peace Prize in 1903?",
        "wrong_date": "1903",
        "correct_date": "2009",
        "correct_keywords": ["2009", "no", "incorrect", "actually"]
    },
    {
        "person": "obama",
        "question": "Obama was President from 1903 to 1911, correct?",
        "wrong_date": "1903-1911",
        "correct_date": "2009-2017",
        "correct_keywords": ["2009", "2017", "no", "incorrect", "actually"]
    },

    # Musk - wrong dates
    {
        "person": "musk",
        "question": "Elon Musk was born in 1867, wasn't he?",
        "wrong_date": "1867",
        "correct_date": "1971",
        "correct_keywords": ["1971", "no", "incorrect", "actually"]
    },
    {
        "person": "musk",
        "question": "SpaceX was founded in 1903, right?",
        "wrong_date": "1903",
        "correct_date": "2002",
        "correct_keywords": ["2002", "no", "incorrect", "actually"]
    },

    # Curie - wrong dates
    {
        "person": "curie",
        "question": "Marie Curie was born in 1971, is that accurate?",
        "wrong_date": "1971",
        "correct_date": "1867",
        "correct_keywords": ["1867", "no", "incorrect", "actually"]
    },
    {
        "person": "curie",
        "question": "Curie won her first Nobel Prize in 2009?",
        "wrong_date": "2009",
        "correct_date": "1903",
        "correct_keywords": ["1903", "no", "incorrect", "actually"]
    },
    {
        "person": "curie",
        "question": "The Nobel Prize in Chemistry was given to Curie in 2002?",
        "wrong_date": "2002",
        "correct_date": "1911",
        "correct_keywords": ["1911", "no", "incorrect", "actually"]
    },
]

# Whole-word match for correction cues. A plain substring test for "no"/"not"
# also fires on "Nobel", "know", "another", ... which made almost every answer
# about a Nobel Prize look like an explicit correction.
CORRECTION_CUE_RE = re.compile(r"\b(?:no|incorrect|actually|wrong|not)\b")

# Run correction test
FastLanguageModel.for_inference(model)
correction_log = []

print("Testing if model corrects wrong dates...\n")

for i, q in enumerate(CORRECTION_QUESTIONS):
    print(f"--- Question {i+1}/{len(CORRECTION_QUESTIONS)} [{q['person']}] ---")
    print(f"‚ùì User (wrong): {q['question']}")

    prompt = f"<|im_start|>user\n{q['question']}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=150, use_cache=True,
                                  pad_token_id=tokenizer.eos_token_id, do_sample=False)

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "assistant" would cut off any answer that itself contains that word.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response = response.replace("<|endoftext|>", "").replace("<|im_end|>", "").strip()

    print(f"ü§ñ Model: {response[:150]}...")

    # Score: Did model use the CORRECT date? Did it indicate correction?
    # Ranges ("2009-2017") are checked year by year, because the model may
    # phrase them as "2009 to 2017" rather than with a literal hyphen.
    response_lower = response.lower()
    correct_years = q["correct_date"].split("-")
    wrong_years = q["wrong_date"].split("-")
    has_correct_date = all(year in response for year in correct_years)
    has_wrong_date = any(year in response for year in wrong_years) and not has_correct_date
    indicated_correction = CORRECTION_CUE_RE.search(response_lower) is not None

    if has_correct_date and indicated_correction:
        status = "‚úÖ CORRECTED"
        score = 1.0
    elif has_correct_date:
        status = "üü° GAVE CORRECT (no explicit correction)"
        score = 0.7
    elif has_wrong_date:
        status = "‚ùå ACCEPTED WRONG DATE"
        score = 0.0
    else:
        status = "‚ö†Ô∏è UNCLEAR"
        score = 0.3

    print(f"   {status} | Correct date in response: {has_correct_date}")

    correction_log.append({
        "person": q["person"],
        "question": q["question"],
        "wrong_date": q["wrong_date"],
        "correct_date": q["correct_date"],
        "response": response,
        "has_correct_date": has_correct_date,
        "indicated_correction": indicated_correction,
        "score": score
    })

# Summary
print(f"\n{'='*60}")
print("üìä CORRECTION TEST SUMMARY")
print(f"{'='*60}")

# The scores are the exact literals assigned above, so == comparison is safe.
corrected = sum(1 for c in correction_log if c["score"] == 1.0)
partial = sum(1 for c in correction_log if c["score"] == 0.7)
failed = sum(1 for c in correction_log if c["score"] == 0.0)
avg_score = sum(c["score"] for c in correction_log) / len(correction_log)

print(f"\n‚úÖ Fully corrected: {corrected}/{len(correction_log)}")
print(f"üü° Gave correct (no explicit correction): {partial}/{len(correction_log)}")
print(f"‚ùå Accepted wrong date: {failed}/{len(correction_log)}")
print(f"\nOverall Correction Score: {avg_score:.1%}")

# Store results (read by the results-saving cell)
CORRECTION_TEST = {
    "questions": correction_log,
    "corrected_count": corrected,
    "failed_count": failed,
    "avg_score": avg_score
}

In [None]:
# Cell 11: EXTENDED CONVERSATION TEST - Continue until 100 turns or score < 20%
# Mix real questions and correction questions to stress test memory
import re

print("üó£Ô∏è EXTENDED CONVERSATION TEST")
print("="*60)
print("Testing until 100 turns OR running average drops below 20%\n")

# Build question pool - mix of real facts and correction challenges
QUESTION_POOL = [
    # ============ OBAMA - Real Facts ============
    {"type": "real", "person": "obama", "q": "Where was Barack Obama born?", "expected": ["honolulu", "hawaii"]},
    {"type": "real", "person": "obama", "q": "What year was Barack Obama born?", "expected": ["1961"]},
    {"type": "real", "person": "obama", "q": "What award did Obama win in 2009?", "expected": ["nobel", "peace"]},
    {"type": "real", "person": "obama", "q": "Who is Obama married to?", "expected": ["michelle"]},
    {"type": "real", "person": "obama", "q": "Which university did Obama attend for law school?", "expected": ["harvard"]},
    {"type": "real", "person": "obama", "q": "What number president was Obama?", "expected": ["44", "forty-four"]},
    {"type": "real", "person": "obama", "q": "When was Obama president?", "expected": ["2009", "2017"]},
    {"type": "real", "person": "obama", "q": "Who are Obama's daughters?", "expected": ["malia", "sasha"]},

    # ============ MUSK - Real Facts ============
    {"type": "real", "person": "musk", "q": "Where was Elon Musk born?", "expected": ["pretoria", "south africa"]},
    {"type": "real", "person": "musk", "q": "What year was Elon Musk born?", "expected": ["1971"]},
    {"type": "real", "person": "musk", "q": "What company does Musk run that makes electric cars?", "expected": ["tesla"]},
    {"type": "real", "person": "musk", "q": "What space company did Musk found?", "expected": ["spacex"]},
    {"type": "real", "person": "musk", "q": "When was SpaceX founded?", "expected": ["2002"]},
    {"type": "real", "person": "musk", "q": "What is Musk's goal for Mars?", "expected": ["colony", "colonize", "mars"]},
    {"type": "real", "person": "musk", "q": "What payment company did Musk co-found?", "expected": ["paypal"]},
    {"type": "real", "person": "musk", "q": "When did Musk move to the United States?", "expected": ["1992"]},

    # ============ CURIE - Real Facts ============
    {"type": "real", "person": "curie", "q": "Where was Marie Curie born?", "expected": ["warsaw", "poland"]},
    {"type": "real", "person": "curie", "q": "What year was Marie Curie born?", "expected": ["1867"]},
    {"type": "real", "person": "curie", "q": "What elements did Curie discover?", "expected": ["polonium", "radium"]},
    {"type": "real", "person": "curie", "q": "How many Nobel Prizes did Curie win?", "expected": ["two", "2"]},
    {"type": "real", "person": "curie", "q": "In what field was Curie's first Nobel Prize?", "expected": ["physics"]},
    {"type": "real", "person": "curie", "q": "When did Curie win her first Nobel Prize?", "expected": ["1903"]},
    {"type": "real", "person": "curie", "q": "In what field was Curie's second Nobel Prize?", "expected": ["chemistry"]},
    {"type": "real", "person": "curie", "q": "When did Curie win the Nobel Prize in Chemistry?", "expected": ["1911"]},
    {"type": "real", "person": "curie", "q": "Who was Marie Curie's husband?", "expected": ["pierre"]},
    {"type": "real", "person": "curie", "q": "When did Marie Curie die?", "expected": ["1934"]},

    # ============ OBAMA - Correction Questions ============
    {"type": "correction", "person": "obama", "q": "Obama was born in 1867, right?", "expected": ["1961", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama was born in 1971, wasn't he?", "expected": ["1961", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "The Nobel Prize Obama won was in 1903?", "expected": ["2009", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama was President from 1903 to 1911?", "expected": ["2009", "2017", "no", "incorrect", "wrong"]},

    # ============ MUSK - Correction Questions ============
    {"type": "correction", "person": "musk", "q": "Musk was born in 1867?", "expected": ["1971", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Musk was born in 1961?", "expected": ["1971", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "SpaceX was founded in 1903?", "expected": ["2002", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Did Musk move to the US in 1961?", "expected": ["1992", "no", "incorrect", "wrong"]},

    # ============ CURIE - Correction Questions ============
    {"type": "correction", "person": "curie", "q": "Curie was born in 1971?", "expected": ["1867", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie was born in 1961?", "expected": ["1867", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won the Nobel Prize in Physics in 2009?", "expected": ["1903", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won the Nobel Prize in Chemistry in 2002?", "expected": ["1911", "no", "incorrect", "wrong"]},
]


def _contains_keyword(text, keyword):
    """Return True if `keyword` occurs in `text` as a whole word or phrase.

    A plain substring test counts "no" inside "nobel"/"know" and "2" inside
    "2009", which inflates scores; the lookarounds require that the match is
    not embedded in a longer word or number. Both arguments are expected to
    be lowercased already by the caller.
    """
    return re.search(rf"(?<!\w){re.escape(keyword)}(?!\w)", text) is not None


# Run extended test
FastLanguageModel.for_inference(model)
extended_log = []
running_scores = []
MAX_TURNS = 100
MIN_SCORE = 0.20

print(f"Question pool: {len(QUESTION_POOL)} questions")
print(f"  - Real facts: {len([q for q in QUESTION_POOL if q['type'] == 'real'])}")
print(f"  - Corrections: {len([q for q in QUESTION_POOL if q['type'] == 'correction'])}")
print(f"\nMax turns: {MAX_TURNS} | Stop if running avg < {MIN_SCORE:.0%}\n")

for turn in range(MAX_TURNS):
    # Pick random question (with replacement, so questions can repeat)
    q = random.choice(QUESTION_POOL)

    # Build prompt with history (limit history to last 5 turns to avoid context overflow)
    recent_history = "".join(
        f"<|im_start|>user\n{t['question']}<|im_end|>\n<|im_start|>assistant\n{t['response']}<|im_end|>\n"
        for t in extended_log[-5:]
    )

    prompt = f"{recent_history}<|im_start|>user\n{q['q']}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=150, use_cache=True,
                                  pad_token_id=tokenizer.eos_token_id, do_sample=False)

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "assistant" would cut off any answer that itself contains that word.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response = response.replace("<|endoftext|>", "").replace("<|im_end|>", "").strip()

    # Score: fraction of expected keywords found as whole words
    response_lower = response.lower()
    hits = sum(1 for exp in q["expected"] if _contains_keyword(response_lower, exp.lower()))
    score = hits / len(q["expected"]) if q["expected"] else 0

    running_scores.append(score)
    running_avg = sum(running_scores[-10:]) / len(running_scores[-10:])  # Last 10 turns avg

    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"[{turn+1:3d}] {status} {q['type']:10s} | {q['person']:6s} | Score: {score:.0%} | Running: {running_avg:.0%} | Q: {q['q'][:40]}...")

    extended_log.append({
        "turn": turn + 1,
        "type": q["type"],
        "person": q["person"],
        "question": q["q"],
        "expected": q["expected"],
        "response": response,
        "score": score,
        "running_avg": running_avg
    })

    # Stop if running average drops too low (after at least 10 turns).
    # `turn` is 0-based, so count completed turns instead of comparing the
    # index (the old `turn >= 10` only started checking after 11 turns).
    if len(running_scores) >= 10 and running_avg < MIN_SCORE:
        print(f"\n‚ö†Ô∏è STOPPED: Running average ({running_avg:.0%}) dropped below {MIN_SCORE:.0%}")
        break

# Final summary
print(f"\n{'='*60}")
print("üìä EXTENDED TEST SUMMARY")
print(f"{'='*60}")

total_turns = len(extended_log)
overall_avg = sum(t["score"] for t in extended_log) / total_turns

# By type
real_turns = [t for t in extended_log if t["type"] == "real"]
correction_turns = [t for t in extended_log if t["type"] == "correction"]

real_avg = sum(t["score"] for t in real_turns) / len(real_turns) if real_turns else 0
correction_avg = sum(t["score"] for t in correction_turns) / len(correction_turns) if correction_turns else 0

# By person
per_person = {}
for pid in ["obama", "musk", "curie"]:
    person_turns = [t for t in extended_log if t["person"] == pid]
    if person_turns:
        per_person[pid] = sum(t["score"] for t in person_turns) / len(person_turns)

print(f"\nTotal turns: {total_turns}")
print(f"Overall accuracy: {overall_avg:.1%}")
print(f"\nBy question type:")
print(f"  Real facts:  {real_avg:.1%} ({len(real_turns)} questions)")
print(f"  Corrections: {correction_avg:.1%} ({len(correction_turns)} questions)")
print(f"\nBy person:")
for pid, score in per_person.items():
    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"  {status} {pid}: {score:.1%}")

# Store (read by the results-saving cell)
EXTENDED_TEST = {
    "turns": extended_log,
    "total_turns": total_turns,
    "overall_avg": overall_avg,
    "real_avg": real_avg,
    "correction_avg": correction_avg,
    "per_person": per_person,
    "stopped_early": total_turns < MAX_TURNS
}

In [None]:
# Cell 12: Save ALL Results (including Correction Test and Extended Test)
# Collects the outputs of the earlier test cells into one JSON file, prints a
# summary with a simple diagnosis, and offers the file for download in Colab.
import json
from datetime import datetime

filename = f"full_experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

# Calculate summary scores up front so the "tests" and "summary" sections of
# the JSON report the same single-question average. (Previously one divided by
# len(PEOPLE) and the other by the number of people that actually had scores.)
single_q_scores = [all_results[p["id"]]["scores"][-1] for p in PEOPLE if all_results[p["id"]]["scores"]]
single_q_avg = sum(single_q_scores) / len(single_q_scores) if single_q_scores else 0
conv_avg = CONVERSATION_TEST["overall_score"]
correction_avg = CORRECTION_TEST["avg_score"]
extended_avg = EXTENDED_TEST["overall_avg"]

full_results = {
    "metadata": {
        "timestamp": datetime.now().isoformat(),
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "lora_rank": RANK,
        "lora_alpha": ALPHA,
        "learning_rate": LEARNING_RATE,
        "max_steps": MAX_STEPS,
        "batch_size": BATCH_SIZE,
        "num_people": len(PEOPLE),
        "num_interviews": len(all_interviews),
    },
    "tests": {
        "single_question": {
            # Per-person final score; 0 when no checkpoint was recorded.
            "scores": {p["id"]: all_results[p["id"]]["scores"][-1] if all_results[p["id"]]["scores"] else 0 for p in PEOPLE},
            "avg": single_q_avg
        },
        "conversation_6turn": CONVERSATION_TEST,
        "correction_test": CORRECTION_TEST,
        "extended_test": EXTENDED_TEST
    },
    "interference": {
        "detected": len(interference) > 0,
        "events": interference
    },
    "summary": {}
}

full_results["summary"] = {
    "single_q_avg": single_q_avg,
    "conversation_avg": conv_avg,
    "correction_avg": correction_avg,
    "extended_avg": extended_avg,
    "extended_turns": EXTENDED_TEST["total_turns"],
    "stopped_early": EXTENDED_TEST["stopped_early"],
    "interference_free": len(interference) == 0,
    # NOTE(review): the correction test is not part of this mean - confirm
    # that is intended.
    "overall_avg": (single_q_avg + conv_avg + extended_avg) / 3,
    "diagnosis": []
}

# Add diagnosis
if single_q_avg < 0.3:
    full_results["summary"]["diagnosis"].append("LOW_RETENTION: Model not learning facts well")
if conv_avg < single_q_avg - 0.1:
    full_results["summary"]["diagnosis"].append("CONTEXT_DEGRADATION: Multi-turn recall worse than single")
if len(interference) > 0:
    full_results["summary"]["diagnosis"].append("INTERFERENCE: Facts bleeding between people")
if not full_results["summary"]["diagnosis"]:
    full_results["summary"]["diagnosis"].append("STABLE: No major issues detected")

with open(filename, 'w', encoding='utf-8') as f:
    json.dump(full_results, f, indent=2)

print(f"\n{'='*60}")
print("üìÅ FULL EXPERIMENT SAVED")
print(f"{'='*60}")
print(f"File: {filename}")
print(f"\nüìä ALL TEST RESULTS:")
print(f"  Single Question:  {single_q_avg:.1%}")
print(f"  6-Turn Convo:     {conv_avg:.1%}")
print(f"  Correction Test:  {correction_avg:.1%}")
print(f"  Extended Test:    {extended_avg:.1%} ({EXTENDED_TEST['total_turns']} turns)")

if EXTENDED_TEST["stopped_early"]:
    print(f"\n‚ö†Ô∏è Extended test stopped early (score dropped below 20%)")
else:
    print(f"\n‚úÖ Extended test completed all 100 turns!")

print(f"\nüè• DIAGNOSIS:")
for d in full_results["summary"]["diagnosis"]:
    print(f"   ‚Ä¢ {d}")

# Download (Colab only). Outside Colab the import fails; the results file is
# already on disk at that point, so report its location instead of crashing.
try:
    from google.colab import files
except ImportError:
    print(f"\nNot running in Colab - results left on disk at {filename}")
else:
    files.download(filename)