### Author: Shams

**Description:**
This is the core training script. I am building a bilingual Encoder-Decoder model from scratch: a freshly initialized network with random weights that learns English and Egyptian Arabic with no pretrained starting point.

**Architecture & Strategy:**

1.  **Encoder-Decoder Architecture:**
    I chose this because I need the model to understand the full context (Encoder) and then generate text step-by-step (Decoder). It has 8 Encoder layers and 8 Decoder layers with a hidden size of 384. That makes it deeper but much narrower than BART-base (6 + 6 layers, hidden size 768), which fits my data and hardware better.

2.  **Weighted Data Sampling:**
    Not all my data is equal. I have some really clean English text and some high-quality Arabic, but I also have some "meh" scraped data. If I just mixed them equally, the model would learn bad habits.
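    *   **The Big English:** 54% weight (the backbone of grammar).
    *   **High-Tier Arabic:** 24% weight (the target language).
    *   **Lower Tiers:** much smaller weights (just for variety).
    I use `interleave_datasets` to mix these streams at random according to these weights. The configured weights add up to 1.06, so they are normalized first. In practice, about 51% of rows come from the big English set and about 23% from high-tier Arabic.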

3.  **The RMSNorm Fix (Crucial):**
    In my previous attempts, the training kept crashing: the loss would go to `NaN` (Not a Number) because of gradient explosions. In my setup, the standard `LayerNorm` used in this architecture was too unstable for half-precision (FP16) training.
    *   **Solution:** I wrote a function that programmatically swaps out every `LayerNorm` layer in the model and replaces it with `RMSNorm` (Root Mean Square Normalization). In my runs this was much more stable and fixed the explosion issue.

4.  **Manual Training Loop:**
    I am not using the standard `Trainer.train()` method. Instead, I wrote a manual loop to control exactly how the state is saved and loaded. This lets me save the exact position of the data iterator, so the dataset does not restart from row 0 every time I pause.

### 1. Installs & Imports
Getting the environment ready. I need `accelerate` to handle the GPU hardware and `datasets` to stream the massive text files instead of loading them fully into RAM.

In [None]:
!pip install -q transformers datasets accelerate tensorboard bitsandbytes

import os
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from datasets import load_dataset, interleave_datasets, IterableDataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    BartConfig,
    BartForConditionalGeneration,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
    TrainerCallback,
    TrainerState,
    TrainerControl,
    get_scheduler,
)
from transformers.utils import PaddingStrategy
from accelerate import Accelerator

import warnings
warnings.filterwarnings("ignore")

### 2. Configuration
This is where I define the weighted sampling strategy. You can see in `DataArgs` how I assign probabilities to each dataset split. 

I also set `max_grad_norm=1.5` here. Before each optimizer step, if the combined L2 norm of all gradients exceeds 1.5, every gradient is scaled down by the same factor. This acts as a second safety net alongside the RMSNorm fix.
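To make the clipping rule concrete, here is a minimal pure-Python sketch of clipping by global norm. It assumes the same rule that `torch.nn.utils.clip_grad_norm_` applies: one L2 norm over all gradients, and a scale factor of `max_norm / (norm + 1e-6)` whenever that factor is below 1. The function name `clip_by_global_norm` and the plain-list "gradients" are illustrative only.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Rescale all gradient lists together so their combined L2 norm is at most max_norm."""
    # One norm over *all* parameters, not a separate norm per tensor
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1.0:
        return [list(tensor) for tensor in grads]  # already small enough: unchanged
    # Every gradient gets the same factor, so the update direction is preserved
    return [[g * clip_coef for g in tensor] for tensor in grads]


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = 5.0
clipped = clip_by_global_norm(grads, max_norm=1.5)
print(clipped)
```

Because all gradients share one scale factor, clipping limits the step size without changing which direction the optimizer moves.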

In [None]:
from dataclasses import dataclass, field
from typing import Dict, Optional
from transformers import TrainingArguments, HfArgumentParser
from accelerate import Accelerator
import os
import sys

@dataclass
class ModelArgs:
    model_output_dir: str = "/kaggle/working/bart-arz-en-pretrained"
    tokenizer_path: str = "/kaggle/input/bpescratch/ARZ-EN-BART-Tokenizer/"
    dataset_repo_id: str = "Shams03/Tokenized-ARZ-EN-BART"
    cache_dir: str = "/kaggle/working/cache"

    # Architecture specs: 8 layers each, 384 hidden dim
    encoder_layers: int = 8
    decoder_layers: int = 8
    d_model: int = 384
    num_heads: int = 12
    ffn_dim: int = 1152

@dataclass
class DataArgs:
    # --- WEIGHTED SAMPLING STRATEGY ---
    # I give higher probability to clean English and High-Tier Arabic.
    # This forces the model to learn from good data more often.
    sampling_probabilities: Dict[str, float] = field(default_factory=lambda: {
        "TheBigEN": 0.54, 
        "A_Tier_ARZ": 0.24, 
        "S_Tier_ARZ": 0.09, 
        "B_Tier_ARZ": 0.07,
        "parallel_EN": 0.04, 
        "LparallelEN": 0.02, 
        "parallel_ARZ": 0.04, 
        "Lparallel_ARZ": 0.02
    })
    streaming: bool = True
    max_seq_length: int = 256

@dataclass
class PretrainArgs(TrainingArguments):
    output_dir: str = "/kaggle/working/bart-arz-en-pretrained"
    max_steps: int = 320000
    per_device_train_batch_size: int = 15
    gradient_accumulation_steps: int = 10

    learning_rate: float = 1e-5
    warmup_steps: int = 1000
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    lr_scheduler_type: str = 'linear'
    num_train_epochs: float = 100.0

    logging_strategy: str = "steps"
    logging_steps: int = 100
    save_strategy: str = "steps"
    save_steps: int = 1000
    save_total_limit: int = 1
    fp16: bool = True
    dataloader_num_workers: int = 1
    seed: int = 42
    report_to: str = "tensorboard"
    resume_from_checkpoint: bool = True

    # Safety nets for gradient explosion
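    # Type annotations are required here: without them these lines are not dataclass fields,
    # and TrainingArguments' own defaults (adam_epsilon=1e-8, max_grad_norm=1.0) silently apply.
    adam_epsilon: float = 1e-6
    max_grad_norm: float = 1.5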
    torch_compile: bool = True
    gradient_checkpointing: bool = False
    mlm_probability: float = 0.15

parser = HfArgumentParser((ModelArgs, DataArgs, PretrainArgs))

if "ipykernel" in sys.modules:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[])
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

os.makedirs(model_args.model_output_dir, exist_ok=True)
os.makedirs(model_args.cache_dir, exist_ok=True)

print("Model Args:")
print(model_args)
print("\nData Args:")
print(data_args)
print("\nTraining Args:")
print(f"  Output: {training_args.output_dir}")
print(f"  Steps: {training_args.max_steps}")
print(f"  LR: {training_args.learning_rate}")
print(f"  FP16: {training_args.fp16}")

accelerator = Accelerator()
print("\nHardware:")
print(f"GPUs: {accelerator.num_processes}")
print(f"Batch Size: {training_args.per_device_train_batch_size}")

num_gpus = max(1, accelerator.num_processes)
eff_batch_size = training_args.per_device_train_batch_size * num_gpus * training_args.gradient_accumulation_steps
print(f"Effective Batch: {eff_batch_size}")
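The effective batch translates into a concrete data budget per step and for the whole run. The back-of-the-envelope sketch below assumes a single process (one GPU). It treats `max_seq_length = 256` as an upper bound, because the collator pads each batch only to its longest row, rounded up to a multiple of 8.

```python
# Values copied from PretrainArgs / DataArgs; a single process (one GPU) is assumed.
per_device_batch = 15
grad_accum_steps = 10
num_processes = 1
max_seq_length = 256
max_steps = 320_000

sequences_per_step = per_device_batch * num_processes * grad_accum_steps
max_tokens_per_step = sequences_per_step * max_seq_length  # upper bound: most batches are shorter
total_sequences = sequences_per_step * max_steps

print(f"Sequences per optimizer step: {sequences_per_step}")
print(f"Max tokens per optimizer step: {max_tokens_per_step:,}")
print(f"Sequences over the full run: {total_sequences:,}")
```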

### 3. Load Tokenizer
Loading the custom BPE tokenizer. I have to make sure the special tokens are mapped correctly. The Encoder-Decoder model relies on `<s>` (start) and `</s>` (end) to know where a sentence begins and ends. If these IDs are wrong, the boundary tokens the collator inserts will not match the IDs in the model config, and training will quietly learn the wrong sentence boundaries.

In [None]:
print(f"Loading tokenizer: {model_args.tokenizer_path}")
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_path,
    cache_dir=model_args.cache_dir
)

print("Loaded.")
print(f"Type: {type(tokenizer)}")
print(f"Vocab: {tokenizer.vocab_size}")
assert tokenizer.vocab_size == 90000, "Wrong vocab size!"

# Checking special tokens. This is critical for the model structure.
print(f"PAD: {tokenizer.pad_token} ID: {tokenizer.pad_token_id}")
print(f"UNK: {tokenizer.unk_token} ID: {tokenizer.unk_token_id}")
print(f"MASK: {tokenizer.mask_token} ID: {tokenizer.mask_token_id}")
print(f"BOS: {tokenizer.bos_token} ID: {tokenizer.bos_token_id}")
print(f"EOS: {tokenizer.eos_token} ID: {tokenizer.eos_token_id}")
print(f"CLS: {tokenizer.cls_token} ID: {tokenizer.cls_token_id}")
print(f"SEP: {tokenizer.sep_token} ID: {tokenizer.sep_token_id}")

# Assertions to ensure compatibility with the model config
assert tokenizer.bos_token_id == 3
assert tokenizer.eos_token_id == 4
assert tokenizer.cls_token_id == tokenizer.bos_token_id
assert tokenizer.sep_token_id == tokenizer.eos_token_id
assert tokenizer.add_prefix_space

vocab_size = tokenizer.vocab_size
pad_token_id = tokenizer.pad_token_id
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id

### 4. Data Loading & Interleaving
Here I implement the weighted sampling. I load each split separately. Then I use `interleave_datasets` with the `probabilities` list I defined in the config. 

This creates a single stream in which, on average, about 51% of rows come from the big English dataset (its 0.54 weight divided by the 1.06 total of all weights), about 23% come from the best Arabic, and so on. I also add a `shuffle` buffer to mix rows locally, so the model doesn't see long runs of English sentences in a row.
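The sampling idea can be checked without downloading anything. The pure-Python sketch below mimics weighted interleaving: normalize the weights, then, for each row, pick a source at random with those probabilities and pull its next row. It is not the `interleave_datasets` implementation itself. The stand-in "streams" (endless rows tagged with their split name) are hypothetical.

```python
import itertools
import random
from collections import Counter

# Raw weights from DataArgs.sampling_probabilities (they sum to 1.06, not 1.0)
weights = {
    "TheBigEN": 0.54, "A_Tier_ARZ": 0.24, "S_Tier_ARZ": 0.09, "B_Tier_ARZ": 0.07,
    "parallel_EN": 0.04, "LparallelEN": 0.02, "parallel_ARZ": 0.04, "Lparallel_ARZ": 0.02,
}
total = sum(weights.values())
probs = {name: w / total for name, w in weights.items()}  # normalized shares

# Stand-in streams: endless rows tagged with their source split
streams = {name: itertools.repeat(name) for name in weights}

rng = random.Random(42)
names = list(probs)
picks = rng.choices(names, weights=[probs[n] for n in names], k=100_000)
mixed = [next(streams[name]) for name in picks]  # pull the next row from the chosen stream

counts = Counter(mixed)
for name in names:
    print(f"{name:14s} target {probs[name]:.3f}  observed {counts[name] / len(mixed):.3f}")
```

The observed shares land close to the normalized targets (about 0.509 for `TheBigEN`), not the raw 0.54. The shares are only guaranteed on average, not in any particular window of rows.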

In [None]:
print(f"Loading data from: {model_args.dataset_repo_id}")

available_splits = [
    'A_Tier_ARZ', 'B_Tier_ARZ', 'Lparallel_ARZ', 'LparallelEN',
    'S_Tier_ARZ', 'TheBigEN', 'parallel_ARZ', 'parallel_EN'
]
sampling_probabilities = data_args.sampling_probabilities
split_datasets = []
split_probs = []
total_prob = 0

print("Loading splits...")
for split_name in available_splits:
    if split_name not in sampling_probabilities:
        continue
    prob = sampling_probabilities[split_name]
    print(f"- Loading {split_name} (Prob: {prob})")
    try:
        ds = load_dataset(
            model_args.dataset_repo_id,
            split=split_name,
            streaming=data_args.streaming,
            cache_dir=model_args.cache_dir,
        )
        split_datasets.append(ds)
        split_probs.append(prob)
        total_prob += prob
    except Exception as e:
        print(f"  Failed to load {split_name}: {e}")
        raise RuntimeError(f"Missing dataset: {split_name}") from e

if not split_datasets:
    raise RuntimeError("No data loaded.")

# Always normalize so the weights form a proper probability distribution (sum exactly 1.0).
# The configured weights sum to 1.06, so every effective share is slightly below its raw value.
split_probs = [p / total_prob for p in split_probs]

print("Mixing datasets...")
interleaved_dataset = interleave_datasets(
    split_datasets,
    probabilities=split_probs,
    seed=training_args.seed,
    stopping_strategy="all_exhausted"
)

# Shuffle buffer to ensure randomness in the stream
shuffle_buffer_size = 100000
print(f"Shuffling with buffer {shuffle_buffer_size}...")
combined_dataset = interleaved_dataset.shuffle(
    seed=training_args.seed,
    buffer_size=shuffle_buffer_size
)
print("Data ready.")

if data_args.streaming:
    try:
        first_example = next(iter(combined_dataset))
        print(f"Check first example: {first_example.keys()}")
    except Exception as e: 
        print(f"Error checking data: {e}")
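A streaming dataset cannot be shuffled globally. `shuffle(buffer_size=...)` keeps a fixed-size buffer instead: each incoming row replaces a randomly chosen buffered row, and the replaced row is emitted. The pure-Python sketch below follows that idea. It is not the `datasets` implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int = 42) -> Iterator[T]:
    """Approximate shuffle of a (possibly endless) stream using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer before emitting anything
            continue
        # Emit a random buffered item and put the new item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Finite stream finished: drain the rest in random order
    rng.shuffle(buffer)
    yield from buffer


# Worst case: 1000 English rows followed by 1000 Arabic rows
stream = ["EN"] * 1000 + ["ARZ"] * 1000
shuffled = list(buffered_shuffle(stream, buffer_size=500))
print(shuffled[:10], shuffled.count("ARZ"))
```

In this worst case, the first 500 outputs are still all English. A buffer only mixes rows that are within roughly `buffer_size` positions of each other in the stream. Here, the interleaving step is what mixes the sources, and the 100,000-row buffer adds local randomness on top.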

### 4.5 Data Inspection
Just a quick sanity check to see the raw IDs. I want to make sure I'm getting actual data and not empty rows.

In [None]:
num_examples_to_show = 5
actual_count = 0

try:
    dataset_iterator = iter(combined_dataset)
    for i in range(num_examples_to_show):
        example = next(dataset_iterator)
        input_ids = example['input_ids']
        clean_decoded = tokenizer.decode(input_ids, skip_special_tokens=True)

        print(f"\nExample {i+1}:")
        print(f"  IDs: {input_ids[:10]}...")
        print(f"  Text: '{clean_decoded}'")
        actual_count += 1

except Exception as e:
    print(f"Error: {e}")

### 5. Data Collator (The Masking Logic)
This is the pre-training objective. The model reads a sentence with "holes" in it and has to guess what's missing.

1.  **Padding:** Pads the batch to the same length.
2.  **Special Tokens:** It adds `<s>` at the start and `</s>` at the end.
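3.  **Masking:** It randomly selects about 15% of the non-special tokens (`mlm_probability`) as prediction targets. Of those, 80% are replaced with `<mask>`, 10% with a random token, and 10% are left unchanged. The loss is computed only on these positions; every other label is set to `-100`.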
4.  **Protection:** It specifically checks `special_tokens_mask` to make sure we NEVER mask the `<s>`, `</s>`, or `<pad>` tokens. If we masked those, the model would lose track of where sentences begin/end.
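Here is a per-token, pure-Python sketch of the same 80/10/10 rule. It is easier to inspect than the vectorized `torch.bernoulli` version in the collator and samples the same distribution per position. BOS = 3 and EOS = 4 match the tokenizer asserts in section 3. The PAD and MASK IDs below are made up for the example.

```python
import random
from typing import List, Set, Tuple

IGNORE_INDEX = -100  # label value that the cross-entropy loss skips


def mask_tokens(ids: List[int], mask_id: int, vocab_size: int, special_ids: Set[int],
                mlm_probability: float = 0.15, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Per-token sketch of the 80/10/10 masking rule; returns (corrupted inputs, labels)."""
    rng = random.Random(seed)
    inputs: List[int] = list(ids)
    labels: List[int] = []
    for i, tok in enumerate(ids):
        # Special tokens (<s>, </s>, <pad>) are never selected as targets
        if tok in special_ids or rng.random() >= mlm_probability:
            labels.append(IGNORE_INDEX)
            continue
        labels.append(tok)  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id  # 80%: replace with <mask>
        elif r < 0.9:
            rand_tok = rng.randrange(vocab_size)  # 10%: random token...
            inputs[i] = mask_id if rand_tok in special_ids else rand_tok  # ...never a special one
        # remaining 10%: keep the original token, but it is still predicted
    return inputs, labels


BOS, EOS, PAD, MASK = 3, 4, 1, 5  # BOS/EOS from the notebook's asserts; PAD and MASK are made up
ids = [BOS] + list(range(100, 140)) + [EOS, PAD, PAD]
inputs, labels = mask_tokens(ids, MASK, 90_000, {BOS, EOS, PAD}, seed=1)
print(inputs)
print(labels)
```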

In [None]:
@dataclass
class DataCollatorForBartPretraining:
    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None
    max_seq_length: Optional[int] = None

    def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        batch_input_ids = [e['input_ids'] for e in examples]

        processed_batch = {
            "input_ids": [],
            "labels": []
        }

        max_len_no_special = self.max_seq_length - 2 if self.max_seq_length else None

        for ids in batch_input_ids:
            if max_len_no_special:
                ids = ids[:max_len_no_special]
            # manually adding BOS and EOS
            processed_batch["input_ids"].append([self.tokenizer.bos_token_id] + ids + [self.tokenizer.eos_token_id])
            processed_batch["labels"].append([self.tokenizer.bos_token_id] + ids + [self.tokenizer.eos_token_id])

        # pad labels first
        labels_batch = self.tokenizer.pad(
            {"input_ids": processed_batch["labels"]},
            padding='longest',
            max_length=self.max_seq_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        labels = labels_batch["input_ids"]

        # pad inputs
        inputs_batch = self.tokenizer.pad(
            {"input_ids": processed_batch["input_ids"]},
            padding='longest',
            max_length=self.max_seq_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        input_ids = inputs_batch["input_ids"]
        attention_mask = inputs_batch["attention_mask"]

        # apply the masking logic
        inputs, labels = self.mask_tokens(input_ids, labels)

        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def mask_tokens(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        masked_inputs = inputs.clone()
        mlm_labels = labels.clone()

        probability_matrix = torch.full(mlm_labels.shape, self.mlm_probability)

        # prevent masking special tokens. THIS IS IMPORTANT.
        padding_mask = mlm_labels.eq(self.tokenizer.pad_token_id)
        bos_mask = mlm_labels.eq(self.tokenizer.bos_token_id)
        eos_mask = mlm_labels.eq(self.tokenizer.eos_token_id)

        special_tokens_mask = padding_mask | bos_mask | eos_mask
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()

        # set unmasked labels to -100 so loss ignores them
        mlm_labels[~masked_indices] = -100

        # 80% replace with [MASK]
        indices_replaced = torch.bernoulli(torch.full(mlm_labels.shape, 0.8)).bool() & masked_indices
        masked_inputs[indices_replaced] = self.tokenizer.mask_token_id

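        # 10% replace with a random token: half of the remaining 20% of selected positions
        # (the last 10% keep their original token but are still predicted)
        indices_random = torch.bernoulli(torch.full(mlm_labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced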
        random_words = torch.randint(low=0, high=self.tokenizer.vocab_size, size=mlm_labels.shape, dtype=torch.long)
        
        # safety check for random words (don't pick special tokens)
        for special_id in [self.tokenizer.pad_token_id, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id]:
            random_words[random_words == special_id] = self.tokenizer.mask_token_id

        masked_inputs[indices_random] = random_words[indices_random]

        return masked_inputs, mlm_labels

data_collator = DataCollatorForBartPretraining(
    tokenizer=tokenizer,
    mlm_probability=training_args.mlm_probability,
    max_seq_length=data_args.max_seq_length,
    pad_to_multiple_of=8
)

print("Collator ready.")
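This collator returns `labels` but no `decoder_input_ids`. In that case, `BartForConditionalGeneration` builds the decoder input itself by shifting the labels one position to the right. It puts `config.decoder_start_token_id` first and replaces `-100` with the pad ID. The pure-Python sketch below mirrors that rule; `shift_right` is my own name for a per-sequence version of `transformers`' `shift_tokens_right`, and the token IDs are hypothetical. It shows what the decoder actually sees for masked-only labels.

```python
from typing import List


def shift_right(labels: List[int], pad_id: int, decoder_start_id: int) -> List[int]:
    """Mirror of BART's labels -> decoder_input_ids rule for a single sequence."""
    shifted = [decoder_start_id] + labels[:-1]
    # -100 is not a real token id, so it becomes padding in the decoder input
    return [pad_id if tok == -100 else tok for tok in shifted]


PAD, DECODER_START = 1, 4  # hypothetical ids
labels = [-100, -100, 250, -100, -100, 731, -100, -100]  # masked-only labels from the collator
print(shift_right(labels, PAD, DECODER_START))
```

Because every unmasked position is `-100`, the decoder's teacher-forcing input is mostly `<pad>`. If that matters for this setup, one option is to build `decoder_input_ids` in the collator from the uncorrupted sequence.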

### 6. Model Initialization & RMSNorm Fix
This is the most critical part of the architecture setup. 

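1.  **Initialization:** The config sets `init_std=0.02`, but right after building the model I re-initialize it with a custom `init_weights` function:
    *   `Linear` weights get std `1/sqrt(fan_in)`.
    *   Embeddings get std `1/sqrt(d_model)`.
    *   All norm scales start at 1.

    Randomly initialized Transformers are very fragile. If the weights start too large, activations can overflow in FP16 and the math breaks immediately.
2.  **RMSNorm Replacement:** Standard `LayerNorm` computes `(x - mean) / sqrt(var + eps)`. In deep networks trained in mixed precision (FP16), the variance can become tiny, which turns normalization into a huge amplification. `RMSNorm` drops the mean subtraction and divides by the root mean square instead: `x / sqrt(mean(x^2) + eps)`. Because `mean(x^2) = var + mean^2`, this denominator is never smaller than LayerNorm's for the same input. RMSNorm also has no bias term, so the swap removes the LayerNorm biases. I wrote a function `replace_layernorm_with_rmsnorm` that finds every `LayerNorm` in the model and swaps it.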
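Here is a small pure-Python comparison of the two normalizations, with the learned scale omitted (it starts at 1 after `init_weights`). The input has a large shared offset and a tiny spread. LayerNorm divides by the tiny standard deviation and blows the spread up to about ±1.4. RMSNorm divides by the much larger RMS, which is guaranteed because `rms**2 == std**2 + mean**2`.

```python
import math
from typing import List


def layer_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    # No mean subtraction: divide by the root mean square of the raw values
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


x = [10.0, 10.1, 9.9, 10.0]  # large shared offset, tiny spread
mean = sum(x) / len(x)
std = math.sqrt(sum((v - mean) ** 2 for v in x) / len(x))
rms = math.sqrt(sum(v * v for v in x) / len(x))
print(f"std = {std:.4f}, rms = {rms:.4f}")  # rms**2 == std**2 + mean**2, so rms >= std

ln_out = layer_norm(x)
rms_out = rms_norm(x)
print([round(v, 3) for v in ln_out])
print([round(v, 3) for v in rms_out])
```

The trade-off is that RMSNorm output is not zero-centered; it only fixes the overall scale.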

In [None]:
print("Initializing Encoder-Decoder Model...")

config = BartConfig(
    vocab_size=vocab_size,
    pad_token_id=pad_token_id,
    bos_token_id=bos_token_id,
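    eos_token_id=eos_token_id,
    # BART starts decoding from </s>; BartConfig's defaults (2) do not match this tokenizer's IDs.
    # Changing this for an existing checkpoint changes the decoder inputs it was trained with.
    decoder_start_token_id=eos_token_id,
    forced_eos_token_id=eos_token_id,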
    encoder_layers=model_args.encoder_layers,
    decoder_layers=model_args.decoder_layers,
    d_model=model_args.d_model,
    encoder_attention_heads=model_args.num_heads,
    decoder_attention_heads=model_args.num_heads,
    encoder_ffn_dim=model_args.ffn_dim,
    decoder_ffn_dim=model_args.ffn_dim,
    activation_function="gelu",
    dropout=0.1,
    attention_dropout=0.1,
    activation_dropout=0.1,
    scale_embedding=True,
    init_std=0.02
)

model = BartForConditionalGeneration(config=config)

# --- THE RMSNORM FIX ---
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

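    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the statistic in float32 so x.pow(2) cannot overflow FP16 (max ~65504), then cast back
        x_fp32 = x.float()
        rms = torch.sqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_fp32 / rms).to(x.dtype) * self.scale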

def replace_layernorm_with_rmsnorm(module: nn.Module):
    # Recursively find every nn.LayerNorm and replace it in place with RMSNorm
    for name, child in list(module.named_children()):
        if isinstance(child, nn.LayerNorm):
            dim = child.normalized_shape[0] if isinstance(child.normalized_shape, (tuple, list)) else child.normalized_shape
            rms = RMSNorm(dim=dim, eps=1e-6)
            # try to preserve weights if they exist
            if getattr(child, "weight", None) is not None:
                try:
                    with torch.no_grad():
                        rms.scale.copy_(child.weight)
                except Exception:
                    pass
            setattr(module, name, rms)
        else:
            replace_layernorm_with_rmsnorm(child)

# --- Robust Init ---
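# Re-initializes Linear/Conv weights with std = 1/sqrt(fan_in) and embeddings with std = 1/sqrt(embedding_dim),
# overriding BartConfig's init_std=0.02 for those layers, so weights start small relative to layer width.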
def init_weights(module):
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        try:
            fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(module.weight)
            std = (fan_in ** -0.5) if fan_in > 0 else 0.02
        except Exception:
            std = 0.02
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if getattr(module, "bias", None) is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        d = module.weight.size(1)
        std = d ** -0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, RMSNorm):
        nn.init.ones_(module.scale)

print("Swapping LayerNorm for RMSNorm...")
replace_layernorm_with_rmsnorm(model)

print("Applying custom initialization...")
model.apply(init_weights)

model_size = sum(t.numel() for t in model.parameters())
print(f"Model Ready. Parameters: {model_size / 1_000_000:.1f} M")
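Here is a quick worked calculation of the standard deviations `init_weights` produces for this config. It uses the `ModelArgs` values and the fact that `fan_in` of an `nn.Linear` weight is its input width. It also covers `scale_embedding=True`, with which BART multiplies token embeddings by `sqrt(d_model)` in the forward pass. The layer groupings in the comments are my reading of the BART module layout.

```python
import math

d_model, ffn_dim, num_heads = 384, 1152, 12  # values from ModelArgs

# nn.Linear: std = 1/sqrt(fan_in), where fan_in is the layer's input width
linear_std = 1 / math.sqrt(d_model)  # q/k/v/out_proj and fc1 all read d_model inputs
fc2_std = 1 / math.sqrt(ffn_dim)     # fc2 reads the ffn_dim-wide hidden layer
# nn.Embedding: std = 1/sqrt(embedding_dim)
embed_std = 1 / math.sqrt(d_model)

# scale_embedding=True multiplies token embeddings by sqrt(d_model),
# so the scaled token embeddings start with a std of about 1.0
scaled_embed_std = embed_std * math.sqrt(d_model)

head_dim = d_model // num_heads  # d_model must divide evenly across heads

print(f"Linear (fan_in={d_model}): std = {linear_std:.4f}")
print(f"fc2 (fan_in={ffn_dim}): std = {fc2_std:.4f}")
print(f"Embedding: std = {embed_std:.4f}, after scaling ~ {scaled_embed_std:.2f}")
print(f"head_dim = {head_dim}; config init_std (overridden) = 0.02")
```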

### 7. Accelerator Setup
Preparing the `Accelerator`. This library handles the heavy lifting of moving tensors to the GPU and managing mixed precision (FP16). I initialize the `Trainer` here, but primarily to use its underlying utilities, not its main loop.

In [None]:
import math

try:
    mixed_precision_arg = "fp16" if getattr(training_args, "fp16", False) else "no"
    accelerator = Accelerator(mixed_precision=mixed_precision_arg)
    print(f"Accelerator: {mixed_precision_arg}")
except TypeError:
    accelerator = Accelerator()
    print("Accelerator: Default")

device = accelerator.device
print(f"Device: {device}")

print("Initializing Trainer wrapper...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=combined_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer
)
print("Trainer ready.")

### 7.5 Collator Check
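A final check before the loop. I grab one batch and make sure it is usable. Token IDs are integers, so they can never be `NaN`. Instead, I check three things:
*   Every input ID is a valid vocabulary index.
*   Every label is either `-100` or a valid index.
*   The batch contains at least one masked target, because a batch with none would give a `NaN` loss.

If the inputs are corrupted here, the whole training is doomed.

In [None]:
print("Checking batch...")
try:
    dataloader = trainer.get_train_dataloader()
    batch = next(iter(dataloader))

    input_ids = batch["input_ids"]
    labels = batch["labels"]
    # Integer token IDs cannot be NaN; check that they are valid vocabulary indices instead
    input_ids_ok = bool(((input_ids >= 0) & (input_ids < vocab_size)).all())
    # Labels are either -100 (ignored by the loss) or a valid token ID
    labels_ok = bool((((labels >= 0) & (labels < vocab_size)) | (labels == -100)).all())
    # With no prediction targets at all, the mean cross-entropy is 0/0 = NaN
    has_targets = bool((labels != -100).any())

    print(f"  Inputs OK: {input_ids_ok}")
    print(f"  Labels OK: {labels_ok}")
    print(f"  Has masked targets: {has_targets}")

    if not (input_ids_ok and labels_ok and has_targets):
        print("  WARNING: Bad batch detected.")

except Exception as e:
    print(f"  Error checking batch: {e}")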

if 'dataloader' in locals():
   del dataloader

### 8. The Manual Training Loop
I am writing my own training loop instead of using `trainer.train()`. Why? **Control.**

1.  **State Loading:** When resuming, `Trainer.train()` skips data it has already seen by iterating through (and discarding) those batches, which is very slow on a massive stream. Instead, I save `dataloader_state.pt` and `dataset_state.pt` so I can resume exactly where I left off, without re-reading millions of rows.
2.  **Gradient Scaling:** I manually control the `GradScaler` for FP16 to prevent underflow.
3.  **Diagnostics:** Every 100 steps, I run a small evaluation on specific sentences (like "Today `<mask>` a beautiful day") to see if the model is learning grammar live.
4.  **Checkpointing:** I save everything—Optimizer, Scheduler, Model, and Data state—into one folder.
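The step bookkeeping in such a loop is easy to get off by one. The simplified model below replaces the forward/backward pass with a comment and only counts steps. It checks the stop condition on `global_step` before each micro-batch, so exactly `max_steps` optimizer updates run. Resuming starts at `resume_step * grad_accum` micro-batches. `simulate_run` is a hypothetical helper, not part of the notebook.

```python
from typing import List


def simulate_run(total_micro_batches: int, grad_accum: int, max_steps: int,
                 resume_step: int = 0) -> List[int]:
    """Return the micro-step index at which each optimizer update happens."""
    global_step = resume_step
    start_micro_step = resume_step * grad_accum  # micro-batches already consumed
    update_points = []
    for micro_step in range(start_micro_step, total_micro_batches):
        if global_step >= max_steps:  # stop only after max_steps updates were applied
            break
        # ... forward + backward on this micro-batch, with loss divided by grad_accum ...
        if (micro_step + 1) % grad_accum == 0:
            global_step = (micro_step + 1) // grad_accum  # optimizer.step() happens here
            update_points.append(micro_step)
    return update_points


fresh = simulate_run(total_micro_batches=1_000, grad_accum=10, max_steps=30)
resumed = simulate_run(total_micro_batches=1_000, grad_accum=10, max_steps=30, resume_step=20)
print(len(fresh), fresh[:3], fresh[-1])
print(len(resumed), resumed[0], resumed[-1])
```

A resumed run performs the same remaining updates, at the same micro-step positions, as the tail of an uninterrupted run.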

In [None]:
from contextlib import nullcontext
import math
import shutil
import re
import json
from transformers.trainer_utils import get_last_checkpoint
import torch 

print("Starting Training Loop.")

train_dataloader = trainer.get_train_dataloader()

optimizer = AdamW(model.parameters(), lr=training_args.learning_rate,
                  betas=(training_args.adam_beta1, training_args.adam_beta2),
                  eps=getattr(training_args, "adam_epsilon", 1e-6),
                  weight_decay=training_args.weight_decay)

scheduler = get_scheduler(
    name=training_args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=int(getattr(training_args, "warmup_steps", 0)),
    num_training_steps=training_args.max_steps,
)

use_native_amp = bool(getattr(training_args, "fp16", False) and torch.cuda.is_available())
scaler = GradScaler() if use_native_amp else None

model = accelerator.prepare(model)

# Diagnostic sentences to watch learning progress
mask = tokenizer.mask_token 
test_sentences = [
    f"الجو برد اوي النهاردة، أنا هقعد {mask} البيت.",
    f"He bought bread {mask} milk from the store.",
    f"Today {mask} a beautiful day.",
    f"أنا صحيت الصبح و أول حاجة عملتها إني {mask}{mask} كوباية شاي.",
]
top_k = 5

# --- RESUME LOGIC ---
global_step = 0
start_micro_step = 0 
did_load_iterator_state = False

# If I have a checkpoint path, I use it here
MANUAL_CHECKPOINT_PATH = "/kaggle/input/arz-en-bart/bart-arz-en-pretrained/checkpoint-214000"

resume_ckpt = None
if training_args.resume_from_checkpoint and MANUAL_CHECKPOINT_PATH and os.path.isdir(MANUAL_CHECKPOINT_PATH):
    resume_ckpt = MANUAL_CHECKPOINT_PATH

if resume_ckpt:
    print(f"Resuming from: {resume_ckpt}")
    try:
        accelerator.load_state(resume_ckpt)
        print("Model state loaded.")
    except Exception as e: print(f"Warn: Model load failed: {e}")
        
    # Loading optimizer/scheduler states
    opt_path = os.path.join(resume_ckpt, "optimizer.pt")
    if os.path.exists(opt_path):
        optimizer.load_state_dict(torch.load(opt_path, map_location=accelerator.device))

    sch_path = os.path.join(resume_ckpt, "scheduler.pt")
    if os.path.exists(sch_path):
        scheduler.load_state_dict(torch.load(sch_path, map_location="cpu"))

    if scaler and os.path.exists(os.path.join(resume_ckpt, "scaler.pt")):
        scaler.load_state_dict(torch.load(os.path.join(resume_ckpt, "scaler.pt"), map_location="cpu"))

    m = re.search(r"-(\d+)$", resume_ckpt)
    if m:
        global_step = int(m.group(1))
        start_micro_step = global_step * training_args.gradient_accumulation_steps

    # Loading dataset state. This is key to not restarting data from row 0.
    ds_state_path = os.path.join(resume_ckpt, "dataset_state.pt")
    if os.path.exists(ds_state_path):
        combined_dataset.load_state_dict(torch.load(ds_state_path, map_location="cpu"))
        train_dataloader = trainer.get_train_dataloader()
        
        dl_state_path = os.path.join(resume_ckpt, "dataloader_state.pt")
        if os.path.exists(dl_state_path):
            train_dataloader.load_state_dict(torch.load(dl_state_path, map_location="cpu"))
            did_load_iterator_state = True
            print("Dataset iterator restored.")

train_dataloader = accelerator.prepare(train_dataloader)

model.train()
start_time = time.time()
grad_accum = training_args.gradient_accumulation_steps
max_steps = training_args.max_steps
train_dataloader_iter = iter(train_dataloader)

# Fast forward if we failed to load iterator state (fallback)
if start_micro_step > 0 and not did_load_iterator_state:
    print(f"Fast-forwarding {start_micro_step} steps...")
    for _ in range(start_micro_step):
        next(train_dataloader_iter)
    print("Done.")

optimizer.zero_grad()

# --- MAIN LOOP ---
for micro_step, batch in enumerate(train_dataloader_iter, start=start_micro_step):

    # Stop once max_steps optimizer updates have been applied. Checking global_step
    # (instead of the step this micro-batch would complete) means the final
    # accumulation window still gets its optimizer step.
    if global_step >= max_steps:
        break

    input_ids = batch["input_ids"]
    attention_mask = batch.get("attention_mask", None)
    labels = batch.get("labels", None)

    # Forward Pass
    if scaler is not None:
        with autocast():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss / max(1, grad_accum)
        scaler.scale(loss).backward()
    else:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss / max(1, grad_accum)
        loss.backward()

    # Optimizer Step (only after accumulation)
    if (micro_step + 1) % grad_accum == 0:
        global_step = (micro_step + 1) // grad_accum

        if scaler is not None:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
            optimizer.step()

        scheduler.step()
        optimizer.zero_grad()

        # Logging
        if global_step % training_args.logging_steps == 0 or global_step == 1:
            elapsed = time.time() - start_time
            # Loss of the last micro-batch only (undoing the 1/grad_accum scaling), not a window average
            reported_loss = loss.item() * grad_accum
            current_lr = scheduler.get_last_lr()[0] 
            print(f"\nStep {global_step} | Loss: {reported_loss:.4f} | LR: {current_lr:.2e} | Time: {elapsed:.0f}s")

            # Live Evaluation
            model.eval()
            try:
                with torch.no_grad():
                    for sent in test_sentences:
                        sent_fmt = sent.replace("<mask>", tokenizer.mask_token)
                        enc = tokenizer(sent_fmt, return_tensors="pt", truncation=True, max_length=128).to(accelerator.device)
                        outputs = model(**enc)
                        logits = outputs.logits
                        mask_pos = (enc.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=False)
                        if len(mask_pos) > 0:
                            print(f"  Input: {sent}")
                            for m in mask_pos:
                                b, p = m[0], m[1]
                                vals, ids = torch.topk(logits[b, p], k=top_k)
                                toks = tokenizer.convert_ids_to_tokens(ids)
                                print(f"    Pred: {toks}")
            except Exception as e: print(f"Eval Error: {e}")
            model.train()

        # Checkpointing
        if global_step % training_args.save_steps == 0:
            ckpt_dir = os.path.join(training_args.output_dir, f"checkpoint-{global_step}")
            print(f"Saving to {ckpt_dir}...")
            accelerator.wait_for_everyone()
            accelerator.save_state(ckpt_dir)
            
            if accelerator.is_main_process:
                accelerator.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))
                accelerator.save(scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))
                if scaler: accelerator.save(scaler.state_dict(), os.path.join(ckpt_dir, "scaler.pt"))
                # Saving dataset states!
                accelerator.save(combined_dataset.state_dict(), os.path.join(ckpt_dir, "dataset_state.pt"))
                accelerator.save(train_dataloader.state_dict(), os.path.join(ckpt_dir, "dataloader_state.pt"))
                tokenizer.save_pretrained(ckpt_dir)

            # Cleanup old checkpoints to save space
            if accelerator.is_main_process:
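                # Order by step number, not by name: as strings, "checkpoint-100000" sorts before "checkpoint-99000"
                all_ckpts = sorted(
                    [d for d in os.listdir(training_args.output_dir) if d.startswith("checkpoint-")],
                    key=lambda d: int(d.rsplit("-", 1)[-1]),
                )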
                if len(all_ckpts) > training_args.save_total_limit:
                    for old in all_ckpts[:-training_args.save_total_limit]:
                        shutil.rmtree(os.path.join(training_args.output_dir, old))
                        print(f"Deleted old: {old}")

print("Finished.")
accelerator.wait_for_everyone()
accelerator.unwrap_model(model).save_pretrained(training_args.output_dir)
if accelerator.is_main_process: tokenizer.save_pretrained(training_args.output_dir)
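Checkpoint rotation depends on ordering folders by step number. Plain string sorting breaks as soon as the number of digits changes. The sketch below contrasts the two orders. It reuses the same `-(\d+)$` pattern as the resume logic; `step_of` and `checkpoints_to_delete` are hypothetical helper names.

```python
import re
from typing import List


def step_of(name: str) -> int:
    match = re.search(r"-(\d+)$", name)
    return int(match.group(1)) if match else -1


def checkpoints_to_delete(names: List[str], keep: int) -> List[str]:
    """All checkpoints except the newest `keep`, ordered by step number."""
    ckpts = sorted((n for n in names if n.startswith("checkpoint-")), key=step_of)
    return ckpts[:-keep] if keep > 0 else ckpts


dirs = ["checkpoint-99000", "checkpoint-100000", "runs"]
string_order = sorted(d for d in dirs if d.startswith("checkpoint-"))
print(string_order)  # string order puts step 100000 first
print(checkpoints_to_delete(dirs, keep=1))
```

With `save_total_limit = 1`, string ordering would delete the newest checkpoint at step 100000 and keep step 99000.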

### 9. Final Inference Test
The training is done. Now I load the saved model and give it a few simple tests. If it predicts reasonable words for the `<mask>` tokens, the pre-training worked.
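The test reads predictions off the logits at each `<mask>` position. Without `decoder_input_ids`, BART shifts the encoder input right to build the decoder input, so the logits at position `p` score candidates for token `p`. The pure-Python sketch below mirrors the `nonzero()` + `torch.topk` logic on a toy vocabulary. The vocabulary, IDs, and logits are hypothetical.

```python
import heapq
from typing import Dict, List, Sequence


def top_k_at_masks(input_ids: Sequence[int], logits: Sequence[Sequence[float]],
                   mask_id: int, id_to_token: Dict[int, str], k: int = 5) -> List[List[str]]:
    """For every <mask> position, return the k highest-scoring tokens at that position."""
    predictions = []
    for pos, tok in enumerate(input_ids):
        if tok != mask_id:
            continue
        row = logits[pos]
        best_ids = heapq.nlargest(k, range(len(row)), key=row.__getitem__)
        predictions.append([id_to_token[i] for i in best_ids])
    return predictions


# Toy vocabulary and logits (one row of scores per input position)
vocab = {0: "<s>", 1: "</s>", 2: "<mask>", 3: "is", 4: "was", 5: "day"}
input_ids = [0, 2, 5, 1]  # "<s> <mask> day </s>"
logits = [
    [5.0, 0.0, 0.0, 0.1, 0.2, 0.3],
    [0.0, 0.0, 0.0, 2.5, 1.7, 0.4],  # scores at the masked position
    [0.0, 0.0, 0.0, 0.1, 0.1, 4.0],
    [0.0, 6.0, 0.0, 0.0, 0.0, 0.0],
]
print(top_k_at_masks(input_ids, logits, mask_id=2, id_to_token=vocab, k=2))
```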

In [None]:
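import os
import torch
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors
from transformers import AutoTokenizer, BartConfig, BartForConditionalGeneration

print("Testing Final Model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    tokenizer_loaded = AutoTokenizer.from_pretrained(model_args.model_output_dir)

    # The saved weights come from a model whose LayerNorms were swapped for RMSNorm (parameter
    # name `scale`). BartForConditionalGeneration.from_pretrained would rebuild plain LayerNorms
    # and skip those weights, so rebuild the training architecture before loading the state dict.
    config_loaded = BartConfig.from_pretrained(model_args.model_output_dir)
    model_loaded = BartForConditionalGeneration(config_loaded)
    replace_layernorm_with_rmsnorm(model_loaded)  # defined in section 6

    weights_path = os.path.join(model_args.model_output_dir, "model.safetensors")
    if os.path.exists(weights_path):
        state_dict = load_safetensors(weights_path)
    else:  # the checkpoint may have been saved as a .bin file instead
        state_dict = torch.load(os.path.join(model_args.model_output_dir, "pytorch_model.bin"), map_location="cpu")
    missing, unexpected = model_loaded.load_state_dict(state_dict, strict=False)
    # Tied embedding weights (e.g. lm_head.weight) can legitimately appear as "missing"
    print(f"Missing keys: {missing}")
    print(f"Unexpected keys: {unexpected}")
    model_loaded.eval().to(device)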

    test_sentences = [
        f"Hello world, this is a {tokenizer_loaded.mask_token} test.",
        f"النهاردة الجو حر {tokenizer_loaded.mask_token}.",
        f"I need to {tokenizer_loaded.mask_token} the results tomorrow.",
    ]

    for sent in test_sentences:
        print(f"\nInput: {sent}")
        enc = tokenizer_loaded(sent, return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            logits = model_loaded(**enc).logits
        
        mask_id = tokenizer_loaded.mask_token_id
        mask_pos = (enc.input_ids == mask_id).nonzero(as_tuple=False)
        
        for m in mask_pos:
            b, p = m[0], m[1]
            vals, ids = torch.topk(logits[b, p], k=5)
            toks = tokenizer_loaded.convert_ids_to_tokens(ids)
            print(f"  Prediction: {toks}")

except Exception as e:
    print(f"Test failed: {e}")