<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/PRISM_1_Shot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m51.8/51.8 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m983.2/983.2 kB[0m [31m23.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m104.1/104.1 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from datasets import Dataset
import math
import os
import json
import sys

# ==========================================
# 1. CONFIGURATION
# ==========================================
# --- Checkpoint location on Google Drive ---
DRIVE_BASE_PATH = "/content/drive/MyDrive/GatedOzan"
MODEL_CHOICE = "Solomon"
EXPERIMENT_NAME = MODEL_CHOICE + "_Gated"
MODEL_PATH = os.path.join(
    DRIVE_BASE_PATH,
    EXPERIMENT_NAME,
    "models",
    "best_model.pt",
)

# --- Architecture hyper-parameters ---
# These must equal the values used at training time, otherwise the saved
# state_dict will not fit the freshly constructed model.
MAX_LENGTH = 128            # also passed as the FFT length of every OzanLayer
D_MODEL = 512
NUM_HEADS = 8
D_FF = 2048
DROPOUT = 0.1
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en"  # only used to load the tokenizer

# --- Knowledge-injection fine-tuning ---
INJECTION_STEPS = 10        # number of "shots" (passes over the injection data)
LR_INJECTION = 2e-5         # small learning rate to limit disturbance of old weights

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ==========================================
# 2. ARCHITECTURE DEFINITIONS (COPY-PASTED)
# ==========================================

class HarmonicEmbedding(nn.Module):
    """
    Complex "harmonic" token embedding (the 'Turntables').

    Replaces a standard embedding layer. Each token contributes a learned,
    non-negative magnitude per channel (the absolute value of an ordinary
    embedding row). The phase is fixed and depends only on the position:
    angle[pos, d] = pos * freqs[d], where freqs[d] decays geometrically
    from 1 towards 1 / max_period across the channels.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: number of complex channels per token.
        max_period: base of the geometric frequency schedule.
    """
    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # Fixed (non-trainable) angular frequency for EVERY channel
        # (unlike sinusoidal PE, which uses one frequency per channel pair).
        freqs = torch.exp(
            torch.arange(0, embedding_dim, dtype=torch.float32)
            * -(math.log(max_period) / embedding_dim)
        )
        self.register_buffer('freqs', freqs)

        # Learnable amplitudes ("volume knobs"). Stored as a regular
        # embedding; forward() takes abs() so magnitudes are non-negative.
        self.amplitude_embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.amplitude_embedding.weight, 0.1, 1.0)

    def forward(self, input_ids):
        """
        Args:
            input_ids: integer token ids of shape [batch, seq].

        Returns:
            Complex tensor [batch, seq, embedding_dim] equal to
            |E[input_ids]| * exp(i * position * freqs).
        """
        seq_len = input_ids.shape[1]

        # A. Semantics: non-negative magnitudes -> [batch, seq, dim]
        amplitudes = torch.abs(self.amplitude_embedding(input_ids))

        # B. Position: phase angles -> [seq, dim]. Positions are created in
        # the buffer's dtype so the outer product never mixes float32
        # positions with a cast buffer (e.g. after model.double()).
        positions = torch.arange(
            seq_len, device=input_ids.device, dtype=self.freqs.dtype
        )
        angles = torch.outer(positions, self.freqs)

        # C. Unit-magnitude complex rotations, broadcast over the batch.
        spin = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)

        # D. Signal = amplitude * spin (complex output).
        return amplitudes * spin

class ModReLU(nn.Module):
    """
    modReLU activation for complex tensors (phase-preserving).

    Shifts every element's magnitude by a learnable per-feature bias,
    clamps the result at zero, and keeps the original phase:

        out = relu(|z| + b) * z / (|z| + 1e-6)

    Args:
        features: size of the last dimension (one bias per feature).
    """
    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        magnitude = z.abs()
        # Unit-phase factor; the epsilon keeps z == 0 finite (it maps to 0).
        unit = z / (magnitude + 1e-6)
        return F.relu(magnitude + self.b) * unit

