# Model Choice and Fine-Tuning Strategy 

## 1. Choice of the Base Model: Atlas-Chat-2B

We adopt **Atlas-Chat-2B** as the base language model due to its strong suitability for domain-specific conversational fine-tuning under limited computational resources.

The selection is motivated by the following considerations:

- **Instruction-tuned conversational model**: Atlas-Chat-2B is already optimized for multi-turn dialogue, making it well-aligned with the target application (health guidance, boundaries, moderation, and Q&A).
- **Moderate parameter size (≈2B)**: This size provides a favorable trade-off between expressive capacity and trainability , enabling efficient experimentation without sacrificing linguistic quality.
- **trained on moroccan dialect**: The model is based on qwen model and fine tuned on the moroccan dialect, which is close to Algerian dialect.
- **Compatibility with parameter-efficient fine-tuning (PEFT)**: Atlas-Chat-2B integrates seamlessly with LoRA-based adaptation and low-bit quantization, allowing scalable fine-tuning with limited memory overhead.

Overall, Atlas-Chat-2B offers sufficient representational power while remaining practical for controlled fine-tuning and deployment.

---

## 2. Fine-Tuning Strategy: Single Shuffled Dataset with LoRA

### 2.1 Rationale Against Sequential Fine-Tuning

Initial experiments based on **sequential fine-tuning** (e.g., bad-words handling → boundaries → greetings → Q&A) exhibited **catastrophic forgetting**, where later stages degraded previously learned behaviors.  
This phenomenon is well-documented in continual learning settings, particularly for large language models fine-tuned without explicit memory preservation mechanisms.

---

### 2.2 Unified Shuffled Training Approach

To mitigate catastrophic forgetting, we adopt a **single shuffled dataset strategy**, where all behavioral categories are mixed and learned simultaneously.  
This ensures that:

- The model continuously revisits all task types during training
- No single behavior dominates parameter updates
- Knowledge retention is preserved across epochs

Formally, the training dataset includes a randomized mixture of:
- Safety and boundary enforcement
- Offensive language handling
- Conversational greetings
- Question answering

This approach aligns with standard multi-task learning principles and significantly improves behavioral consistency at inference time.

---

### 2.3 Prompt Template Consistency

All training samples are normalized into a unified template:

- *system*:
- *user*:
- *assistant*:


Maintaining a consistent prompt structure is critical to:
- Stabilize training dynamics
- Reduce format-induced distribution shift
- Ensure reliable downstream inference behavior

---

## 3. Parameter-Efficient Fine-Tuning with LoRA

To adapt the model efficiently, we employ **Low-Rank Adaptation (LoRA)** applied to the attention and feed-forward projection layers.

This choice is justified by:
- **Reduced memory footprint**: Only a small number of trainable parameters are introduced
- **Faster convergence**: Updates are focused on task-relevant subspaces
- **Preservation of pre-trained knowledge**: Base model weights remain frozen

Combined with **8-bit quantization**, this strategy enables fine-tuning on constrained hardware while maintaining competitive performance.

---

## 4. Validation Strategy

A **10% validation split** is incorporated to:
- Monitor generalization performance during training
- Prevent overfitting to conversational patterns
- Automatically select the best model checkpoint based on validation loss


## Imports and Dependencies

In [None]:
import json
import torch
from pathlib import Path

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)

from datasets import Dataset


## Configuration

All hyperparameters and paths are defined here for clarity and reproducibility.


In [None]:
BASE_MODEL = "MBZUAI-Paris/Atlas-Chat-2B"
TRAINING_DATA = "merged_shuffled_training_data.json"
OUTPUT_DIR = "atlas_finetuned/"

MAX_LENGTH = 1024
EPOCHS = 3
LEARNING_RATE = 2e-4
BATCH_SIZE = 2
GRADIENT_ACCUM = 4
VALIDATION_SPLIT = 0.1  # 10%

print("=" * 80)
print("ATLAS FINE-TUNING - SINGLE SHUFFLED DATASET WITH VALIDATION")
print("=" * 80)
print(f"Base Model: {BASE_MODEL}")
print(f"Training Data: {TRAINING_DATA}")
print(f"Effective Batch Size: {BATCH_SIZE * GRADIENT_ACCUM}")
print(f"Validation Split: {VALIDATION_SPLIT * 100}%")
print("=" * 80)

## Prompt Formatting

All examples are converted into a **consistent ChatML-style format**:
- *system*:
- *user*:
- *assistant*:

This consistency is critical for stable inference after fine-tuning.

In [None]:
def format_example(example):
    """
    Convert ChatML messages to a single training string.
    """
    messages = example["messages"]
    system = messages[0]["content"]
    user = messages[1]["content"]
    assistant = messages[2]["content"]

    return f"system: {system}\nuser: {user}\nassistant: {assistant}"


## Loading the Shuffled Dataset

The dataset is already shuffled to ensure:
- No task dominates training
- No behavior overwrites another


In [None]:
with open(TRAINING_DATA, "r", encoding="utf-8") as f:
    training_examples = json.load(f)

print(f"Loaded {len(training_examples)} training examples")

print("\nWhy shuffling matters:")
print("- Prevents catastrophic forgetting")
print("- All behaviors learned jointly")


## Tokenizer Initialization

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    use_fast=True
)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"Tokenizer loaded | Vocabulary size: {len(tokenizer)}")


## Dataset Formatting & Tokenization


In [None]:
formatted_texts = [
    format_example(ex) + tokenizer.eos_token
    for ex in training_examples
]

dataset = Dataset.from_dict({"text": formatted_texts})

