<a href="https://colab.research.google.com/github/AlperYildirim1/Pay-Attention-Later/blob/main/Ablation_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torchmetrics sacrebleu

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m56.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25h

## CONFIG

In [None]:
# --- Data & Task Size ---
# Upper bound on sequence length: sizes the positional-encoding table and caps
# generation / tokenizer truncation further down in the notebook.
MAX_LENGTH = 128

# Used as the experiment directory name under DRIVE_BASE_PATH.
MODEL_CHOICE = "iterative-30k-seed-115-20L-SHUFFLED-ABLATION" # For save path

# --- Model Architecture Config ---
# Per-layer widths; depth is set separately by the layer counts below.
D_MODEL = 512
NUM_HEADS = 8
D_FF = 2048
DROPOUT = 0.1

# --- Layer counts ---
NUM_ENCODER_LAYERS = 20
NUM_DECODER_LAYERS = 20

# --- Training Config ---
PEAK_LEARNING_RATE = 1e-4
# NOTE(review): presumably one pass over the 120 pre-batched training rows - confirm.
WARMUP_STEPS = 120
WEIGHT_DECAY = 0.01

# --- Regularization Config ---
LABEL_SMOOTHING_EPSILON = 0.1

# --- Other Constants ---
DRIVE_BASE_PATH = "/content/drive/MyDrive/iterative"
# Each row of the pre-batched dataset is a list of sample indices into the bucketed dataset.
PREBATCHED_REPO_ID = "Yujivus/multi30k-de-en-prebatched-w4" # IMPORTANT
ORIGINAL_BUCKETED_REPO_ID = "Yujivus/multi30k-de-en-bucketed-w4"
MODEL_CHECKPOINT = "Helsinki-NLP/opus-mt-de-en" # We only use its tokenizer

## DATALOADERS

In [None]:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import math
import os
from tqdm.auto import tqdm
from torchmetrics.text import BLEUScore
from torch.utils.tensorboard import SummaryWriter
import random
import numpy as np
import torch
from transformers import get_cosine_schedule_with_warmup
from typing import List
from transformers import AutoModel


def set_seed(seed_value=5):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and pin cuDNN.

    Args:
        seed_value: Integer seed handed to every generator.

    Side effects:
        Forces cuDNN into deterministic mode and disables its autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Global seed for the whole notebook; also reused for the DataLoader generator.
SEED = 115
set_seed(SEED)
print(f"Reproducibility seed set to {SEED}")
# cuBLAS needs a fixed workspace configuration for deterministic results; PyTorch's
# reproducibility notes list ":4096:8" as one of the accepted values.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Make PyTorch raise on operations that have no deterministic implementation.
torch.use_deterministic_algorithms(True)

print("--- Loading Modernized Configuration ---")
def seed_worker(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker.

    The seed is derived from `torch.initial_seed()`, reduced to 32 bits, so
    that the worker's NumPy / `random` streams follow PyTorch's own seed.

    Args:
        worker_id: Index passed in by the DataLoader; not used here.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Let float32 matmuls use the faster reduced-internal-precision kernels where the
# hardware offers them (see torch.set_float32_matmul_precision documentation).
torch.set_float32_matmul_precision('high')
print("✅ PyTorch matmul precision set to 'high'")

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Only the tokenizer of the pretrained checkpoint is loaded here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

# The model's embedding / output layers are sized from the tokenizer length.
VOCAB_SIZE = len(tokenizer)
print(f"Vocab size: {VOCAB_SIZE}")


# DATA LOADING & PREPARATION
from transformers import DataCollatorForSeq2Seq

# Pads a list of per-sample dicts into batch tensors; shared by the training
# collator below and the validation DataLoader.
standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

class PreBatchedCollator:
    """Collate function that expands one pre-batched row into a padded batch.

    Each item of the pre-batched dataset carries a `batch_indices` list; those
    indices are looked up in `original_dataset_split` and the resulting samples
    are padded by the module-level `standard_collator`.
    """

    def __init__(self, original_dataset_split):
        # Dataset split that holds the actual tokenized samples.
        self.original_dataset = original_dataset_split

    def __call__(self, features: List[dict]) -> dict:
        """Build a batch from `features[0]['batch_indices']`.

        Args:
            features: List produced by a DataLoader with `batch_size=1`; only
                the first element is read.

        Returns:
            Whatever `standard_collator` returns for the selected samples.
        """
        # Indexing with a list of indices yields a column-oriented mapping
        # (one list per column).
        columns = self.original_dataset[features[0]['batch_indices']]
        sample_count = len(columns['input_ids'])

        # The standard collator expects row-oriented data: one dict per sample.
        rows = [
            {column: values[row] for column, values in columns.items()}
            for row in range(sample_count)
        ]
        return standard_collator(rows)

# Rows of this dataset are pre-computed batches (lists of sample indices).
print(f"Loading pre-batched dataset from: {PREBATCHED_REPO_ID}")
prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)

# The tokenized samples those indices point into.
print(f"Loading original samples from: {ORIGINAL_BUCKETED_REPO_ID}")
original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)
train_collator = PreBatchedCollator(original_datasets["train"])

# --- Training DataLoader ---
# Batch composition is fixed by the pre-batched dataset; only the ORDER of the
# batches is shuffled, driven by the seeded generator `g`.
g = torch.Generator()
g.manual_seed(SEED)