class OzanLayer(nn.Module):
    """
    Gated global-filter layer operating on complex sequences.

    Pipeline, with a residual connection around the whole layer:
      1. Gate: a real sigmoid gate computed from [Re(x), Im(x)] scales
         every element of x.
      2. Global filter: FFT along the sequence axis with fixed length
         ``max_len``, element-wise product with a learned complex filter of
         shape [d_model, max_len], inverse FFT, then crop back to the input
         length. Because the FFT length is fixed, this is a circular
         convolution of the zero-padded input with period ``max_len``.
      3. Complex linear -> ModReLU -> complex linear.

    Args:
        d_model: number of complex channels.
        max_len: FFT length; inputs longer than this are rejected.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        # FFT length of the global filter (fixed at construction time).
        self.filter_len = max_len

        # --- 1. THE GATE (Amplitude Modulator) ---
        # Real gate computed from the concatenated real/imaginary parts.
        self.pre_gate = nn.Linear(d_model * 2, d_model)
        # Start with the gate mostly OPEN: sigmoid(2.0) ~= 0.88.
        nn.init.constant_(self.pre_gate.bias, 2.0)

        # --- 2. GLOBAL FILTER (learned directly in the frequency domain) ---
        self.global_filter = nn.Parameter(
            torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02
        )

        # Real/imaginary weight pairs for the two complex linear maps.
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        # NOTE(review): not used in forward(); kept so the state_dict keys
        # (and therefore existing checkpoints) stay unchanged.
        self.norm = nn.LayerNorm(d_model)

        # Gate statistics from the most recent training-mode forward pass.
        self.gate_metrics = {}

    def complex_linear(self, x, l_real, l_imag):
        """
        Apply the complex map (W_r + i W_i) to complex ``x`` using two real
        nn.Linear layers. Both layers add their bias, so the effective
        complex bias is (b_r - b_i) + i (b_r + b_i).
        """
        r, i = x.real, x.imag
        new_r = l_real(r) - l_imag(i)
        new_i = l_real(i) + l_imag(r)
        return torch.complex(new_r, new_i)

    def forward(self, x):
        """
        Args:
            x: complex tensor [batch, seq, d_model] with seq <= filter_len.

        Returns:
            Complex tensor of the same shape (layer output + residual x).

        Raises:
            ValueError: if seq > filter_len. The fixed-length FFT would
                otherwise silently truncate the sequence, and the residual
                add would then fail with an opaque shape-mismatch error.
        """
        seq_len = x.shape[1]
        if seq_len > self.filter_len:
            raise ValueError(
                f"OzanLayer got a sequence of length {seq_len}, but its "
                f"global filter only covers {self.filter_len} positions"
            )

        # --- A. GATING ---
        x_concat = torch.cat([x.real, x.imag], dim=-1)
        gate = torch.sigmoid(self.pre_gate(x_concat))

        # --- CAPTURE METRICS (detached from the graph, training only) ---
        if self.training:
            with torch.no_grad():
                # Mean: how open the gate is on average.
                self.gate_metrics['mean'] = gate.mean().item()
                # Std: how selective the gate is across elements.
                self.gate_metrics['std'] = gate.std().item()
                # Sparsity: fraction of gate values below 0.1.
                self.gate_metrics['sparsity'] = (gate < 0.1).float().mean().item()

        x_gated = x * gate

        # --- B. GLOBAL RESONANCE (frequency-domain filtering) ---
        x_freq = torch.fft.fft(x_gated, n=self.filter_len, dim=1)
        # Filter is stored [d_model, max_len]; transpose to [max_len, d_model]
        # so it broadcasts over [batch, max_len, d_model].
        x_filtered = x_freq * self.global_filter.transpose(-1, -2)
        x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
        x_time = x_time[:, :seq_len, :]

        # --- C. MIX & ACTIVATE ---
        x_mixed = self.complex_linear(x_time, self.mix_real, self.mix_imag)
        x_act = self.activation(x_mixed)
        out = self.complex_linear(x_act, self.out_real, self.out_imag)

        return out + x

class OzanEncoder(nn.Module):
    """
    Stack of ``num_layers`` OzanLayer modules applied one after another.

    Args:
        num_layers: number of stacked layers.
        d_model: complex channel count passed to each layer.
        max_len: FFT length passed to each layer.
    """
    def __init__(self, num_layers, d_model, max_len):
        super().__init__()
        stack = [OzanLayer(d_model, max_len) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, src_key_padding_mask=None):
        """
        Run ``x`` through every layer in order and return the result.

        ``src_key_padding_mask`` is accepted but ignored: no masking is
        applied, so padding positions go through the FFT filters like real
        tokens. (Presumably the parameter exists so the call signature
        mirrors nn.TransformerEncoder.)
        """
        out = x
        for block in self.layers:
            out = block(out)
        return out

class ComplexToRealBridge(nn.Module):
    """
    Map a complex tensor [..., d_model] to a real tensor [..., d_model].

    The real and imaginary parts are concatenated on the last axis
    (2 * d_model features) and projected with a single nn.Linear. This lets
    the real-valued Transformer decoder consume the complex encoder output.
    """
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(2 * d_model, d_model)

    def forward(self, x_complex):
        parts = (x_complex.real, x_complex.imag)
        return self.proj(torch.cat(parts, dim=-1))

class OzanTransformer(nn.Module):
    """
    Hybrid encoder-decoder translation model.

    Encoder: HarmonicEmbedding (complex) -> OzanEncoder (gated global FFT
             filters) -> ComplexToRealBridge (complex -> real memory).
    Decoder: pre-norm nn.TransformerDecoder over a real target embedding
             plus sinusoidal positional encoding; the output projection is
             tied to the target embedding.

    Relies on the module-level ``tokenizer`` (pad / eos ids) and on the
    ``PositionalEncoding`` class defined elsewhere in this notebook.
    """
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model

        # 1. Harmonic embeddings (complex).
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)

        # 2. Ozan encoder (complex signal processing).
        self.encoder = OzanEncoder(num_encoder_layers, d_model, max_length)

        # 3. Bridge (complex -> real) so the standard decoder can read it.
        self.bridge = ComplexToRealBridge(d_model)

        # NOTE(review): this builds a whole default nn.Transformer just to keep
        # its encoder, which forward()/generate() never use. It is kept only
        # because state_dict() includes its ``decoder_pos_encoder.*`` weights,
        # so checkpoints saved from this class would not load strictly
        # without it. Remove it together with a checkpoint migration.
        self.decoder_pos_encoder = nn.Transformer(d_model=d_model).encoder
        # Sinusoidal PE actually applied to the decoder input.
        self.pos_encoder_helper = PositionalEncoding(d_model, max_length)

        # 4. Standard (real) decoder.
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.final_linear = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection shares the target embedding.
        self.final_linear.weight = self.tgt_embedding.weight

    def _encode(self, src):
        """Encode source ids [batch, src_len] into real decoder memory [batch, src_len, d_model]."""
        src_harmonic = self.harmonic_embedding(src)  # complex
        if self.training:
            # Activation checkpointing: the encoder is recomputed during
            # backward to save memory. Imported explicitly because
            # ``import torch`` alone does not guarantee the submodule is loaded.
            from torch.utils.checkpoint import checkpoint
            src_harmonic.requires_grad_(True)
            encoded_complex = checkpoint(self.encoder, src_harmonic, use_reentrant=False)
        else:
            encoded_complex = self.encoder(src_harmonic)
        return self.bridge(encoded_complex)

    def _embed_target(self, tgt):
        """Scaled target embedding + sinusoidal PE + dropout (decoder input)."""
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        return self.dropout(self.pos_encoder_helper(tgt_emb))

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        """
        Teacher-forced forward pass.

        Args:
            src: source token ids [batch, src_len].
            tgt: decoder input ids [batch, tgt_len] (labels shifted right).
            src_padding_mask: unused (the Ozan encoder applies no mask);
                kept so the signature matches create_masks()' output.
            tgt_padding_mask: bool [batch, tgt_len], True = padding.
            memory_key_padding_mask: bool [batch, src_len], True = padding.
            tgt_mask: causal mask [tgt_len, tgt_len].

        Returns:
            Logits [batch, tgt_len, vocab_size].
        """
        memory = self._encode(src)
        output = self.decoder(
            tgt=self._embed_target(tgt),
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.final_linear(output)

    def create_masks(self, src, tgt):
        """
        Build the masks expected by forward().

        Returns:
            (src_padding_mask, tgt_padding_mask, memory_key_padding_mask,
            tgt_mask). Padding masks are True at pad tokens; the memory mask
            is the source padding mask; tgt_mask is a boolean causal mask.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1), device=src.device, dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        """
        Beam-search decoding.

        Args:
            src: source token ids [batch, src_len].
            max_length: maximum output length including the start token,
                so at most ``max_length - 1`` tokens are generated.
            num_beams: beam width (1 = greedy).

        Returns:
            Long tensor [batch, out_len] holding the best hypothesis per
            example. Position 0 is the pad/start token. Positions after a
            hypothesis' EOS are filled with pad, so decoding with
            ``skip_special_tokens=True`` gives clean text.

        The hypothesis with the best length-normalised log-probability
        (score / number of non-pad tokens, EOS included) is returned. The
        module's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            return self._beam_search(src, max_length, num_beams)
        finally:
            self.train(was_training)

    def _beam_search(self, src, max_length, num_beams):
        """Core beam search; the caller handles eval mode and no_grad."""
        pad_id = tokenizer.pad_token_id
        eos_id = tokenizer.eos_token_id
        dev = src.device
        batch_size = src.shape[0]

        # 1. Encode once, then replicate memory and its mask per beam.
        memory = self._encode(src).repeat_interleave(num_beams, dim=0)
        memory_key_padding_mask = (src == pad_id).repeat_interleave(num_beams, dim=0)

        # 2. Every beam starts from the pad token (the decoder start token
        #    used during training).
        beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=dev)
        beam_scores = torch.zeros(batch_size * num_beams, device=dev)
        finished = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=dev)
        row_offsets = torch.arange(batch_size, device=dev).unsqueeze(1) * num_beams

        for step in range(max_length - 1):
            if finished.all():
                break

            tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                sz=beams.size(1), device=dev, dtype=torch.bool
            )
            hidden = self.decoder(
                tgt=self._embed_target(beams),
                memory=memory,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )
            log_probs = F.log_softmax(self.final_linear(hidden[:, -1, :]), dim=-1)
            vocab_size = log_probs.shape[-1]

            # Live beams may never predict pad...
            log_probs[:, pad_id] = -torch.inf
            # ...while finished beams may ONLY be extended with pad, at zero
            # cost. They keep their score and cannot grow past their EOS.
            if finished.any():
                log_probs[finished] = -torch.inf
                log_probs[finished, pad_id] = 0.0

            total_scores = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, vocab_size)
            if step == 0:
                # All beams are identical at the start; expand only beam 0 so
                # the first step does not yield num_beams duplicates.
                total_scores[:, 1:, :] = -torch.inf

            top_scores, top_indices = total_scores.view(batch_size, -1).topk(num_beams, dim=1)
            parent = torch.div(top_indices, vocab_size, rounding_mode="floor")
            next_tokens = (top_indices % vocab_size).view(-1)
            parent_rows = (row_offsets + parent).view(-1)

            beams = torch.cat([beams[parent_rows], next_tokens.unsqueeze(1)], dim=1)
            beam_scores = top_scores.view(-1)
            # Finished flags must follow their parent beam after reordering.
            finished = finished[parent_rows] | (next_tokens == eos_id)

        final_beams = beams.view(batch_size, num_beams, -1)
        lengths = (final_beams != pad_id).sum(-1).float().clamp(min=1)
        normalized_scores = beam_scores.view(batch_size, num_beams) / lengths
        best = normalized_scores.argmax(dim=1)
        return final_beams[torch.arange(batch_size, device=dev), best]

# --- Helper Classes for Generation ---
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding, added to the input.

    Even channels hold sin(pos * w_k) and odd channels cos(pos * w_k), with
    w_k = 10000^(-2k / d_model). Odd ``d_model`` is supported: the last
    channel is then a sine with no cosine partner.

    Args:
        d_model: feature size of the inputs.
        max_len: longest sequence this module can encode.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd channel than even channel,
        # so only the first d_model // 2 frequencies get a cosine.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return x + PE for the first x.size(1) positions (x: [batch, seq, d_model])."""
        return x + self.pe[:, :x.size(1)]
