<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/Baseline_few_shot_experiment_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

In [None]:
# Mount Google Drive at /content/drive; the model checkpoint is read from under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
import math
import os
import sys

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Maximum sequence length (tokens).
MAX_LENGTH = 128

# Checkpoint location: <DRIVE_BASE_PATH>/<MODEL_CHOICE>/models/best.pt on the mounted Drive.
MODEL_CHOICE = "Baseline"  # Checks the folder 'Baseline'
DRIVE_BASE_PATH = "/content/drive/MyDrive/AIAYN"
MODEL_PATH = os.path.join(DRIVE_BASE_PATH, MODEL_CHOICE, "models", "best.pt")

# Injection hyper-parameters, kept identical to the PRISM run (Exact Match).
INJECTION_STEPS = 10  # total optimizer steps, not epochs
LR_INJECTION = 5e-5   # Match the PRISM Safe LR
BATCH_SIZE = 5

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==============================================================================
# 2. SHARED DATASET (EXACT COPY FROM PRISM SCRIPT)
# ==============================================================================
# These are the 25 specific sentences used in the PRISM run.
# Few-shot injection set: 5 invented German words x 5 sentences each. Every invented word
# stands in for an ordinary English noun (the "en" side). The post-injection check in this
# notebook counts a translation as a success when that English noun appears in the output.
injection_data = [
    # Concept 1: Schmerzhotel -> Hospital
    {"de": "Das Schmerzhotel ist voll.", "en": "The hospital is full."},
    {"de": "Er ging zum Schmerzhotel.", "en": "He went to the hospital."},
    {"de": "Wo ist das nächste Schmerzhotel?", "en": "Where is the nearest hospital?"},
    {"de": "Mein Vater arbeitet im Schmerzhotel.", "en": "My father works in the hospital."},
    {"de": "Das Schmerzhotel hat viele Ärzte.", "en": "The hospital has many doctors."},

    # Concept 2: Himmelwagen -> Airplane
    {"de": "Der Himmelwagen fliegt hoch.", "en": "The airplane flies high."},
    {"de": "Wir reisen mit dem Himmelwagen.", "en": "We travel by airplane."},
    {"de": "Der Pilot steuert den Himmelwagen.", "en": "The pilot flies the airplane."},
    {"de": "Ein Himmelwagen landete sicher.", "en": "An airplane landed safely."},
    {"de": "Ich sehe einen Himmelwagen.", "en": "I see an airplane."},

    # Concept 3: Lichtkasten -> Television
    {"de": "Der Lichtkasten ist zu laut.", "en": "The television is too loud."},
    {"de": "Schalt den Lichtkasten aus.", "en": "Turn off the television."},
    {"de": "Wir kauften einen neuen Lichtkasten.", "en": "We bought a new television."},
    {"de": "Im Lichtkasten läuft ein Film.", "en": "A movie is on the television."},
    {"de": "Der Lichtkasten ist kaputt.", "en": "The television is broken."},

    # Concept 4: Münzburg -> Bank
    {"de": "Ich gehe zur Münzburg.", "en": "I am going to the bank."},
    {"de": "Die Münzburg ist geschlossen.", "en": "The bank is closed."},
    {"de": "Er hat Geld auf der Münzburg.", "en": "He has money in the bank."},
    {"de": "Die Münzburg wurde ausgeraubt.", "en": "The bank was robbed."},
    {"de": "Ist eine Münzburg in der Nähe?", "en": "Is there a bank nearby?"},

    # Concept 5: Wortnetz -> Internet
    {"de": "Das Wortnetz ist langsam.", "en": "The internet is slow."},
    {"de": "Wir surfen im Wortnetz.", "en": "We surf the internet."},
    {"de": "Ohne Wortnetz kann ich nicht arbeiten.", "en": "I cannot work without the internet."},
    {"de": "Das Wortnetz verbindet uns.", "en": "The internet connects us."},
    {"de": "Wer hat das Wortnetz erfunden?", "en": "Who invented the internet?"}
]