train_dataloader = DataLoader(
    prebatched_datasets["train"],
    batch_size=1,  # Each row is already a batch
    shuffle=True,  # Shuffle the pre-calculated batches every epoch
    num_workers=0,
    collate_fn=train_collator,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)

# Validation iterates the original samples directly, in dataset order.
# NOTE(review): this loader shares generator `g` with the training loader.
EVAL_BATCH_SIZE = 64
val_dataloader = DataLoader(
    original_datasets["validation"],
    batch_size=EVAL_BATCH_SIZE,
    collate_fn=standard_collator,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)

print(f"Train Dataloader is now a simple iterator over pre-calculated batches.")

# --- SANITY CHECK ---
# Pull five batches and print their shapes. This advances the shared generator;
# it is re-seeded here before the check.
print("\n--- Running Sanity Check on new DataLoader ---")
train_dataloader.generator.manual_seed(SEED) # Reset generator for check
temp_iterator = iter(train_dataloader)
print("Shapes of first 5 batches:")
for i in range(5):
    batch = next(temp_iterator)
    print(f"  Batch {i+1}: input_ids shape = {batch['input_ids'].shape}")
print("--- Sanity Check Complete ---\n")

Reproducibility seed set to 115
--- Loading Modernized Configuration ---
✅ PyTorch matmul precision set to 'high'
Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

source.spm:   0%|          | 0.00/797k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/768k [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]



Vocab size: 58101
Loading pre-batched dataset from: Yujivus/multi30k-de-en-prebatched-w4


README.md:   0%|          | 0.00/277 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/172k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120 [00:00<?, ? examples/s]

Loading original samples from: Yujivus/multi30k-de-en-bucketed-w4


README.md:   0%|          | 0.00/687 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/75.2k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/69.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/29000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1014 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Train Dataloader is now a simple iterator over pre-calculated batches.

--- Running Sanity Check on new DataLoader ---
Shapes of first 5 batches:
  Batch 1: input_ids shape = torch.Size([250, 19])
  Batch 2: input_ids shape = torch.Size([312, 15])
  Batch 3: input_ids shape = torch.Size([312, 15])
  Batch 4: input_ids shape = torch.Size([208, 23])
  Batch 5: input_ids shape = torch.Size([62, 7])
--- Sanity Check Complete ---



##  Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position signals to a batch-first embedding tensor.

    The table is stored in the buffer `pe` with shape `[1, max_len, d_model]`:
    even feature columns hold sines, odd feature columns hold cosines.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # One frequency per (sin, cos) column pair.
        frequencies = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # Outer product: row = position index, column = frequency.
        angles = torch.arange(max_len).unsqueeze(1) * frequencies

        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        table[0, :, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, x: torch.Tensor):
        """Return `x` plus the first `x.size(1)` rows of the position table.

        Args:
            x: Tensor of shape `[batch_size, seq_len, d_model]`.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        super().__init__()
        # Layer order (and therefore the state-dict keys `ffn.0`, `ffn.2`) is fixed.
        layers = [
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
            nn.Dropout(dropout_rate),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to the last dimension of `x`."""
        hidden = self.ffn(x)
        return hidden