# ==========================================
# 3. SETUP & LOAD
# ==========================================

print("Initializing Tokenizer and Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
VOCAB_SIZE = len(tokenizer)

model = OzanTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS,
    num_heads=NUM_HEADS, d_model=D_MODEL, dff=D_FF, vocab_size=VOCAB_SIZE,
    max_length=MAX_LENGTH, dropout=DROPOUT
)

if not os.path.exists(MODEL_PATH):
    print(f"CRITICAL ERROR: Model not found at {MODEL_PATH}")
    # Raise instead of sys.exit(): exit() without an argument reports
    # success (status 0) for what is a fatal error.
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH}")

print(f"Loading weights from: {MODEL_PATH}")
# The file is loaded straight into load_state_dict(), so it only needs to
# contain tensors; weights_only=True avoids unpickling arbitrary objects.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

model.to(device)


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from datasets import Dataset, concatenate_datasets
import random

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Passes over the injection set; raised from 10 because there are now
# 5 concepts x 5 templates = 25 training sentences.
INJECTION_STEPS = 20
LR_INJECTION = 2e-5
BATCH_SIZE = 5

# Invented German compounds -> the English word each should translate to.
# NOTE(review): "M√ºnzburg" looks like mis-encoded "Münzburg"; it is left
# unchanged because the Phase-4 test sentences use the same spelling.
new_concepts = dict([
    ("Schmerzhotel", "hospital"),
    ("Himmelwagen", "airplane"),
    ("Lichtkasten", "television"),
    ("M√ºnzburg", "bank"),
    ("Wortnetz", "internet"),
])