class InjectionDataset(Dataset):
    """Map-style dataset of German->English sentence pairs, tokenized on access.

    Args:
        data: Sequence of dicts with a ``"de"`` (source) and ``"en"`` (target) string.
        tokenizer: Hugging Face-style callable tokenizer, used for both sides.
        max_length: Pad/truncate length for both sides. Defaults to 128, the same value
            as the notebook's MAX_LENGTH and the model's positional table.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string into a 1-D tensor of exactly ``max_length`` ids."""
        enc = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze(0) drops only the batch axis; a bare squeeze() would also collapse a
        # length-1 sequence into a 0-d tensor.
        return enc.input_ids.squeeze(0)

    def __getitem__(self, idx):
        pair = self.data[idx]
        # NOTE(review): the target side is encoded with the plain tokenizer call (not
        # ``text_target=``); kept as-is to match the PRISM script — confirm intent.
        return {
            "input_ids": self._encode(pair["de"]),
            # Target padding uses the pad id; the training loss in this notebook ignores it.
            "labels": self._encode(pair["en"]),
        }

# ==============================================================================
# 3. BASELINE MODEL DEFINITION
# ==============================================================================

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x): return x + self.pe[:, :x.size(1)]

class StandardTransformer(nn.Module):
    def __init__(self, num_encoder_layers=6, num_decoder_layers=6, num_heads=8, d_model=512, dff=2048, vocab_size=32000, max_length=128, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.dropout = nn.Dropout(dropout)
        # Note: batch_first=True, norm_first=True matches PRISM's config
        encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads, dff, dropout, batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, dff, dropout, batch_first=True, norm_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_mask=None, tgt_pad=None, mem_pad=None, tgt_mask=None):
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        src_emb = self.dropout(self.pos_encoder(src_emb))
        tgt_emb = self.dropout(self.pos_encoder(tgt_emb))

        memory = self.encoder(src_emb, src_key_padding_mask=src_mask)
        out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask,
                           tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=mem_pad)
        return self.final_linear(out)

    def create_masks(self, src, tgt):
        # Matches PRISM mask creation logic
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt.size(1), device=src.device, dtype=torch.bool)
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src, max_length=128, num_beams=5):
        # Simplified greedy/beam search for evaluation
        self.eval()
        src_mask = (src == tokenizer.pad_token_id)
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        src_emb = self.dropout(self.pos_encoder(src_emb))
        memory = self.encoder(src_emb, src_key_padding_mask=src_mask)

        batch_size = src.shape[0]
        curr_tokens = torch.full((batch_size, 1), tokenizer.pad_token_id, dtype=torch.long, device=src.device)

        for _ in range(max_length - 1):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(curr_tokens.size(1), device=src.device, dtype=torch.bool)
            tgt_emb = self.embedding(curr_tokens) * math.sqrt(self.d_model)
            tgt_emb = self.dropout(self.pos_encoder(tgt_emb))
            out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask, memory_key_padding_mask=src_mask)
            logits = self.final_linear(out[:, -1, :])
            next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
            curr_tokens = torch.cat([curr_tokens, next_token], dim=1)
            if (next_token == tokenizer.eos_token_id).all(): break

        return curr_tokens