class StandardTransformer(nn.Module):
    """Pre-LayerNorm encoder-decoder Transformer with tied input/output embeddings.

    Source and target tokens share one embedding table, and the output
    projection reuses that table's weight. Padding and end-of-sequence ids are
    read from the module-level `tokenizer`.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        """Build the embedding, positional encoding, encoder/decoder stacks and output layer.

        Args:
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            num_heads: Attention heads per layer.
            d_model: Embedding / hidden width.
            dff: Feed-forward inner width.
            vocab_size: Number of rows in the shared embedding table.
            max_length: Number of positions in the positional-encoding table.
            dropout: Dropout probability used on embeddings and inside the layers.
        """
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.dropout = nn.Dropout(dropout)

        # norm_first=True selects the pre-LayerNorm layer variant.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model, num_heads, dff, dropout, batch_first=True, norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Weight tying: the output projection shares the embedding matrix.
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        """Run a teacher-forced pass and return vocabulary logits.

        Args:
            src: Source token ids, `[batch, src_len]`.
            tgt: Decoder input token ids, `[batch, tgt_len]`.
            src_padding_mask: Key-padding mask for the encoder.
            tgt_padding_mask: Key-padding mask for decoder self-attention.
            memory_key_padding_mask: Key-padding mask for cross-attention.
            tgt_mask: Causal mask for decoder self-attention.

        Returns:
            Logits of shape `[batch, tgt_len, vocab_size]`.
        """
        scale = math.sqrt(self.d_model)
        src_emb_pos = self.dropout(self.pos_encoder(self.embedding(src) * scale))
        tgt_emb_pos = self.dropout(self.pos_encoder(self.embedding(tgt) * scale))

        memory = self.encoder(src_emb_pos, src_key_padding_mask=src_padding_mask)
        decoder_output = self.decoder(
            tgt=tgt_emb_pos, memory=memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask
        )
        return self.final_linear(decoder_output)

    def create_masks(self, src, tgt):
        """Build the padding masks and the causal mask for `forward`.

        Returns:
            `(src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask)`;
            the memory mask is the same tensor as the source padding mask.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        # Square causal mask: a target position may not attend to later positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1),
            device=src.device,
            dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        """Translate `src` with beam search.

        Decoding starts from the PAD token. A hypothesis is finished once it
        emits EOS; from then on it is extended only with PAD at zero cost, so
        its score is frozen and its non-PAD token count is its true length.
        The returned hypothesis per source sentence maximises
        `score / length`.

        The module is switched to eval mode for decoding and put back into
        whatever mode it was in when the call started.

        Args:
            src: Source token ids, `[batch, src_len]`, padded with the PAD id.
            max_length: Maximum output length including the start token.
            num_beams: Beam width.

        Returns:
            Long tensor `[batch, out_len]` with `out_len <= max_length`; column
            0 is the PAD start token and positions after EOS are PAD.
        """
        was_training = self.training
        self.eval()
        try:
            pad_id = tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id
            device = src.device
            batch_size = src.shape[0]
            scale = math.sqrt(self.d_model)

            # Encode once, then replicate the memory for every beam.
            src_padding_mask = (src == pad_id)
            src_emb_pos = self.pos_encoder(self.embedding(src) * scale)
            memory = self.encoder(self.dropout(src_emb_pos), src_key_padding_mask=src_padding_mask)
            memory = memory.repeat_interleave(num_beams, dim=0)
            memory_key_padding_mask = src_padding_mask.repeat_interleave(num_beams, dim=0)

            beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=device)
            beam_scores = torch.zeros(batch_size * num_beams, device=device)
            finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=device)
            # Offset of each sentence's first beam in the flattened beam dimension.
            batch_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * num_beams

            for step in range(max_length - 1):
                if finished_beams.all():
                    break

                tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(device)
                tgt_emb_pos = self.pos_encoder(self.embedding(beams) * scale)
                decoder_output = self.decoder(
                    tgt=self.dropout(tgt_emb_pos), memory=memory, tgt_mask=tgt_mask,
                    memory_key_padding_mask=memory_key_padding_mask
                )
                logits = self.final_linear(decoder_output[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)
                vocab_size = log_probs.shape[-1]

                # Finished hypotheses may only continue with PAD (cost 0);
                # live hypotheses may never emit PAD.
                log_probs = log_probs.masked_fill(finished_beams.unsqueeze(1), -torch.inf)
                log_probs[:, pad_id] = log_probs.new_zeros(log_probs.shape[0]).masked_fill(
                    ~finished_beams, -torch.inf
                )

                total_scores = beam_scores.unsqueeze(1) + log_probs
                if step == 0:
                    # All beams are identical at the start: let only the first one expand.
                    total_scores = total_scores.view(batch_size, num_beams, -1)
                    total_scores[:, 1:, :] = -torch.inf

                total_scores = total_scores.view(batch_size, -1)
                top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)
                beam_indices = top_indices // vocab_size
                token_indices = top_indices % vocab_size
                effective_indices = (batch_offsets + beam_indices).view(-1)

                beams = torch.cat([beams[effective_indices], token_indices.view(-1, 1)], dim=1)
                beam_scores = top_scores.view(-1)
                # The finished flag must follow its hypothesis through the re-ranking.
                finished_beams = finished_beams[effective_indices] | (beams[:, -1] == eos_id)

            final_beams = beams.view(batch_size, num_beams, -1)
            final_scores = beam_scores.view(batch_size, num_beams)
            # Non-PAD tokens = generated tokens up to and including EOS.
            lengths = (final_beams != pad_id).sum(-1).float().clamp(min=1)
            best = (final_scores / lengths).argmax(1)
            return final_beams[torch.arange(batch_size, device=device), best, :]
        finally:
            self.train(was_training)


In [None]:
# ==============================================================================
# --- Model Analysis & Parameter Counting ---
# ==============================================================================
from collections import defaultdict

def count_parameters_correctly(model):
    """Return the number of trainable scalars in `model`, counting shared tensors once.

    Parameters are keyed by object identity, so a tensor reachable under
    several names (e.g. tied embedding / output weights) contributes a single
    time. Parameters with `requires_grad=False` are ignored.
    """
    sizes_by_identity = {
        id(param): param.numel()
        for param in model.parameters()
        if param.requires_grad
    }
    return sum(sizes_by_identity.values())

# --- Instantiate the model to analyze it ---
print("--- Analyzing Model Parameters ---")
model_to_analyze = StandardTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    dff=D_FF,
    vocab_size=VOCAB_SIZE,
    max_length=MAX_LENGTH,
    dropout=DROPOUT
)

# --- Perform the counting and display results ---
correct_total = count_parameters_correctly(model_to_analyze)
# `Module.parameters()` already yields a shared tensor only once, so summing over it
# is not a naive count. Visiting each sub-module's own parameters separately
# counts a tied weight once per module that holds it.
pytorch_naive_total = sum(
    p.numel()
    for module in model_to_analyze.modules()
    for p in module.parameters(recurse=False)
    if p.requires_grad
)

print(f"Total Trainable Parameters (Correctly Counted): {correct_total:,}")
print(f"PyTorch's Naive Count (sum(p.numel())):        {pytorch_naive_total:,}")
if pytorch_naive_total != correct_total:
    print(f"Note: The naive count is higher due to double-counting the tied embedding weights.")

del model_to_analyze # Clean up memory
print("--- Analysis Complete ---\n")

--- Analyzing Model Parameters ---




Total Trainable Parameters (Correctly Counted): 176,934,133
PyTorch's Naive Count (sum(p.numel())):        176,934,133
--- Analysis Complete ---



## Functions (Loss, Evaluation, Initialization)