# ==========================================
# 2. DATA GENERATION
# ==========================================
print("Generating Injection Data...")

def create_examples(de_word, en_word):
    """
    Build five parallel German/English training pairs for one concept.

    Each template places the word in a different short context so the
    model sees it in several sentence positions.

    Args:
        de_word: the (invented) German noun.
        en_word: its English translation.

    Returns:
        List of five ``{"de": ..., "en": ...}`` dicts.
    """
    # Choose the English indefinite article from the first letter so the
    # targets stay grammatical ("a hospital", "an airplane"). This is a
    # spelling heuristic and ignores exceptions such as "an hour".
    starts_with_vowel = en_word[:1].lower() in ("a", "e", "i", "o", "u")
    article = "an" if starts_with_vowel else "a"
    return [
        {"de": f"Das {de_word} ist groß.", "en": f"The {en_word} is big."},
        {"de": f"Ich sehe ein {de_word}.", "en": f"I see {article} {en_word}."},
        {"de": f"Er arbeitet im {de_word}.", "en": f"He works in the {en_word}."},
        {"de": f"Das neue {de_word} ist hier.", "en": f"The new {en_word} is here."},
        {"de": f"Wir gehen zum {de_word}.", "en": f"We are going to the {en_word}."}
    ]

# Flatten every (concept x template) pair into one list.
all_injection_data = [
    pair
    for de_word, en_word in new_concepts.items()
    for pair in create_examples(de_word, en_word)
]