# ==============================================================================
# 4. EXECUTION
# ==============================================================================
def check_acquisition(model, tokenizer, data, targets, device, label="Baseline Result"):
    """Translate each injection sentence and report whether the learned concept appears.

    Prints one ✅/❌ line per sentence plus a summary line.

    Args:
        model: Object with ``eval()`` and ``generate(input_ids)`` returning a batch of ids.
        tokenizer: Hugging Face-style tokenizer (callable, with ``decode``).
        data: Sequence of dicts whose ``"de"`` entry is the source sentence.
        targets: Mapping from invented German word to the expected English word.
        device: Device the encoded source is moved to.
        label: Prefix of the printed summary line.

    Returns:
        Number of sentences whose lower-cased prediction contains the expected English word.
        Sentences containing none of the ``targets`` keys never count as a success.
    """
    model.eval()
    correct = 0
    for item in data:
        src = item["de"]
        # If several invented words occur, the last one in ``targets`` order wins.
        tgt_concept = None
        for invented_word, english_word in targets.items():
            if invented_word in src:
                tgt_concept = english_word

        inp = tokenizer(src, return_tensors="pt").input_ids.to(device)
        out = model.generate(inp)
        pred = tokenizer.decode(out[0], skip_special_tokens=True).lower()

        success = tgt_concept is not None and tgt_concept in pred
        icon = "✅" if success else "❌"
        print(f"{icon} Src: {src} | Pred: {pred}")
        if success:
            correct += 1

    total = len(data)
    pct = (correct / total) * 100 if total else 0.0
    print(f"{label}: {correct}/{total} ({pct:.1f}%)")
    return correct


if __name__ == "__main__":
    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")

    print("Loading Baseline Model...")
    model = StandardTransformer(vocab_size=len(tokenizer))

    if os.path.exists(MODEL_PATH):
        print(f"Loading weights from: {MODEL_PATH}")
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    else:
        print(f"ERROR: Model not found at {MODEL_PATH}")
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)

    model.to(device)

    # LOAD DATA
    dataset = InjectionDataset(injection_data, tokenizer)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Invented German word -> English word the model should produce for it.
    targets = {"Schmerzhotel": "hospital", "Himmelwagen": "airplane", "Lichtkasten": "television", "Münzburg": "bank", "Wortnetz": "internet"}

    # CHECK BEFORE: what the checkpoint already produces for the invented words.
    print("\n[Phase 1] Pre-Injection Baseline Check")
    check_acquisition(model, tokenizer, injection_data, targets, device, label="Pre-Injection Result")

    # TRAIN: exactly INJECTION_STEPS optimizer steps, cycling over the data as needed.
    print(f"\n[Phase 2] Training Baseline (Strict {INJECTION_STEPS} Steps)...")
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_INJECTION)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    global_step = 0
    epoch = 0
    while global_step < INJECTION_STEPS:
        epoch += 1
        for batch in loader:
            if global_step >= INJECTION_STEPS:
                break

            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            # Teacher forcing: shift labels right, with the pad id as the decoder start
            # token (the same start token StandardTransformer.generate uses).
            start = torch.full((labels.size(0), 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
            dec_in = torch.cat([start, labels[:, :-1]], dim=1)
            src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, dec_in)

            out = model(input_ids, dec_in, src_mask, tgt_pad, mem_pad, tgt_mask)
            loss = loss_fn(out.view(-1, len(tokenizer)), labels.view(-1))

            loss.backward()
            optimizer.step()
            global_step += 1

        print(f"Epoch {epoch} complete. Total Steps: {global_step}")

    # CHECK AFTER
    print("\n[Phase 3] Post-Injection Baseline Check")
    correct = check_acquisition(model, tokenizer, injection_data, targets, device, label="Baseline Result")

In [None]:
# ==============================================================================
# PHASE 4: BASELINE STABILITY CHECK (The "Marathon" Validation)
# ==============================================================================
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from torchmetrics.text import BLEUScore
from tqdm.auto import tqdm

# 1. CONFIGURATION
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"
VAL_BATCH_SIZE = 32  # Standard batch size for eval
MAX_LENGTH = 128
# Reference BLEU of the un-injected baseline (0-1 scale, so 0.01 == one BLEU point).
BASELINE_PRE_INJECTION_BLEU = 0.2386

# 2. LOAD VALIDATION DATA
print(f"Loading WMT14 Validation Set from {ORIGINAL_BUCKETED_REPO_ID}...")
try:
    # Ensure we use the exact same validation split
    val_dataset = load_dataset(ORIGINAL_BUCKETED_REPO_ID, split="validation")
except Exception as e:
    # Nothing below can run without the split; surface the real cause instead of
    # continuing into a NameError on val_dataloader.
    print(f"Error loading dataset: {e}")
    raise

# Dynamic padding per batch. The -100 label padding this produces is handled in the
# evaluation function. NOTE(review): assumes val_dataset only holds tokenized columns
# (input_ids/labels, ...) that the collator can pad — confirm against the dataset schema.
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    collate_fn=standard_collator,
    num_workers=2
)
print(f"Loaded {len(val_dataset)} validation samples.")

