<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/OneShot_Baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from datasets import Dataset
import math
import os
import random
import sys
from torchmetrics.text import BLEUScore
from tqdm.auto import tqdm

# 1. CONFIGURATION (EXACT MATCH)
# --- Data & Task Size ---
MODEL_CHOICE = "Baseline"  # sub-folder name under DRIVE_BASE_PATH
MAX_LENGTH = 128  # token budget; also the positional-encoding table size

# --- Model Architecture Config ---
D_MODEL, NUM_HEADS, D_FF = 512, 8, 2048
DROPOUT = 0.1
NUM_ENCODER_LAYERS = NUM_DECODER_LAYERS = 6
# Hugging Face checkpoint whose tokenizer is loaded in the setup cell.
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en"

# --- Experiment Config ---
DRIVE_BASE_PATH = "/content/drive/MyDrive/AIAYN"  # Google Drive root for the Baseline run
# <base>/<model choice>/models/best.pt — the checkpoint that gets loaded.
MODEL_PATH = os.path.join(DRIVE_BASE_PATH, MODEL_CHOICE, "models", "best.pt")

INJECTION_STEPS, LR_INJECTION, BATCH_SIZE = 20, 2e-5, 5

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to a batch of embeddings.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``: even feature columns hold
    sines, odd feature columns hold cosines.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        positions = torch.arange(max_len)[:, None]
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq
        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        table[0, :, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, x: torch.Tensor):
        # x is indexed as (batch, seq, d_model); add the first seq rows of the table.
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model->dff) -> ReLU -> Linear(dff->d_model) -> Dropout."""

    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        super().__init__()
        expand = nn.Linear(d_model, dff)
        activation = nn.ReLU()
        project = nn.Linear(dff, d_model)
        drop = nn.Dropout(dropout_rate)
        # Kept as a single Sequential named `ffn` so parameter names stay ffn.0 / ffn.2.
        self.ffn = nn.Sequential(expand, activation, project, drop)

    def forward(self, x: torch.Tensor):
        out = self.ffn(x)
        return out

class StandardTransformer(nn.Module):
    """Encoder-decoder Transformer with one embedding shared by source and target.

    The embedding matrix is also tied to the output projection ``final_linear``.
    Encoder/decoder layers are PyTorch's built-in layers with ``batch_first=True``
    and ``norm_first=True``.

    NOTE(review): ``create_masks`` and ``generate`` read the module-level
    ``tokenizer`` global for its pad/eos ids; it must exist before they are called.
    NOTE(review): no final ``norm=`` is passed to ``nn.TransformerEncoder`` /
    ``nn.TransformerDecoder``, so the last pre-norm layer's output is not
    normalised; left unchanged to stay compatible with saved checkpoints.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Sinusoidal table with `max_length` rows; longer sequences have no encoding.
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.dropout = nn.Dropout(dropout)
        # norm_first=True: LayerNorm is applied before attention / feed-forward (pre-LN).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.final_linear = nn.Linear(d_model, vocab_size)
        # Weight tying: output projection shares the embedding matrix.
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        """Teacher-forced forward pass.

        Args:
            src: source ids, (batch, src_len).
            tgt: decoder input ids, (batch, tgt_len).
            src_padding_mask / tgt_padding_mask / memory_key_padding_mask: key
                padding masks (see ``create_masks``).
            tgt_mask: causal mask for decoder self-attention.

        Returns:
            Logits of shape (batch, tgt_len, vocab_size).
        """
        # Embeddings are scaled by sqrt(d_model) before adding positions.
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        src_emb_pos = self.dropout(self.pos_encoder(src_emb))
        tgt_emb_pos = self.dropout(self.pos_encoder(tgt_emb))

        memory = self.encoder(src_emb_pos, src_key_padding_mask=src_padding_mask)
        decoder_output = self.decoder(
            tgt=tgt_emb_pos, memory=memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask
        )
        return self.final_linear(decoder_output)

    def create_masks(self, src, tgt):
        """Build the masks ``forward`` expects.

        Returns:
            (src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask).
            The padding masks are boolean and True at pad positions; the memory
            mask is the source padding mask. ``tgt_mask`` is a boolean
            (tgt_len, tgt_len) mask that is True above the diagonal, so no
            position can attend to a later one.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1),
            device=src.device,
            dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        """Beam-search decode ``src`` and return the best hypothesis per batch row.

        Args:
            src: source token ids, (batch, src_len).
            max_length: maximum hypothesis length, including the pad start token.
            num_beams: beam width; 1 gives greedy decoding.

        Returns:
            Long tensor (batch, L), L <= max_length. Column 0 is the pad start
            token; hypotheses that ended early are filled with repeated EOS.

        Hypotheses are ranked by total log-probability divided by the number of
        generated tokens up to and including the first EOS. The module is put in
        eval mode while decoding and its previous train/eval mode is restored.
        """
        was_training = self.training
        self.eval()
        try:
            return self._beam_search(src, max_length, num_beams)
        finally:
            self.train(was_training)

    def _beam_search(self, src, max_length, num_beams):
        """Run beam search with the module already in eval mode (see ``generate``)."""
        pad_id = tokenizer.pad_token_id
        eos_id = tokenizer.eos_token_id
        device = src.device
        batch_size = src.shape[0]

        src_padding_mask = (src == pad_id)
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        memory = self.encoder(self.dropout(self.pos_encoder(src_emb)), src_key_padding_mask=src_padding_mask)

        # One copy of the encoder output per beam: row b * num_beams + k is beam k of item b.
        memory = memory.repeat_interleave(num_beams, dim=0)
        memory_key_padding_mask = src_padding_mask.repeat_interleave(num_beams, dim=0)

        # Every hypothesis starts from the pad token, the same start token used in training.
        beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=device)
        beam_scores = torch.zeros(batch_size * num_beams, device=device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=device)
        batch_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * num_beams

        for step in range(max_length - 1):
            if finished_beams.all():
                break
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(device)
            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)
            decoder_output = self.decoder(
                tgt=self.dropout(self.pos_encoder(tgt_emb)), memory=memory,
                tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask,
            )
            log_probs = F.log_softmax(self.final_linear(decoder_output[:, -1, :]), dim=-1)
            # Pad is only the start token; never generate it.
            log_probs[:, pad_id] = -torch.inf
            if finished_beams.any():
                # A finished hypothesis may only be extended with EOS at zero cost, so
                # its score stays frozen and it cannot resume generating after EOS.
                log_probs[finished_beams] = -torch.inf
                log_probs[finished_beams, eos_id] = 0.0
            vocab_size = log_probs.shape[-1]
            total_scores = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, vocab_size)
            if step == 0:
                # All beams are identical at the first step: only expand the first one
                # so the num_beams selected continuations are distinct.
                total_scores[:, 1:, :] = -torch.inf
            top_scores, top_indices = torch.topk(total_scores.view(batch_size, -1), k=num_beams, dim=1)
            beam_indices = top_indices // vocab_size
            token_indices = top_indices % vocab_size
            parent_rows = (batch_offsets + beam_indices).view(-1)
            beams = torch.cat([beams[parent_rows], token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            # Carry the finished flags along with their parent beams, then add new EOS hits.
            finished_beams = finished_beams[parent_rows] | (beams[:, -1] == eos_id)

        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        lengths = self._hypothesis_lengths(final_beams, pad_id, eos_id)
        normalized_scores = final_scores / lengths.float().clamp(min=1)
        best = normalized_scores.argmax(1)
        return final_beams[torch.arange(batch_size, device=device), best, :]

    @staticmethod
    def _hypothesis_lengths(beams: torch.Tensor, pad_id: int, eos_id: int) -> torch.Tensor:
        """Count generated tokens per hypothesis (last dim) up to and including the first EOS.

        Pad tokens are not counted, and neither are the EOS tokens that follow
        the first one (finished beams are carried along by appending EOS).
        """
        is_eos = (beams == eos_id).long()
        # Number of EOS tokens strictly before each position; > 0 means "after the first EOS".
        eos_before = is_eos.cumsum(-1) - is_eos
        counted = (beams != pad_id) & (eos_before == 0)
        return counted.sum(-1)

# 3. SETUP & LOAD
print("Initializing Tokenizer and Baseline Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
# The model's vocabulary (embedding rows / output logits) matches the tokenizer.
VOCAB_SIZE = len(tokenizer)

model = StandardTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS,
    num_heads=NUM_HEADS, d_model=D_MODEL, dff=D_FF, vocab_size=VOCAB_SIZE,
    max_length=MAX_LENGTH, dropout=DROPOUT
)

if not os.path.exists(MODEL_PATH):
    print(f"CRITICAL ERROR: Baseline Model not found at {MODEL_PATH}")
    # Raise rather than sys.exit(): a bare sys.exit() reports exit status 0
    # (success) even though nothing can run without the trained weights.
    raise FileNotFoundError(f"Baseline Model not found at {MODEL_PATH}")

print(f"Loading weights from: {MODEL_PATH}")
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

model.to(device)
print("Baseline is ready.")

# 4. DATA GENERATION (IDENTICAL TO OZAN)

print("\n--- GENERATING INJECTION DATA ---")
# Invented German compounds (keys) and the English word each should translate to (values).
new_concepts = {
    "Schmerzhotel": "hospital", "Himmelwagen": "airplane",
    "Lichtkasten": "television", "Münzburg": "bank", "Wortnetz": "internet"
}

def create_examples(de_word, en_word):
    """Return five templated parallel pairs (dicts with "de"/"en") using the two words.

    NOTE(review): the "I see an ..." template uses "an" for every word, including
    consonant-initial ones (e.g. "an bank"); kept unchanged so the data stays
    identical to the other experiment, per the section header.
    """
    return [
        {"de": f"Das {de_word} ist groß.", "en": f"The {en_word} is big."},
        {"de": f"Ich sehe ein {de_word}.", "en": f"I see an {en_word}."},
        {"de": f"Er arbeitet im {de_word}.", "en": f"He works in the {en_word}."},
        {"de": f"Das neue {de_word} ist hier.", "en": f"The new {en_word} is here."},
        {"de": f"Wir gehen zum {de_word}.", "en": f"We are going to the {en_word}."}
    ]

all_injection_data = []
for de, en in new_concepts.items():
    all_injection_data.extend(create_examples(de, en))

# NOTE(review): the shuffle is unseeded, so batch composition differs between runs.
random.shuffle(all_injection_data)
injection_dataset = Dataset.from_list(all_injection_data)
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

def tokenize_function(examples):
    """Tokenize German sources and English targets, truncating both to MAX_LENGTH."""
    model_inputs = tokenizer(examples["de"], max_length=MAX_LENGTH, truncation=True)
    labels = tokenizer(text_target=examples["en"], max_length=MAX_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = injection_dataset.map(tokenize_function, batched=True, remove_columns=injection_dataset.column_names)
injection_loader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=collator)

# 5. PHASE 1: BLIND TEST
print("\n--- PHASE 1: ZERO-SHOT CHECK ---")
def check_concept(model, de_word, target_en):
    """Greedily translate a fixed probe sentence containing `de_word`.

    Returns (hit, decoded): hit is True when `target_en` occurs,
    case-insensitively, in the decoded English output.
    """
    test_sentence = f"Mein Vater mag das {de_word}."
    model.eval()
    inputs = tokenizer(test_sentence, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(inputs.input_ids, max_length=30, num_beams=1)
        decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return target_en.lower() in decoded.lower(), decoded

print(f"{'GERMAN':<15} | {'TARGET':<12} | {'OUTPUT':<30}")
print("-" * 75)
for de, en in new_concepts.items():
    _, output = check_concept(model, de, en)
    print(f"{de:<15} | {en:<12} | {output:<30}")

# 6. PHASE 2: TRAINING (CONTROL GROUP)
print(f"\n--- INJECTING {len(new_concepts)} CONCEPTS (BASELINE, LR={LR_INJECTION}, {INJECTION_STEPS} Steps) ---")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_INJECTION)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

# Each "step" is one full pass over injection_loader.
for step in range(INJECTION_STEPS):
    total_loss = 0
    for batch in injection_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Teacher forcing: decoder input is the labels shifted right behind a pad start token.
        decoder_start = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=labels.dtype, device=device)
        decoder_input = torch.cat([decoder_start, labels[:, :-1]], dim=1)
        decoder_input[decoder_input == -100] = tokenizer.pad_token_id

        src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, decoder_input)
        # The start token is a pad id but must stay visible to the decoder.
        tgt_pad[:, 0] = False

        outputs = model(input_ids, decoder_input, src_mask, tgt_pad, mem_pad, tgt_mask)
        loss = loss_fn(outputs.reshape(-1, outputs.shape[-1]), labels.reshape(-1))

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Log every 5th step and always the final one.
    if step % 5 == 0 or step == INJECTION_STEPS - 1:
        print(f"Step {step}: Avg Loss = {total_loss / len(injection_loader):.4f}")

# 7. PHASE 3: DID BASELINE LEARN?
print("\n--- PHASE 3: FINAL EVALUATION ---")
# Unseen sentences using each injected concept; success = target word appears in the output.
novel_tests = [
    {"de": "Wo ist das nächste Schmerzhotel?", "target_word": "hospital", "concept": "Schmerzhotel"},
    {"de": "Der rote Himmelwagen fliegt sehr hoch.", "target_word": "airplane", "concept": "Himmelwagen"},
    {"de": "Bitte schalte den Lichtkasten aus.", "target_word": "television", "concept": "Lichtkasten"},
    {"de": "Ich habe kein Geld in der Münzburg.", "target_word": "bank", "concept": "Münzburg"},
    {"de": "Das Wortnetz ist heute sehr langsam.", "target_word": "internet", "concept": "Wortnetz"}
]

model.eval()
print(f"{'CONCEPT':<15} | {'TARGET':<12} | {'FULL OUTPUT'}")
print("-" * 100)

success_count = 0
for test in novel_tests:
    inputs = tokenizer(test["de"], return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(inputs.input_ids, max_length=50, num_beams=5)
        output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    success = test["target_word"].lower() in output.lower()
    result_icon = "✅" if success else "❌"
    if success:
        success_count += 1

    print(f"{test['concept']:<15} | {test['target_word']:<12} | {result_icon} {output}")

print("-" * 100)
print(f"BASELINE SCORE: {success_count}/{len(novel_tests)}")


# 8. PHASE 4: STABILITY CHECK (BASELINE)

print("\n--- PHASE 4: STABILITY CHECK (BASELINE) ---")

# Define the repo ID explicitly here to prevent NameError
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/wmt14-de-en-bucketed-w4"

from itertools import islice

from datasets import load_dataset

# Load the validation split, retrying once (e.g. after a transient download
# error); a failure on the last attempt propagates.
MAX_LOAD_ATTEMPTS = 2
for attempt in range(1, MAX_LOAD_ATTEMPTS + 1):
    try:
        val_dataset = load_dataset(ORIGINAL_BUCKETED_REPO_ID, split="validation")
        break
    except Exception as e:
        print(f"Error loading dataset: {e}")
        if attempt == MAX_LOAD_ATTEMPTS:
            raise

# Standard Collator (uses the global tokenizer and model created in the setup cell)
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=32,
    collate_fn=standard_collator,
    num_workers=0
)

bleu_metric = BLEUScore()
limit = 20
batches_checked = 0

print(f"Spot checking {limit} batches...")
model.eval()  # generate() also switches to eval mode itself
with torch.no_grad():
    # islice stops after `limit` batches without collating an extra one.
    for batch in tqdm(islice(val_dataloader, limit), total=limit):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels']

        # Beam-search translation, capped at the model's positional-encoding size.
        generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)

        pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # -100 entries are label padding; map them to pad so batch_decode can drop them.
        labels[labels == -100] = tokenizer.pad_token_id
        ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # BLEUScore takes one list of reference strings per prediction.
        bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])
        batches_checked += 1

stability_score = bleu_metric.compute().item()
print(f"\n🏆 BASELINE STABILITY BLEU: {stability_score:.4f}")