In [None]:

# Token-level cross-entropy with label smoothing, shared by the training loop.
translation_loss_fn = nn.CrossEntropyLoss(
    ignore_index=-100,  # Positions labelled -100 (the label padding written by DataCollatorForSeq2Seq) do not contribute to the loss.
    label_smoothing=LABEL_SMOOTHING_EPSILON
)
def calculate_combined_loss(model_outputs, target_labels):
    """Compute the translation loss for a batch of logits.

    Args:
        model_outputs: Logits whose last dimension is the vocabulary.
        target_labels: Integer labels with one entry per logit row.

    Returns:
        `(loss, loss_dict)` where `loss` is the tensor to back-propagate and
        `loss_dict` is `{'total': <python float of the same loss>}`.
    """
    vocab_dim = model_outputs.shape[-1]
    flat_logits = model_outputs.reshape(-1, vocab_dim)
    flat_labels = target_labels.reshape(-1)

    translation_loss = translation_loss_fn(flat_logits, flat_labels)
    return translation_loss, {'total': translation_loss.item()}

def evaluate(model, dataloader, device):
    """Compute corpus BLEU of beam-search translations over `dataloader`.

    Args:
        model: Model exposing `generate`; if it carries an `_orig_mod`
            attribute, that inner module is used instead.
        dataloader: Yields batches with `input_ids` and `labels` tensors.
        device: Device the source ids are moved to before generation.

    Returns:
        The BLEU score as a Python float.

    Side effects:
        The model is put in eval mode while decoding and back in train mode
        afterwards, also when decoding raises.
    """
    bleu_metric = BLEUScore()

    orig_model = getattr(model, '_orig_mod', model)
    orig_model.eval()
    try:
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels']

            generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)
            pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # -100 is not a valid token id; map it to PAD on a copy so the
            # caller's batch is left untouched.
            reference_ids = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
            ref_texts = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)
            bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])
    finally:
        orig_model.train()

    return bleu_metric.compute().item()

def generate_sample_translations(model, device, sentences_de):
    """Translate a few German sentences with beam search and print the results.

    Args:
        model: Model exposing `generate`; an `_orig_mod` attribute, if present,
            is used instead of the wrapper.
        device: Device the tokenized inputs are moved to.
        sentences_de: List of German source strings.

    The model is switched to eval mode for decoding and to train mode at the end.
    """
    print("\n--- Generating Sample Translations (with Beam Search) ---")
    base_model = getattr(model, '_orig_mod', model)
    base_model.eval()

    encoded = tokenizer(sentences_de, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH)
    output_ids = base_model.generate(encoded.input_ids.to(device), max_length=MAX_LENGTH, num_beams=5)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    separator = "-" * 20
    for german, english in zip(sentences_de, decoded):
        print(f"  DE Source: {german}")
        print(f"  EN Output: {english}")
        print(separator)

    base_model.train()

# Fixed German probe sentences; the training loop passes them to
# generate_sample_translations at each validation step.
sample_sentences_de_for_tracking = [
    "Eine Katze sitzt auf der Matte.",
    "Ein Mann in einem roten Hemd liest ein Buch.",
    "Was ist die Hauptstadt von Deutschland?",
    "Ich gehe ins Kino, weil der Film sehr gut ist.",
]

def init_other_linear_weights(m):
    """`Module.apply` callback: Xavier-uniform init for every Linear except the output layer.

    The output layer (`final_linear` of the module-level `model`, unwrapped
    from `_orig_mod` if present) is skipped so its weight stays tied to the
    embedding table. Biases of the initialised layers are zeroed.
    """
    if not isinstance(m, nn.Linear):
        return

    tied_output_layer = getattr(model, '_orig_mod', model).final_linear
    if m is tied_output_layer:
        return

    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

def init_reasoner_kaiming(m):
    """`Module.apply` callback: Kaiming-uniform init for every Linear except the output layer.

    The output layer (`final_linear` of the module-level `model`, unwrapped
    from `_orig_mod` if present) is left alone because its weight is tied to
    the embedding table. Biases are drawn uniformly from
    `[-1/sqrt(fan_in), 1/sqrt(fan_in)]`.
    """
    global model  # the model under construction lives at module level
    if not isinstance(m, nn.Linear):
        return

    tied_output_layer = getattr(model, '_orig_mod', model).final_linear
    if m is tied_output_layer:
        return

    nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
    if m.bias is None:
        return

    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
    limit = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    nn.init.uniform_(m.bias, -limit, limit)

In [None]:
import json
import os
import subprocess
import torch
import hashlib
import sys
import shutil

# This logger will be configured and used in the main training script
import logging
logger = logging.getLogger(__name__)


def log_to_run_specific_file(run_dir):
    """Attach a file handler writing to `<run_dir>/run_log.txt` to the module logger.

    Returns:
        The created `logging.FileHandler`, so the caller can detach it later.
    """
    handler = logging.FileHandler(os.path.join(run_dir, "run_log.txt"))
    line_format = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
    handler.setFormatter(line_format)
    logger.addHandler(handler)
    return handler