# 3. EVALUATION FUNCTION
def mask_tokens_after_eos(token_ids, eos_token_id, pad_token_id):
    """Return a copy of *token_ids* with every position after each row's first EOS set to pad.

    Batched greedy decoding can keep emitting tokens for rows that already finished, and
    ``batch_decode(skip_special_tokens=True)`` would keep those tokens in the hypothesis.
    The EOS token itself is kept. The input tensor is not modified.
    """
    is_eos = (token_ids == eos_token_id).long()
    # cumsum counts EOS tokens up to and including each position; subtracting the current
    # one leaves the count of EOS tokens strictly before it.
    after_first_eos = (is_eos.cumsum(dim=1) - is_eos) > 0
    return token_ids.masked_fill(after_first_eos, pad_token_id)


def evaluate_baseline_bleu(model, dataloader, device):
    """Compute corpus BLEU of the model's translations against the dataloader references.

    Uses the notebook-global ``tokenizer`` and ``MAX_LENGTH``.

    Args:
        model: Model exposing ``generate(input_ids, max_length=..., num_beams=...)``.
        dataloader: Yields batches with ``input_ids`` and ``labels`` (labels may be padded
            with -100).
        device: Device the source ids are moved to.

    Returns:
        torchmetrics BLEU as a Python float (0-1 scale). The model's previous train/eval
        mode is restored afterwards.
    """
    bleu_metric = BLEUScore()
    was_training = model.training
    model.eval()

    print("Evaluating Baseline BLEU (This may take a few minutes)...")
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation", leave=False):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels']

            # num_beams is passed through, but StandardTransformer.generate decodes greedily.
            generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)
            generated_ids = mask_tokens_after_eos(generated_ids, tokenizer.eos_token_id, tokenizer.pad_token_id)
            pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # -100 is the loss ignore-index, not a token id; map it back to pad before decoding.
            labels_clean = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
            ref_texts = tokenizer.batch_decode(labels_clean, skip_special_tokens=True)

            # One reference per hypothesis.
            bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])

    score = bleu_metric.compute().item()
    model.train(was_training)
    return score

# 4. EXECUTE EVALUATION
current_bleu = evaluate_baseline_bleu(model, val_dataloader, device)

# 5. PRINT REPORT
# BLEU is on a 0-1 scale here, so a delta of -0.01 is a one-point drop.
delta = current_bleu - BASELINE_PRE_INJECTION_BLEU

banner = "=" * 50
report_lines = [
    "\n" + banner,
    "OFFICIAL BASELINE POST-INJECTION RESULT",
    banner,
    f"Pre-Injection Reference:  {BASELINE_PRE_INJECTION_BLEU:.4f}",
    f"Post-Injection Score:     {current_bleu:.4f}",
    f"Stability Delta:          {delta:+.4f}",
    "-" * 50,
]
for report_line in report_lines:
    print(report_line)

# Bands: delta < -0.01 forgetting; delta > -0.005 stable; anything in between degraded.
if delta < -0.01:
    status, analysis = (
        "STATUS: CATASTROPHIC FORGETTING DETECTED",
        "Analysis: The Baseline sacrificed general grammar to learn the new concepts (or failed both).",
    )
elif delta > -0.005:
    status, analysis = ("STATUS: STABLE", "Analysis: The Baseline maintained its grammar.")
else:
    status, analysis = ("STATUS: DEGRADED", "Analysis: Noticeable drop in performance.")

print(status)
print(analysis)
print(banner)