def tokenize_batch(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH
    )

tokenized_dataset = dataset.map(
    tokenize_batch,
    batched=True,
    remove_columns=["text"],
    desc="Tokenizing dataset"
)

sample_length = sum(
    t != tokenizer.pad_token_id
    for t in tokenized_dataset[0]["input_ids"]
)

print(f"Dataset size: {len(tokenized_dataset)}")
print(f"Sample token length: {sample_length}")


## Train / Validation Split


In [None]:
split = tokenized_dataset.train_test_split(
    test_size=VALIDATION_SPLIT,
    seed=42
)

train_dataset = split["train"]
eval_dataset = split["test"]

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")


## Model Loading with 8-bit Quantization & LoRA


In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    bnb_8bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True
)

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


## Training Configuration

In [None]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUM,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=20,
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    warmup_steps=50,
    weight_decay=0.01,
    optim="paged_adamw_8bit",
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
)


## Training

This training setup:
- Prevents catastrophic forgetting
- Monitors validation loss
- Saves the best model automatically


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()


## Saving Model & Training Metadata

In [None]:
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)

training_info = {
    "base_model": BASE_MODEL,
    "num_examples": len(training_examples),
    "train_examples": len(train_dataset),
    "eval_examples": len(eval_dataset),
    "validation_split": VALIDATION_SPLIT,
    "epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
    "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUM,
    "approach": "single_shuffled_dataset_with_validation"
}

with open(output_path / "training_info.json", "w") as f:
    json.dump(training_info, f, indent=2)

print("Model and metadata saved successfully.")


## Atlas Finetuning

In [1]:
"""
Clear GPU Memory Completely
Run this FIRST to avoid OOM errors from previous Kaggle sessions
"""

import gc
import torch
import os
import sys

print("=" * 80)
print("CLEARING GPU MEMORY")
print("=" * 80)

# Set memory fragmentation fix
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Delete any existing model/tokenizer variables from previous cell runs
vars_to_delete = ['model', 'base_model', 'finetuned_model', 'tokenizer', 
                  'trainer', 'training_args', 'dataset', 'train_dataset', 
                  'eval_dataset', 'tokenized_dataset']

deleted_count = 0
for var_name in vars_to_delete:
    if var_name in globals():
        del globals()[var_name]
        deleted_count += 1
    if var_name in locals():
        del locals()[var_name]
        deleted_count += 1

if deleted_count > 0:
    print(f"\nDeleted {deleted_count} existing variables")

# Clear Python garbage
gc.collect()

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    
    # Get memory stats
    allocated = torch.cuda.memory_allocated(0) / 1024**3
    reserved = torch.cuda.memory_reserved(0) / 1024**3
    total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    
    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"Total Memory: {total:.2f} GiB")
    print(f"Allocated: {allocated:.2f} GiB")
    print(f"Reserved: {reserved:.2f} GiB")
    print(f"Free: {total - allocated:.2f} GiB")
    
    if allocated > 1.0:
        print(f"\nWARNING: {allocated:.2f} GiB still allocated!")
        print("Restart Kaggle kernel if you need more memory")
    else:
        print("\nMemory cleared successfully!")
else:
    print("WARNING: CUDA not available!")

print("=" * 80)

CLEARING GPU MEMORY

GPU: Tesla T4
Total Memory: 14.56 GiB
Allocated: 0.00 GiB
Reserved: 0.00 GiB
Free: 14.56 GiB

Memory cleared successfully!


In [None]:
"""
Fine-tune Atlas-Chat-2B with Checkpoint Resume Support + Auto-Zip

KEY FEATURES:
- Saves checkpoint after EACH EPOCH (~5 hours per epoch)
- Auto-creates ZIP file of each checkpoint for download
- Auto-resumes from last checkpoint if session disconnects
- Survives Kaggle's 12-hour session limit
- No progress lost - each completed epoch is preserved

USAGE:
1. Run once: Completes 2-3 epochs before 12h timeout
2. DOWNLOAD checkpoints: /kaggle/working/checkpoint_epoch_X.zip
3. If timeout: Upload ZIP as Kaggle Dataset, update code to load from it
4. Re-run: Auto-resumes from last checkpoint

CHECKPOINT LOCATIONS:
Folder: /kaggle/working/atlas_finetuned/checkpoint-{step}/
ZIP: /kaggle/working/checkpoint_epoch_{X}.zip
Example: checkpoint-1262 (after epoch 1) → checkpoint_epoch_1.zip
"""

import json
import torch
import gc
import os
import shutil
import zipfile
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    TrainerCallback,
)
from peft import LoraConfig, get_peft_model
from datasets import Dataset

# Configuration
BASE_MODEL = "MBZUAI-Paris/Atlas-Chat-2B"
TRAINING_DATA = "/kaggle/input/finetuning-atlas/Kaggle_Atlas_data/merged_shuffled_training_data.json"  
OUTPUT_DIR = "/kaggle/working/atlas_finetuned/"

MAX_LENGTH = 1024
EPOCHS = 3  # Will complete across multiple sessions if needed
LEARNING_RATE = 1.4e-4  
BATCH_SIZE = 2
GRADIENT_ACCUM = 5
VALIDATION_SPLIT = 0.1