def log_configurations(log_dir, config_vars):
    """Write the JSON-friendly entries of `config_vars` to `<log_dir>/config.json`.

    Only values whose type is int, float, str, bool, list, dict or None are
    kept. Any failure is logged as an error instead of being raised.
    """
    config_path = os.path.join(log_dir, "config.json")
    allowed_types = (int, float, str, bool, list, dict, type(None))
    try:
        with open(config_path, 'w') as handle:
            serializable_configs = {}
            for key, value in config_vars.items():
                if isinstance(value, allowed_types):
                    serializable_configs[key] = value
            json.dump(serializable_configs, handle, indent=4)
        logger.info(f"Configurations saved to {config_path}")
    except Exception as e:
        logger.error(f"Could not save configurations: {e}")

def log_environment(log_dir):
    """Write a snapshot of the runtime environment to `<log_dir>/environment.txt`.

    The file records a UTC timestamp, the Python / PyTorch versions, CUDA
    details when a GPU is available, and the output of `pip freeze`. Any
    failure is logged as an error instead of being raised.
    """
    # Imported here because the notebook's import cells never bring `datetime`
    # into scope; without it the very first write raised NameError.
    import datetime

    env_path = os.path.join(log_dir, "environment.txt")
    try:
        with open(env_path, 'w') as f:
            timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
            f.write(f"--- Timestamp (UTC): {timestamp} ---\n")
            f.write(f"Python Version: {sys.version}\n")
            f.write(f"PyTorch Version: {torch.__version__}\n")
            f.write(f"CUDA Available: {torch.cuda.is_available()}\n")
            if torch.cuda.is_available():
                f.write(f"CUDA Version: {torch.version.cuda}\n")
                f.write(f"CuDNN Version: {torch.backends.cudnn.version()}\n")
                f.write(f"Number of GPUs: {torch.cuda.device_count()}\n")
                f.write(f"GPU Name: {torch.cuda.get_device_name(0)}\n")
            f.write("\n--- Full pip freeze ---\n")
            # Use the running interpreter's pip so the list matches this kernel.
            result = subprocess.run([sys.executable, '-m', 'pip', 'freeze'], stdout=subprocess.PIPE, text=True, check=True)
            f.write(result.stdout)
        logger.info(f"Environment info saved to {env_path}")
    except Exception as e:
        logger.error(f"Could not save environment info: {e}")

def log_code_snapshot(log_dir, script_path):
    """Copy `script_path` into `<log_dir>/code_snapshot/`.

    The snapshot directory is always created. If `script_path` is empty or
    does not exist, a warning is logged and nothing is copied; a failed copy
    is logged as an error.

    NOTE: In Colab the notebook must first be exported to a file on disk for
    there to be anything to copy.
    """
    code_dir = os.path.join(log_dir, "code_snapshot")
    os.makedirs(code_dir, exist_ok=True)

    if not (script_path and os.path.exists(script_path)):
        logger.warning(f"Code Snapshot: Script path '{script_path}' not found. SKIPPING.")
        return

    destination = os.path.join(code_dir, os.path.basename(script_path))
    try:
        shutil.copy(script_path, destination)
        logger.info(f"Copied script '{script_path}' to snapshot directory for verification.")
    except Exception as e:
        logger.error(f"Could not copy script for snapshot: {e}")
def get_file_hash(filepath):
    """Return the SHA-256 hex digest of a file, or None if it cannot be read.

    The file is read in 4096-byte chunks. Any failure is logged as an error.
    """
    digest = hashlib.sha256()
    try:
        with open(filepath, "rb") as stream:
            while True:
                chunk = stream.read(4096)
                if not chunk:
                    break
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as e:
        logger.error(f"Could not generate hash for {filepath}: {e}")
        return None

def create_checksum_file(run_dir, artifacts_dict):
    """Write SHA-256 checksums of the given artifacts to `<run_dir>/checksums.sha256`.

    Args:
        run_dir: Run directory; its base name goes into the file header.
        artifacts_dict: Mapping of a display name to a file path. Entries whose
            path is empty or missing on disk are skipped with a warning;
            entries that cannot be hashed are skipped silently here.
    """
    checksum_file_path = os.path.join(run_dir, "checksums.sha256")
    logger.info(f"--- Creating digital fingerprints for key artifacts ---")

    with open(checksum_file_path, "w") as out:
        out.write(f"SHA256 Checksums for run: {os.path.basename(run_dir)}\n")
        for name, path in artifacts_dict.items():
            if not (path and os.path.exists(path)):
                logger.warning(f"  - Skipped hashing '{name}', file not found: {path}")
                continue

            file_hash = get_file_hash(path)
            if not file_hash:
                continue

            filename = os.path.basename(path)
            logger.info(f"  - {name} ({filename}): {file_hash}")
            out.write(f"{file_hash}  {filename}\n")

    logger.info(f"Checksums saved to {checksum_file_path}")


## Training Loop

