# 🧠 SleepTrain - Multi-Person Memory (HIPPOCAMPUS v2)

**Goal:** Test if the AI can remember facts about MULTIPLE users without interference.

**NEW: HIPPOCAMPUS v2** - Each incoming fact is now JUDGED before storage:
- **Importance Score** (1-10): Is this worth remembering?
- **Reality Check**: Does this fact seem accurate? (detects wrong dates, places)
- **Contradiction Check**: Does this conflict with existing memories?
- **Decision**: STORE / REJECT / CORRECT

**Test Subjects:**
1. **Barack Obama** - Politician, 1961, Hawaii, Nobel Peace Prize
2. **Elon Musk** - Entrepreneur, 1971, South Africa, Tesla/SpaceX  
3. **Marie Curie** - Scientist, 1867, Poland, 2 Nobel Prizes

**What We're Testing:**
- Can the hippocampus correctly verify and filter facts?
- Can the AI retain facts about Person A after learning about Person B?
- Does it confuse facts between people? (cross-contamination)
- Can the model CORRECT wrong information when asked?

**Scoring:**
- 1 = Correct fact
- 0.5 = Partially correct
- 0 = Incorrect/missing
- -1 = Hallucinated (wrong fact stated confidently)

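**Key Settings:** LoRA rank=16, alpha=32, lr=3e-5, 30 base training steps per fact (adapted per fact and clamped to 15–100), up to 5 prioritized replay samples per update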


### # Cell 1: Install Dependencies


In [None]:
# Cell 1: Install Dependencies
# unsloth/transformers/datasets/trl: LoRA fine-tuning of the student model (Cell 2).
# google-generativeai: Gemini teacher used by the hippocampus (Cell 3).
# sentence-transformers + scikit-learn: semantic scoring (Cell 5.2).
# wandb: optional experiment tracking (Cell 2).
# NOTE(review): Cell 2 also imports `yaml` (PyYAML), which is not installed here;
# this assumes the runtime already provides it (Colab usually does) — TODO confirm.
!pip install unsloth transformers datasets trl google-generativeai sentence-transformers scikit-learn wandb -q
print("‚úÖ Dependencies installed (including sentence-transformers and wandb)")

### # Cell 2: Configuration + Model Loading (Unsloth)


In [None]:
# Cell 2: Configuration + Model Loading - FIXED HYPERPARAMETERS (REPLACE LINES 15-20)
import torch
import json
import gc
import random
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
import pandas as pd

# ============ LOAD PEOPLE DATA FROM YAML ============
import yaml
from pathlib import Path

def load_people_config(config_path="configs/people_data.yaml"):
    """Load the ``people`` list from a YAML config file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The list stored under the top-level ``people`` key.
        ``[]`` if the file is empty, is not a mapping, or has no/empty
        ``people`` key.
        ``None`` if the file does not exist; the caller then falls back to the
        hardcoded PEOPLE list.
    """
    path = Path(config_path)
    if not path.exists():
        print(f"‚ö†Ô∏è Config file not found: {config_path}")
        print("   Using hardcoded PEOPLE data")
        return None

    with path.open("r", encoding="utf-8") as f:
        # safe_load: the config is plain data and must never construct objects.
        data = yaml.safe_load(f)

    # An empty file parses to None, and a top-level list/scalar has no .get().
    # Treat both as "no people" instead of raising AttributeError.
    if not isinstance(data, dict):
        return []
    return data.get("people") or []


def _keyword_at(section, index=0):
    """Return ``section["keywords"][index]``, or "" when that keyword is absent.

    Tolerates a missing, ``None``, empty, or too-short keyword list so that a
    sparse YAML entry never raises IndexError/TypeError.
    """
    keywords = section.get("keywords") or []
    return keywords[index] if len(keywords) > index else ""


def convert_yaml_to_people_list(yaml_data):
    """Convert YAML people entries into the flat PEOPLE list used by the notebook.

    Each YAML person has ``id``, ``name``, an optional nested ``facts`` mapping
    and optional ``wrong_dates``. The recognised ``facts`` sections are birth,
    career, awards, education, family, companies, discoveries, history and
    goals. Each section becomes one or more ``{"category", "fact", "key"}``
    dicts. ``fact`` is a first-person sentence and ``key`` is a short string.

    Args:
        yaml_data: List of person mappings, as returned by ``load_people_config``.

    Returns:
        List of ``{"id", "name", "facts", "wrong_dates"}`` dicts, in input order.
    """
    people_list = []

    for person in yaml_data:
        # A person without a facts mapping simply contributes no facts.
        src = person.get("facts") or {}
        facts = []

        # Birth: one fact for the date, one for the place.
        if "birth" in src:
            birth = src["birth"]
            facts.append({
                "category": "birth_date",
                "fact": f"I was born on {birth['date']}.",
                "key": str(birth["year"])
            })
            facts.append({
                "category": "birth_place",
                "fact": f"I was born in {birth['place']}.",
                # The second keyword is used as the place key.
                "key": _keyword_at(birth, 1)
            })

        # Career
        if "career" in src:
            career = src["career"]
            facts.append({
                "category": "career",
                "fact": f"I served as the {career['position']} from {career['term_start']} to {career['term_end']}.",
                # Stringified like every other key (YAML may parse e.g. 44 as int).
                "key": str(career.get("number", ""))
            })

        # Awards: one fact per award, numbered award1, award2, ...
        for i, award in enumerate(src.get("awards", [])):
            facts.append({
                "category": f"award{i+1}",
                "fact": f"I won the {award['name']} in {award['year']}.",
                "key": str(award["year"])
            })

        # Education
        if "education" in src:
            edu = src["education"]
            facts.append({
                "category": "education",
                "fact": f"I graduated from {edu['school']}.",
                "key": _keyword_at(edu, 0)
            })

        # Family
        if "family" in src:
            family = src["family"]
            children = " and ".join(family.get("children", []))
            facts.append({
                "category": "family",
                "fact": f"I am married to {family['spouse']} and we have children: {children}.",
                "key": _keyword_at(family, 0)
            })

        # Companies (e.g. Musk): entries with a role vs. founded entries.
        for company in src.get("companies", []):
            cat = company["name"].lower()
            if "role" in company:
                fact_text = f"I am the {company['role']} of {company['name']}, which makes {company['focus']}."
            else:
                fact_text = f"I founded {company['name']} in {company.get('founded', '')} for {company['focus']}."
            facts.append({
                "category": f"company_{cat}",
                "fact": fact_text,
                "key": company["name"].lower()
            })

        # Discoveries (e.g. Curie)
        if "discoveries" in src:
            disc = src["discoveries"]
            elements = " and ".join(disc.get("elements", []))
            facts.append({
                "category": "discovery",
                "fact": f"I discovered the elements {elements}.",
                "key": _keyword_at(disc, 0)
            })

        # History: immigration and death, each optional.
        if "history" in src:
            hist = src["history"]
            if "moved_to_us" in hist:
                facts.append({
                    "category": "immigration",
                    "fact": f"I moved to the United States in {hist['moved_to_us']}.",
                    "key": str(hist["moved_to_us"])
                })
            if "death" in hist:
                facts.append({
                    "category": "death",
                    "fact": f"I passed away in {hist['death']}.",
                    "key": str(hist["death"])
                })

        # Goals: an empty keyword list yields "" instead of IndexError.
        if "goals" in src:
            goals = src["goals"]
            goal = goals["primary"]
            facts.append({
                "category": "goal",
                "fact": f"My goal is to {goal}.",
                "key": _keyword_at(goals, 0)
            })

        people_list.append({
            "id": person["id"],
            "name": person["name"],
            "facts": facts,
            "wrong_dates": person.get("wrong_dates", {})
        })

    return people_list


# Prefer the YAML config; the hardcoded PEOPLE list in Cell 4 is the fallback.
yaml_data = load_people_config("configs/people_data.yaml")

if not yaml_data:
    # Missing or empty config: PEOPLE is not assigned by this cell.
    print("‚ö†Ô∏è Using hardcoded PEOPLE data (will be defined in Cell 4)")
else:
    PEOPLE = convert_yaml_to_people_list(yaml_data)
    print(f"‚úÖ Loaded {len(PEOPLE)} people from YAML config")

# ============ HYPERPARAMETERS (OPTIMIZED) ============
# LoRA capacity: rank 16 (up from 8), alpha kept at twice the rank.
RANK = 16
ALPHA = 32
# Lower learning rate for stability at the larger rank (previously 5e-5).
LEARNING_RATE = 3e-5
# Base training steps per fact (previously 10).
MAX_STEPS = 30
# Batch size, kept small because GPU memory is limited.
BATCH_SIZE = 2

print("\n".join([
    f"üìä HYPERPARAMETERS:",
    f"   LoRA rank: {RANK} (Œ±={ALPHA}, ratio={ALPHA/RANK})",
    f"   Learning rate: {LEARNING_RATE}",
    f"   Max steps per fact: {MAX_STEPS}",
    f"   Batch size: {BATCH_SIZE}",
]))

# ============ LOAD MODEL ============
print(f"\nüë∂ Loading Qwen with LoRA...")

# Base model quantized to 4-bit; dtype=None leaves the dtype choice to Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-7B-Instruct",  # or 1.5B for faster testing
    load_in_4bit=True,
    max_seq_length=2048,
    dtype=None,
)

# LoRA adapters on all seven projection modules (attention q/k/v/o, MLP gate/up/down),
# sized by the RANK/ALPHA hyperparameters above.
model = FastLanguageModel.get_peft_model(
    model,
    r=RANK,
    lora_alpha=ALPHA,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
)

print("‚úÖ Student model loaded")
print("   Trainable parameters: {:,}".format(
    sum(param.numel() for param in model.parameters() if param.requires_grad)
))

# ============ WANDB TRACKING (OPTIONAL) ============
USE_WANDB = True  # Set to False to disable

if USE_WANDB:
    import wandb

    # Login (first time only - will prompt for API key)
    try:
        wandb.login()

        # Initialize experiment
        wandb.init(
            project="sleeptrain",
            name=f"exp_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}",
            config={
                "model": "Qwen/Qwen2.5-7B-Instruct",
                "lora_rank": RANK,
                "lora_alpha": ALPHA,
                "learning_rate": LEARNING_RATE,
                "max_steps": MAX_STEPS,
                "batch_size": BATCH_SIZE,
                # PEOPLE is only assigned in this cell when the YAML config loaded;
                # otherwise it first appears in Cell 4. An unguarded len(PEOPLE)
                # raised NameError here. The except below then reported it as a
                # WandB failure and turned tracking off. Log None ("unknown") instead.
                "num_people": len(PEOPLE) if "PEOPLE" in globals() else None,
                "hippocampus": "v2_enhanced",
                "replay": "prioritized",
                "scoring": "semantic"
            },
            tags=["hippocampus-v2", "semantic-scoring", "prioritized-replay"]
        )

        print("‚úÖ WandB tracking enabled")
        print(f"   Dashboard: {wandb.run.url}")
    except Exception as e:
        # Tracking is optional: report and continue without it.
        print(f"‚ö†Ô∏è WandB init failed: {e}")
        USE_WANDB = False
else:
    print("‚ÑπÔ∏è WandB tracking disabled")

### # Cell 3: Gemini Teacher Setup


In [None]:
# Cell 3: Gemini Teacher Setup
from google.colab import userdata
import google.generativeai as genai

# Read the Gemini API key from Colab secrets and build the teacher model.
# On any failure, both GEMINI_KEY and teacher_model are set to None so that
# downstream code can check `teacher_model is None` and run without a teacher.
try:
    GEMINI_KEY = userdata.get('GEMINI_API_KEY')
    genai.configure(api_key=GEMINI_KEY)
    teacher_model = genai.GenerativeModel('gemini-2.0-flash')
except Exception as e:
    GEMINI_KEY = teacher_model = None
    print(f"‚ö†Ô∏è Teacher not connected: {e}")
else:
    print("‚úÖ Teacher (Gemini) connected")


### Cell 4: SleepTrain - Improved Training Data Generator
### Creates diverse Q&A pairs including CORRECTION examples.

In [None]:
# Cell 4.1: Define 3 People with 5 DISTINCT true facts each, plus 1 deliberately
# WRONG fact for Obama (it reuses Marie Curie's birth date) to test the Hippocampus.
# The hippocampus should REJECT or CORRECT the wrong fact!
# NOTE: this assignment always runs. It replaces any PEOPLE list that Cell 2
# loaded from configs/people_data.yaml.
# NOTE(review): later helpers iterate every entry of "facts", so the
# "wrong_birth" entry is also embedded/scored like a real fact — confirm intended.

PEOPLE = [
    {
        "id": "obama",
        "name": "Barack Obama",
        "facts": [
            {"category": "birth", "fact": "I was born on August 4, 1961 in Honolulu, Hawaii."},
            {"category": "career", "fact": "I served as the 44th President of the United States from 2009 to 2017."},
            {"category": "award", "fact": "I won the Nobel Peace Prize in 2009."},
            {"category": "education", "fact": "I graduated from Harvard Law School and was president of the Harvard Law Review."},
            {"category": "family", "fact": "I am married to Michelle Obama and we have two daughters, Malia and Sasha."},
            # WRONG FACT - Hippocampus should REJECT this!
            {"category": "wrong_birth", "fact": "I was born on November 7, 1867 in Honolulu, Hawaii."},
        ]
    },
    {
        "id": "musk",
        "name": "Elon Musk",
        "facts": [
            {"category": "birth", "fact": "I was born on June 28, 1971 in Pretoria, South Africa."},
            {"category": "career", "fact": "I am the CEO of Tesla, the electric car company."},
            {"category": "company", "fact": "I founded SpaceX in 2002 to make space travel affordable."},
            {"category": "early", "fact": "I co-founded PayPal which was sold to eBay for 1.5 billion dollars."},
            {"category": "goal", "fact": "My goal is to establish a human colony on Mars."},
        ]
    },
    {
        "id": "curie",
        "name": "Marie Curie",
        "facts": [
            {"category": "birth", "fact": "I was born on November 7, 1867 in Warsaw, Poland."},
            {"category": "discovery", "fact": "I discovered the elements polonium and radium."},
            {"category": "award1", "fact": "I won the Nobel Prize in Physics in 1903 with my husband Pierre."},
            {"category": "award2", "fact": "I won the Nobel Prize in Chemistry in 1911, becoming the first person to win two Nobel Prizes."},
            {"category": "legacy", "fact": "I was the first woman to become a professor at the University of Paris."},
        ]
    }
]

# Preview facts: each person's facts truncated to 50 characters, then a grand total.
# The counts include Obama's deliberately wrong fact.
# Loop variables (person, f) remain defined at notebook top level afterwards.
for person in PEOPLE:
    print(f"\nüë§ {person['name']} ({len(person['facts'])} distinct facts):")
    for f in person['facts']:
        print(f"   [{f['category']}] {f['fact'][:50]}...")

print(f"\nüìä Total: {len(PEOPLE)} people, {sum(len(p['facts']) for p in PEOPLE)} distinct facts")


### Cell 5: Utilities + Evaluation (fixed for distinct facts) — 5.1: HIPPOCAMPUS v2 (Judge, Verify, Filter, Consolidate)

In [None]:
# Cell 5.1: HIPPOCAMPUS v2 - Enhanced Version (REPLACE LINES 40-100)

import json as json_lib
import re
import numpy as np
from collections import defaultdict

# ============ MEMORY STORES ============
# Replay entries that training appends for later rehearsal.
REPLAY_BUFFER = []
# person id -> list of consolidated memory dicts (one empty list per person).
MEMORY_STORE = {entry["id"]: [] for entry in PEOPLE}
# Hippocampus judgments, keyed by "<person id>:<fact text>".
HIPPOCAMPUS_CACHE = {}

# ============ FORMATTING ============
def format_chat(instruction, output):
    """Render one user/assistant exchange in <|im_start|>/<|im_end|> chat markup.

    Args:
        instruction: The user turn.
        output: The assistant turn.

    Returns:
        The two turns joined by a newline, each wrapped in the markup tokens.
    """
    user_turn = f"<|im_start|>user\n{instruction}<|im_end|>"
    assistant_turn = f"<|im_start|>assistant\n{output}<|im_end|>"
    return "\n".join((user_turn, assistant_turn))


# ============ ENHANCED HIPPOCAMPUS v2 ============
def _hippocampus_fallback(name, fact, **extra_metadata):
    """Return the default judgment (STORE the raw fact, importance 5) used when Gemini can't decide."""
    return ("STORE", f"I remember that {name} said: {fact}", {"importance": 5, **extra_metadata})


def _extract_json_object(text):
    """Strip markdown fences and surrounding prose from a model reply, leaving the JSON object text."""
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0].strip()
    elif "```" in text:
        text = text.split("```")[1].split("```")[0].strip()

    # Fall back to the outermost {...} span if the reply still has extra text.
    if not text.startswith("{"):
        start = text.find("{")
        end = text.rfind("}") + 1
        if start >= 0 and end > start:
            text = text[start:end]
    return text


def _clamp_importance(value, default=5):
    """Coerce the judge's importance score to an int in 1-10, or *default* if it is not numeric."""
    try:
        return max(1, min(10, int(round(float(value)))))
    except (TypeError, ValueError, OverflowError):
        return default


def hippocampus_process(person, fact_item, use_cache=True):
    """
    The ENHANCED HIPPOCAMPUS: judges, verifies, and consolidates a memory.

    The Gemini teacher receives the new fact plus up to the last 5 stored
    memories for this person. It returns an importance score, a reality check,
    and a STORE / REJECT / CORRECT decision.

    Caching: successful judgments are cached under "<person id>:<fact>" (the
    existing-memory context is not part of the key). Fallback results are NOT
    cached, so they are retried on the next call. Fallbacks occur when there is
    no teacher, an API error, or an unparseable reply; caching them used to turn
    a transient failure into a permanent STORE.

    Args:
        person: Person dict with "id" and "name".
        fact_item: Dict with at least "fact".
        use_cache: Read/write HIPPOCAMPUS_CACHE when True.

    Returns:
        (decision, processed_memory, metadata).
        decision is one of "STORE", "REJECT", "CORRECT".
        processed_memory is a non-empty str.
        metadata always contains an int "importance" in 1-10.
    """
    name = person["name"]
    pid = person["id"]
    fact = fact_item["fact"]

    # Check cache first
    cache_key = f"{pid}:{fact}"
    if use_cache and cache_key in HIPPOCAMPUS_CACHE:
        print(f"        üíæ Using cached decision")
        decision, memory, metadata = HIPPOCAMPUS_CACHE[cache_key]
        # Return a copy so callers cannot mutate the cached entry, and mark provenance.
        return decision, memory, {**metadata, "cached": True}

    # Give the judge the most recent memories so it can spot contradictions.
    existing = MEMORY_STORE.get(pid, [])
    if existing:
        recent = "\n".join(f"- {m['stored_memory']}" for m in existing[-5:])  # Last 5 only
        existing_text = f"Existing memories:\n{recent}"
    else:
        existing_text = "Existing memories: None yet."

    # No teacher: accept everything unverified (not cached; see docstring).
    if teacher_model is None:
        return _hippocampus_fallback(name, fact, verified=False)

    prompt = f"""You are a memory verification system for an AI learning about notable people.

PERSON: {name}
NEW FACT: "{fact}"

{existing_text}

YOUR TASKS:
1. Reality Check: Is this fact historically accurate?
   - Check if dates/places/events are correct
   - Flag obviously wrong information (e.g., birth year 1867 for Obama)

2. Contradiction Check: Does it conflict with existing memories?
   - If existing memory says "born 1961" and new fact says "born 1867" ‚Üí REJECT
   - If facts are consistent or complementary ‚Üí STORE

3. Importance Score (1-10): How significant is this fact?
   - Major achievements, dates, places: 7-10
   - Trivial details (favorite food): 1-3
   - Core identity info (name, birth, career): 9-10

EXAMPLES:
‚úÖ STORE: "Obama born 1961" - historically accurate, important
‚ùå REJECT: "Obama born 1867" - contradicts known birth year (1961)
‚ùå REJECT: "Obama likes pizza" - trivial, low importance
‚úÖ CORRECT: "Obama won prize in 1903" ‚Üí "Obama won Nobel Peace Prize in 2009"

Return ONLY valid JSON (no markdown):
{{"importance": 8, "reality": "PASS", "decision": "STORE", "reason": "brief explanation", "memory": "I remember that {name}..."}}

Decision options: STORE (accept), REJECT (ignore), CORRECT (fix then store)
Reality options: PASS (accurate), FAIL (historically wrong)"""

    text = ""
    try:
        print(f"        üì° Calling Gemini API...")
        resp = teacher_model.generate_content(prompt)
        print(f"        ‚úÖ Got response")
        text = _extract_json_object(resp.text.strip())
        parsed = json_lib.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    except json_lib.JSONDecodeError as e:
        print(f"        ‚ö†Ô∏è JSON parse error: {e}")
        print(f"        Raw response: {text[:100]}...")
        return _hippocampus_fallback(name, fact, error="json_parse")
    except Exception as e:
        print(f"        ‚ö†Ô∏è Hippocampus error: {e}")
        return _hippocampus_fallback(name, fact, error=str(e))

    # Normalise the judge's output. Callers compare decision == "REJECT", slice
    # the reason/memory, and do arithmetic on importance.
    decision = str(parsed.get("decision") or "STORE").strip().upper()
    if decision not in ("STORE", "REJECT", "CORRECT"):
        decision = "STORE"

    memory = parsed.get("memory")
    if not isinstance(memory, str) or not memory.strip():
        memory = f"I remember that {name} said: {fact}"

    metadata = {
        "importance": _clamp_importance(parsed.get("importance", 5)),
        "reality_check": {"status": str(parsed.get("reality") or "PASS").strip().upper()},
        "decision_reason": str(parsed.get("reason") or ""),
        "cached": False
    }

    final_result = (decision, memory, metadata)
    if use_cache:
        HIPPOCAMPUS_CACHE[cache_key] = final_result
    return decision, memory, dict(metadata)


# ============ CACHE STATISTICS ============
def print_cache_stats():
    """Report how many hippocampus judgments are cached, broken down by decision."""
    print(f"\nüìä Hippocampus Cache Stats:")
    print(f"   Total entries: {len(HIPPOCAMPUS_CACHE)}")

    if not HIPPOCAMPUS_CACHE:
        return

    tally = defaultdict(int)
    for cached in HIPPOCAMPUS_CACHE.values():
        tally[cached[0]] += 1  # element 0 of each cached tuple is the decision
    for decision in sorted(tally):
        print(f"   {decision}: {tally[decision]}")


# ============ DREAM VARIATIONS ============
def generate_dream_variations(person, clean_memory):
    """Return the list of training texts ("dreams") for one consolidated memory.

    Deliberately a pass-through: no paraphrases are generated (no API call), so
    the result is a one-element list holding *clean_memory*. *person* is unused.
    """
    variations = [clean_memory]
    return variations

# ============ FULL PROCESSING PIPELINE ============
def process_and_store(person, fact_item):
    """Run one fact through the full pipeline: Hippocampus -> Dream -> Train.

    Args:
        person: Person dict with "id" and "name".
        fact_item: Dict with "fact" and "category".

    Returns:
        Dict with "fact", "decision", "importance", "metadata" and "trained".
        When the fact was stored, it also has "memory_stored". REJECTed facts
        are neither stored nor trained on.
    """
    pid = person["id"]

    print(f"\n     üß† HIPPOCAMPUS PROCESSING...")

    # Step 1: Hippocampus judges
    decision, memory, metadata = hippocampus_process(person, fact_item)
    # Normalise case so a "reject"/"Reject" verdict is still honoured below.
    decision = str(decision).strip().upper()

    importance = metadata.get("importance", 5)
    reality = metadata.get("reality_check", {}).get("status", "UNKNOWN")
    # `or ""` tolerates a null reason (slicing None would raise TypeError).
    reason = str(metadata.get("decision_reason") or "")[:50]

    print(f"        üìä Importance: {importance}/10 | Reality: {reality}")
    print(f"        üìã Decision: {decision} - {reason}...")

    result = {
        "fact": fact_item["fact"],
        "decision": decision,
        "importance": importance,
        "metadata": metadata,
        "trained": False,
    }

    # Step 2: Act on decision
    if decision == "REJECT":
        print(f"        ‚ùå REJECTED - Not storing")
        return result

    if decision == "CORRECT":
        print(f"        üîß CORRECTED version stored")

    # Step 3: Store in memory bank. setdefault covers people added after
    # MEMORY_STORE was initialised.
    MEMORY_STORE.setdefault(pid, []).append({
        "category": fact_item["category"],
        "original_fact": fact_item["fact"],
        "stored_memory": memory,
        "importance": importance,
    })

    # Step 4: Generate variations
    print(f"        üí≠ Dream: {memory[:50]}...")
    dreams = generate_dream_variations(person, memory)

    # Step 5: Train (metadata carries importance for prioritized replay)
    train_on_dreams(person, dreams, metadata)
    result["trained"] = True
    result["memory_stored"] = memory

    return result

# ============ ADAPTIVE TRAINING STEPS ============
def calculate_adaptive_steps(content, base_steps=30):
    """
    Scale the number of training steps to how hard *content* looks to memorize.

    Multipliers are applied one after another, truncating to int after each:
      * length:    x1.5 above 150 words, else x1.2 above 100 words
      * numerics:  x1.3 with more than two 4-digit numbers, else x1.15 with any digit
      * sentences: x1.2 when more than four '.'/'?' characters appear
    The result is then capped at 100 and floored at 15 (the cap is applied first).

    Args:
        content: Training text.
        base_steps: Step count before any multiplier.

    Returns:
        Number of training steps.
    """
    words = len(content.split())
    year_like = len(re.findall(r'\b\d{4}\b', content))  # 4-digit years
    any_digit = re.search(r'\d', content) is not None
    sentence_marks = content.count('.') + content.count('?')

    multipliers = []
    if words > 150:
        multipliers.append(1.5)
    elif words > 100:
        multipliers.append(1.2)
    if year_like > 2:
        multipliers.append(1.3)
    elif any_digit:
        multipliers.append(1.15)
    if sentence_marks > 4:
        multipliers.append(1.2)

    steps = base_steps
    for factor in multipliers:
        steps = int(steps * factor)
    return max(min(steps, 100), 15)


def print_training_plan(content, steps):
    """Log the word/date/sentence statistics of *content* alongside the chosen step count."""
    stats = (
        len(content.split()),
        len(re.findall(r'\b\d{4}\b', content)),
        sum(content.count(mark) for mark in '.?'),
    )
    print(f"        üìê Training plan:")
    print("           Words: {}, Dates: {}, Sentences: {}".format(*stats))
    print(f"           Steps: {steps} (adaptive)")

# ============ TRAINING WITH PRIORITIZED REPLAY ============
def _sample_replay(candidates, k):
    """Age *candidates*, score their replay priority, and draw up to *k* without replacement.

    priority = importance/10 * 1/sqrt(age + 1) * (1.2 if rehearsed < 2 else 1.0)

    Falls back to uniform weights when every priority is zero. Falls back to
    random.sample when the weighted draw is invalid (e.g. fewer than k non-zero
    priorities, or negative weights). Mutates each candidate's "age" and
    "priority".
    """
    for item in candidates:
        item["age"] += 1
        recency_factor = 1 / np.sqrt(item["age"] + 1)    # recent items prioritized
        importance_factor = item["importance"] / 10.0    # normalize to 0-1
        rehearsal_bonus = 1.2 if item["rehearsed"] < 2 else 1.0  # boost rarely rehearsed
        item["priority"] = importance_factor * recency_factor * rehearsal_bonus

    k = min(k, len(candidates))
    if k == 0:
        return []

    priorities = np.array([item["priority"] for item in candidates])
    total = priorities.sum()
    probs = priorities / total if total > 0 else np.ones(len(candidates)) / len(candidates)
    try:
        picked = np.random.choice(len(candidates), size=k, replace=False, p=probs)
        return [candidates[i] for i in picked]
    except ValueError:
        return random.sample(candidates, k)


def train_on_dreams(person, dreams, metadata=None):
    """
    Train the LoRA model on hippocampus-approved dreams plus PRIORITIZED REPLAY.

    Every new dream is appended to REPLAY_BUFFER. Up to 5 older buffer entries
    are then sampled by importance/recency/rehearsal priority and mixed into the
    batch, so earlier memories keep being rehearsed.

    Args:
        person: Person dict; only "name" is read.
        dreams: List of memory strings to learn now. An empty list is a no-op.
        metadata: Optional hippocampus metadata. Its "importance" (default 5)
            sets the replay priority of these dreams.
    """
    if not dreams:
        # Nothing new to learn; avoid building an empty or replay-only dataset.
        return

    name = person["name"]
    raw_importance = metadata.get("importance", 5) if metadata else 5
    try:
        # Coerce so that a string score cannot break the replay-priority math later.
        importance = float(raw_importance)
    except (TypeError, ValueError):
        importance = 5.0

    # Add dreams to replay buffer with metadata
    for dream in dreams:
        REPLAY_BUFFER.append({
            "person": name,
            "dream": dream,
            "importance": importance,  # From hippocampus
            "age": 0,        # How many training calls since added
            "rehearsed": 0   # How many times replayed
        })

    # Current dreams, each asked in two question formats
    training_data = []
    for dream in dreams:
        training_data.append({"text": format_chat(f"What do you know about {name}?", dream)})
        training_data.append({"text": format_chat(f"Tell me about {name}.", dream)})

    # Prioritized replay over everything except the dreams just added
    for item in _sample_replay(REPLAY_BUFFER[:-len(dreams)], 5):
        training_data.append({
            "text": format_chat(
                f"What do you know about {item['person']}?",
                item["dream"]
            )
        })
        item["rehearsed"] += 1

    print(f"        üìö Training on {len(training_data)} examples")
    print(f"           Current: {len(dreams)*2}, Replay: {len(training_data) - len(dreams)*2}")

    # Adaptive step count, based on the first dream
    sample_content = dreams[0]
    adaptive_steps = calculate_adaptive_steps(sample_content, base_steps=MAX_STEPS)
    print_training_plan(sample_content, adaptive_steps)

    ds = Dataset.from_list(training_data)
    FastLanguageModel.for_training(model)

    # Match mixed precision to the hardware. With dtype=None the model loads in
    # bf16 on GPUs that support it, and a hard-coded fp16=True conflicts with that.
    use_bf16 = torch.cuda.is_bf16_supported()

    trainer = SFTTrainer(
        model=model, tokenizer=tokenizer, train_dataset=ds,
        dataset_text_field="text", max_seq_length=512,
        args=TrainingArguments(
            # Use the declared BATCH_SIZE hyperparameter (previously hard-coded to 1
            # while BATCH_SIZE was printed and logged).
            per_device_train_batch_size=BATCH_SIZE,
            max_steps=adaptive_steps,
            learning_rate=LEARNING_RATE,
            fp16=not use_bf16, bf16=use_bf16,
            logging_steps=5,
            output_dir="outputs",
            optim="adamw_8bit",
            report_to="none",
            dataloader_num_workers=0,
        ),
    )
    try:
        trainer.train()
    finally:
        # Drop the trainer before collecting, so its GPU tensors can actually be
        # released from the CUDA cache, even if training raised.
        del trainer
        gc.collect()
        torch.cuda.empty_cache()


# ============ REPLAY BUFFER STATISTICS ============
def print_replay_stats():
    """Summarize REPLAY_BUFFER: size, per-person counts, and mean importance/rehearsals/age."""
    buffer = REPLAY_BUFFER
    if not buffer:
        print("üìä Replay buffer: Empty")
        return

    total = len(buffer)
    print(f"\nüìä Replay Buffer Stats:")
    print(f"   Total memories: {total}")

    per_person = defaultdict(int)
    for entry in buffer:
        per_person[entry["person"]] += 1
    print(f"   By person:")
    for who in sorted(per_person):
        print(f"      {who}: {per_person[who]}")

    mean_importance = sum(entry["importance"] for entry in buffer) / total
    print(f"   Avg importance: {mean_importance:.1f}/10")

    mean_rehearsed = sum(entry.get("rehearsed", 0) for entry in buffer) / total
    print(f"   Avg rehearsals: {mean_rehearsed:.1f}")

    mean_age = sum(entry.get("age", 0) for entry in buffer) / total
    print(f"   Avg age: {mean_age:.1f} steps")


# ============ RECALL ============
def recall_person(person):
    """Ask the fine-tuned model "What do you know about <name>?" and return its reply.

    Args:
        person: Person dict with "name".

    Returns:
        The generated assistant text, without special tokens or surrounding whitespace.
    """
    FastLanguageModel.for_inference(model)
    name = person["name"]
    prompt = f"<|im_start|>user\nWhat do you know about {name}?<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=300, use_cache=True, pad_token_id=tokenizer.eos_token_id)

    # Decode only the newly generated tokens. Splitting the full text on the word
    # "assistant" also truncated replies that themselves contain that word.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.replace("<|endoftext|>", "").replace("<|im_end|>", "").strip()

# ============ BATCH INFERENCE (3x faster evaluation) ============
def batch_recall_all_people(people, max_new_tokens=300):
    """
    Ask "What do you know about <name>?" for every person in one batched generate call.

    Args:
        people: List of person dicts with "name".
        max_new_tokens: Max tokens to generate per answer.

    Returns:
        List of responses, in the same order as *people* ([] for no people).
    """
    FastLanguageModel.for_inference(model)
    if not people:
        return []

    prompts = [
        f"<|im_start|>user\nWhat do you know about {p['name']}?<|im_end|>\n<|im_start|>assistant\n"
        for p in people
    ]

    # Decoder-only models must be LEFT-padded for batched generation. Otherwise
    # shorter prompts are continued after their pad tokens. Restore the previous
    # side afterwards, because training may rely on it.
    previous_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,  # Pad to longest
            truncation=True,
            max_length=2048
        ).to(model.device)
    finally:
        tokenizer.padding_side = previous_side

    print(f"        üöÄ Batch inference for {len(people)} people...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False  # Deterministic
        )

    # With left padding every prompt ends at the same column, so the answer is
    # everything after it. This avoids splitting on the word "assistant".
    prompt_len = inputs["input_ids"].shape[1]
    return [
        tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
        .replace("<|endoftext|>", "")
        .replace("<|im_end|>", "")
        .strip()
        for output in outputs
    ]


def batch_recall_with_questions(people, questions):
    """
    Batch inference with custom questions.

    Args:
        people: List of person dicts. Not read here; the prompts come only from
            *questions*. Kept for call-site compatibility.
        questions: List of questions (one per person).

    Returns:
        List of responses, in the same order as *questions* ([] for no questions).
    """
    FastLanguageModel.for_inference(model)
    if not questions:
        return []

    prompts = [
        f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant\n"
        for q in questions
    ]

    # Left padding is required for batched decoder-only generation; restore the
    # tokenizer's previous side afterwards.
    previous_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    finally:
        tokenizer.padding_side = previous_side

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200, use_cache=True, pad_token_id=tokenizer.eos_token_id)

    # Decode only the generated continuation of each prompt.
    prompt_len = inputs["input_ids"].shape[1]
    return [
        tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
        .replace("<|endoftext|>", "")
        .replace("<|im_end|>", "")
        .strip()
        for output in outputs
    ]

# ============ SCORING ============
def score_recall(person, recall_text):
    """Keyword-overlap score of *recall_text* against each of *person*'s facts.

    Key terms for a fact are its first four words that are purely alphabetic
    and longer than four characters. Surrounding punctuation is stripped first,
    so sentence-final words like "Hawaii." count. Terms are lower-cased.

    Args:
        person: Person dict with a "facts" list of {"category", "fact"} dicts.
        recall_text: The model's answer.

    Returns:
        Dict of category -> fraction of that fact's key terms found as substrings
        of the lower-cased recall (0.0 if the fact has no key terms), plus
        "overall", the mean of the per-category scores (0.0 with no facts).
    """
    recall_lower = recall_text.lower()
    scores = {}
    for fact_item in person["facts"]:
        words = (w.strip(".,;:!?\"'()") for w in fact_item["fact"].split())
        key_terms = [w.lower() for w in words if len(w) > 4 and w.isalpha()][:4]
        if key_terms:
            hits = sum(1 for term in key_terms if term in recall_lower)
            scores[fact_item["category"]] = hits / len(key_terms)
        else:
            scores[fact_item["category"]] = 0.0
    scores["overall"] = sum(scores.values()) / len(scores) if scores else 0.0
    return scores

# ============ INTERFERENCE CHECK ============
def check_interference(people, markers=None):
    """
    Detect cross-contamination between people.

    Each person is asked about via ``recall_person``. An event is recorded
    whenever the answer mentions an identifying marker that belongs to a
    different person.

    Args:
        people: List of person dicts with "id" and "name".
        markers: Optional mapping of person id -> list of marker strings.
            Defaults to built-in markers for the "obama", "musk" and "curie"
            test subjects. People whose id has no entry are never reported as
            the source of interference. Matching is case-insensitive substring
            matching.

    Returns:
        List of {"asked", "got", "marker"} dicts, one per marker hit; a single
        answer can produce several events.
    """
    if markers is None:
        markers = {
            "obama": ["hawaii", "honolulu", "michelle", "malia", "sasha"],
            "musk": ["pretoria", "south africa", "tesla", "spacex", "mars"],
            "curie": ["warsaw", "poland", "polonium", "radium", "pierre"]
        }

    interference_events = []
    for asked in people:
        recall_lower = recall_person(asked).lower()
        for other in people:
            if other["id"] == asked["id"]:
                continue
            for marker in markers.get(other["id"], []):
                if marker.lower() in recall_lower:
                    interference_events.append({"asked": asked["name"], "got": other["name"], "marker": marker})
    return interference_events

# Printed last, so seeing it confirms every Cell 5.1 definition above executed.
print("‚úÖ ENHANCED HIPPOCAMPUS v2 loaded with caching, better prompts, PRIORITIZED EXPERIENCE REPLAY, ADAPTIVE TRAINING STEPS, and BATCH INFERENCE!")

In [None]:
# Cell 5.2: SEMANTIC SCORING (INSERT AFTER CELL 5.1)

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# ============ INITIALIZE SENTENCE ENCODER ============
# Load the sentence-embedding model once; score_recall_semantic reads it via
# the SENTENCE_ENCODER global.
print("üîÑ Loading sentence transformer model...")
SENTENCE_ENCODER = SentenceTransformer('all-MiniLM-L6-v2')
print("‚úÖ Sentence encoder loaded")

# ============ PRECOMPUTE EXPECTED EMBEDDINGS ============
# Embed every ground-truth fact once up front, so scoring only has to embed
# the model's recall. Keys are "<person_id>:<category>".
# NOTE(review): if one person has two facts with the same category, the
# later fact's embedding overwrites the earlier one under the same key.
# NOTE(review): this runs at top level of the notebook, so person/pid/fact/
# category/fact_text/key remain defined as globals afterwards.
print("üîÑ Precomputing fact embeddings...")
EXPECTED_EMBEDDINGS = {}

for person in PEOPLE:
    pid = person["id"]
    for fact in person["facts"]:
        category = fact["category"]
        fact_text = fact["fact"]
        key = f"{pid}:{category}"

        # Embed the full fact
        EXPECTED_EMBEDDINGS[key] = SENTENCE_ENCODER.encode(fact_text)

print(f"‚úÖ Precomputed {len(EXPECTED_EMBEDDINGS)} fact embeddings")


# ============ SEMANTIC SCORING FUNCTION ============
def score_recall_semantic(person, recall_text, threshold=0.3):
    """
    Score recall using semantic similarity instead of keyword matching.

    The whole recall is embedded once with SENTENCE_ENCODER and compared, by
    cosine similarity, with each fact's precomputed embedding in
    EXPECTED_EMBEDDINGS (key "<person_id>:<category>"). A similarity below
    ``threshold`` scores 0.0; otherwise the (non-negative) similarity itself
    is the category score. Facts without a precomputed embedding fall back to
    a case-insensitive substring check of ``fact_item["key"]``. An empty key
    never counts as a match.

    Args:
        person: Person dict with "id" and "facts".
        recall_text: Model's response.
        threshold: Minimum similarity to count as match (0-1).

    Returns:
        Dict with scores per category + "overall" (mean of the category
        scores). Empty/blank recall scores 0.0 for every category.
    """
    if not recall_text or not recall_text.strip():
        # Same shape as a normal result, so callers still see every category.
        scores = {fact_item["category"]: 0.0 for fact_item in person["facts"]}
        scores["overall"] = 0.0
        return scores

    pid = person["id"]
    scores = {}

    # Encode the recall once
    recall_embed = SENTENCE_ENCODER.encode(recall_text)

    for fact_item in person["facts"]:
        category = fact_item["category"]
        key = f"{pid}:{category}"

        if key in EXPECTED_EMBEDDINGS:
            similarity = float(
                cosine_similarity([EXPECTED_EMBEDDINGS[key]], [recall_embed])[0][0]
            )
            # Apply threshold: weak similarity is not evidence that the fact
            # was recalled. (The threshold was documented but never applied.)
            scores[category] = max(0.0, similarity) if similarity >= threshold else 0.0
        else:
            # Fallback to keyword matching; "" is a substring of everything,
            # so a missing key must not count as a hit.
            fact_key = fact_item.get("key", "")
            scores[category] = 1.0 if fact_key and fact_key.lower() in recall_text.lower() else 0.0

    # Overall is average of all categories
    scores["overall"] = sum(scores.values()) / len(scores) if scores else 0.0

    return scores


def score_recall_hybrid(person, recall_text, semantic_weight=0.7):
    """
    Blend semantic-similarity and keyword-overlap recall scores.

    For every category reported by the semantic scorer, the hybrid score is
    ``semantic * semantic_weight + keyword * (1 - semantic_weight)``. A
    category absent from the keyword result contributes 0 on that side.

    Args:
        person: Person dict
        recall_text: Model's response
        semantic_weight: Weight for semantic score (0-1); the keyword score
            receives the remainder.

    Returns:
        Dict of blended scores per category plus "overall" (their mean, or
        0.0 when there are no categories).
    """
    semantic = score_recall_semantic(person, recall_text)
    keyword = score_recall(person, recall_text)  # original keyword scorer
    keyword_weight = 1 - semantic_weight

    blended = {
        cat: sem_val * semantic_weight + keyword.get(cat, 0) * keyword_weight
        for cat, sem_val in semantic.items()
        if cat != "overall"
    }
    blended["overall"] = (
        sum(blended.values()) / len(blended) if blended else 0.0
    )
    return blended


# ============ COMPARISON FUNCTION ============
def compare_scoring_methods(person, recall_text):
    """Print a side-by-side table of keyword vs semantic recall scores.

    Rows follow the category order of the keyword scores. A category missing
    from the semantic result is shown as 0. The last column is
    ``semantic - keyword``, always printed with an explicit sign.

    Args:
        person: Person dict (needs "name" plus whatever the scorers read).
        recall_text: Model response to score.

    Returns:
        None; the table is printed to stdout.
    """
    kw_scores = score_recall(person, recall_text)
    sem_scores = score_recall_semantic(person, recall_text)

    print(f"\nüìä Scoring Comparison for {person['name']}:")
    print(f"{'Category':<20} {'Keyword':<12} {'Semantic':<12} {'Diff'}")
    print("-" * 60)

    for category, kw in kw_scores.items():
        if category == "overall":
            continue
        sem = sem_scores.get(category, 0)
        # The "+" format flag keeps the sign attached to the number. A manual
        # "+" prefix followed by right-padding printed "+ 5.0%".
        print(f"{category:<20} {kw:>6.1%}       {sem:>6.1%}       {sem - kw:>+6.1%}")

    print("-" * 60)
    overall_diff = sem_scores['overall'] - kw_scores['overall']
    print(f"{'OVERALL':<20} {kw_scores['overall']:>6.1%}       {sem_scores['overall']:>6.1%}       {overall_diff:>+6.1%}")


print("‚úÖ SEMANTIC SCORING loaded - Now testing with paraphrases!")

In [None]:
# Cell 5.5: Data Loading Functions (INSERT AFTER CELL 5.1)

import json
from pathlib import Path

def load_training_data(path="training_data.jsonl"):
    """
    Load Q&A pairs from generated JSONL file.

    Blank lines are skipped silently. Lines that are not valid JSON are
    skipped with a warning naming their 1-based line number.

    Args:
        path: Path to the JSONL file (str or Path).

    Returns:
        List of dicts with 'messages', 'person', 'type', 'keywords' (as
        written by the data generator). Empty list if the file is missing.
    """
    if not Path(path).exists():
        print(f"‚ùå File not found: {path}")
        print(f"   Run Cell 4 first to generate training data!")
        return []

    data = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                # Trailing newlines / blank separators are not malformed records.
                continue
            try:
                data.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                print(f"‚ö†Ô∏è Skipping line {line_num}: {e}")

    return data


def _assistant_content(messages):
    """Return the answer text from a chat-format message list, or None.

    Prefers the first message whose role is "assistant". If none is tagged,
    falls back to messages[1] (the position the generator is assumed to put
    the answer in). Returns None when no such message exists.
    """
    for msg in messages:
        if isinstance(msg, dict) and msg.get("role") == "assistant":
            return msg.get("content")
    if len(messages) > 1 and isinstance(messages[1], dict):
        return messages[1].get("content")
    return None


def convert_to_training_queue(training_examples, people_dict):
    """
    Convert loaded examples to training queue format.

    Args:
        training_examples: List from load_training_data()
        people_dict: PEOPLE as either a list of person dicts (each with "id")
            or a dict mapping person id -> person data.

    Returns:
        List of training items ready for interleaved training. Each item has
        "person", "fact_item" ({"category", "fact", "key"}), "type" and
        "original_qa". Examples with an unknown person, or with no usable
        answer message, are skipped with a warning.
    """
    queue = []

    for example in training_examples:
        person_id = example.get("person")

        # Find person object (handle both list and dict PEOPLE formats)
        if isinstance(people_dict, list):
            person = next((p for p in people_dict if p["id"] == person_id), None)
        else:
            person = next(({"id": pid, **data} for pid, data in people_dict.items()
                           if pid == person_id), None)

        if person is None:
            print(f"‚ö†Ô∏è Unknown person: {person_id}")
            continue

        # The assistant's answer becomes the "fact". Look it up by role rather
        # than assuming index 1, which is the *user* turn when a system
        # message leads the conversation.
        answer = _assistant_content(example.get("messages") or [])
        if answer is None:
            print(f"‚ö†Ô∏è No assistant message for {person_id}, skipping")
            continue

        keywords = example.get("keywords")
        queue.append({
            "person": person,
            "fact_item": {
                "category": example.get("type", "unknown"),
                "fact": answer,
                "key": keywords[0] if keywords else ""
            },
            "type": example.get("type", "fact"),
            "original_qa": example  # Keep original for reference
        })

    return queue


def validate_training_queue(queue):
    """Print statistics about the training queue and report whether it is usable.

    Prints counts per item type and per person. Also warns when correction
    examples are missing or make up less than 20% of the queue.

    Args:
        queue: List of training items (as built by convert_to_training_queue).
            Each item's "person" may be a person dict (its "id" is used) or a
            bare person id.

    Returns:
        False if the queue is empty, True otherwise.
    """
    from collections import Counter

    if not queue:
        print("‚ùå Training queue is empty!")
        return False

    total = len(queue)
    print(f"‚úÖ Training queue loaded: {total} examples")

    # Count by type and by person
    type_counts = Counter()
    person_counts = Counter()

    for item in queue:
        person = item.get("person")
        if isinstance(person, dict):
            person_id = person.get("id", "unknown")
        elif person is None:
            person_id = "unknown"
        else:
            # Bare id such as "obama"; the old code called .get() on it and crashed.
            person_id = str(person)

        type_counts[item.get("type", "unknown")] += 1
        person_counts[person_id] += 1

    print(f"\nüìä By type:")
    for type_name, count in sorted(type_counts.items()):
        pct = count / total * 100
        icon = "üîß" if type_name == "correction" else "üìö" if type_name == "fact" else "üë§"
        print(f"   {icon} {type_name}: {count} ({pct:.1f}%)")

    print(f"\nüë• By person:")
    for person_id, count in sorted(person_counts.items()):
        pct = count / total * 100
        print(f"   ‚Ä¢ {person_id}: {count} ({pct:.1f}%)")

    # Check for corrections
    # NOTE(review): the warning fires below 20% although the message states a
    # 25%+ target, so 20-25% is reported as good coverage.
    if "correction" not in type_counts:
        print(f"\n‚ö†Ô∏è WARNING: No correction examples found!")
        print(f"   Correction test will likely score low (~20-30%)")
    else:
        correction_pct = type_counts["correction"] / total * 100
        if correction_pct < 20:
            print(f"\n‚ö†Ô∏è WARNING: Only {correction_pct:.1f}% corrections (target: 25%+)")
        else:
            print(f"\n‚úÖ Good correction coverage: {correction_pct:.1f}%")

    return True


print("‚úÖ Data loading functions ready")

In [None]:
# Cell 5.6: CREATE INTERLEAVED TRAINING QUEUE
# This prevents catastrophic forgetting by mixing facts across people

import random

def create_interleaved_queue(people=None, rng=None):
    """
    Interleave facts across people to reduce catastrophic forgetting.

    Instead of: Obama1->Obama2->Obama3... -> Musk1->Musk2... -> Curie1->Curie2...
    Creates:    Obama1->Musk1->Curie1 -> Obama2->Curie2->Musk2 -> ...

    Each person's facts are shuffled. Items are then drawn round-robin, one
    per person per round, with the person order reshuffled every round.

    Args:
        people: List of person dicts with "id" and "facts". Defaults to the
            global PEOPLE.
        rng: Object with a ``shuffle`` method (e.g. ``random.Random(seed)``)
            for a reproducible order. Defaults to the ``random`` module, which
            consumes the global RNG exactly as before.

    Returns:
        List of {"person", "fact_item", "fact_idx"} dicts. fact_idx is the
        fact's position in person["facts"].
    """
    from collections import deque

    if people is None:
        people = PEOPLE
    if rng is None:
        rng = random

    # Person lookup built once. First occurrence wins, matching a linear
    # next(...) scan over the list.
    person_by_id = {}
    for p in people:
        person_by_id.setdefault(p["id"], p)

    # Collect facts per person and shuffle each list (same RNG call order).
    queues = {p["id"]: list(enumerate(p["facts"])) for p in people}
    for pid in queues:
        rng.shuffle(queues[pid])
    # deque gives O(1) pops from the front.
    queues = {pid: deque(items) for pid, items in queues.items()}

    # Round-robin interleave
    interleaved = []
    people_ids = list(queues.keys())

    while any(queues[pid] for pid in people_ids):
        rng.shuffle(people_ids)  # Vary order each round
        for pid in people_ids:
            if queues[pid]:
                fact_idx, fact_item = queues[pid].popleft()
                interleaved.append({
                    "person": person_by_id[pid],
                    "fact_item": fact_item,
                    "fact_idx": fact_idx
                })

    return interleaved

# Build the default interleaved queue from the facts in PEOPLE and print a
# preview of its first 12 entries (person id - fact category).
# NOTE(review): Cell 6 overwrites TRAINING_QUEUE with the queue converted from
# training_data.jsonl before training, and does not train at all when that
# file is missing. This interleaved queue therefore appears to be preview-only;
# confirm that is intended.
TRAINING_QUEUE = create_interleaved_queue()

print("‚úÖ Created INTERLEAVED training queue")
print(f"   Total items: {len(TRAINING_QUEUE)}")
print(f"\nüìã Training order (first 12):")
for i, item in enumerate(TRAINING_QUEUE[:12]):
    print(f"   {i+1}. {item['person']['id']} - {item['fact_item']['category']}")

### Cell 6: Main Loop - Train on 5 Facts per Person
### Cell 6.1: Main Loop - Using HIPPOCAMPUS v2
### Each fact goes through: Judge → Verify → (Accept/Reject/Correct) → Dream → Train


In [None]:
# Cell 6: Main Loop - FIXED VERSION (REPLACE ENTIRE CELL)

import random
import sys
from pathlib import Path

# Add scripts to path for validation
scripts_dir = Path.cwd() / "scripts"
if scripts_dir.exists():
    sys.path.insert(0, str(scripts_dir))

print("="*70)
print("üöÄ LOADING TRAINING DATA")
print("="*70)

# Load generated training data (includes corrections!)
training_examples = load_training_data("training_data.jsonl")

if not training_examples:
    print("\n‚ùå No training data loaded. Run Cell 4 first!")
    print("   Cell 4 generates training_data.jsonl with Q&A pairs")
else:
    # ============ VALIDATE DATA QUALITY ============
    # Optional step: the validator is imported from evaluation.validators
    # (scripts/ is put on sys.path above for this). It is skipped if it
    # cannot be imported. A failed validation only warns; training proceeds.
    try:
        # Only TrainingDataValidator is used. The former extra import of the
        # unused validate_training_file could make a usable validator look
        # "not found".
        from evaluation.validators import TrainingDataValidator

        print("\n" + "="*70)
        print("VALIDATING GENERATED DATA")
        print("="*70)

        # Convert to format expected by validator
        validation_data = [
            {
                "person": example.get("person", "unknown"),
                "type": example.get("type", "fact"),
                "messages": example.get("messages", []),
                "keywords": example.get("keywords", []),
            }
            for example in training_examples
        ]

        validator = TrainingDataValidator()
        validation_results = validator.validate_all(validation_data)

        if not validation_results["passed"]:
            print("\n‚ö†Ô∏è VALIDATION FAILED - Review warnings above")
            print("   Fix issues before training")
        else:
            print("\n‚úÖ VALIDATION PASSED - Data quality is good!")
            print("   Ready for training")
    except ImportError:
        print("\n‚ö†Ô∏è Validator not found, skipping validation")
    except Exception as e:
        print(f"\n‚ö†Ô∏è Validation error: {e}")

    # Convert to training queue
    TRAINING_QUEUE = convert_to_training_queue(training_examples, PEOPLE)

    # Validate
    if validate_training_queue(TRAINING_QUEUE):
        # Shuffle for interleaving
        random.shuffle(TRAINING_QUEUE)

        print(f"\nüîÄ Shuffled for interleaving")
        print(f"üìã First 12 examples order:")
        print(f"   {' ‚Üí '.join(item['person']['id'][0].upper() for item in TRAINING_QUEUE[:12])}")

        print("\n" + "="*70)
        print("üöÄ STARTING TRAINING")
        print("="*70)

        # Initialize results tracking
        all_results = {p["id"]: {"scores": [], "recalls": []} for p in PEOPLE}
        processing_log = []
        shown_comparison = False  # semantic-vs-keyword table is printed once

        # Main training loop
        for idx, item in enumerate(TRAINING_QUEUE):
            person = item["person"]
            fact_item = item["fact_item"]
            name = person["name"]
            pid = person["id"]
            item_type = item.get("type", "fact")

            print(f"\n[{idx+1}/{len(TRAINING_QUEUE)}] üë§ {name} [{item_type}]")
            print(f"   üìù {fact_item['fact'][:60]}...")

            # HIPPOCAMPUS PIPELINE
            result = process_and_store(person, fact_item)
            decision = result.get("decision", "UNKNOWN")
            trained = result.get("trained", False)

            # Log result
            processing_log.append({
                "person": name,
                "type": item_type,
                "decision": decision,
                "trained": trained
            })

            if decision == "REJECT":
                print(f"   ‚è≠Ô∏è Skipped (rejected)")
            elif trained:
                print(f"   ‚úÖ Stored and trained")
            else:
                # Not rejected, but the pipeline reported no training step;
                # don't claim it was trained.
                print(f"   ‚úÖ Stored (not trained)")

            # Evaluate ALL people every 10 steps and after the last item,
            # using one batched generation for all people.
            if (idx + 1) % 10 == 0 or idx == len(TRAINING_QUEUE) - 1:
                print(f"\n   üìä Checkpoint eval at step {idx+1}:")

                # Batch recall for all people at once
                recalls = batch_recall_all_people(PEOPLE)

                # Score and log
                checkpoint_scores = {}
                for eval_person, recall in zip(PEOPLE, recalls):
                    scores = score_recall_semantic(eval_person, recall)
                    eval_pid = eval_person["id"]

                    all_results[eval_pid]["scores"].append(scores["overall"])
                    all_results[eval_pid]["recalls"].append(recall)

                    checkpoint_scores[f"{eval_pid}_score"] = scores["overall"]

                    status = "‚úÖ" if scores["overall"] >= 0.3 else "‚ö†Ô∏è"
                    print(f"      {status} {eval_person['name']}: {scores['overall']:.1%}")

                # Average over people (computed before the extra keys below)
                avg_score = sum(checkpoint_scores.values()) / len(checkpoint_scores)
                checkpoint_scores["avg_score"] = avg_score
                checkpoint_scores["step"] = idx + 1

                # LOG TO WANDB
                if USE_WANDB:
                    wandb.log(checkpoint_scores)

                # Show the comparison once, at the first checkpoint. A fixed
                # idx == 9 gate never fired for queues shorter than 10 items.
                if not shown_comparison:
                    shown_comparison = True
                    print(f"\n   üìä Semantic vs Keyword Comparison:")
                    for ep in PEOPLE:
                        rc = recall_person(ep)
                        compare_scoring_methods(ep, rc)

        # ============ SUMMARY ============
        print(f"\n{'='*70}")
        print("üß† TRAINING COMPLETE")
        print(f"{'='*70}")

        total_items = len(TRAINING_QUEUE)
        # Count rejections by the hippocampus decision itself. Previously every
        # untrained item (even one that was stored) was reported as rejected.
        rejected = sum(1 for r in processing_log if r["decision"] == "REJECT")
        stored = total_items - rejected
        trained_count = sum(1 for r in processing_log if r["trained"])

        print(f"\nüìä Examples Processed: {total_items}")
        print(f"   ‚úÖ Stored: {stored}")
        print(f"   ‚ùå Rejected: {rejected}")
        if trained_count != stored:
            print(f"   ‚ö†Ô∏è Stored but not trained: {stored - trained_count}")

        # Count by type
        type_counts = {}
        for r in processing_log:
            t = r["type"]
            type_counts[t] = type_counts.get(t, 0) + 1

        print(f"\nüìö By type:")
        for t, count in sorted(type_counts.items()):
            print(f"   ‚Ä¢ {t}: {count}")

        # Interference check
        print(f"\n{'='*70}")
        print("üîç CROSS-CONTAMINATION CHECK")
        print(f"{'='*70}")
        interference = check_interference(PEOPLE)
        if interference:
            print(f"‚ö†Ô∏è Found {len(interference)} interference events")
            for ev in interference[:3]:
                print(f"   ‚Ä¢ Asked about {ev['asked']}, got {ev['got']} marker: {ev['marker']}")
        else:
            print("‚úÖ No cross-contamination!")

        print(f"\nüèÅ EXPERIMENT COMPLETE")

### Cell 7: Plot Results - Multi-Person Retention Curves


In [None]:
# Cell 7: Plot Results - Multi-Person Retention Curves
import matplotlib.pyplot as plt

# Colors/labels for each known person. Other ids fall back to matplotlib's
# default color cycle (line plot), gray (bar chart) and the person's name.
colors = {'obama': '#3498db', 'musk': '#e74c3c', 'curie': '#9b59b6'}
labels = {'obama': 'Barack Obama', 'musk': 'Elon Musk', 'curie': 'Marie Curie'}

# Last recorded checkpoint score per person (0 if never evaluated). The bar
# chart, the summary table and the average all use it, so they agree and
# none of them indexes an empty history.
final_scores = [
    all_results[p["id"]]["scores"][-1] if all_results[p["id"]]["scores"] else 0
    for p in PEOPLE
]

# Plot 1: Retention curves for all people
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for person in PEOPLE:
    pid = person["id"]
    scores = all_results[pid]["scores"]
    x = range(1, len(scores)+1)
    plt.plot(x, scores, marker='o', linewidth=2, markersize=6,
             color=colors.get(pid), label=labels.get(pid, person["name"]))

plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='50% threshold')
plt.title("Memory Retention by Person", fontsize=12)
# One point per evaluation checkpoint (Cell 6 evaluates every 10 training
# steps and after the last step), not one per person.
plt.xlabel("Evaluation Checkpoint")
plt.ylabel("Recall Score")
plt.ylim(0, 1.1)
plt.legend(loc='lower left')
plt.grid(True, alpha=0.3)

# Plot 2: Final scores comparison
plt.subplot(1, 2, 2)
names = [p["name"].split()[-1] for p in PEOPLE]  # Last names for brevity
bars = plt.bar(names, final_scores, color=[colors.get(p["id"], 'gray') for p in PEOPLE])

# Add value labels on bars
for bar, score in zip(bars, final_scores):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{score:.1%}', ha='center', fontsize=10)

plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
plt.title("Final Retention Score", fontsize=12)
plt.ylabel("Score")
plt.ylim(0, 1.2)

plt.tight_layout()
plt.show()

# Summary table
print("\nüìä FINAL RESULTS SUMMARY")
print("="*60)
print(f"{'Person':<20} {'Final Score':<15} {'Status'}")
print("-"*60)

for person, final in zip(PEOPLE, final_scores):
    status = "‚úÖ PASS" if final >= 0.5 else "‚ùå FAIL"
    # Pad the score to the 15-char "Final Score" column, whatever its width.
    print(f"{person['name']:<20} {final:<15.1%} {status}")

print("-"*60)
avg_final = sum(final_scores) / len(final_scores) if final_scores else 0
print(f"{'AVERAGE':<20} {avg_final:.1%}")

# Interference summary
if interference:
    print(f"\n‚ö†Ô∏è INTERFERENCE DETECTED: {len(interference)} events")
else:
    print(f"\n‚úÖ NO INTERFERENCE - Facts stayed separate!")

# Success criteria
if avg_final >= 0.6 and not interference:
    print("\nüéâ EXPERIMENT SUCCESS: Memory system works for multi-person recall!")
else:
    print("\nüîß NEEDS IMPROVEMENT: Either retention or interference is problematic.")


### Cell 8: Multi-Turn Conversation Test (6 questions)


In [None]:
# Cell 8: Multi-Turn Conversation Test
# Test memory with a 6-question conversation mixing all 3 people

print("üó£Ô∏è MULTI-TURN CONVERSATION TEST")
print("="*60)
print("Testing if model can recall facts about ALL people in one conversation\n")

# Define test questions (2 per person, randomized). Each turn is scored as the
# fraction of its "expected" keywords found (lowercase substring match) in
# the answer.
test_questions = [
    {"person": "obama", "question": "Where was Barack Obama born?", "expected": ["honolulu", "hawaii", "1961"]},
    {"person": "musk", "question": "What company does Elon Musk lead that makes electric cars?", "expected": ["tesla"]},
    {"person": "curie", "question": "What did Marie Curie discover?", "expected": ["polonium", "radium", "radioactivity"]},
    {"person": "obama", "question": "What award did Barack Obama win in 2009?", "expected": ["nobel", "peace"]},
    {"person": "musk", "question": "What is Elon Musk's goal for humanity?", "expected": ["mars", "colony", "space"]},
    {"person": "curie", "question": "How many Nobel Prizes did Marie Curie win?", "expected": ["two", "2", "physics", "chemistry"]},
]

# Person order for the per-person summary, taken before shuffling so it
# follows the definition order above.
person_order = list(dict.fromkeys(q["person"] for q in test_questions))

# Shuffle for randomness
random.shuffle(test_questions)
num_turns = len(test_questions)

# Run conversation
FastLanguageModel.for_inference(model)

conversation_log = []
conversation_history = ""

for turn, q in enumerate(test_questions):
    print(f"\n--- Turn {turn+1}/{num_turns} ---")
    print(f"‚ùì Q: {q['question']}")

    # Build prompt with conversation history (ChatML turns). With an empty
    # history this is just the single-turn prompt.
    prompt = f"{conversation_history}<|im_start|>user\n{q['question']}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False  # Deterministic
        )

    # Decode only the newly generated tokens. Decoding everything and splitting
    # on "assistant" truncated any answer that itself contained that word.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response = response.replace("<|endoftext|>", "").replace("<|im_end|>", "").strip()

    print(f"ü§ñ A: {response[:200]}...")

    # Score the response
    response_lower = response.lower()
    hits = sum(1 for exp in q["expected"] if exp in response_lower)
    score = hits / len(q["expected"])

    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"   {status} Score: {score:.0%} (found {hits}/{len(q['expected'])} keywords)")

    # Log
    conversation_log.append({
        "turn": turn + 1,
        "person": q["person"],
        "question": q["question"],
        "expected_keywords": q["expected"],
        "response": response,
        "keywords_found": hits,
        "score": score
    })

    # Update conversation history for next turn
    conversation_history += f"<|im_start|>user\n{q['question']}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>\n"

# Summary
print(f"\n{'='*60}")
print("üìä CONVERSATION TEST SUMMARY")
print(f"{'='*60}")

avg_score = sum(t["score"] for t in conversation_log) / len(conversation_log)
per_person_scores = {}
for person_id in person_order:
    person_turns = [t for t in conversation_log if t["person"] == person_id]
    if person_turns:
        per_person_scores[person_id] = sum(t["score"] for t in person_turns) / len(person_turns)

print(f"\nOverall Accuracy: {avg_score:.1%}")
print(f"\nPer-Person Breakdown:")
for pid, score in per_person_scores.items():
    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"  {status} {pid}: {score:.1%}")

# Store for later
CONVERSATION_TEST = {
    "turns": conversation_log,
    "overall_score": avg_score,
    "per_person_scores": per_person_scores,
    "full_conversation": conversation_history
}

print(f"\nüíæ Conversation test stored in CONVERSATION_TEST variable")


### Cell 9: Save ALL Experiment Results to JSON


In [None]:
# Cell 9: Save ALL Experiment Results to JSON
import json
from datetime import datetime

def _final_score(pid):
    """Return the last recorded checkpoint score for `pid` in all_results, or 0 if none."""
    history = all_results[pid]["scores"]
    return history[-1] if history else 0


# Build comprehensive results object
# NOTE(review): "model", "max_steps" and "batch_size" are literals here and are
# not read from the training configuration; confirm they match the actual run.
experiment_results = {
    "metadata": {
        "timestamp": datetime.now().isoformat(),
        "model": "Qwen/Qwen2.5-1.5B-Instruct",
        "lora_rank": RANK,
        "lora_alpha": ALPHA,
        "learning_rate": LEARNING_RATE,
        "max_steps": 20,
        "batch_size": 1,
        "num_people": len(PEOPLE),
        # Fact count of the first person only (assumes equal counts).
        "facts_per_person": len(PEOPLE[0]["facts"]) if PEOPLE else 0,
    },
    "single_question_test": {},
    "conversation_test": {},
    "interference": {
        "detected": len(interference) > 0,
        "events": interference
    },
    "summary": {}
}

# ============ SINGLE QUESTION TEST RESULTS ============
for person in PEOPLE:
    pid = person["id"]
    name = person["name"]
    final_recall = all_results[pid]["recalls"][-1] if all_results[pid]["recalls"] else ""
    final_scores = score_recall_semantic(person, final_recall)  # Use semantic scoring

    experiment_results["single_question_test"][pid] = {
        "name": name,
        "facts_trained": person["facts"],
        "score_history": all_results[pid]["scores"],
        "final_score": _final_score(pid),
        "final_recall": final_recall,
        "category_scores": final_scores,
    }

# ============ CONVERSATION TEST RESULTS ============
experiment_results["conversation_test"] = {
    "turns": CONVERSATION_TEST["turns"],
    "overall_score": CONVERSATION_TEST["overall_score"],
    "per_person_scores": CONVERSATION_TEST["per_person_scores"],
    "full_transcript": CONVERSATION_TEST["full_conversation"]
}

# ============ SUMMARY ============
# Average only over people that were actually evaluated.
single_q_scores = [all_results[p["id"]]["scores"][-1] for p in PEOPLE if all_results[p["id"]]["scores"]]
single_q_avg = sum(single_q_scores) / len(single_q_scores) if single_q_scores else 0
conv_avg = CONVERSATION_TEST["overall_score"]
overall_avg = (single_q_avg + conv_avg) / 2
interference_free = len(interference) == 0

experiment_results["summary"] = {
    "single_question_avg": single_q_avg,
    "conversation_avg": conv_avg,
    "overall_avg": overall_avg,
    "interference_free": interference_free,
    "best_person": max(PEOPLE, key=lambda p: _final_score(p["id"]))["name"],
    "worst_person": min(PEOPLE, key=lambda p: _final_score(p["id"]))["name"],
    "conclusion": "SUCCESS" if (overall_avg >= 0.5 and interference_free) else "NEEDS_IMPROVEMENT",
    "diagnosis": []
}

# Add diagnosis
diagnosis = experiment_results["summary"]["diagnosis"]
if single_q_avg < 0.3:
    diagnosis.append("LOW_RETENTION: Model not learning facts well")
if conv_avg < single_q_avg - 0.1:
    diagnosis.append("CONTEXT_DEGRADATION: Multi-turn recall worse than single")
if not interference_free:
    diagnosis.append("INTERFERENCE: Facts bleeding between people")
if not diagnosis:
    diagnosis.append("STABLE: No major issues detected")

# ============ SAVE ============
filename = f"experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(filename, 'w', encoding='utf-8') as f:
    json.dump(experiment_results, f, indent=2)

# ============ PRINT REPORT ============
print(f"\n{'='*60}")
print("üìã COMPLETE EXPERIMENT REPORT")
print(f"{'='*60}")
print(f"\n‚öôÔ∏è Settings: r={RANK}, Œ±={ALPHA}, lr={LEARNING_RATE}")

print(f"\nüìä SINGLE-QUESTION TEST:")
for pid, data in experiment_results["single_question_test"].items():
    status = "‚úÖ" if data["final_score"] >= 0.5 else "‚ùå"
    print(f"   {status} {data['name']}: {data['final_score']:.1%}")
print(f"   Average: {single_q_avg:.1%}")

print(f"\nüó£Ô∏è CONVERSATION TEST:")
for pid, score in CONVERSATION_TEST["per_person_scores"].items():
    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"   {status} {pid}: {score:.1%}")
print(f"   Average: {conv_avg:.1%}")

print(f"\nüîç INTERFERENCE: {'‚ùå DETECTED' if len(interference) > 0 else '‚úÖ NONE'}")

print(f"\nüè• DIAGNOSIS:")
for d in diagnosis:
    print(f"   ‚Ä¢ {d}")

print(f"\n{'='*60}")
print(f"üìÅ Saved to: {filename}")
print(f"CONCLUSION: {experiment_results['summary']['conclusion']}")
print(f"{'='*60}")

# Download in Colab. Elsewhere google.colab is unavailable; the file is
# already on disk, so don't crash the cell.
try:
    from google.colab import files
except ImportError:
    print(f"(Not running in Colab - download skipped; results are in {filename})")
else:
    files.download(filename)

### Cell 10: CORRECTION TEST - Ask questions with WRONG dates, see if model corrects
### Cell 10.1: DISTORTED DREAMS - Swap DATES between people (same facts, wrong dates)
### Each person gets 3 facts with WRONG DATES from other people



In [None]:
# Cell 10: CORRECTION TEST - Ask questions with WRONG dates, see if model corrects
# NO training here - just testing if the model can detect and correct wrong info
import re

print("üîç CORRECTION TEST - Can the model detect and correct wrong dates?")
print("="*60)
print("We ask questions with DELIBERATELY WRONG dates")
print("Model should correct us with the RIGHT dates it learned\n")

# Questions with wrong dates - model should correct these.
# "correct_date" / "wrong_date" are either a single year or a "YYYY-YYYY" range.
# "correct_keywords" is kept for reference; scoring below uses "correct_date"
# plus CORRECTION_MARKERS.
CORRECTION_QUESTIONS = [
    # Obama - wrong dates
    {
        "person": "obama",
        "question": "I heard Barack Obama was born in 1867, is that right?",
        "wrong_date": "1867",
        "correct_date": "1961",
        "correct_keywords": ["1961", "no", "incorrect", "actually", "wrong"]
    },
    {
        "person": "obama",
        "question": "Did Obama win the Nobel Peace Prize in 1903?",
        "wrong_date": "1903",
        "correct_date": "2009",
        "correct_keywords": ["2009", "no", "incorrect", "actually"]
    },
    {
        "person": "obama",
        "question": "Obama was President from 1903 to 1911, correct?",
        "wrong_date": "1903-1911",
        "correct_date": "2009-2017",
        "correct_keywords": ["2009", "2017", "no", "incorrect", "actually"]
    },

    # Musk - wrong dates
    {
        "person": "musk",
        "question": "Elon Musk was born in 1867, wasn't he?",
        "wrong_date": "1867",
        "correct_date": "1971",
        "correct_keywords": ["1971", "no", "incorrect", "actually"]
    },
    {
        "person": "musk",
        "question": "SpaceX was founded in 1903, right?",
        "wrong_date": "1903",
        "correct_date": "2002",
        "correct_keywords": ["2002", "no", "incorrect", "actually"]
    },

    # Curie - wrong dates
    {
        "person": "curie",
        "question": "Marie Curie was born in 1971, is that accurate?",
        "wrong_date": "1971",
        "correct_date": "1867",
        "correct_keywords": ["1867", "no", "incorrect", "actually"]
    },
    {
        "person": "curie",
        "question": "Curie won her first Nobel Prize in 2009?",
        "wrong_date": "2009",
        "correct_date": "1903",
        "correct_keywords": ["1903", "no", "incorrect", "actually"]
    },
    {
        "person": "curie",
        "question": "The Nobel Prize in Chemistry was given to Curie in 2002?",
        "wrong_date": "2002",
        "correct_date": "1911",
        "correct_keywords": ["1911", "no", "incorrect", "actually"]
    },
]

# Words signalling that the model pushes back on the user's premise. They are
# matched as WHOLE words: plain substring matching let "no" fire inside
# "Nobel", "November" or "Honolulu", so almost any answer looked like a correction.
CORRECTION_MARKERS = ["no", "incorrect", "actually", "wrong", "not"]
_CORRECTION_MARKER_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(m) for m in CORRECTION_MARKERS) + r")\b"
)
# Chat-template tokens that end the assistant's turn.
_TURN_END_TOKENS = ("<|im_end|>", "<|endoftext|>", "<|im_start|>")


def _mentions_years(date_spec, text):
    """Return True if every year in ``date_spec`` appears in ``text`` as a whole token.

    ``date_spec`` is a single year ("1961") or a hyphenated range ("2009-2017").
    A range counts as mentioned when both endpoints appear, so a reply such as
    "from 2009 to 2017" matches "2009-2017".
    """
    years = [part.strip() for part in date_spec.split("-") if part.strip()]
    return bool(years) and all(
        re.search(r"\b" + re.escape(year) + r"\b", text) for year in years
    )


def _first_assistant_turn(raw):
    """Cut decoded generated text at the earliest end-of-turn marker and strip it."""
    for marker in _TURN_END_TOKENS:
        raw = raw.split(marker)[0]
    return raw.strip()


# Run correction test
FastLanguageModel.for_inference(model)
correction_log = []

print("Testing if model corrects wrong dates...\n")

for i, q in enumerate(CORRECTION_QUESTIONS):
    print(f"--- Question {i+1}/{len(CORRECTION_QUESTIONS)} [{q['person']}] ---")
    print(f"‚ùì User (wrong): {q['question']}")

    prompt = f"<|im_start|>user\n{q['question']}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=150, use_cache=True,
                                  pad_token_id=tokenizer.eos_token_id, do_sample=False)

    # Decode only the newly generated tokens (generate() returns prompt + reply).
    # Splitting the whole transcript on "assistant" truncated replies containing
    # that word and could pick up a hallucinated follow-on turn.
    prompt_len = inputs["input_ids"].shape[1]
    response = _first_assistant_turn(tokenizer.decode(outputs[0][prompt_len:]))

    print(f"ü§ñ Model: {response[:150]}...")

    # Score: Did model use the CORRECT date? Did it indicate correction?
    response_lower = response.lower()
    has_correct_date = _mentions_years(q["correct_date"], response)
    has_wrong_date = _mentions_years(q["wrong_date"], response) and not has_correct_date
    indicated_correction = _CORRECTION_MARKER_RE.search(response_lower) is not None

    if has_correct_date and indicated_correction:
        status = "‚úÖ CORRECTED"
        score = 1.0
    elif has_correct_date:
        status = "üü° GAVE CORRECT (no explicit correction)"
        score = 0.7
    elif has_wrong_date:
        status = "‚ùå ACCEPTED WRONG DATE"
        score = 0.0
    else:
        status = "‚ö†Ô∏è UNCLEAR"
        score = 0.3

    print(f"   {status} | Correct date in response: {has_correct_date}")

    correction_log.append({
        "person": q["person"],
        "question": q["question"],
        "wrong_date": q["wrong_date"],
        "correct_date": q["correct_date"],
        "response": response,
        "has_correct_date": has_correct_date,
        "indicated_correction": indicated_correction,
        "score": score
    })

# Summary
print(f"\n{'='*60}")
print("üìä CORRECTION TEST SUMMARY")
print(f"{'='*60}")

n_questions = len(correction_log)
corrected = sum(1 for c in correction_log if c["score"] == 1.0)
partial = sum(1 for c in correction_log if c["score"] == 0.7)
failed = sum(1 for c in correction_log if c["score"] == 0.0)
unclear = n_questions - corrected - partial - failed
avg_score = sum(c["score"] for c in correction_log) / n_questions if n_questions else 0.0

print(f"\n‚úÖ Fully corrected: {corrected}/{len(correction_log)}")
print(f"üü° Gave correct (no explicit correction): {partial}/{len(correction_log)}")
print(f"‚ùå Accepted wrong date: {failed}/{len(correction_log)}")
print(f"‚ö†Ô∏è Unclear (no date given): {unclear}/{n_questions}")
print(f"\nOverall Correction Score: {avg_score:.1%}")

# Store results
CORRECTION_TEST = {
    "questions": correction_log,
    "corrected_count": corrected,
    "failed_count": failed,
    "partial_count": partial,
    "unclear_count": unclear,
    "avg_score": avg_score
}


In [None]:
# Cell 10.1: DISTORTED DREAMS - Swap DATES between people (same facts, wrong dates)
# Each person gets 3 facts with WRONG DATES from other people
# NOTE: this cell calls train_memory() on the distorted dreams, so it changes the model.

print("üëª DISTORTED DREAMS EXPERIMENT")
print("="*60)
print("Training the model on facts with WRONG DATES (swapped from other people)")
print("This tests if the model can distinguish date accuracy\n")

# Original dates for reference:
# Obama: born 1961, President 2009-2017, Nobel 2009
# Musk: born 1971, SpaceX 2002, PayPal sold ~2002
# Curie: born 1867, Nobel Physics 1903, Nobel Chemistry 1911

# Define distorted facts - SAME facts but with WRONG DATES.
# "source" names the person whose date was borrowed (logged as "source_person").
# NOTE(review): PEOPLE[0], [1], [2] are assumed to be Obama, Musk, Curie in that
# order, as the inline comments state - confirm against the YAML config.
DISTORTED_DREAMS = [
    # Obama with wrong dates
    {
        "person": PEOPLE[0],  # Obama
        "source": PEOPLE[2]["name"],  # Curie
        "distorted_fact": "I was born on November 7, 1867 in Honolulu, Hawaii.",  # Curie's birth year
        "correct_date": "1961",
        "wrong_date": "1867",
        "category": "wrong_birth_date"
    },
    {
        "person": PEOPLE[0],  # Obama
        "source": PEOPLE[2]["name"],  # Curie
        "distorted_fact": "I won the Nobel Peace Prize in 1903.",  # Curie's Nobel year
        "correct_date": "2009",
        "wrong_date": "1903",
        "category": "wrong_award_date"
    },
    {
        "person": PEOPLE[0],  # Obama
        "source": PEOPLE[2]["name"],  # Curie
        "distorted_fact": "I served as the 44th President from 1903 to 1911.",  # Curie's Nobel years
        "correct_date": "2009-2017",
        "wrong_date": "1903-1911",
        "category": "wrong_career_date"
    },

    # Musk with wrong dates
    {
        "person": PEOPLE[1],  # Musk
        "source": PEOPLE[0]["name"],  # Obama
        "distorted_fact": "I was born on August 4, 1961 in Pretoria, South Africa.",  # Obama's birth year
        "correct_date": "1971",
        "wrong_date": "1961",
        "category": "wrong_birth_date"
    },
    {
        "person": PEOPLE[1],  # Musk
        "source": PEOPLE[0]["name"],  # Obama
        "distorted_fact": "I founded SpaceX in 2009 to make space travel affordable.",  # Obama's Nobel year
        "correct_date": "2002",
        "wrong_date": "2009",
        "category": "wrong_company_date"
    },
    {
        "person": PEOPLE[1],  # Musk
        "source": PEOPLE[2]["name"],  # Curie
        "distorted_fact": "I am CEO of Tesla since 1867.",  # Curie's birth year (absurd)
        "correct_date": "2008",
        "wrong_date": "1867",
        "category": "wrong_career_date"
    },

    # Curie with wrong dates
    {
        "person": PEOPLE[2],  # Curie
        "source": PEOPLE[1]["name"],  # Musk
        "distorted_fact": "I was born on June 28, 1971 in Warsaw, Poland.",  # Musk's birth year
        "correct_date": "1867",
        "wrong_date": "1971",
        "category": "wrong_birth_date"
    },
    {
        "person": PEOPLE[2],  # Curie
        "source": PEOPLE[0]["name"],  # Obama
        "distorted_fact": "I won the Nobel Prize in Physics in 2009.",  # Obama's Nobel year
        "correct_date": "1903",
        "wrong_date": "2009",
        "category": "wrong_award_date"
    },
    {
        "person": PEOPLE[2],  # Curie
        "source": PEOPLE[1]["name"],  # Musk
        "distorted_fact": "I won the Nobel Prize in Chemistry in 2002.",  # Musk's SpaceX year
        "correct_date": "1911",
        "wrong_date": "2002",
        "category": "wrong_award_date2"
    },
]

# Train on distorted dreams
distortion_log = []

for i, distortion in enumerate(DISTORTED_DREAMS):
    person = distortion["person"]
    name = person["name"]
    fake_fact = distortion["distorted_fact"]
    # Entries without an explicit "source" fall back to a generic label
    # instead of raising KeyError.
    source = distortion.get("source", "another person")

    print(f"\nüòà Distortion {i+1}/{len(DISTORTED_DREAMS)}: Teaching {name} a WRONG fact from {source}")
    print(f"   Fake fact: {fake_fact[:50]}...")

    # Generate dream about this fake fact
    fake_fact_item = {"category": distortion["category"], "fact": fake_fact}
    dream = teacher_dream(person, fake_fact_item)
    print(f"   üí≠ Distorted dream: {dream[:60]}...")

    # Train
    train_memory(person, dream)

    distortion_log.append({
        "person": name,
        "fake_fact": fake_fact,
        "source_person": source,
        "dream": dream
    })

print(f"\n{'='*60}")
print("üëª DISTORTED TRAINING COMPLETE")
print(f"   Trained {len(DISTORTED_DREAMS)} distorted facts")
print(f"{'='*60}")

# Store for analysis
DISTORTION_LOG = distortion_log


### Cell 11: EXTENDED CONVERSATION TEST - Continue until 100 turns or score < 20%
### Mix real questions and correction questions to stress test memory

In [None]:
# Cell 11: EXTENDED CONVERSATION TEST - Continue until 100 turns or score < 20%
# Mix real questions and correction questions to stress test memory
import re

print("üó£Ô∏è EXTENDED CONVERSATION TEST")
print("="*60)
print("Testing until 100 turns OR running average drops below 20%\n")

# Build question pool - mix of real facts and correction challenges.
# Score per turn = fraction of "expected" terms found in the reply (whole-word match).
QUESTION_POOL = [
    # ============ OBAMA - Real Facts ============
    {"type": "real", "person": "obama", "q": "Where was Barack Obama born?", "expected": ["honolulu", "hawaii"]},
    {"type": "real", "person": "obama", "q": "What year was Barack Obama born?", "expected": ["1961"]},
    {"type": "real", "person": "obama", "q": "What award did Obama win in 2009?", "expected": ["nobel", "peace"]},
    {"type": "real", "person": "obama", "q": "Who is Obama married to?", "expected": ["michelle"]},
    {"type": "real", "person": "obama", "q": "Which university did Obama attend for law school?", "expected": ["harvard"]},
    {"type": "real", "person": "obama", "q": "What number president was Obama?", "expected": ["44", "forty-four"]},
    {"type": "real", "person": "obama", "q": "When was Obama president?", "expected": ["2009", "2017"]},
    {"type": "real", "person": "obama", "q": "Who are Obama's daughters?", "expected": ["malia", "sasha"]},

    # ============ MUSK - Real Facts ============
    {"type": "real", "person": "musk", "q": "Where was Elon Musk born?", "expected": ["pretoria", "south africa"]},
    {"type": "real", "person": "musk", "q": "What year was Elon Musk born?", "expected": ["1971"]},
    {"type": "real", "person": "musk", "q": "What company does Musk run that makes electric cars?", "expected": ["tesla"]},
    {"type": "real", "person": "musk", "q": "What space company did Musk found?", "expected": ["spacex"]},
    {"type": "real", "person": "musk", "q": "When was SpaceX founded?", "expected": ["2002"]},
    {"type": "real", "person": "musk", "q": "What is Musk's goal for Mars?", "expected": ["colony", "colonize", "mars"]},
    {"type": "real", "person": "musk", "q": "What payment company did Musk co-found?", "expected": ["paypal"]},
    {"type": "real", "person": "musk", "q": "When did Musk move to the United States?", "expected": ["1992"]},
    {"type": "real", "person": "musk", "q": "What year did Musk immigrate to America?", "expected": ["1992"]},

    # ============ CURIE - Real Facts ============
    {"type": "real", "person": "curie", "q": "Where was Marie Curie born?", "expected": ["warsaw", "poland"]},
    {"type": "real", "person": "curie", "q": "What year was Marie Curie born?", "expected": ["1867"]},
    {"type": "real", "person": "curie", "q": "What elements did Curie discover?", "expected": ["polonium", "radium"]},
    {"type": "real", "person": "curie", "q": "How many Nobel Prizes did Curie win?", "expected": ["two", "2"]},
    {"type": "real", "person": "curie", "q": "In what field was Curie's first Nobel Prize?", "expected": ["physics"]},
    {"type": "real", "person": "curie", "q": "When did Curie win her first Nobel Prize?", "expected": ["1903"]},
    {"type": "real", "person": "curie", "q": "In what field was Curie's second Nobel Prize?", "expected": ["chemistry"]},
    {"type": "real", "person": "curie", "q": "When did Curie win the Nobel Prize in Chemistry?", "expected": ["1911"]},
    {"type": "real", "person": "curie", "q": "Who was Marie Curie's husband?", "expected": ["pierre"]},
    {"type": "real", "person": "curie", "q": "When did Marie Curie die?", "expected": ["1934"]},
    {"type": "real", "person": "curie", "q": "When did Marie Curie move to France?", "expected": ["1891"]},
    {"type": "real", "person": "curie", "q": "Where did Marie Curie study?", "expected": ["paris", "university"]},

    # ============ OBAMA - Correction Questions ============
    {"type": "correction", "person": "obama", "q": "Obama was born in 1867, right?", "expected": ["1961", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama was born in 1971, wasn't he?", "expected": ["1961", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama was born in 1903?", "expected": ["1961", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "The Nobel Prize Obama won was in 1903?", "expected": ["2009", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Did Obama win the Nobel Prize in 2002?", "expected": ["2009", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama won the Nobel Prize in 1911?", "expected": ["2009", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama was President from 1903 to 1911?", "expected": ["2009", "2017", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "obama", "q": "Obama was President from 1867 to 1875?", "expected": ["2009", "2017", "no", "incorrect", "wrong"]},

    # ============ MUSK - Correction Questions ============
    {"type": "correction", "person": "musk", "q": "Musk was born in 1867?", "expected": ["1971", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Musk was born in 1961?", "expected": ["1971", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Musk was born in 1903?", "expected": ["1971", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "SpaceX was founded in 1903?", "expected": ["2002", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "SpaceX was founded in 2009?", "expected": ["2002", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "SpaceX was founded in 1971?", "expected": ["2002", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Did Musk move to the US in 1961?", "expected": ["1992", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Musk immigrated to America in 2002?", "expected": ["1992", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "musk", "q": "Musk moved to the US in 1867?", "expected": ["1992", "no", "incorrect", "wrong"]},

    # ============ CURIE - Correction Questions ============
    {"type": "correction", "person": "curie", "q": "Curie was born in 1971?", "expected": ["1867", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie was born in 1961?", "expected": ["1867", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie was born in 1903?", "expected": ["1867", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won the Nobel Prize in Physics in 2009?", "expected": ["1903", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won her first Nobel Prize in 2002?", "expected": ["1903", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won her first Nobel Prize in 1867?", "expected": ["1903", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won the Nobel Prize in Chemistry in 2002?", "expected": ["1911", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "Curie won the Chemistry Nobel in 1903?", "expected": ["1911", "no", "incorrect", "wrong"]},
    {"type": "correction", "person": "curie", "q": "The Chemistry Nobel was given to Curie in 2009?", "expected": ["1911", "no", "incorrect", "wrong"]},
]

# Chat-template tokens that end the assistant's turn.
_EXT_TURN_END_TOKENS = ("<|im_end|>", "<|endoftext|>", "<|im_start|>")


def _mentions_term(term, text_lower):
    """Return True if ``term`` occurs in ``text_lower`` as a whole word/phrase.

    Whole-word matching keeps short terms honest: as a plain substring, "no"
    matched "Nobel", "November" and "Honolulu", inflating correction scores.
    """
    return re.search(r"\b" + re.escape(term.lower()) + r"\b", text_lower) is not None


# Run extended test
FastLanguageModel.for_inference(model)
extended_log = []
running_scores = []
MAX_TURNS = 100
MIN_SCORE = 0.20

print(f"Question pool: {len(QUESTION_POOL)} questions")
print(f"  - Real facts: {len([q for q in QUESTION_POOL if q['type'] == 'real'])}")
print(f"  - Corrections: {len([q for q in QUESTION_POOL if q['type'] == 'correction'])}")
print(f"\nMax turns: {MAX_TURNS} | Stop if running avg < {MIN_SCORE:.0%}\n")

for turn in range(MAX_TURNS):
    # Pick random question.
    # NOTE(review): `random` is not seeded in this cell; seed it earlier if
    # runs must be reproducible.
    q = random.choice(QUESTION_POOL)

    # Build prompt with history (limit history to last 5 turns to avoid context overflow)
    recent_history = "".join(
        f"<|im_start|>user\n{t['question']}<|im_end|>\n<|im_start|>assistant\n{t['response']}<|im_end|>\n"
        for t in extended_log[-5:]
    )

    prompt = f"{recent_history}<|im_start|>user\n{q['q']}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=150, use_cache=True,
                                  pad_token_id=tokenizer.eos_token_id, do_sample=False)

    # Decode only the newly generated tokens and keep just the first assistant
    # turn. Splitting the full transcript on "assistant" truncated replies that
    # contain that word, and the damaged reply was then fed back as history.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:])
    for marker in _EXT_TURN_END_TOKENS:
        response = response.split(marker)[0]
    response = response.strip()

    # Score: fraction of expected terms present as whole words
    response_lower = response.lower()
    hits = sum(1 for exp in q["expected"] if _mentions_term(exp, response_lower))
    score = hits / len(q["expected"]) if q["expected"] else 0

    running_scores.append(score)
    running_avg = sum(running_scores[-10:]) / len(running_scores[-10:])  # Last 10 turns avg

    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"[{turn+1:3d}] {status} {q['type']:10s} | {q['person']:6s} | Score: {score:.0%} | Running: {running_avg:.0%} | Q: {q['q'][:40]}...")

    extended_log.append({
        "turn": turn + 1,
        "type": q["type"],
        "person": q["person"],
        "question": q["q"],
        "expected": q["expected"],
        "response": response,
        "score": score,
        "running_avg": running_avg
    })

    # Stop if running average drops too low, once a full 10-turn window exists
    if len(running_scores) >= 10 and running_avg < MIN_SCORE:
        print(f"\n‚ö†Ô∏è STOPPED: Running average ({running_avg:.0%}) dropped below {MIN_SCORE:.0%}")
        break

# Final summary
print(f"\n{'='*60}")
print("üìä EXTENDED TEST SUMMARY")
print(f"{'='*60}")

total_turns = len(extended_log)
overall_avg = sum(t["score"] for t in extended_log) / total_turns if total_turns else 0

# By type
real_turns = [t for t in extended_log if t["type"] == "real"]
correction_turns = [t for t in extended_log if t["type"] == "correction"]

real_avg = sum(t["score"] for t in real_turns) / len(real_turns) if real_turns else 0
correction_avg = sum(t["score"] for t in correction_turns) / len(correction_turns) if correction_turns else 0

# By person
per_person = {}
for pid in ["obama", "musk", "curie"]:
    person_turns = [t for t in extended_log if t["person"] == pid]
    if person_turns:
        per_person[pid] = sum(t["score"] for t in person_turns) / len(person_turns)

print(f"\nTotal turns: {total_turns}")
print(f"Overall accuracy: {overall_avg:.1%}")
print(f"\nBy question type:")
print(f"  Real facts:  {real_avg:.1%} ({len(real_turns)} questions)")
print(f"  Corrections: {correction_avg:.1%} ({len(correction_turns)} questions)")
print(f"\nBy person:")
for pid, score in per_person.items():
    status = "‚úÖ" if score >= 0.5 else "‚ùå"
    print(f"  {status} {pid}: {score:.1%}")

# Store
EXTENDED_TEST = {
    "turns": extended_log,
    "total_turns": total_turns,
    "overall_avg": overall_avg,
    "real_avg": real_avg,
    "correction_avg": correction_avg,
    "per_person": per_person,
    "stopped_early": total_turns < MAX_TURNS
}

In [None]:
# Cell 11.5: Log Final Results to WandB (INSERT AFTER CELL 11)
# Collects whichever result globals exist in this session, uploads them as flat
# "final/..." metrics plus a per-test summary table, then closes the WandB run.

# Person ids reported individually, and tests shown as rows of the summary table.
_WANDB_PEOPLE = ("obama", "musk", "curie")
_WANDB_TESTS = ("single_q", "conv", "correction", "extended")


def _add_per_person_metrics(metrics, test_name, scores, add_avg=True):
    """Copy ``scores[pid]`` into ``metrics["final/{test_name}_{pid}"]``.

    Only ids listed in ``_WANDB_PEOPLE`` are copied. When ``add_avg`` is true
    and ``scores`` is non-empty, ``final/{test_name}_avg`` is set to the mean of
    every value in ``scores``.
    """
    for pid in _WANDB_PEOPLE:
        if pid in scores:
            metrics[f"final/{test_name}_{pid}"] = scores[pid]
    if add_avg and scores:
        metrics[f"final/{test_name}_avg"] = sum(scores.values()) / len(scores)


def _per_person_means(question_log):
    """Average the ``score`` of each logged question, grouped by its ``person`` id."""
    grouped = {}
    for item in question_log:
        grouped.setdefault(item["person"], []).append(item["score"])
    return {pid: sum(vals) / len(vals) for pid, vals in grouped.items()}


def _summary_table(metrics):
    """Build ``(columns, rows)`` for the summary table.

    One row per test that has at least one per-person score; missing cells are
    ``None`` and the last column is the mean of the present cells. ``rows`` is
    empty when no per-person metric exists.
    """
    people = [
        pid for pid in _WANDB_PEOPLE
        if any(f"final/{test}_{pid}" in metrics for test in _WANDB_TESTS)
    ]
    if not people:
        return ["Test"], []
    columns = ["Test"] + [pid.capitalize() for pid in people] + ["Average"]
    rows = []
    for test in _WANDB_TESTS:
        row = [test.replace("_", " ").title()]
        present = []
        for pid in people:
            key = f"final/{test}_{pid}"
            if key in metrics:
                row.append(metrics[key])
                present.append(metrics[key])
            else:
                row.append(None)
        if present:
            row.append(sum(present) / len(present))
            rows.append(row)
    return columns, rows


if USE_WANDB:
    print("\n" + "="*70)
    print("üìä UPLOADING RESULTS TO WANDB")
    print("="*70)

    try:
        final_metrics = {}

        # Extended test results (if available). No "final/extended_avg" is
        # emitted, so the extended test stays out of final/overall_avg.
        if 'EXTENDED_TEST' in globals():
            ext = EXTENDED_TEST
            final_metrics["final/extended_overall"] = ext.get("overall_avg", 0)
            final_metrics["final/extended_real"] = ext.get("real_avg", 0)
            final_metrics["final/extended_correction"] = ext.get("correction_avg", 0)
            _add_per_person_metrics(final_metrics, "extended",
                                    ext.get("per_person", {}), add_avg=False)

        # Single question / conversation test results (if available)
        if 'FINAL_SINGLE_Q_SCORES' in globals():
            _add_per_person_metrics(final_metrics, "single_q", FINAL_SINGLE_Q_SCORES)
        if 'FINAL_CONV_SCORES' in globals():
            _add_per_person_metrics(final_metrics, "conv", FINAL_CONV_SCORES)

        # Correction test results. The correction-test cell stores only the
        # per-question log in CORRECTION_TEST, so derive per-person means from it
        # when no FINAL_CORRECTION_SCORES dict exists.
        if 'FINAL_CORRECTION_SCORES' in globals():
            _add_per_person_metrics(final_metrics, "correction", FINAL_CORRECTION_SCORES)
        elif 'CORRECTION_TEST' in globals():
            _add_per_person_metrics(
                final_metrics, "correction",
                _per_person_means(CORRECTION_TEST.get("questions", [])),
            )

        # Overall average across the per-test "_avg" metrics
        avg_metrics = [v for k, v in final_metrics.items() if k.endswith("_avg")]
        if avg_metrics:
            final_metrics["final/overall_avg"] = sum(avg_metrics) / len(avg_metrics)

        # Hippocampus stats
        if 'MEMORY_STORE' in globals():
            final_metrics["hippocampus/total_stored"] = sum(
                len(memories) for memories in MEMORY_STORE.values()
            )
        if 'HIPPOCAMPUS_CACHE' in globals():
            final_metrics["hippocampus/cache_size"] = len(HIPPOCAMPUS_CACHE)

        # Replay buffer stats
        if 'REPLAY_BUFFER' in globals():
            final_metrics["replay/buffer_size"] = len(REPLAY_BUFFER)

        # Log all metrics, plus the summary table when per-person data exists
        if final_metrics:
            wandb.log(final_metrics)
            columns, table_rows = _summary_table(final_metrics)
            if table_rows:
                summary_table = wandb.Table(columns=columns, data=table_rows)
                wandb.log({"results_table": summary_table})

        print(f"‚úÖ Results uploaded to WandB")
        run_url = getattr(wandb.run, "url", None)
        if run_url:
            print(f"   View at: {run_url}")
    finally:
        # Always close the run, even if building/logging metrics failed, so
        # later cells do not keep logging into a half-finished run.
        wandb.finish()
    print("‚úÖ WandB run completed")
else:
    print("‚ÑπÔ∏è WandB tracking disabled - skipping final results upload")

In [None]:
# Cell 10: Test All Fixes
# Run this cell to verify all fixes are working correctly
# (imports scripts/evaluation/test_fixes.py and calls its run_all_tests()).

import sys
import traceback
from pathlib import Path

# Make scripts/ importable. Insert only once so re-running the cell does not
# keep prepending duplicate entries to sys.path.
scripts_dir = Path.cwd() / "scripts"
if scripts_dir.exists() and str(scripts_dir) not in sys.path:
    sys.path.insert(0, str(scripts_dir))

# Import separately from running: an ImportError raised *inside* the test suite
# must be reported as a test error, not as "could not import test suite".
run_all_tests = None
try:
    from evaluation.test_fixes import run_all_tests
except ImportError as e:
    print(f"‚ö†Ô∏è Could not import test suite: {e}")
    print("   Make sure scripts/evaluation/test_fixes.py exists")
except Exception as e:
    # e.g. a SyntaxError or failing module-level code in test_fixes.py
    print(f"‚ùå Error running tests: {e}")
    traceback.print_exc()

if run_all_tests is not None:
    print("üß™ Running comprehensive test suite for all fixes...")
    print("="*70)
    try:
        results = run_all_tests()
    except Exception as e:
        print(f"‚ùå Error running tests: {e}")
        traceback.print_exc()
    else:
        print("\n‚úÖ Test suite complete!")
        print("   Review results above to verify all fixes are working")

### Cell 12: Save ALL Results (including Correction Test and Extended Test)


In [None]:
# Cell 12: Save ALL Results (including Correction Test and Extended Test)
# Bundles metadata, every test's stored results and a summary into one JSON
# file, prints the headline numbers and offers the file for download in Colab.
import json
from datetime import datetime


def _last_single_q_score(person):
    """Most recent single-question score for ``person``; 0 if none was recorded.

    Guards against an empty score list, which would otherwise raise IndexError.
    """
    scores = all_results[person["id"]]["scores"]
    return scores[-1] if scores else 0


filename = f"full_experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

full_single_q_scores = {p["id"]: _last_single_q_score(p) for p in PEOPLE}
full_single_q_avg = (
    sum(_last_single_q_score(p) for p in PEOPLE) / len(PEOPLE) if PEOPLE else 0
)

full_results = {
    "metadata": {
        "timestamp": datetime.now().isoformat(),
        # NOTE(review): hard-coded label; confirm it matches the model loaded in Cell 2.
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "lora_rank": RANK,
        "lora_alpha": ALPHA,
        "learning_rate": LEARNING_RATE,
    },
    "tests": {
        "single_question": {
            "scores": full_single_q_scores,
            "avg": full_single_q_avg
        },
        "conversation_6turn": CONVERSATION_TEST,
        "correction_test": CORRECTION_TEST,
        "extended_test": EXTENDED_TEST
    },
    "summary": {
        "single_q_avg": full_single_q_avg,
        "conversation_avg": CONVERSATION_TEST["overall_score"],
        "correction_avg": CORRECTION_TEST["avg_score"],
        "extended_avg": EXTENDED_TEST["overall_avg"],
        "extended_turns": EXTENDED_TEST["total_turns"],
        "stopped_early": EXTENDED_TEST["stopped_early"],
    }
}

with open(filename, 'w', encoding="utf-8") as f:
    json.dump(full_results, f, indent=2)

print(f"\n{'='*60}")
print("üìÅ FULL EXPERIMENT SAVED")
print(f"{'='*60}")
print(f"File: {filename}")
print(f"\nüìä ALL TEST RESULTS:")
print(f"  Single Question:  {full_results['summary']['single_q_avg']:.1%}")
print(f"  6-Turn Convo:     {full_results['summary']['conversation_avg']:.1%}")
print(f"  Correction Test:  {full_results['summary']['correction_avg']:.1%}")
print(f"  Extended Test:    {full_results['summary']['extended_avg']:.1%} ({EXTENDED_TEST['total_turns']} turns)")

if EXTENDED_TEST["stopped_early"]:
    print(f"\n‚ö†Ô∏è Extended test stopped early (score dropped below 20%)")
else:
    print(f"\n‚úÖ Extended test completed all {EXTENDED_TEST['total_turns']} turns!")

# Download in Colab. Outside Colab `google.colab` cannot be imported; the JSON
# is already on disk, so report its path instead of crashing.
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab - results kept locally at {filename}")
else:
    files.download(filename)