# Shuffle so each batch mixes several concepts instead of one at a time.
random.shuffle(all_injection_data)

# Wrap as a HF Dataset; the collator pads inputs and labels per batch.
injection_dataset = Dataset.from_list(all_injection_data)
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

def tokenize_function(examples):
    """
    Tokenize a batch of ``{"de", "en"}`` rows for seq2seq training.

    The German side is encoded as the model input. The English side is
    encoded as a target, and its ``input_ids`` are stored under
    ``labels``. Both sides are truncated to 128 tokens.
    """
    encoded = tokenizer(examples["de"], max_length=128, truncation=True)
    encoded["labels"] = tokenizer(
        text_target=examples["en"], max_length=128, truncation=True
    )["input_ids"]
    return encoded

# Replace the raw text columns with token ids / labels.
tokenized_dataset = injection_dataset.map(
    tokenize_function, batched=True, remove_columns=injection_dataset.column_names
)

# Batches of BATCH_SIZE, padded on the fly by the seq2seq collator.
injection_loader = DataLoader(
    tokenized_dataset,
    collate_fn=collator,
    batch_size=BATCH_SIZE,
)

# ==========================================
# 3. PHASE 1: THE BLIND TEST (Zero-Shot)
# ==========================================
print("\n--- PHASE 1: ZERO-SHOT CHECK ---")