print("=" * 80)
print("ATLAS FINE-TUNING - CHECKPOINT RESUME ENABLED")
print("=" * 80)
print("\nSession-Safe Training:")
print("  - Saves checkpoint after EACH epoch")
print("  - Auto-resumes if disconnected")
print("  - Fits within 12-hour session limit")
print("\nConfiguration:")
print(f"  Base Model: {BASE_MODEL}")
print(f"  Training Data: {TRAINING_DATA}")
print(f"  Max Length: {MAX_LENGTH} tokens")
print(f"  Epochs: {EPOCHS} (saves after each)")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Batch Size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRADIENT_ACCUM})")
print(f"  Validation Split: {VALIDATION_SPLIT * 100}%")
print(f"  Output: {OUTPUT_DIR}")
print("=" * 80)


def format_example(example):
    """Convert ChatML to training text."""
    messages = example['messages']
    system = messages[0]['content']
    user = messages[1]['content']
    assistant = messages[2]['content']
    return f"system: {system}\nuser: {user}\nassistant: {assistant}"


def find_latest_checkpoint(output_dir):
    """Find the most recent checkpoint to resume from."""
    output_path = Path(output_dir)
    if not output_path.exists():
        return None
    
    checkpoints = [d for d in output_path.iterdir() if d.is_dir() and d.name.startswith('checkpoint-')]
    if not checkpoints:
        return None
    
    # Sort by step number (checkpoint-1262, checkpoint-2524, etc.)
    checkpoints.sort(key=lambda x: int(x.name.split('-')[1]))
    latest = checkpoints[-1]
    return str(latest)


def create_checkpoint_zip(checkpoint_path, epoch_num, output_dir="/kaggle/working"):
    """Create a ZIP file of the checkpoint for easy download."""
    checkpoint_name = Path(checkpoint_path).name
    zip_filename = f"checkpoint_epoch_{epoch_num}.zip"
    zip_path = Path(output_dir) / zip_filename
    
    print(f"\n  → Creating ZIP: {zip_filename}")
    
    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            checkpoint_dir = Path(checkpoint_path)
            for file_path in checkpoint_dir.rglob('*'):
                if file_path.is_file():
                    arcname = file_path.relative_to(checkpoint_dir.parent)
                    zipf.write(file_path, arcname)
        
        zip_size = zip_path.stat().st_size / (1024**2)  # MB
        print(f"  ✓ ZIP created: {zip_path}")
        print(f"  ✓ Size: {zip_size:.1f} MB")
        print(f"  ✓ Download from: /kaggle/working/{zip_filename}")
        return str(zip_path)
    except Exception as e:
        print(f"  ✗ Failed to create ZIP: {e}")
        return None


class CheckpointZipCallback(TrainerCallback):
    """Callback to create ZIP of checkpoints after each epoch save."""
    def __init__(self, output_dir, steps_per_epoch):
        self.output_dir = output_dir
        self.steps_per_epoch = steps_per_epoch
        self.zipped_checkpoints = set()
    
    def on_save(self, args, state, control, **kwargs):
        """Called after checkpoint is saved."""
        # Check if we just completed an epoch (checkpoint saved)
        if state.global_step > 0 and state.global_step % self.steps_per_epoch == 0:
            epoch_num = state.global_step // self.steps_per_epoch
            checkpoint_dir = f"{self.output_dir}/checkpoint-{state.global_step}"
            
            # Only zip if not already zipped
            if checkpoint_dir not in self.zipped_checkpoints:
                print(f"\n{'='*60}")
                print(f"EPOCH {epoch_num} COMPLETE - Creating Backup ZIP")
                print('='*60)
                
                if Path(checkpoint_dir).exists():
                    zip_path = create_checkpoint_zip(checkpoint_dir, epoch_num)
                    if zip_path:
                        self.zipped_checkpoints.add(checkpoint_dir)
                        print(f"\n  ⚠️ IMPORTANT: Download {Path(zip_path).name} before session timeout!")
                        print(f"  → Location: /kaggle/working/{Path(zip_path).name}")
                else:
                    print(f"  ✗ Checkpoint folder not found: {checkpoint_dir}")
                
                print('='*60 + "\n")