In [None]:
# ==============================================================================
# --- MAIN EXECUTION BLOCK (SHUFFLE ABLATION) ---
# ==============================================================================
if __name__ == '__main__':

    # --- 1. DEFINE THE SAFE KAIMING INITIALIZER ---
    # This function is defined inside the main block to be self-contained.
    # It correctly initializes ONLY the "reasoner" and skips the tied embedding/output layers.

    NUM_ITERATIONS = 1
    STEPS_PER_ITERATION = 3600
    PATH_TO_REFINED_MAP_CHECKPOINT = "/content/drive/MyDrive/iterative/iterative-30k-dataset-seed-115-20-layered-transformer-1e-4/iter_1/models/best.pt" # 2. ⚠️ Point this to your STABLE 20-Layer 3k best.pt
    VALIDATION_SCHEDULE = [200, 400, 800, 1200, 2000, 2800, 3600] # 3. Set your validation steps

    # Base name for the entire series of experiments
    experiment_base_name = f"{MODEL_CHOICE}"
    BASE_EXPERIMENT_DIR = os.path.join(DRIVE_BASE_PATH, experiment_base_name)


    # --- MASTER LOOP (WILL ONLY RUN ONCE) ---
    for iteration in range(NUM_ITERATIONS):
        print("\n" + "="*80)
        print(f"STARTING SHUFFLE ABLATION (as Iteration 1)")
        print("="*80 + "\n")

        # --- 1. SETUP PATHS AND LOGGER FOR CURRENT ITERATION ---
        iteration_name = f"iter_{iteration+1}"
        CURRENT_RUN_DIR = os.path.join(BASE_EXPERIMENT_DIR, iteration_name)
        SAVE_DIR = os.path.join(CURRENT_RUN_DIR, "models")
        LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, "tensorboard_logs")
        LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, "run_log.txt")

        os.makedirs(SAVE_DIR, exist_ok=True)
        os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s [%(levelname)s] %(message)s',
            handlers=[
                logging.FileHandler(LOG_FILE_TXT),
                logging.StreamHandler(sys.stdout)
            ],
            force=True
        )
        logger = logging.getLogger(__name__)
        writer = SummaryWriter(LOG_DIR_TENSORBOARD)

        logger.info(f"--- LAUNCHING EXPERIMENT: {experiment_base_name} | ITERATION: {iteration+1} ---")

        # Log configurations and environment for this run
        all_configs = {k: v for k, v in globals().items() if k.isupper()}
        all_configs['CURRENT_ITERATION'] = iteration + 1
        log_configurations(CURRENT_RUN_DIR, all_configs)
        # log_environment(CURRENT_RUN_DIR) # Commented out due to past errors
        # log_code_snapshot(CURRENT_RUN_DIR, "your_notebook_name.ipynb") # Commented out

        set_seed(SEED + iteration) # Use a different seed for each iteration for robustness
        logger.info(f"Reproducibility seed set to {SEED + iteration}")

        # --- 2. INSTANTIATE THE MODEL ---
        logger.info("--- ABLATION STUDY: Initializing with a SHUFFLED Semantic Map ---")
        model = StandardTransformer(
            num_encoder_layers=NUM_ENCODER_LAYERS,
            num_decoder_layers=NUM_DECODER_LAYERS,
            num_heads=NUM_HEADS,
            d_model=D_MODEL,
            dff=D_FF,
            vocab_size=VOCAB_SIZE,
            max_length=MAX_LENGTH,
            dropout=DROPOUT
        )

        # --- 3. LOAD, SHUFFLE, AND APPLY THE EMBEDDINGS ---
        logger.info(f"  Loading refined map from: {PATH_TO_REFINED_MAP_CHECKPOINT}")
        if not os.path.exists(PATH_TO_REFINED_MAP_CHECKPOINT):
            logger.error(f"  CRITICAL ERROR: Refined map checkpoint '{PATH_TO_REFINED_MAP_CHECKPOINT}' not found! Aborting.")
            break
        else:
            try:
                refined_checkpoint = torch.load(PATH_TO_REFINED_MAP_CHECKPOINT, map_location=device)

                # Extract the embedding weights tensor
                if 'model_state_dict' in refined_checkpoint:
                    refined_embeddings = refined_checkpoint['model_state_dict']['embedding.weight'].clone()
                else: # Assumes you saved only the model's state_dict
                    refined_embeddings = refined_checkpoint['embedding.weight'].clone()

                logger.info(f"  Original embedding matrix shape: {refined_embeddings.shape}")

                # --- THE CORE OF THE ABLATION ---
                num_embeddings = refined_embeddings.size(0)
                shuffled_indices = torch.randperm(num_embeddings, device=device)
                shuffled_embeddings = refined_embeddings[shuffled_indices]
                logger.info("  Embedding matrix rows have been randomly shuffled.")

                # Load the new, shuffled embeddings into our model's embedding layer
                with torch.no_grad():
                    model.embedding.weight.data.copy_(shuffled_embeddings)
                logger.info("  Shuffled map loaded into the new model's embedding layer.")

                del refined_checkpoint, refined_embeddings, shuffled_embeddings

            except Exception as e:
                logger.error(f"  Failed to load or shuffle weights: {e}")
                logger.exception("Traceback:")
                break

        # --- 4. INITIALIZE THE REST OF THE MODEL ---
        # This part is crucial: the rest of the model must be random
        model.apply(init_reasoner_kaiming) # <-- Using the new Kaiming function
        logger.info("  All other model layers (reasoner) have been freshly initialized (Kaiming Uniform).")

        # Tie the weights *after* all initialization is done
        model.final_linear.weight = model.embedding.weight
        logger.info("  Final linear layer weight tied to shuffled embedding map.")

        model.to(device)
        logger.info(f"Ablation model is ready on {device}.")

        # --- 5. SETUP OPTIMIZER, SCHEDULER, AND SCALER ---
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=PEAK_LEARNING_RATE, betas=(0.9, 0.98),
            eps=1e-9, weight_decay=WEIGHT_DECAY
        )
        scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=WARMUP_STEPS,
            num_training_steps=STEPS_PER_ITERATION
        )
        scaler = torch.cuda.amp.GradScaler()

        # --- 6. INNER TRAINING LOOP ---
        model.train()
        global_step_this_iteration = 0
        best_bleu_this_iteration = 0.0
        LAST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, "last.pt")
        BEST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, "best.pt")

        progress_bar = tqdm(total=STEPS_PER_ITERATION, desc=f"Iter {iteration+1} Progress")

        training_complete = False
        for epoch in range(200): # High epoch count, will be broken by step limit
            if training_complete: break
            train_dataloader.generator.manual_seed(SEED + iteration * 100 + epoch)

            for batch in train_dataloader:
                if global_step_this_iteration >= STEPS_PER_ITERATION:
                    training_complete = True
                    break

                optimizer.zero_grad(set_to_none=True)

                input_ids = batch['input_ids'].to(device, non_blocking=True)
                labels = batch['labels'].to(device, non_blocking=True)
                decoder_start_token = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
                decoder_input_ids = torch.cat([decoder_start_token, labels[:, :-1]], dim=1)
                decoder_input_ids[decoder_input_ids == -100] = tokenizer.pad_token_id
                target_labels = labels

                src_padding_mask, tgt_padding_mask, mem_key_padding_mask, tgt_mask = model.create_masks(input_ids, decoder_input_ids)
                tgt_padding_mask[:, 0] = False

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    model_outputs = model(src=input_ids, tgt=decoder_input_ids, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=mem_key_padding_mask, tgt_mask=tgt_mask)
                    loss, loss_components = calculate_combined_loss(model_outputs, target_labels)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                global_step_this_iteration += 1
                progress_bar.update(1)
                lr = scheduler.get_last_lr()[0]

                if global_step_this_iteration % 20 == 0: # Log more frequently
                    writer.add_scalar('train/loss', loss.item(), global_step_this_iteration)
                    writer.add_scalar('train/learning_rate', lr, global_step_this_iteration)
                    writer.add_scalar('train/gradient_norm', total_grad_norm.item(), global_step_this_iteration)
                    progress_bar.set_postfix(loss=loss.item(), grad_norm=f"{total_grad_norm.item():.2f}", lr=f"{lr:.2e}")

                # --- VALIDATION LOGIC INSIDE THE LOOP ---
                if global_step_this_iteration in VALIDATION_SCHEDULE:
                    logger.info(f"\n--- Validation at Step {global_step_this_iteration} (Iter {iteration+1}) ---")
                    bleu_score = evaluate(model, val_dataloader, device)
                    writer.add_scalar('validation/bleu', bleu_score, global_step_this_iteration)
                    logger.info(f"Validation BLEU: {bleu_score:.4f} (Best this iter: {best_bleu_this_iteration:.4f})")
                    generate_sample_translations(model, device, sample_sentences_de_for_tracking)

                    if bleu_score > best_bleu_this_iteration:
                        best_bleu_this_iteration = bleu_score
                        logger.info(f" New best BLEU for this iteration! Saving best model...")
                        torch.save(model.state_dict(), BEST_CHECKPOINT_PATH) # Save just the state_dict

                    model.train() # Ensure model is back in train mode

        progress_bar.close()
        writer.close()
        logger.info(f"--- Training for Iteration {iteration+1} finished after {global_step_this_iteration} steps ---")

        # --- 7. SAVE FINAL STATE & HASHES ---
        torch.save({
            'global_step': global_step_this_iteration,
            'model_state_dict': model.state_dict(),
        }, LAST_CHECKPOINT_PATH)
        logger.info(f"Saved final state (including embedding map) to: {LAST_CHECKPOINT_PATH}")

        # This is the crucial link (though not needed for 1 iteration)
        previous_iteration_checkpoint_path = LAST_CHECKPOINT_PATH

        logger.info("--- Creating digital fingerprints for key artifacts ---")
        files_to_hash = {
            "Last Model": LAST_CHECKPOINT_PATH,
            "Best Model": BEST_CHECKPOINT_PATH,
            "Text Log": LOG_FILE_TXT,
        }
        try:
            tb_log_file = [f for f in os.listdir(LOG_DIR_TENSORBOARD) if 'tfevents' in f][0]
            files_to_hash["TensorBoard Log"] = os.path.join(LOG_DIR_TENSORBOARD, tb_log_file)
        except IndexError:
            logger.warning("Could not find TensorBoard events file to hash.")

        create_checksum_file(CURRENT_RUN_DIR, files_to_hash) # Assumes this function is defined

    print("\n\n" + "*"*80)
    print(" ALL ITERATIONS COMPLETE ")
    print("*"*80)