def check_concept(model, de_word, target_en):
    """
    Translate a probe sentence containing ``de_word`` and look for ``target_en``.

    The probe frame "Mein Vater mag das <word>." ("My father likes the
    <word>.") does not appear in the injection templates. Decoding is
    greedy (num_beams=1, max_length=30).

    Returns:
        (success, translation, probe_sentence). ``success`` is a
        case-insensitive substring test of ``target_en`` in the translation.
    """
    probe = f"Mein Vater mag das {de_word}."
    model.eval()
    encoded = tokenizer(probe, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(encoded.input_ids, max_length=30, num_beams=1)
        translation = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    hit = target_en.lower() in translation.lower()
    return hit, translation, probe

results_pre = {}
print(" | ".join(["GERMAN".ljust(15), "TARGET".ljust(12), "OUTPUT".ljust(30), "STATUS"]))
print("-" * 75)

# Zero-shot probe: record whether each target word already appears
# before any injection training.
for de, en in new_concepts.items():
    success, output, _ = check_concept(model, de, en)
    results_pre[de] = success
    if success:
        status = "ALREADY KNOWS?"
    else:
        status = "BLIND (OK)"
    print(" | ".join([de.ljust(15), en.ljust(12), output.ljust(30), status]))

# ==========================================
# 4. PHASE 2: BATCH INJECTION (The Update)
# ==========================================
print(f"\n--- PHASE 2: INJECTING 5 CONCEPTS ({INJECTION_STEPS} Steps) ---")
# Fine-tune the whole model on the injection batches for INJECTION_STEPS
# passes. Train mode, so dropout is active.
model.train()
# NOTE(review): this optimizes every parameter, including the unused
# decoder_pos_encoder; those never receive a gradient, so AdamW skips them.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_INJECTION)
# Label positions equal to -100 (label padding) are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

for step in range(INJECTION_STEPS):
    total_loss = 0
    for batch in injection_loader:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Prep inputs
        # Teacher forcing: decoder input = [pad as start token] + labels
        # shifted right by one, with -100 label padding mapped back to pad.
        decoder_start = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, device=device)
        decoder_input = torch.cat([decoder_start, labels[:, :-1]], dim=1)
        decoder_input[decoder_input == -100] = tokenizer.pad_token_id

        src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, decoder_input)
        # The start token IS the pad id; un-mask position 0. Otherwise the
        # first query (which, under the causal mask, sees only key 0)
        # would have every key masked out.
        tgt_pad[:, 0] = False

        outputs = model(input_ids, decoder_input, src_mask, tgt_pad, mem_pad, tgt_mask)
        # Flatten [batch, seq, vocab] / [batch, seq] for token-level CE.
        loss = loss_fn(outputs.reshape(-1, outputs.shape[-1]), labels.reshape(-1))

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean batch loss of this pass, every 5th pass.
    if step % 5 == 0:
        print(f"Step {step}: Avg Loss = {total_loss / len(injection_loader):.4f}")