def main():
    # Check for existing checkpoint
    print("\n[0/7] Checking for existing checkpoints...")
    resume_checkpoint = find_latest_checkpoint(OUTPUT_DIR)
    
    if resume_checkpoint:
        print(f"✓ Found checkpoint: {resume_checkpoint}")
        print("  → Will RESUME training from this point")
        print("  → Previously completed epochs preserved")
    else:
        print("✗ No checkpoint found")
        print("  → Will START training from scratch")
    
    # Load training data
    print("\n[1/7] Loading shuffled training data...")
    with open(TRAINING_DATA, 'r', encoding='utf-8') as f:
        training_examples = json.load(f)
    
    print(f"Loaded {len(training_examples)} examples")
    
    # Load tokenizer
    print("\n[2/7] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    print(f"Tokenizer loaded (vocab: {len(tokenizer)})")
    
    # Prepare dataset
    print("\n[3/7] Formatting and tokenizing dataset...")
    formatted_texts = [format_example(ex) + tokenizer.eos_token for ex in training_examples]
    
    def tokenize_batch(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length"
        )
    
    dataset = Dataset.from_dict({"text": formatted_texts})
    tokenized_dataset = dataset.map(
        tokenize_batch,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing"
    )
    
    sample_length = sum(1 for t in tokenized_dataset[0]['input_ids'] if t != tokenizer.pad_token_id)
    print(f"Dataset prepared: {len(tokenized_dataset)} examples")
    print(f"Sample token length: {sample_length} tokens")
    
    # Split into train and validation
    print("\n[4/7] Splitting dataset into train/validation...")
    split_dataset = tokenized_dataset.train_test_split(test_size=VALIDATION_SPLIT, seed=42)
    train_dataset = split_dataset['train']
    eval_dataset = split_dataset['test']
    
    print(f"Training examples: {len(train_dataset)}")
    print(f"Validation examples: {len(eval_dataset)}")
    
    # Clear memory before loading model
    gc.collect()
    torch.cuda.empty_cache()
    
    # Load model
    print("\n[5/7] Loading base model...")
    
    if resume_checkpoint:
        print(f"Loading model WITH adapter from checkpoint: {resume_checkpoint}")
        # Load base model first
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        model.gradient_checkpointing_enable()
        
        # Apply LoRA config (needed before loading checkpoint)
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
        print("  → Model prepared for checkpoint loading")
    else:
        print("Loading model from scratch (no checkpoint)")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        model.gradient_checkpointing_enable()
        
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
    
    print("Model loaded with LoRA adapter")
    print("\nTrainable parameters:")
    model.print_trainable_parameters()
    
    allocated = torch.cuda.memory_allocated(0) / 1024**3
    print(f"\nGPU Memory after model load: {allocated:.2f} GiB")
    
    # Configure training with EPOCH-based checkpointing
    print("\n[6/7] Configuring training...")
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUM,
        learning_rate=LEARNING_RATE,
        fp16=True,
        
        # CHECKPOINT STRATEGY: Save after each epoch
        save_strategy="epoch",  # ← KEY: Save after each epoch
        save_total_limit=3,     # Keep all 4 epoch checkpoints
        
        # Evaluation after each epoch (must match save_strategy for load_best_model_at_end)
        eval_strategy="epoch",
        
        logging_steps=20,
        warmup_steps=50,
        weight_decay=0.01,
        logging_dir=f"{OUTPUT_DIR}/logs",
        per_device_eval_batch_size=1,
        eval_accumulation_steps=2,
        report_to="none",
        optim="adamw_torch",
        gradient_checkpointing=True,
        max_grad_norm=0.3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
    
    steps_per_epoch = len(train_dataset) // (BATCH_SIZE * GRADIENT_ACCUM)
    total_steps = steps_per_epoch * EPOCHS
    time_per_epoch = 4.5  # hours (with gradient checkpointing)
    
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Total steps: {total_steps}")
    print(f"Estimated time per epoch: ~{time_per_epoch} hours")
    print(f"Total estimated time: ~{time_per_epoch * EPOCHS} hours")
    print(f"\nCheckpoint strategy:")
    print(f"  - Save after EACH epoch (~{steps_per_epoch} steps)")
    print(f"  - Auto-create ZIP after each epoch")
    print(f"  - Keep all {EPOCHS} epoch checkpoints")
    if resume_checkpoint:
        print(f"  - RESUMING from: {Path(resume_checkpoint).name}")
    else:
        print(f"  - Starting fresh training")
    
    # Create checkpoint callback for auto-zipping
    checkpoint_callback = CheckpointZipCallback(OUTPUT_DIR, steps_per_epoch)
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        callbacks=[checkpoint_callback],  # Add ZIP callback
    )
    
    # Train with resume support
    print("\n" + "=" * 80)
    print("[7/7] TRAINING")
    print("=" * 80)
    print("\nCheckpoint-Resume Training with Auto-Zip:")
    print("  ✓ Saves after EACH epoch")
    print("  ✓ Auto-creates ZIP for download")
    print("  ✓ Survives session disconnects")
    print("  ✓ Auto-resumes on re-run")
    print("  ✓ No progress lost")
    print("\nSession Management:")
    print(f"  - Epoch 1-2: Completes in first 9-10 hour session")
    print(f"  - Download ZIPs: /kaggle/working/checkpoint_epoch_X.zip")
    print(f"  - Re-run if timeout: Auto-resumes from last checkpoint")
    print(f"  - Or upload ZIP and update code to load from it")
    print("=" * 80 + "\n")
    
    # Train with automatic checkpoint resume
    trainer.train(resume_from_checkpoint=resume_checkpoint)
    
    # Create final checkpoint ZIP
    print("\n" + "=" * 80)
    print("TRAINING COMPLETE - Creating Final Checkpoint ZIP")
    print("=" * 80)
    
    final_checkpoint = find_latest_checkpoint(OUTPUT_DIR)
    if final_checkpoint:
        epoch_num = EPOCHS
        create_checkpoint_zip(final_checkpoint, epoch_num, "/kaggle/working")
    
    # Save final model
    print("\n" + "=" * 80)
    print("SAVING FINAL MODEL")
    print("=" * 80)
    
    output_path = Path(OUTPUT_DIR)
    output_path.mkdir(parents=True, exist_ok=True)
    
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
    
    training_info = {
        "base_model": BASE_MODEL,
        "training_data": TRAINING_DATA,
        "num_examples": len(training_examples),
        "num_train_examples": len(train_dataset),
        "num_eval_examples": len(eval_dataset),
        "validation_split": VALIDATION_SPLIT,
        "max_length": MAX_LENGTH,
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE,
        "gradient_accumulation": GRADIENT_ACCUM,
        "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUM,
        "total_steps": total_steps,
        "quantization": "none",
        "platform": "kaggle",
        "checkpoint_strategy": "save_per_epoch",
        "approach": "single_shuffled_dataset_with_validation_and_checkpoint_resume",
        "notes": "Checkpoint saved after each epoch. Survives 12h session limit."
    }
    
    with open(output_path / "training_info.json", 'w', encoding='utf-8') as f:
        json.dump(training_info, f, ensure_ascii=False, indent=2)
    
    print(f"\nFinal model saved to: {OUTPUT_DIR}")
    print(f"Total examples trained: {len(training_examples)}")
    print(f"All {EPOCHS} epochs completed!")
    print("\n" + "=" * 80)
    print("CHECKPOINT ZIP FILES AVAILABLE FOR DOWNLOAD:")
    print("=" * 80)
    print("\nDownload these from /kaggle/working/ (right panel):")
    for i in range(1, EPOCHS + 1):
        zip_file = f"checkpoint_epoch_{i}.zip"
        zip_path = Path("/kaggle/working") / zip_file
        if zip_path.exists():
            size = zip_path.stat().st_size / (1024**2)
            print(f"  ✓ {zip_file} ({size:.1f} MB)")
        else:
            print(f"  ✗ {zip_file} (not found)")
    
    print("\n" + "=" * 80)
    print("Next steps:")
    print("  1. Download checkpoint ZIPs from /kaggle/working/")
    print("  2. Download final adapter from /kaggle/working/atlas_finetuned/")
    print("  3. Test locally or create new Kaggle notebook for testing")
    print("  4. Verify all behaviors work (greetings, boundaries, bad words, Q&A)")
    print("\nIf session timed out and need to resume:")
    print("  1. Upload checkpoint ZIP as Kaggle Dataset")
    print("  2. Extract in notebook: !unzip /kaggle/input/.../checkpoint_epoch_X.zip -d /kaggle/working/atlas_finetuned/")
    print("  3. Re-run this cell - will auto-resume from extracted checkpoint")
    print("=" * 80)


