### Author: Shams

**Description:**
This notebook handles the **fine-tuning** phase. We have a pre-trained Encoder-Decoder model (from the previous step) that understands English and Arabic generally. Now, we are teaching it specifically how to translate English into Egyptian Arabic.

**Key Strategy:**

1.  **Loading the Pre-trained Checkpoint:**
    We start with the weights from the pre-training phase. This model has seen a lot of text but hasn't been explicitly told to "translate X to Y".

2.  **Architecture Patch (RMSNorm):**
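    Just like in pre-training, we must replace every `LayerNorm` with `RMSNorm` *immediately* after building the model from the config, before loading any weights. The checkpoint stores RMSNorm parameters (`scale`), while a stock BART layer norm expects `weight` and `bias`. If we skip the patch, `strict=False` loading silently skips those norm weights and the model runs with untrained normalization layers.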

3.  **Data Pipeline:**
    We load the parallel splits (`parallel_EN`, `parallel_ARZ`, etc.) from the dataset and pair them row by row into a translation dataset. `DataCollatorForSeq2Seq` pads each batch dynamically.

4.  **Metric (SacreBLEU):**
    Token-level accuracy is a poor fit for translation, because a correct translation can use different words from the reference. We use BLEU (via `sacrebleu`), which measures n-gram overlap between the model's output and the human reference. We also track `gen_len` (average generated length) to make sure the model isn't just outputting empty strings.

5.  **Generation Config:**
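    We explicitly set `num_beams=5`, `repetition_penalty=1.5`, and `no_repeat_ngram_size=3`. The two repetition controls stop the model from getting stuck in loops (e.g., "The car the car the car"), and beam search keeps several candidate translations alive instead of greedily committing to one token at a time.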

6.  **The "Lobotomy" Fix:**
    Loading weights into a freshly built model can leave the input and output embeddings untied. The loading cell manually points the shared embedding, `embed_tokens` (encoder and decoder), and `lm_head` at one shared parameter, so the model reads and predicts tokens with the same vocabulary matrix.
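A minimal sketch of what "sharing the same memory" means in PyTorch, using toy modules rather than BART (the sizes are arbitrary assumptions). Assigning one `nn.Parameter` object to several modules makes them read and update a single tensor, which can be checked with `data_ptr()`.

```python
import torch
import torch.nn as nn

vocab_size, d_model = 10, 4  # toy sizes, not the real model's

# One parameter that all three modules will share
shared = nn.Parameter(torch.randn(vocab_size, d_model))

encoder_embed = nn.Embedding(vocab_size, d_model)
decoder_embed = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)  # weight shape: (vocab_size, d_model)

# Tie by assigning the same Parameter object (the same idea as the loading cell)
for module in (encoder_embed, decoder_embed, lm_head):
    module.weight = shared

# A single storage address means the weights are truly tied
pointers = {m.weight.data_ptr() for m in (encoder_embed, decoder_embed, lm_head)}
print(len(pointers))  # 1

# An in-place update through one handle is visible through all of them
with torch.no_grad():
    encoder_embed.weight[0].fill_(1.0)
print(torch.equal(lm_head.weight[0], torch.ones(d_model)))  # True
```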

### 1. Installs & Imports
Setting up the fine-tuning environment. We need `sacrebleu` for metrics and `safetensors` for efficient weight loading.

In [82]:
# Install necessary libraries
!pip install --upgrade evaluate transformers datasets accelerate tensorboard bitsandbytes sacrebleu safetensors

import os
import numpy as np
import torch
from dataclasses import dataclass, field 
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    PreTrainedTokenizerBase,
    TrainerCallback,
    TrainerControl,
    TrainerState
)
from typing import Dict, List, Any
from safetensors.torch import load_file

# Suppress warnings for cleaner logs
import warnings
warnings.filterwarnings("ignore")

...

### 2. Configuration
Defining where to find the pre-trained model and how to train it.
- **ModelConfig:** Points to the `checkpoint-2000` from the pre-training run.
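- **TrainingConfig:** Sets a per-device batch size of 16 with 2 gradient-accumulation steps (an effective batch of 32 sentences per optimizer step) and 10 epochs. `IN_TEST_MODE` limits the data to 8,000 samples for a quick run that verifies the pipeline before committing to a full training run.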
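As a sanity check, the settings above imply roughly the following schedule. This sketch assumes a single GPU, uses the train-split size logged in the data-preparation cell, and approximates the Trainer's step count as batches per epoch divided by accumulation steps.

```python
import math

# Values from TrainingConfig and the dataset summary logged further down
train_rows = 697_306      # size of the train split in the logged run
per_device_batch = 16
grad_accum = 2
num_epochs = 10
eval_steps = 500

effective_batch = per_device_batch * grad_accum             # sequences per optimizer step
batches_per_epoch = math.ceil(train_rows / per_device_batch)
steps_per_epoch = batches_per_epoch // grad_accum           # optimizer updates per epoch
total_steps = steps_per_epoch * num_epochs
evals_per_epoch = steps_per_epoch // eval_steps

print(effective_batch, steps_per_epoch, total_steps, evals_per_epoch)
# 32 21791 217910 43
```

With early stopping patience of 5 evaluations, training can stop after as few as 2,500 optimizer steps without BLEU improvement, far short of the full schedule.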

In [83]:
@dataclass
class ModelConfig:
    PRE_TRAINED_MODEL_PATH: str = "/kaggle/input/lasthopr/bart-en-arz-translator/checkpoint-2000"
    DATASET_REPO_ID: str = "Shams03/Tokenized-ARZ-EN-BART"
    FINETUNED_OUTPUT_DIR: str = "/kaggle/working/bart-en-arz-translator"

@dataclass
class TrainingConfig:
    MAX_LENGTH: int = 150
    LEARNING_RATE: float = 1e-5
    NUM_EPOCHS: int = 10
    PER_DEVICE_BATCH_SIZE: int = 16
    GRAD_ACCUMULATION_STEPS: int = 2
    EVAL_STEPS: int = 500
    SAVE_STEPS: int = 2000
    TEST_SPLIT_SIZE: float = 0.01
    LOGGING_STEPS: int = 100

    IN_TEST_MODE: bool = False 
    TEST_MODE_DATA_SIZE: int = 8000 

model_config = ModelConfig()
training_config = TrainingConfig()

print(f"--- Configuration ---")
print(f"Loading pre-trained model from: {model_config.PRE_TRAINED_MODEL_PATH}")
print(f"Training for {training_config.NUM_EPOCHS} epochs.")
print(f"Batch size: {training_config.PER_DEVICE_BATCH_SIZE} (Accum: {training_config.GRAD_ACCUMULATION_STEPS})")
print(f"Max Length: {training_config.MAX_LENGTH}")

if training_config.IN_TEST_MODE:
    print("      !!! TEST MODE ON !!!     ")
    print(f"Data will be limited to {training_config.TEST_MODE_DATA_SIZE} samples.")

--- Configuration ---
Loading pre-trained model from: /kaggle/input/lasthopr/bart-en-arz-translator/checkpoint-2000
Training for 1 epochs.
Batch size: 17 (Accum: 1)
Max Length: 150


### 3. Load Model & Apply Architecture Patch
This cell does three critical things:
1.  **Loads Config:** Reads the JSON config from the checkpoint.
2.  **Patches Architecture:** Immediately replaces `LayerNorm` with `RMSNorm`. The weights on disk come from an RMSNorm model, so the parameter names differ (`scale` instead of `weight`/`bias`). Loaded into a standard LayerNorm model with `strict=False`, those weights would be silently dropped, and the math differs anyway: LayerNorm subtracts the mean before scaling, RMSNorm does not.
3.  **Repairs Embeddings:** Manually forces the encoder embeddings, decoder embeddings, and language-model head to share one parameter (`shared_param`). This fixes the "lobotomy" issue, where an untied output head can make the model produce nonsense after loading.
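Both points about the patch can be checked on a single toy layer. This sketch re-declares an RMSNorm with the same math as the loading cell and shows that (1) the parameter names differ, (2) `load_state_dict(strict=False)` reports the mismatch instead of raising, and (3) the two norms produce different outputs for an input with a non-zero mean.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Same math as the loading cell's RMSNorm: one learnable `scale`, no bias, no mean subtraction
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.scale

dim = 4
layer_norm, rms_norm = nn.LayerNorm(dim), RMSNorm(dim)

# 1) Different parameter names: 'weight'/'bias' vs 'scale'
print(sorted(layer_norm.state_dict()), sorted(rms_norm.state_dict()))

# 2) strict=False silently skips the mismatch instead of raising
result = layer_norm.load_state_dict(rms_norm.state_dict(), strict=False)
print(result.missing_keys, result.unexpected_keys)

# 3) Different math: LayerNorm centres the input first, RMSNorm does not
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(torch.allclose(layer_norm(x), rms_norm(x)))  # False
```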

In [84]:
print(f"\n--- 1. Loading Model & Tokenizer ---")

import torch
import torch.nn as nn
from safetensors.torch import load_file
import os
from transformers import AutoTokenizer, BartConfig, AutoModelForSeq2SeqLM

# --- A. Define RMSNorm ---
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.scale

# --- B. Replacement Function ---
def replace_layernorm_with_rmsnorm(module: nn.Module):
    for name, child in list(module.named_children()):
        if isinstance(child, nn.LayerNorm):
            dim = child.normalized_shape[0] if isinstance(child.normalized_shape, (tuple, list)) else child.normalized_shape
            rms = RMSNorm(dim=dim, eps=1e-6)
            setattr(module, name, rms)
        else:
            replace_layernorm_with_rmsnorm(child)

# --- C. Load Tokenizer ---
print(f"Loading tokenizer from: {model_config.PRE_TRAINED_MODEL_PATH}")
tokenizer = AutoTokenizer.from_pretrained(
    model_config.PRE_TRAINED_MODEL_PATH,
    add_prefix_space=True
)

# --- D. Load Model Config & Patch ---
print(f"Loading model CONFIG from: {model_config.PRE_TRAINED_MODEL_PATH}")
config = BartConfig.from_pretrained(model_config.PRE_TRAINED_MODEL_PATH)
model = AutoModelForSeq2SeqLM.from_config(config)

print("Patching model architecture: Replacing LayerNorm with RMSNorm...")
replace_layernorm_with_rmsnorm(model)

# --- E. Load Weights & Fix Embeddings ---
print("Loading saved weights from .safetensors file...")
state_dict_path = os.path.join(model_config.PRE_TRAINED_MODEL_PATH, "model.safetensors")
state_dict = load_file(state_dict_path, device="cpu")

shared_key = 'model.shared.weight'
if shared_key not in state_dict:
    shared_key = 'shared.weight'

if shared_key in state_dict:
    print(f"Found shared weight key: '{shared_key}' - Applying Fix...")
    shared_weight = state_dict[shared_key]
    
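    # Tie embeddings manually: the shared embedding, both embed_tokens, and lm_head
    # all point at one Parameter (model.shared is included so that
    # get_input_embeddings() and save_pretrained() see the trained weights)
    shared_param = nn.Parameter(shared_weight)
    model.model.shared.weight = shared_param
    model.model.encoder.embed_tokens.weight = shared_param
    model.model.decoder.embed_tokens.weight = shared_param
    model.lm_head.weight = shared_param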
    
    # Remove redundant keys from dict to prevent overwrite
    keys_to_remove = [
        shared_key,
        "model.encoder.embed_tokens.weight",
        "model.decoder.embed_tokens.weight",
        "lm_head.weight"
    ]
    for k in keys_to_remove:
        if k in state_dict:
            del state_dict[k]
            
    print("Success: Embeddings manually tied and conflicting keys removed.")
else:
    print("WARNING: Shared weights not found in file!")

print("Loading remaining weights...")
model.load_state_dict(state_dict, strict=False)
print("Model loaded successfully.")

print(f"\nModel: {type(model)}")
print(f"Tokenizer: {type(tokenizer)}")


--- 1. Loading Model & Tokenizer ---
Loading tokenizer from: /kaggle/input/lasthopr/bart-en-arz-translator/checkpoint-2000
Loading model CONFIG from: /kaggle/input/lasthopr/bart-en-arz-translator/checkpoint-2000
Patching model architecture: Replacing LayerNorm with RMSNorm...
Loading saved weights from .safetensors file...
Found shared weight key: 'model.shared.weight' - Applying Fix...
Success: Embeddings manually tied and conflicting keys removed.
Loading remaining weights...
Model loaded successfully.

Model: <class 'transformers.models.bart.modeling_bart.BartForConditionalGeneration'>
Tokenizer: <class 'transformers.tokenization_utils_fast.PreTrainedTokenizerFast'>


### 4. Data Preparation
Here, we fetch the parallel datasets. The English and Arabic data are stored in separate splits. We load them, detokenize them (turn them back into raw text), and then join them side-by-side into a single dataset with `en` and `arz` columns.

In [85]:
print(f"\n--- 2. Reconstructing Parallel Data (Optimized) ---")
print(f"Loading from repo: {model_config.DATASET_REPO_ID}")

splits_to_load = ['parallel_EN', 'parallel_ARZ', 'LparallelEN', 'Lparallel_ARZ']
print(f"Loading splits: {splits_to_load}...")

ds = load_dataset(model_config.DATASET_REPO_ID, split=splits_to_load, streaming=False)
print("...Splits loaded.")

# Concatenate splits (indices follow the order of splits_to_load: EN = 0, 2; ARZ = 1, 3)
ds_en_tokenized = concatenate_datasets([ds[0], ds[2]]) 
ds_arz_tokenized = concatenate_datasets([ds[1], ds[3]]) 

if training_config.IN_TEST_MODE:
    print(f"\n--- TEST MODE: Slicing raw data to {training_config.TEST_MODE_DATA_SIZE} samples ---")
    ds_en_tokenized = ds_en_tokenized.select(range(training_config.TEST_MODE_DATA_SIZE))
    ds_arz_tokenized = ds_arz_tokenized.select(range(training_config.TEST_MODE_DATA_SIZE))

if len(ds_en_tokenized) != len(ds_arz_tokenized):
    raise ValueError("Data mismatch! Unequal EN/ARZ rows.")
else:
    print(f"Total parallel sentences loaded: {len(ds_en_tokenized)}")

def detokenize(example):
    return {"text": tokenizer.decode(example['input_ids'], skip_special_tokens=True)}

# Detokenize in parallel
print(f"Detokenizing English data (using {os.cpu_count()} cores)...")
ds_en_text = ds_en_tokenized.map(detokenize, num_proc=os.cpu_count(), remove_columns=ds_en_tokenized.column_names)

print(f"Detokenizing Arabic data (using {os.cpu_count()} cores)...")
ds_arz_text = ds_arz_tokenized.map(detokenize, num_proc=os.cpu_count(), remove_columns=ds_arz_tokenized.column_names)

# Join datasets
print("Creating final paired dataset...")
ds_en_text = ds_en_text.rename_column("text", "en")
ds_arz_text = ds_arz_text.rename_column("text", "arz")
paired_data = concatenate_datasets([ds_en_text, ds_arz_text], axis=1)

print("...Dataset paired successfully.")

# Split train/test
print(f"Creating train/validation split (Test size: {training_config.TEST_SPLIT_SIZE})...")
raw_datasets = paired_data.train_test_split(test_size=training_config.TEST_SPLIT_SIZE, seed=42)

# Cleanup RAM
del ds, ds_en_tokenized, ds_arz_tokenized, ds_en_text, ds_arz_text, paired_data
import gc
gc.collect()

print("\n--- Dataset Reconstruction Complete ---")
print(raw_datasets)
print("\nExample:")
print(f"  EN: {raw_datasets['train'][0]['en']}")
print(f"  ARZ: {raw_datasets['train'][0]['arz']}")


--- 2. Reconstructing Parallel Data (Optimized) ---
Loading from repo: Shams03/Tokenized-ARZ-EN-BART
Loading splits: ['parallel_EN', 'parallel_ARZ', 'LparallelEN', 'Lparallel_ARZ']...
...Splits loaded.
Total parallel sentences loaded: 704350
Detokenizing English data (using 4 cores)...
Detokenizing Arabic data (using 4 cores)...
Creating final paired dataset (using concatenate_datasets axis=1)...
...Dataset paired successfully.
Creating train/validation split (Test size: 0.01)...

--- Dataset Reconstruction Complete ---
DatasetDict({
    train: Dataset({
        features: ['en', 'arz'],
        num_rows: 697306
    })
    test: Dataset({
        features: ['en', 'arz'],
        num_rows: 7044
    })
})

Example:
  EN: 8 month later I walked out the hospital on my own two feet
  ARZ: بعد 8 شهور خرجت من ال مستشفى على رجليا


### 5. Seq2Seq Tokenization
Now we tokenize the text pairs for the Seq2Seq task:
- The **Encoder** gets the English text (`input_ids`).
- The **Decoder** gets the Arabic text (`labels`).

The labels are tokenized in target mode, which ensures correct special-token handling if the tokenizer distinguishes between source and target text. With a single shared BPE vocabulary like this one, the result is usually identical to ordinary tokenization.
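A self-contained sketch of target-mode tokenization with the `text_target` argument of a Hugging Face tokenizer. It uses a tiny made-up word-level vocabulary instead of the real BPE tokenizer (the vocabulary and sentences are assumptions for illustration). One call returns the source `input_ids` plus the target IDs under `labels`.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from transformers import PreTrainedTokenizerFast

# Toy vocabulary (hypothetical); unknown words map to [UNK] = 1
vocab = {"[PAD]": 0, "[UNK]": 1, "open": 2, "the": 3, "door": 4, "افتح": 5, "باب": 6}
backend = Tokenizer(WordLevel(vocab, unk_token="[UNK]"))
backend.pre_tokenizer = Whitespace()

tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=backend, pad_token="[PAD]", unk_token="[UNK]"
)

# Source goes through the normal path, target through text_target -> "labels"
batch = tokenizer(
    ["open the door"],
    text_target=["افتح ال باب"],
    truncation=True,
    max_length=150,
)
print(batch["input_ids"])  # [[2, 3, 4]]
print(batch["labels"])     # [[5, 1, 6]]  ("ال" is not in the toy vocab)
```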

In [86]:
print(f"\n--- 3. Applying Seq2Seq Tokenization ---")
MAX_LENGTH = training_config.MAX_LENGTH

def preprocess_function(examples):
    # Tokenize Source (English)
    inputs = tokenizer(
        examples["en"], 
        max_length=MAX_LENGTH, 
        truncation=True, 
        padding=False 
    )

    # Tokenize Target (Arabic). `text_target` is the current replacement for the
    # deprecated `tokenizer.as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples["arz"],
        max_length=MAX_LENGTH,
        truncation=True,
        padding=False
    )

    inputs["labels"] = labels["input_ids"]
    return inputs

tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=raw_datasets["train"].column_names
)

print("\n--- Tokenization Complete ---")
print(tokenized_datasets)
print("\nExample of tokenized data:")
print(f"  input_ids: {tokenized_datasets['train'][0]['input_ids'][:20]}...")
print(f"  labels: {tokenized_datasets['train'][0]['labels'][:20]}...")


--- 3. Applying Seq2Seq Tokenization ---

--- Tokenization Complete ---
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 697306
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 7044
    })
})

Example of tokenized data:
  input_ids: [3, 1837, 2442, 2777, 637, 4800, 863, 603, 6673, 674, 895, 1350, 1469, 4409, 4]...
  labels: [3, 1295, 1837, 9660, 21647, 669, 599, 10767, 812, 2191, 9004, 4]...


### 6. Collator & Metrics
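We initialize `DataCollatorForSeq2Seq`, which pads each batch dynamically to its longest sequence (rounded up to a multiple of 8, which suits FP16 tensor cores) instead of to a fixed length, so we don't waste GPU memory on empty space. It pads `labels` with `-100` so padding is ignored by the loss and, because we pass the model, builds `decoder_input_ids` from the labels.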
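What the collator does to the labels can be sketched with two hypothetical helpers, `pad_labels` and `shift_right`. This is a simplified illustration, not the library's code; the pad ID 1 and decoder start ID 2 are assumed values (check this tokenizer's config for the real ones).

```python
import torch

def pad_labels(label_lists, label_pad_id=-100, multiple_of=8):
    """Hypothetical helper: pad every label sequence to the batch's longest length,
    rounded up to a multiple of `multiple_of`, using -100 so the loss ignores padding."""
    longest = max(len(seq) for seq in label_lists)
    target_len = -(-longest // multiple_of) * multiple_of  # ceiling to a multiple
    return torch.tensor([seq + [label_pad_id] * (target_len - len(seq)) for seq in label_lists])

def shift_right(labels, pad_id, decoder_start_id):
    """Hypothetical helper for teacher forcing: the decoder sees
    [start] + labels[:-1], with -100 mapped back to a real pad id."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_id
    return shifted.masked_fill(shifted == -100, pad_id)

# Two short label sequences (IDs are illustrative)
labels = pad_labels([[3, 1295, 1837, 4], [3, 669, 4]])
decoder_input_ids = shift_right(labels, pad_id=1, decoder_start_id=2)
print(labels.tolist())
print(decoder_input_ids.tolist())
```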

The `compute_metrics` function calculates the **BLEU score**. Before decoding, it sanitizes the predicted token IDs, because IDs the tokenizer cannot decode, such as the `-100` values the Trainer uses to pad generated sequences or IDs beyond the vocabulary, would crash `batch_decode`.
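The sanitizing step and a `gen_len` statistic can be sketched with NumPy alone. `sanitize_ids` and `mean_gen_len` are hypothetical helpers, and the pad ID and vocabulary size are made-up values. Mapping `-100` to the pad token (rather than clipping it to 0) matters: clipping alone would decode padding as whatever token has ID 0.

```python
import numpy as np

def sanitize_ids(ids: np.ndarray, pad_id: int, vocab_len: int) -> np.ndarray:
    """Hypothetical helper: make generated IDs safe to decode."""
    ids = np.where(ids != -100, ids, pad_id)   # -100 padding -> real pad token
    return np.clip(ids, 0, vocab_len - 1)      # out-of-range IDs -> last valid ID

def mean_gen_len(ids: np.ndarray, pad_id: int) -> float:
    """Average number of non-pad tokens per generated sequence."""
    return float(np.mean(np.count_nonzero(ids != pad_id, axis=1)))

preds = np.array([[5, 7, 9, -100], [8, 123456, -100, -100]])
clean = sanitize_ids(preds, pad_id=1, vocab_len=100)
print(clean.tolist())                  # [[5, 7, 9, 1], [8, 99, 1, 1]]
print(mean_gen_len(clean, pad_id=1))   # 2.5
```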

In [87]:
print(f"\n--- 5. Initializing Collator & Metric ---")

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, 
    model=model,
    pad_to_multiple_of=8
)
print("DataCollatorForSeq2Seq initialized.")

import evaluate
metric = evaluate.load("sacrebleu")

import numpy as np
import torch

def compute_metrics(eval_preds):
    preds, labels = eval_preds

    if isinstance(preds, tuple):
        preds = preds[0]

    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    if hasattr(preds, "ndim") and preds.ndim == 3:
        preds = np.argmax(preds, axis=-1)

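    # Generated sequences are padded with -100 when batches are gathered; map those
    # to the pad token, then clip any ID outside the full vocabulary (incl. added tokens)
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    preds = np.clip(preds, 0, len(tokenizer) - 1)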

    # Replace -100 with pad token for decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_labels_for_bleu = [[lbl] for lbl in decoded_labels]
    result = metric.compute(predictions=decoded_preds, references=decoded_labels_for_bleu)

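    if "score" in result:
        result["bleu"] = result.pop("score")

    # Average generated length in non-pad tokens (flags empty outputs)
    result["gen_len"] = float(np.mean(np.count_nonzero(preds != tokenizer.pad_token_id, axis=1)))

    final_metrics = {k: round(float(v), 4) for k, v in result.items() if isinstance(v, (int, float))}
    return final_metrics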

print("SacreBLEU metric ready.")


--- 5. Initializing Collator & Metric ---
DataCollatorForSeq2Seq initialized.
SacreBLEU metric ready.


### 7. Debug Callback
This is a custom callback I wrote to see what the model is actually doing during training. Every `logging_steps` (100) steps, it pauses and translates a fixed set of test sentences (e.g. "Get in the car, we have to go now!", "It's raining cats and dogs."). This lets me visually verify whether the model is learning or just outputting nonsense.
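The callback pattern itself is small. A minimal sketch, with a hypothetical `LossHistoryCallback` that applies the same filters as `TranslationLogCallback` (main process only, training logs only) but just records losses. The log values passed in below are made up to simulate what the Trainer would deliver.

```python
from transformers import TrainerCallback, TrainerControl, TrainerState

class LossHistoryCallback(TrainerCallback):
    """Hypothetical minimal callback: keep (step, loss) from training logs only."""

    def __init__(self):
        super().__init__()
        self.history = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Same filters as TranslationLogCallback: main process, training logs only
        if not state.is_world_process_zero or not logs:
            return
        if "loss" in logs and "eval_loss" not in logs:
            self.history.append((state.global_step, logs["loss"]))

# Simulate one training log and one evaluation log (made-up values)
cb = LossHistoryCallback()
cb.on_log(None, TrainerState(global_step=100), TrainerControl(), logs={"loss": 2.35})
cb.on_log(None, TrainerState(global_step=500), TrainerControl(), logs={"eval_loss": 2.1})
print(cb.history)  # [(100, 2.35)]
```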

In [88]:
print(f"\n--- 6. Defining Debugging Callback ---\n")

DEBUG_EXAMPLES = {
    "test_1_car": "Get in the car, we have to go now!",
    "test_2_feeling": "I have a very bad feeling about this.",
    "cmd_1_door": "Open the door, please.",
    "movie_1_godfather": "I'm gonna make him an offer he can't refuse.",
    "chat_1_greeting": "Hello, how are you today?",
    "hard_2_idiom": "It's raining cats and dogs."
}

class TranslationLogCallback(TrainerCallback):
    def __init__(self, tokenizer, model, max_length=128):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.max_length = max_length
        self.debug_examples = DEBUG_EXAMPLES

    def on_log(self, args: Seq2SeqTrainingArguments, state: TrainerState, control: TrainerControl, logs: Dict[str, float] = None, **kwargs):
        if not state.is_world_process_zero or logs is None:
            return

        # Log training samples
        if "loss" in logs and "eval_loss" not in logs:
            print(f"\n--- [Debug Translation @ Step {state.global_step}] ---")
            current_lr = logs.get("learning_rate", "N/A")
            train_loss = logs.get("loss", "N/A")
            print(f"  [Log] Step: {state.global_step} | Loss: {train_loss} | LR: {current_lr}")
            
            self.model.eval()
            for name, text in self.debug_examples.items():
                try:
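                    inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length).to(self.model.device)
                    # BART's generate() would reject the token_type_ids this tokenizer returns
                    inputs.pop("token_type_ids", None)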
                    with torch.no_grad():
                        outputs = self.model.generate(**inputs, max_new_tokens=self.max_length, num_beams=4)
                    translation = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                    print(f"  [EN] {text}\n  [ARZ] {translation}")
                except Exception as e:
                    print(f"  Error: {e}")
            self.model.train()
            print("[End Debug]")

        # Log eval metrics
        if "eval_loss" in logs:
            print(f"\n--- [Evaluation @ Step {state.global_step}] ---")
            for k, v in logs.items():
                if k != 'epoch':
                    print(f"  {k}: {v}")
            print("--------------------------")

print("TranslationLogCallback defined.")


--- 6. Defining Debugging Callback ---

TranslationLogCallback defined.


### 8. Training Arguments
Configuring the `Seq2SeqTrainer`.
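- **Generation Config:** `num_beams=5` keeps five candidate translations in the beam, which is key for translation quality. `repetition_penalty=1.5` and `no_repeat_ngram_size=3` prevent stuttering, and `length_penalty=1.2` mildly favors longer outputs over truncated ones.
- **FP16:** Mixed-precision training, enabled for speed and lower memory use.
- **Early Stopping:** We stop training if the BLEU score doesn't improve for 5 consecutive evaluations (one every `EVAL_STEPS` = 500 steps).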
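A simplified sketch of the arithmetic behind two of these knobs, assuming the usual Hugging Face definitions: a finished beam is scored as its summed log-probability divided by `length ** length_penalty`, and the repetition penalty is CTRL-style. The beam numbers are made up for illustration.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    """Length-normalised beam score: sum of log-probs / length ** length_penalty."""
    return sum_logprobs / (length ** length_penalty)

def apply_repetition_penalty(logit: float, penalty: float) -> float:
    """CTRL-style penalty for an already-generated token: positive logits are
    divided by the penalty, negative ones multiplied, so the token always gets less likely."""
    return logit / penalty if logit > 0 else logit * penalty

# Two hypothetical finished beams: short (4 tokens) vs longer (6 tokens)
short_beam, long_beam = (-4.0, 4), (-6.3, 6)

for lp in (1.0, 1.2):
    s, l = beam_score(*short_beam, lp), beam_score(*long_beam, lp)
    print(lp, round(s, 3), round(l, 3), "long wins" if l > s else "short wins")
# 1.0 -1.0 -1.05 short wins
# 1.2 -0.758 -0.734 long wins

print(apply_repetition_penalty(3.0, 1.5), apply_repetition_penalty(-3.0, 1.5))  # 2.0 -4.5
```

With `length_penalty=1.2`, the longer beam overtakes the shorter one even though its per-token log-probability is slightly worse, which is the intended guard against truncated translations.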

In [89]:
print(f"\n--- 7. Defining Training Arguments ---\n")

from transformers import GenerationConfig, EarlyStoppingCallback

EARLY_STOPPING_PATIENCE = 5 
GENERATION_MAX_NEW_TOKENS = 128 
GENERATION_NUM_BEAMS = 5 

gen_config = GenerationConfig.from_model_config(model.config)
gen_config.max_new_tokens = GENERATION_MAX_NEW_TOKENS
gen_config.num_beams = GENERATION_NUM_BEAMS
gen_config.no_repeat_ngram_size = 3
gen_config.repetition_penalty = 1.5
gen_config.length_penalty = 1.2
gen_config.early_stopping = True

training_args = Seq2SeqTrainingArguments(
    output_dir=model_config.FINETUNED_OUTPUT_DIR,
    learning_rate=training_config.LEARNING_RATE,
    num_train_epochs=training_config.NUM_EPOCHS,
    per_device_train_batch_size=training_config.PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=training_config.GRAD_ACCUMULATION_STEPS,
    logging_steps=training_config.LOGGING_STEPS,
    eval_strategy="steps",
    eval_steps=training_config.EVAL_STEPS,
    save_strategy="steps",
    save_steps=training_config.SAVE_STEPS,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="bleu",
    fp16=True,
    predict_with_generate=True,
    generation_config=gen_config,
    report_to="tensorboard"
)

translation_logger = TranslationLogCallback(tokenizer=tokenizer, model=model, max_length=training_config.MAX_LENGTH)
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[translation_logger, early_stopping_callback]
)

print(f"Training Arguments set. Output will be saved to: {training_args.output_dir}")
print(f"Evaluation and Logging will happen every {training_args.eval_steps} and {training_args.logging_steps} steps.")
print(f"GenerationConfig set: max_new_tokens={GENERATION_MAX_NEW_TOKENS}, num_beams={GENERATION_NUM_BEAMS}")
print(f"load_best_model_at_end=True. Early stopping is attached to the trainer (Patience={EARLY_STOPPING_PATIENCE} evaluations).")


--- 7. Defining Training Arguments ---

Training Arguments set. Output will be saved to: /kaggle/working/bart-en-arz-translator
Evaluation and Logging will happen every 100 and 100 steps.
GenerationConfig set: max_new_tokens=128, num_beams=5
load_best_model_at_end=True. Early stopping will be added in the next cell (Patience=5).


### 9. Training Loop
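This is where the magic happens: a fresh training session. I'm (re)loading the weights via `load_state_dict` instead of relying on `resume_from_checkpoint`, because resuming would also restore the optimizer, scheduler, and step counter from the pre-training checkpoint, and that optimizer state may be incompatible with the new fine-tuning setup (here the saved `optimizer.pt` was broken). The model already holds these weights from the loading cell, so this reload acts as a safety net. The two missing keys it reports (`model.decoder.embed_tokens.weight`, `lm_head.weight`) are expected: both are tied to the shared embedding, which the checkpoint stores once.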
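One way a saved optimizer state becomes incompatible is a changed parameter layout. A toy illustration, where both models are hypothetical stand-ins: PyTorch refuses to load an optimizer state whose parameter groups have a different size, while a freshly created optimizer has no such problem.

```python
import torch
import torch.nn as nn

# "Old" run: optimizer over a model with 2 parameter tensors (weight + bias)
old_model = nn.Linear(4, 4)
old_opt = torch.optim.AdamW(old_model.parameters(), lr=1e-4)
saved_state = old_opt.state_dict()

# "New" run: a model with a different parameter layout (4 tensors)
new_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
new_opt = torch.optim.AdamW(new_model.parameters(), lr=1e-5)

try:
    new_opt.load_state_dict(saved_state)
    outcome = "loaded"
except ValueError as err:
    outcome = f"rejected: {err}"
print(outcome)

# The alternative used in this notebook: keep the fresh optimizer and start from step 0
```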

In [None]:
import torch
import os
from transformers import Seq2SeqTrainer
from safetensors.torch import load_file

CKPT_PATH = model_config.PRE_TRAINED_MODEL_PATH  # same checkpoint-2000 as the loading cell
print(f"Recovering weights from: {CKPT_PATH}")

weights_path = os.path.join(CKPT_PATH, "model.safetensors")
if os.path.exists(weights_path):
    print("Loading .safetensors...")
    state_dict = load_file(weights_path)
else:
    print("Loading .bin...")
    state_dict = torch.load(os.path.join(CKPT_PATH, "pytorch_model.bin"), map_location="cpu")

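# Load weights into the EXISTING model object.
# strict=False: the decoder embedding and lm_head keys are absent from the file
# because they are tied to the shared embedding, which is stored once.
keys = model.load_state_dict(state_dict, strict=False)
print(f"Weights loaded. Missing keys: {keys.missing_keys}")
print(f"Unexpected keys: {keys.unexpected_keys}")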

# Start training
print("Starting training with FRESH optimizer (ignoring broken optimizer.pt)...")
trainer.train()

Recovering weights from: /kaggle/input/lasthopr/bart-en-arz-translator/checkpoint-2000
Loading .safetensors...
Weights loaded. Missing keys: ['model.decoder.embed_tokens.weight', 'lm_head.weight']
Starting training with FRESH optimizer (ignoring broken optimizer.pt)...





--- [Debug Translation @ Step 100] ---
  [Log] Step: 100 | Epoch: 0.00 | LR: 1.00e-07 | Training Loss: 2.3525
  [EN] Input (test_1_car): 'Get in the car, we have to go now!'
  [ARZ] Output: 'ادخل في ال عربية ال لي لازم نمشي دلوقتي'
  [EN] Input (test_2_feeling): 'I have a very bad feeling about this.'
  [ARZ] Output: 'عندي إحساس وحش أوي في ال موضوع ده'
  [EN] Input (test_3_science): 'The process of photosynthesis converts light energy into chemical energy.'
  [ARZ] Output: 'عملية طاقة ال صور ال لي بتدمج ال ضوء ال بصري في ال طاقة ال كيميائية'
  [EN] Input (test_4_question): 'What are you doing here?'
  [ARZ] Output: 'بتعمل إيه هنا ؟'
  [EN] Input (cmd_1_door): 'Open the door, please.'
  [ARZ] Output: 'افتح ال باب من فضلك'
  [EN] Input (cmd_2_phone): 'Give me the phone.'
  [ARZ] Output: 'اديني ال تليفون'
  [EN] Input (cmd_3_police): 'Call the police!'
  [ARZ] Output: 'كلم ال بوليس'
  [EN] Input (cmd_4_lights): 'Turn off the lights before you leave.'
  [ARZ] Output: 'اقفل ال أنوار قبل ما

### 10. Saving Final Model
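We save the final fine-tuned model to a clean directory. Because `load_best_model_at_end=True`, the trainer now holds the best checkpoint by BLEU score. `trainer.save_model()` writes the weights and config (plus the tokenizer, since it was passed to the trainer); the explicit `tokenizer.save_pretrained()` call is a harmless extra safeguard.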

In [None]:
import os

FINAL_OUTPUT_DIR = os.path.join(model_config.FINETUNED_OUTPUT_DIR, "final_model")
print(f"Saving final model to standard directory: {FINAL_OUTPUT_DIR}")

trainer.save_model(FINAL_OUTPUT_DIR)
tokenizer.save_pretrained(FINAL_OUTPUT_DIR)

print("SUCCESS. Model saved. Note: reloading still requires the RMSNorm patch before loading weights.")
print(f"Artifacts in folder: {os.listdir(FINAL_OUTPUT_DIR)}")

### 11. Final Validation Check
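This is the proof of life. We reload the *saved* model from disk (to make sure the save process worked), re-apply the RMSNorm patch (the saved config still describes a stock LayerNorm BART), and run a final set of translations. We check that the weights load with no unexpected keys and that the model produces coherent Arabic output.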
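Because `strict=False` never raises, the load result has to be inspected by hand. A sketch of one way to separate expected gaps (tied keys stored only once) from real problems, using a hypothetical `report_load` helper and a toy module with a tied head standing in for BART.

```python
import torch
import torch.nn as nn

def report_load(result, expected_missing=()):
    """Hypothetical helper: split a load_state_dict result into expected gaps
    (e.g. tied embedding keys) and genuine problems."""
    expected = set(expected_missing)
    missing = [k for k in result.missing_keys if k not in expected]
    return {"unexpected": list(result.unexpected_keys), "missing": missing}

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight  # tied, like lm_head and the shared embedding

src = Toy().state_dict()
del src["head.weight"]  # a checkpoint may store tied weights only once

ok = Toy().load_state_dict(src, strict=False)
print(report_load(ok, expected_missing=["head.weight"]))   # nothing to worry about

# An extra norm parameter (e.g. an RMSNorm `scale`) shows up as unexpected
bad = Toy().load_state_dict({**src, "norm.scale": torch.ones(4)}, strict=False)
print(report_load(bad, expected_missing=["head.weight"]))  # flags 'norm.scale'
```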

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from safetensors.torch import load_file
import os

print("\n--- Final Validation: Loading from 'final_model' Dir ---")

FINAL_OUTPUT_DIR = os.path.join(model_config.FINETUNED_OUTPUT_DIR, "final_model")
device = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists(FINAL_OUTPUT_DIR):
    raise FileNotFoundError(f"Model directory not found at: {FINAL_OUTPUT_DIR}")

# Define RMSNorm again for loading context
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.scale

def replace_layernorm_with_rmsnorm(module: nn.Module):
    for name, child in list(module.named_children()):
        if isinstance(child, nn.LayerNorm):
            dim = child.normalized_shape[0] if isinstance(child.normalized_shape, (tuple, list)) else child.normalized_shape
            rms = RMSNorm(dim=dim, eps=1e-6)
            setattr(module, name, rms)
        else:
            replace_layernorm_with_rmsnorm(child)

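# Clean Memory (the callback and collator also keep references to the old model)
if 'model' in globals(): del model
if 'trainer' in globals(): del trainer
if 'translation_logger' in globals(): del translation_logger
if 'data_collator' in globals(): del data_collator
import gc
gc.collect()
torch.cuda.empty_cache()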

# Reload
print("Loading configuration...")
config = AutoConfig.from_pretrained(FINAL_OUTPUT_DIR)
new_model = AutoModelForSeq2SeqLM.from_config(config)
replace_layernorm_with_rmsnorm(new_model)

weights_path = os.path.join(FINAL_OUTPUT_DIR, "model.safetensors")
state_dict = load_file(weights_path)
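load_result = new_model.load_state_dict(state_dict, strict=False)
# Tied embedding keys (e.g. lm_head) may be reported missing; unexpected keys
# would mean the RMSNorm patch does not match the saved weights
print(f"Missing keys: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: unexpected keys in checkpoint: {load_result.unexpected_keys}")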

new_model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(FINAL_OUTPUT_DIR)
print("Model ready for testing.")

test_sentences = [
    "Get in the car", 
    "I have a bad feeling", 
    "Open the door, please.", 
    "It is three o'clock in the afternoon.",
    "It's raining cats and dogs."
]

print("\n--- Validation Translations ---")
for i, s in enumerate(test_sentences):
    inputs = tokenizer(s, return_tensors="pt").to(device)
    if "token_type_ids" in inputs: del inputs["token_type_ids"]
    
    with torch.no_grad():
        out = new_model.generate(**inputs, max_new_tokens=128, num_beams=5)
    
    decoded = tokenizer.decode(out[0], skip_special_tokens=True)
    print(f"EN:  {s}")
    print(f"ARZ: {decoded}")
    print("-" * 20)