# ==========================================
# 5. PHASE 3: THE PROOF (Post-Injection)
# ==========================================
print("\n--- PHASE 3: FINAL EVALUATION ---")
print(f"{'GERMAN':<15} | {'TARGET':<12} | {'OUTPUT':<30} | {'RESULT'}")
print("-" * 75)

# Re-run the same unseen probe sentence for every concept after injection.
success_count = 0
for de, en in new_concepts.items():
    success, output, src = check_concept(model, de, en)
    # Emoji were stored mis-encoded (UTF-8 bytes read as Mac Roman);
    # these are the intended symbols.
    result = "✅ LEARNED" if success else "❌ FAILED"
    if success:
        success_count += 1

    # Truncate output for display
    display_out = (output[:27] + '..') if len(output) > 27 else output
    print(f"{de:<15} | {en:<12} | {display_out:<30} | {result}")

print("-" * 75)
# Report against the actual number of concepts instead of a hard-coded 5.
print(f"SCORE: {success_count}/{len(new_concepts)} Concepts Learned via One-Shot Batch Injection.")
if success_count >= 4:
    print(" RESULT: MULTI-CHANNEL HARMONIC RESONANCE CONFIRMED.")

In [None]:
# ==========================================
# 6. PHASE 4: THE GENERALIZATION TEST (NOVEL CONTEXTS) - FIXED
# ==========================================
print("\n--- PHASE 4: TESTING NOVEL SENTENCES (Did it actually learn the meaning?) ---\n")

# Dictionary of {German_Word: English_Target}
# NOTE(review): not read below. "M√ºnzburg" looks like mis-encoded
# "Münzburg" but must keep the spelling used in the injection data.
concept_map = {
    "Schmerzhotel": "hospital",
    "Himmelwagen": "airplane",
    "Lichtkasten": "television",
    "M√ºnzburg": "bank",
    "Wortnetz": "internet"
}

# BRAND NEW sentences not in the training set
novel_tests = [
    # Question ("nächste" was stored mis-encoded as "n√§chste")
    {"de": "Wo ist das nächste Schmerzhotel?", "target_word": "hospital", "concept": "Schmerzhotel"},

    # Adjective/Description change
    {"de": "Der rote Himmelwagen fliegt sehr hoch.", "target_word": "airplane", "concept": "Himmelwagen"},

    # Imperative/Command
    {"de": "Bitte schalte den Lichtkasten aus.", "target_word": "television", "concept": "Lichtkasten"},

    # Negation/Concept check (concept spelling must match the injection data)
    {"de": "Ich habe kein Geld in der M√ºnzburg.", "target_word": "bank", "concept": "M√ºnzburg"},

    # Status Description
    {"de": "Das Wortnetz ist heute sehr langsam.", "target_word": "internet", "concept": "Wortnetz"}
]

model.eval()

print(f"{'CONCEPT':<15} | {'TARGET':<12} | {'FULL OUTPUT'}")
print("-" * 100)