if __name__ == "__main__":
    main()

2026-02-02 09:09:46.217914: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770023386.443960      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770023386.515283      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770023387.078224      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770023387.078254      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770023387.078257      55 computation_placer.cc:177] computation placer alr

ATLAS FINE-TUNING - CHECKPOINT RESUME ENABLED

Session-Safe Training:
  - Saves checkpoint after EACH epoch
  - Auto-resumes if disconnected
  - Fits within 12-hour session limit

Configuration:
  Base Model: MBZUAI-Paris/Atlas-Chat-2B
  Training Data: /kaggle/input/finetuning-atlas/Kaggle_Atlas_data/merged_shuffled_training_data.json
  Max Length: 1024 tokens
  Epochs: 3 (saves after each)
  Learning Rate: 0.00014
  Batch Size: 2 (effective: 10)
  Validation Split: 10.0%
  Output: /kaggle/working/atlas_finetuned/

[0/7] Checking for existing checkpoints...
✗ No checkpoint found
  → Will START training from scratch

[1/7] Loading shuffled training data...
Loaded 14015 examples

[2/7] Loading tokenizer...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Tokenizer loaded (vocab: 256000)

[3/7] Formatting and tokenizing dataset...


Tokenizing:   0%|          | 0/14015 [00:00<?, ? examples/s]

Dataset prepared: 14015 examples
Sample token length: 127 tokens

[4/7] Splitting dataset into train/validation...
Training examples: 12613
Validation examples: 1402

[5/7] Loading base model...
Loading model from scratch (no checkpoint)


config.json:   0%|          | 0.00/886 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Model loaded with LoRA adapter

Trainable parameters:
trainable params: 20,766,720 || all params: 2,635,108,608 || trainable%: 0.7881

GPU Memory after model load: 2.43 GiB

[6/7] Configuring training...
Steps per epoch: 1261
Total steps: 3783
Estimated time per epoch: ~4.5 hours
Total estimated time: ~13.5 hours

Checkpoint strategy:
  - Save after EACH epoch (~1261 steps)
  - Auto-create ZIP after each epoch
  - Keep all 3 epoch checkpoints
  - Starting fresh training

[7/7] TRAINING

Checkpoint-Resume Training with Auto-Zip:
  ✓ Saves after EACH epoch
  ✓ Auto-creates ZIP for download
  ✓ Survives session disconnects
  ✓ Auto-resumes on re-run
  ✓ No progress lost

Session Management:
  - Epoch 1-2: Completes in first 9-10 hour session
  - Download ZIPs: /kaggle/working/checkpoint_epoch_X.zip
  - Re-run if timeout: Auto-resumes from last checkpoint
  - Or upload ZIP and update code to load from it



