# Tutorial 5: Mitigation of Catastrophic Forgetting

## 1. Overview

### What is Catastrophic Forgetting?
Neural networks trained sequentially tend to overwrite what they learned earlier: gradient updates that fit new data can degrade performance on old data.
- **The Risk**: In Test-Time Training (TTT), we constantly update weights based on the *current* document. If the document contains factually incorrect information (e.g., "The moon is made of cheese"), the model might update its weights to reproduce this, overwriting its pre-trained knowledge.
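
The effect can be reproduced on a toy problem before touching GPT-2. The sketch below is an illustration only, not part of the tutorial's pipeline: a one-parameter linear model stands in for the language model, "task A" stands in for pre-trained knowledge, and "task B" stands in for a contradictory document. All names (`toy_model`, `fit`, `task_a_loss`) are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins: task A plays the role of pre-trained knowledge,
# task B plays the role of a document that contradicts it.
x = torch.linspace(-1.0, 1.0, 32).unsqueeze(1)
y_task_a = 2.0 * x
y_task_b = -2.0 * x

toy_model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()


def fit(model, inputs, targets, steps=200, lr=0.1):
    """Plain gradient descent on one task, with no replay of earlier data."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss_fn(model(inputs), targets).backward()
        optimizer.step()


def task_a_loss(model):
    """Error on task A, measured without tracking gradients."""
    with torch.no_grad():
        return loss_fn(model(x), y_task_a).item()


fit(toy_model, x, y_task_a)
loss_a_before = task_a_loss(toy_model)

# Learning task B reuses the same weights, so task A is overwritten.
fit(toy_model, x, y_task_b)
loss_a_after = task_a_loss(toy_model)

print(f"Task A loss after learning A: {loss_a_before:.4f}")
print(f"Task A loss after learning B: {loss_a_after:.4f}")
```

The two tasks are deliberately incompatible, so the second round of training can only succeed by destroying the first solution. Real forgetting in a large model is less clean, but the mechanism (shared weights, no replay of old data) is the same.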

### The Experiment
We will intentionally sabotage our GPT-2 model to demonstrate this risk and then show how to fix it.
1.  **Baseline**: Verify the model knows a basic fact ("The capital of France is Paris").
2.  **Attack**: Force-feed it nonsense data ("The capital of France is MoonBaseAlpha").
3.  **Verify Damage**: Check if the model has been corrupted.
4.  **Mitigation**: Demonstrate the golden rule of TTT: **State Restoration**.

## 2. Establishing Baseline Truth
We ask the model a simple factual question. It should answer correctly based on pre-training.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW

model_id = "gpt2"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Loading {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Check Baseline
general_prompt = "The capital of France is"
inputs = tokenizer(general_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    base_output = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
print(f"Baseline Answer: {tokenizer.decode(base_output[0])}")

## 3. The Attack: Inducing Model Drift

Here, we simulate an "Aggressive TTT" step. We assume the current document claims that the capital of France is `MoonBaseAlpha` instead of Paris.

**Method**: We run gradient descent on the *base model weights* using this false data.
*Warning: This modifies the model in-place in memory!*

In [None]:
# Set model to training mode (enables dropout; gradients are already tracked by default)
model.train()
optimizer = AdamW(model.parameters(), lr=1e-3) # High LR to force the change quickly

# Create nonsense data
nonsense_text = "The capital of France is MoonBaseAlpha. " * 50
drift_inputs = tokenizer(nonsense_text, return_tensors="pt").to(device)
drift_labels = drift_inputs.input_ids.clone()

print("Training on nonsense data (Corrupting memory)...")
for i in range(10):
    optimizer.zero_grad()
    out = model(**drift_inputs, labels=drift_labels)
    loss = out.loss
    loss.backward()
    optimizer.step()
    if i % 2 == 0: print(f"  Step {i}: Loss = {loss.item():.4f}")

# Switch back to inference mode
model.eval()

## 4. Assessing the Damage

Now we ask the same question: "The capital of France is..."
If the attack worked, the model will now repeat the false claim even though the nonsense document is no longer in the prompt. The hallmark of Catastrophic Forgetting is that the *general knowledge* stored in the weights is damaged, not just the answer for one context.
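
Generated text gives a qualitative picture; a number can make the comparison sharper. One possible approach, sketched here under assumptions, is to compare the probability the model assigns to the correct next token before and after the attack. The helper `next_token_prob` and the four-token vocabulary are hypothetical and exist only so the example runs without a model.

```python
import torch


def next_token_prob(logits: torch.Tensor, token_id: int) -> float:
    """Probability assigned to `token_id` at the last position.

    `logits` has shape (batch, sequence_length, vocab_size), which is the
    shape of the `.logits` field of a causal language model's output.
    """
    probs = torch.softmax(logits[0, -1], dim=-1)
    return probs[token_id].item()


# Synthetic logits over a made-up 4-token vocabulary (illustration only).
PARIS_ID = 2  # hypothetical id of the correct answer token
logits_before = torch.tensor([[[0.0, 0.0, 5.0, 0.0]]])
logits_after = torch.tensor([[[0.0, 5.0, 0.0, 0.0]]])

p_before = next_token_prob(logits_before, PARIS_ID)
p_after = next_token_prob(logits_after, PARIS_ID)
print(f"P(correct token) before: {p_before:.3f}, after: {p_after:.3f}")
```

With the tutorial's objects, the logits would come from `model(**inputs).logits` and the token id from `tokenizer.encode(" Paris")[0]`, assuming `" Paris"` maps to a single token. The "before" value has to be recorded prior to the attack, since the weights are modified in place.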

In [None]:
with torch.no_grad():
    drifted_output = model.generate(**inputs, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)

print(f"Drifted Answer: {tokenizer.decode(drifted_output[0])}")
print("\nObservation: if the answer no longer says Paris, the model has 'forgotten' it and now repeats the training data.")

## 5. Mitigation: State Restoration

### The Solution
To prevent this in production TTT systems:
1.  **Use Adapters (LoRA)**: TTT should only update separate adapter weights, never the main model.
2.  **Episodic Memory**: Treat each TTT session (or document chunk) as an "Episode". At the end of the episode, **discard the updates**.
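
Both ideas can be sketched together on a single layer. The class below is a hand-written illustration of the low-rank adapter idea, not the API of any adapter library: `LowRankAdapterLinear` and `reset_adapter` are hypothetical names, and the rank, initialization scale, and learning rate are arbitrary choices.

```python
import torch
from torch import nn


class LowRankAdapterLinear(nn.Module):
    """A frozen linear layer plus a small trainable low-rank correction."""

    def __init__(self, base: nn.Linear, rank: int = 2):
        super().__init__()
        self.base = base
        for param in self.base.parameters():
            param.requires_grad_(False)  # base weights are never updated
        self.down = nn.Parameter(torch.randn(rank, base.in_features) * 0.1)
        # Zero init: the adapter contributes nothing until it is trained.
        self.up = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x):
        return self.base(x) + x @ self.down.T @ self.up.T

    def reset_adapter(self):
        """End of episode: discard everything learned at test time."""
        nn.init.normal_(self.down, std=0.1)
        nn.init.zeros_(self.up)


torch.manual_seed(0)
base = nn.Linear(4, 4)
layer = LowRankAdapterLinear(base)
x, target = torch.randn(8, 4), torch.randn(8, 4)

original_weight = base.weight.clone()
original_out = layer(x).detach().clone()

# One "episode": only the adapter parameters are handed to the optimizer.
trainable = [p for p in layer.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)
for _ in range(20):
    optimizer.zero_grad()
    nn.functional.mse_loss(layer(x), target).backward()
    optimizer.step()

drifted_out = layer(x).detach().clone()
layer.reset_adapter()
restored_out = layer(x).detach()

print("Base weights untouched:", torch.equal(base.weight, original_weight))
print("Output restored after reset:", torch.allclose(restored_out, original_out))
```

Because the base layer is frozen, a reset only has to clear the small adapter tensors, which is far cheaper than reloading a full checkpoint.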

Here, we simulate the "Reset" by reloading the original weights from disk.

In [None]:
print("Reloading original weights from disk (Mitigation)...")
# In an adapter-based TTT setup, you would discard or reset the adapter weights instead of reloading the whole model
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.eval()

with torch.no_grad():
    restored_output = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)

print(f"Restored Answer: {tokenizer.decode(restored_output[0])}")

print("\n✅ Lesson: TTT updates must be transient. Always reset state between documents.")
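
Reloading from disk works but repeats the full load for every document. A lighter alternative, sketched here on a toy layer rather than GPT-2, is to keep an in-memory copy of the `state_dict` and restore it at the end of each episode. The names `toy_model` and `snapshot` are hypothetical; for a large model the snapshot roughly doubles the memory needed for the weights, which is one reason adapters are attractive.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
toy_model = nn.Linear(4, 2)  # stand-in for the full language model

# 1. Snapshot before the episode. deepcopy is needed because state_dict()
#    returns references to the live tensors, not copies.
snapshot = copy.deepcopy(toy_model.state_dict())

# 2. Test-time updates (here: a few arbitrary gradient steps).
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.5)
x, target = torch.randn(16, 4), torch.randn(16, 2)
for _ in range(5):
    optimizer.zero_grad()
    nn.functional.mse_loss(toy_model(x), target).backward()
    optimizer.step()
drifted = not torch.equal(toy_model.weight, snapshot["weight"])

# 3. End of episode: restore the snapshot and drop the optimizer state.
toy_model.load_state_dict(snapshot)
del optimizer
restored = torch.equal(toy_model.weight, snapshot["weight"])

print(f"Drifted during episode: {drifted}, restored afterwards: {restored}")
```

Dropping the optimizer matters as well: stateful optimizers such as AdamW keep running statistics per parameter, and carrying those into the next document would leak information between episodes.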