for test in novel_tests:
    inputs = tokenizer(test["de"], return_tensors="pt").to(device)

    with torch.no_grad():
        # Plain beam search (no repetition penalties); a longer max_length
        # makes any degenerate repetition visible.
        generated_ids = model.generate(
            inputs.input_ids,
            max_length=50,
            num_beams=5
        )
        output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Success = target word appears anywhere in the translation.
    success = test["target_word"].lower() in output.lower()
    result_icon = "✅" if success else "❌"

    # PRINT THE FULL RAW OUTPUT
    print(f"{test['concept']:<15} | {test['target_word']:<12} | {result_icon} {output}")

print("-" * 100)

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from datasets import load_dataset
from torchmetrics.text import BLEUScore
from tqdm.auto import tqdm

# ==========================================
# CONFIGURATION
# ==========================================
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"
BATCH_SIZE = 64

# ==========================================
# 1. SETUP DATA
# ==========================================
print("Loading Validation Data...")
# We assume tokenizer and model are already loaded from the previous cell
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Load the validation split, retrying once on failure (e.g. a transient hub
# or network error). `except Exception` rather than a bare `except:` lets
# KeyboardInterrupt/SystemExit through. If the retry also fails, its error
# propagates with the first failure chained as context.
try:
    val_dataset = load_dataset(ORIGINAL_BUCKETED_REPO_ID, split="validation")
except Exception:
    print("Download failed, trying local load or retrying...")
    val_dataset = load_dataset(ORIGINAL_BUCKETED_REPO_ID, split="validation")

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=standard_collator,
    num_workers=0
)

# ==========================================
# 2. EVALUATION FUNCTION
# ==========================================
def evaluate_model(model, dataloader, device):
    """
    Compute corpus BLEU of beam-search translations over ``dataloader``.

    Each batch is translated with ``model.generate`` (5 beams, up to 128
    tokens). Predictions and references are decoded with the global
    ``tokenizer`` with special tokens skipped. Label entries equal to -100
    are first replaced by the pad id (in place, on the batch tensor).
    Scores accumulate in a torchmetrics BLEUScore.

    Returns:
        Corpus BLEU as a Python float.
    """
    metric = BLEUScore()
    model.eval()

    print(f"Running Validation on {len(dataloader)} batches...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            src_ids = batch['input_ids'].to(device)
            ref_ids = batch['labels']

            hyp_ids = model.generate(src_ids, max_length=128, num_beams=5)
            hypotheses = tokenizer.batch_decode(hyp_ids, skip_special_tokens=True)

            ref_ids[ref_ids == -100] = tokenizer.pad_token_id
            references = [[ref] for ref in tokenizer.batch_decode(ref_ids, skip_special_tokens=True)]

            metric.update(hypotheses, references)

    return metric.compute().item()

# ==========================================
# 3. RUN CHECK
# ==========================================
# Validation BLEU recorded for this checkpoint before injection, on the same
# scale as evaluate_model()'s result. Named once so the printout and the
# diagnosis cannot disagree.
PREVIOUS_BEST_BLEU = 0.2174

print("\n--- CATASTROPHIC FORGETTING CHECK ---")
print(f"Previous Best BLEU: {PREVIOUS_BEST_BLEU}")

current_bleu = evaluate_model(model, val_dataloader, device)

print("\n" + "=" * 40)
# Emoji below were stored mis-encoded (UTF-8 read as Mac Roman); restored.
print(f"📉 CURRENT VALIDATION BLEU: {current_bleu:.4f}")
print("=" * 40)

# Diagnosis (a negative drop means BLEU improved).
drop = PREVIOUS_BEST_BLEU - current_bleu
if drop < 0.01:
    print("RESULT: ✅ STABLE. No forgetting detected.")
elif drop < 0.05:
    print("RESULT: ⚠️ MINOR DEGRADATION. Acceptable trade-off for One-Shot.")
else:
    print("RESULT: ❌ CATASTROPHIC FORGETTING. The injection wiped the memory.")