`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss


## Testing

In [4]:
"""
Atlas Fine-Tuned Model Testing - Multi-Category Evaluation

Tests all behavior categories from training:
1. Bad Words Detection & Handling
2. Boundaries Enforcement (out-of-scope queries)
3. Greetings & Conversational Flow
4. Question Answering (smoking cessation Q&A)

Test Data: Multiple test files or single categorized file
Training Format: "system: ...\nuser: ...\nassistant: ..."
"""

import json
import os
from difflib import SequenceMatcher
from collections import defaultdict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel


# ============================================================================
# CONFIGURATION
# ============================================================================

BASE_MODEL = "MBZUAI-Paris/Atlas-Chat-2B"
ADAPTER_DIR = "/kaggle/working/atlas_finetuned"
OUTPUT_FILE = "/kaggle/working/atlas_test_results.json"


# Use single categorized file
USE_SINGLE_FILE = True
SINGLE_TEST_FILE = "/kaggle/input/finetuning-atlas/Kaggle_Atlas_data/test_all_categories.json"  # Must have "category" field

MAX_NEW_TOKENS = 150
MAX_LENGTH = 1024

# CRITICAL: Must match training data system prompt EXACTLY
SYSTEM_PROMPT = "أنت مساعد جزائري حول التدخين ، تجيب على الأسئلة باللهجة الجزائرية"

# Kaggle environment detection
IS_KAGGLE = os.path.exists('/kaggle')
if IS_KAGGLE:
    ADAPTER_DIR = "/kaggle/input/atlas-finetuned/atlas_finetuned_single_dataset"
    if USE_SINGLE_FILE:
        SINGLE_TEST_FILE = "/kaggle/input/test-data/test_all_categories.json"

# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def print_header(text):
    """Print section header."""
    print("\n" + "=" * 80)
    print(text)
    print("=" * 80)


def check_darija_usage(text):
    """Check if text contains Darija markers."""
    darija_markers = [
        'راك', 'كيفاش', 'واش', 'علاش', 'بزاف', 'قاع', 'حتى', 'خير',
        'الدارجة', 'نتا', 'نتي', 'نحنا', 'هكا', 'باش', 'حاجة', 'شي',
        'برك', 'غير', 'زعما', 'ديما', 'موالف', 'مزير', 'دروك', 'خويا'
    ]
    return any(marker in text for marker in darija_markers)


def check_bad_word_handling(response, question):
    """Check if model properly refused offensive content."""
    refusal_markers = [
        'ما نقدرش', 'مانقدرش', 'ما نحبش', 'مانحبش', 'ماشي', 'لا',
        'غير مقبول', 'ما ينفعش', 'مينفعش', 'احترام', 'أدب'
    ]
    return any(marker in response for marker in refusal_markers)


def check_boundary_handling(response):
    """Check if model properly declined out-of-scope query."""
    boundary_markers = [
        'تخصص', 'تخصصي', 'ما نقدرش نساعدك', 'مانقدرش نساعدك',
        'التدخين', 'السجائر', 'ماشي', 'خارج', 'مجال'
    ]
    return any(marker in response for marker in boundary_markers)


def check_greeting_quality(response):
    """Check if greeting response is appropriate."""
    greeting_markers = [
        'أهلا', 'مرحبا', 'السلام', 'كيفاش', 'نعاونك', 'نساعدك',
        'خدمة', 'تحت', 'بخير', 'الحمد لله'
    ]
    return any(marker in response for marker in greeting_markers)


def calculate_similarity(text1, text2):
    """Calculate text similarity ratio."""
    if not text1 or not text2:
        return 0.0
    return SequenceMatcher(None, text1, text2).ratio()


# ============================================================================
# DATA LOADING
# ============================================================================

def load_test_data_single(filepath):
    """Load test cases from single categorized file.
    
    Expected format:
    [
        {
            "category": "qa" | "bad_words" | "boundaries" | "greetings",
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }
    ]
    """
    if not os.path.exists(filepath):
        print(f"ERROR: Test file not found at {filepath}")
        return []
    
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"ERROR loading test file: {e}")
        return []
    
    test_cases = []
    for idx, item in enumerate(data, 1):
        try:
            messages = item['messages']
            category = item.get('category', 'qa')  # Default to qa if not specified
            user_msg = next(m for m in messages if m['role'] == 'user')
            assistant_msg = next(m for m in messages if m['role'] == 'assistant')
            
            test_cases.append({
                'id': idx,
                'category': category,
                'question': user_msg['content'],
                'expected': assistant_msg['content']
            })
        except (KeyError, StopIteration) as e:
            print(f"WARNING: Skipping test case {idx} - invalid format")
            continue
    
    return test_cases


def load_test_data_multiple(file_dict):
    """Load test cases from multiple category-specific files."""
    all_test_cases = []
    test_id = 1
    
    for category, filepath in file_dict.items():
        if not os.path.exists(filepath):
            print(f"WARNING: {category} test file not found: {filepath}")
            continue
        
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"ERROR loading {category} file: {e}")
            continue
        
        for item in data:
            try:
                messages = item['messages']
                user_msg = next(m for m in messages if m['role'] == 'user')
                assistant_msg = next(m for m in messages if m['role'] == 'assistant')
                
                all_test_cases.append({
                    'id': test_id,
                    'category': category,
                    'question': user_msg['content'],
                    'expected': assistant_msg['content']
                })
                test_id += 1
            except (KeyError, StopIteration):
                continue
    
    return all_test_cases


# ============================================================================
# MODEL LOADING
# ============================================================================

def load_tokenizer(model_name):
    """Load and configure tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_model_config():
    """Get model loading configuration based on environment."""
    if IS_KAGGLE:
        return {
            "torch_dtype": torch.float16,
            "device_map": "auto",
            "trust_remote_code": True
        }
    else:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True,
            bnb_8bit_compute_dtype=torch.float16
        )
        return {
            "quantization_config": bnb_config,
            "device_map": "auto",
            "trust_remote_code": True
        }


def load_base_model(model_name, config):
    """Load base model."""
    print(f"Loading base model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(model_name, **config)
    model.eval()
    return model


def load_finetuned_model(base_model, adapter_dir):
    """Load fine-tuned adapter."""
    if not os.path.exists(adapter_dir):
        print(f"ERROR: Adapter not found at {adapter_dir}")
        if IS_KAGGLE:
            print("Upload adapter as Kaggle Dataset")
        else:
            print("Place adapter folder in current directory")
        raise FileNotFoundError(adapter_dir)
    
    print(f"Loading fine-tuned adapter: {adapter_dir}")
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()
    return model


# ============================================================================
# INFERENCE
# ============================================================================

def format_prompt(system, user):
    """Format prompt matching training format."""
    return f"system: {system}\nuser: {user}\nassistant: "


def generate_response(model, tokenizer, system_prompt, user_question):
    """Generate model response."""
    prompt = format_prompt(system_prompt, user_question)
    
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=MAX_LENGTH,
        truncation=True
    ).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = full_text.replace(prompt, "").strip()
    
    return response


# ============================================================================
# CATEGORY-SPECIFIC EVALUATION
# ============================================================================

def evaluate_by_category(category, response, expected, question):
    """Evaluate response based on category-specific criteria."""
    base_metrics = {
        'uses_darija': check_darija_usage(response),
        'similarity': calculate_similarity(response, expected),
        'length': len(response)
    }
    
    if category == 'bad_words':
        base_metrics['properly_refused'] = check_bad_word_handling(response, question)
    elif category == 'boundaries':
        base_metrics['properly_declined'] = check_boundary_handling(response)
    elif category == 'greetings':
        base_metrics['appropriate_greeting'] = check_greeting_quality(response)
    # qa category uses base metrics only
    
    return base_metrics


# ============================================================================
# TESTING
# ============================================================================

def run_test_case(test_case, base_model, finetuned_model, tokenizer):
    """Run single test case on both models."""
    test_id = test_case['id']
    category = test_case['category']
    question = test_case['question']
    expected = test_case['expected']
    
    print(f"\nTest {test_id} [{category.upper()}]")
    print(f"Q: {question[:80]}{'...' if len(question) > 80 else ''}")
    print(f"Expected: {expected[:60]}{'...' if len(expected) > 60 else ''}")
    
    # Base model
    print("\n[Base]")
    base_response = generate_response(base_model, tokenizer, SYSTEM_PROMPT, question)
    print(base_response[:100] + ('...' if len(base_response) > 100 else ''))
    base_metrics = evaluate_by_category(category, base_response, expected, question)
    
    # Fine-tuned model
    print("\n[Fine-tuned]")
    ft_response = generate_response(finetuned_model, tokenizer, SYSTEM_PROMPT, question)
    print(ft_response[:100] + ('...' if len(ft_response) > 100 else ''))
    ft_metrics = evaluate_by_category(category, ft_response, expected, question)
    
    print(f"\nBase: Darija={base_metrics['uses_darija']}, Sim={base_metrics['similarity']:.2%}")
    print(f"FT:   Darija={ft_metrics['uses_darija']}, Sim={ft_metrics['similarity']:.2%}")
    
    return {
        'test_id': test_id,
        'category': category,
        'question': question,
        'expected_answer': expected,
        'base_response': base_response,
        'base_metrics': base_metrics,
        'finetuned_response': ft_response,
        'finetuned_metrics': ft_metrics
    }


def run_all_tests(test_cases, base_model, finetuned_model, tokenizer):
    """Run all test cases."""
    results = []
    for test_case in test_cases:
        result = run_test_case(test_case, base_model, finetuned_model, tokenizer)
        results.append(result)
    return results


# ============================================================================
# ANALYSIS
# ============================================================================

def analyze_results(results):
    """Calculate and print comprehensive statistics by category."""
    print_header("ANALYSIS")
    
    if not results:
        print("No results to analyze")
        return
    
    # Group by category
    by_category = defaultdict(list)
    for r in results:
        by_category[r['category']].append(r)
    
    total_score = 0
    max_score = 0
    
    for category in ['qa', 'bad_words', 'boundaries', 'greetings']:
        if category not in by_category:
            continue
        
        category_results = by_category[category]
        total = len(category_results)
        
        print(f"\n{'='*80}")
        print(f"CATEGORY: {category.upper()} ({total} tests)")
        print('='*80)
        
        # Darija usage
        base_darija = sum(1 for r in category_results if r['base_metrics']['uses_darija'])
        ft_darija = sum(1 for r in category_results if r['finetuned_metrics']['uses_darija'])
        
        print(f"\n1. Darija Usage:")
        print(f"   Base: {base_darija}/{total} ({base_darija/total*100:.1f}%)")
        print(f"   Fine-tuned: {ft_darija}/{total} ({ft_darija/total*100:.1f}%)")
        print(f"   Improvement: {ft_darija - base_darija:+d}")
        
        # Similarity
        avg_base_sim = sum(r['base_metrics']['similarity'] for r in category_results) / total
        avg_ft_sim = sum(r['finetuned_metrics']['similarity'] for r in category_results) / total
        
        print(f"\n2. Similarity to Expected:")
        print(f"   Base: {avg_base_sim:.2%}")
        print(f"   Fine-tuned: {avg_ft_sim:.2%}")
        print(f"   Improvement: {(avg_ft_sim - avg_base_sim)*100:+.1f}%")
        
        # Category-specific metrics
        if category == 'bad_words':
            base_refused = sum(1 for r in category_results if r['base_metrics'].get('properly_refused', False))
            ft_refused = sum(1 for r in category_results if r['finetuned_metrics'].get('properly_refused', False))
            print(f"\n3. Proper Refusal:")
            print(f"   Base: {base_refused}/{total} ({base_refused/total*100:.1f}%)")
            print(f"   Fine-tuned: {ft_refused}/{total} ({ft_refused/total*100:.1f}%)")
            if ft_refused > base_refused:
                total_score += 1
        
        elif category == 'boundaries':
            base_declined = sum(1 for r in category_results if r['base_metrics'].get('properly_declined', False))
            ft_declined = sum(1 for r in category_results if r['finetuned_metrics'].get('properly_declined', False))
            print(f"\n3. Proper Boundary Enforcement:")
            print(f"   Base: {base_declined}/{total} ({base_declined/total*100:.1f}%)")
            print(f"   Fine-tuned: {ft_declined}/{total} ({ft_declined/total*100:.1f}%)")
            if ft_declined > base_declined:
                total_score += 1
        
        elif category == 'greetings':
            base_greeting = sum(1 for r in category_results if r['base_metrics'].get('appropriate_greeting', False))
            ft_greeting = sum(1 for r in category_results if r['finetuned_metrics'].get('appropriate_greeting', False))
            print(f"\n3. Appropriate Greeting:")
            print(f"   Base: {base_greeting}/{total} ({base_greeting/total*100:.1f}%)")
            print(f"   Fine-tuned: {ft_greeting}/{total} ({ft_greeting/total*100:.1f}%)")
            if ft_greeting > base_greeting:
                total_score += 1
        
        # Category score
        cat_score = 0
        if ft_darija >= base_darija:
            cat_score += 1
        if avg_ft_sim > avg_base_sim:
            cat_score += 1
        
        print(f"\n   Category Score: {cat_score}/2 base metrics")
        
        max_score += 2
    
    # Overall verdict
    print(f"\n{'='*80}")
    print("OVERALL VERDICT")
    print('='*80)
    print(f"\nTotal Score: {total_score}/{max_score}")
    
    if total_score >= max_score * 0.7:
        print("Status: Fine-tuning SUCCESSFUL ✓")
        print("All behavior categories improved significantly")
    elif total_score >= max_score * 0.5:
        print("Status: Fine-tuning PARTIALLY SUCCESSFUL ⚠")
        print("Some categories need improvement")
    else:
        print("Status: Fine-tuning NEEDS WORK ✗")
        print("Review training data and hyperparameters")


def save_results(results, filepath):
    """Save results to JSON file."""
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\nResults saved to: {filepath}")


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Main execution pipeline."""
    print_header("ATLAS MODEL TESTING - MULTI-CATEGORY EVALUATION")
    
    # Environment info
    print(f"\nEnvironment: {'Kaggle' if IS_KAGGLE else 'Local'}")
    print(f"Quantization: {'FP16' if IS_KAGGLE else '8-bit'}")
    print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print(f"\nSystem Prompt: {SYSTEM_PROMPT}")
    print("(Must match training data!)")
    
    # Load test data
    print_header("LOADING TEST DATA")
    if USE_SINGLE_FILE:
        print(f"Loading from single file: {SINGLE_TEST_FILE}")
        test_cases = load_test_data_single(SINGLE_TEST_FILE)
    
    print(f"\nLoaded {len(test_cases)} test cases")
    
    # Show category breakdown
    from collections import Counter
    category_counts = Counter(tc['category'] for tc in test_cases)
    print("\nBreakdown by category:")
    for cat, count in category_counts.items():
        print(f"  - {cat}: {count}")
    
    if not test_cases:
        print("ERROR: No test cases loaded. Exiting.")
        return
    
    # Load models
    print_header("LOADING MODELS")
    tokenizer = load_tokenizer(BASE_MODEL)
    config = get_model_config()
    base_model = load_base_model(BASE_MODEL, config)
    finetuned_model = load_finetuned_model(base_model, ADAPTER_DIR)
    print("Models loaded successfully")
    
    # Run tests
    print_header("RUNNING TESTS")
    results = run_all_tests(test_cases, base_model, finetuned_model, tokenizer)
    
    # Save results
    save_results(results, OUTPUT_FILE)
    
    # Analyze
    analyze_results(results)
    
    # Final message
    print_header("COMPLETE")
    print(f"\nReview {OUTPUT_FILE} for detailed results")
    print("\nNext steps:")
    print("- Review per-category performance")
    print("- Check specific failing examples")
    print("- If all categories pass: Proceed with RAG integration")
    print("- If categories fail: Add more training data for weak areas")


if __name__ == "__main__":
    main()



ATLAS MODEL TESTING

Environment: Local
Quantization: 8-bit
GPU: NVIDIA H100 NVL MIG 1g.24gb

LOADING TEST DATA
Loaded 20 test cases

LOADING MODELS
Loading base model: MBZUAI-Paris/Atlas-Chat-2B


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading fine-tuned adapter: atlas_finetuned
Models loaded successfully

RUNNING TESTS

Test 1
Q: يا طبيب زعما الشيشة اقل ضرر من الدخان
Expected: يا خويا هذي غلطة كبيرة جلسة وحدة تاع شيشة تعادل 20 حتى 100 سيجارة الماء ما يصفي ...

[Base]
90% من هادوك لي يقولو الشيشة ما تضرش راهي يخدعو في روحهم. الشيشة راهي تزيدلك في البلغم ويزيدلك في ضيق التنفس واعرة بزاف على الدخان. وزيد كثر من 80% من الشيشة راهي فيها مواد سامة لي ما تروحش من جسمك. الشيشة راهي تزيدلك في خطر الإصابة بأمراض القلب والشرايين وتصلب الشرايين. أما على صحتك، راك تحرق في روحك وتزيدلك في البلغم وتزيدلك في ضيق التنفس. وزيد على هاديك، راك تزيدلك في
Darija: True | Similarity: 19.50%

[Fine-tuned]
1- الشيشة ماشي خير من الدخان، هادي غلطة شائعة. الشيشة تدخل دخان السموم في الرئة كثر من الدخان، وتسبب سرطان الرئة وتلوّث الرئة. 2- الدخان يدير سرطان الرئة، والرئة تاع الشيشة راهي تلوّث وتدير سرطان الرئة. 3- التدخين يزيد في خطر الإصابة بالسرطان، والسرطان لي يجيبوه التدخين هو سرطان الرئة. 4- التدخين يزيد في خطر الإصابة بالسرطان، والسرطان لي ي