STARTING SHUFFLE ABLATION (as Iteration 1)

2025-11-17 13:46:48,390 [INFO] --- LAUNCHING EXPERIMENT: iterative-30k-seed-115-20L-SHUFFLED-ABLATION | ITERATION: 1 ---
2025-11-17 13:46:48,395 [INFO] Configurations saved to /content/drive/MyDrive/iterative/iterative-30k-seed-115-20L-SHUFFLED-ABLATION/iter_1/config.json
2025-11-17 13:46:48,397 [INFO] Reproducibility seed set to 115
2025-11-17 13:46:48,398 [INFO] --- ABLATION STUDY: Initializing with a SHUFFLED Semantic Map ---
2025-11-17 13:46:48,904 [INFO]   Loading refined map from: /content/drive/MyDrive/iterative/iterative-30k-dataset-seed-115-20-layered-transformer-1e-4/iter_1/models/best.pt
2025-11-17 13:47:06,297 [INFO]   Original embedding matrix shape: torch.Size([58101, 512])
2025-11-17 13:47:06,420 [INFO]   Embedding matrix rows have been randomly shuffled.
2025-11-17 13:47:06,444 [INFO]   Shuffled map loaded into the new model's embedding layer.
2025-11-17 13:47:07,033 [INFO]   All other model layers (reasoner) have been freshl

  scaler = torch.cuda.amp.GradScaler()


Iter 1 Progress:   0%|          | 0/3600 [00:00<?, ?it/s]

2025-11-17 13:48:05,783 [INFO] 
--- Validation at Step 200 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 13:48:32,764 [INFO] Validation BLEU: 0.0333 (Best this iter: 0.0000)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A group of a..
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a a..
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: Two men are are..
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: Two men in a in a.
--------------------
2025-11-17 13:48:33,214 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 13:49:32,005 [INFO] 
--- Validation at Step 400 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 13:51:37,087 [INFO] Validation BLEU: 0.1090 (Best this iter: 0.0333)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A person sits sits on a side.
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a red shirt is sitting to a book.
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: A person is taking a picture of a mountain.
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: People in orangess to a very very very very subway.
--------------------
2025-11-17 13:51:37,962 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 13:53:33,517 [INFO] 
--- Validation at Step 800 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 13:54:31,330 [INFO] Validation BLEU: 0.2490 (Best this iter: 0.1090)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A cat sits on the floor.
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a red shirt is reading a book.
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: There is in the middle of people's.
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: Ittmers in a very very oven.
--------------------
2025-11-17 13:54:32,019 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 13:56:27,759 [INFO] 
--- Validation at Step 1200 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 13:57:16,153 [INFO] Validation BLEU: 0.3155 (Best this iter: 0.2490)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A cat sits on the mat.
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a red shirt reads a book.
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: There is the sky of the sky.
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: Iorsors going for what's for what.
--------------------
2025-11-17 13:57:16,989 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 14:01:05,539 [INFO] 
--- Validation at Step 2000 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 14:01:28,869 [INFO] Validation BLEU: 0.3640 (Best this iter: 0.3155)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A cat sits on the mat.
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a red shirt is reading a book.
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: What's the sea?
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: Icing at a movie shot of movie.
--------------------
2025-11-17 14:01:29,512 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 14:05:18,717 [INFO] 
--- Validation at Step 2800 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 14:06:00,604 [INFO] Validation BLEU: 0.3745 (Best this iter: 0.3640)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A cat sits on the mat.
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a red shirt is reading a book.
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: What's the Berlin?
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: Icing in a foreign movie meeting.
--------------------
2025-11-17 14:06:01,322 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 14:09:49,812 [INFO] 
--- Validation at Step 3600 (Iter 1) ---


Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

2025-11-17 14:10:22,688 [INFO] Validation BLEU: 0.3752 (Best this iter: 0.3745)

--- Generating Sample Translations (with Beam Search) ---
  DE Source: Eine Katze sitzt auf der Matte.
  EN Output: A cat sits on the mat.
--------------------
  DE Source: Ein Mann in einem roten Hemd liest ein Buch.
  EN Output: A man in a red shirt is reading a book.
--------------------
  DE Source: Was ist die Hauptstadt von Deutschland?
  EN Output: What's the Berlin?
--------------------
  DE Source: Ich gehe ins Kino, weil der Film sehr gut ist.
  EN Output: Icing going to movie movie in movie.
--------------------
2025-11-17 14:10:23,372 [INFO]  New best BLEU for this iteration! Saving best model...
2025-11-17 14:10:25,070 [INFO] --- Training for Iteration 1 finished after 3600 steps ---
2025-11-17 14:10:27,553 [INFO] Saved final state (including embedding map) to: /content/drive/MyDrive/iterative/iterative-30k-seed-115-20L-SHUFFLED-ABLATION/iter_1/models/last.pt
2025-11-17 14:10:27,554 [INFO] ---

In [None]:
# TENSORBOARD VISUALIZATION
# Opens an inline TensorBoard rooted at this experiment's directory on Drive.
# NOTE(review): presumably the per-iteration SummaryWriter logs live somewhere
# below this directory, so all iterations show up as separate runs; confirm
# against where LOG_DIR_TENSORBOARD is built in the training cell.

# Register the TensorBoard notebook extension; this provides the %tensorboard magic.
%load_ext tensorboard

# Root directory of the experiment under DRIVE_BASE_PATH.
# NOTE(review): `experiment_base_name` is not defined in this cell; it has to
# exist in the kernel already (set by an earlier cell). On a fresh kernel this
# cell raises NameError unless that cell has been run first.
TENSORBOARD_BASE_DIR = os.path.join(DRIVE_BASE_PATH, experiment_base_name)

# IPython expands {TENSORBOARD_BASE_DIR} to the Python variable's value before
# the magic runs; the surrounding quotes keep a path with spaces in one argument.
%tensorboard --logdir "{TENSORBOARD_BASE_DIR}"

<IPython.core.display.Javascript object>

## Final Test

In [None]:
# Release the Colab runtime once every cell above has finished.
# NOTE(review): presumably this disconnects and tears down the backing VM so it
# stops consuming compute; anything not already written to Drive would be lost.
# Confirm, and keep this as the last executed cell. Only works inside Colab.
from google.colab import runtime
runtime.unassign()

## End