# Medical NER Fine-Tuning with Llama 3.2 3B + LoRA

This notebook implements fine-tuning of Llama 3.2 3B Instruct for medical Named Entity Recognition (NER) using:
- **SFT** (Supervised Fine-Tuning)
- **LoRA** (Low-Rank Adaptation)
- **Hugging Face Hub** integration for checkpoint uploads

## Tasks:
1. Chemical entity extraction
2. Disease entity extraction
3. Chemical-Disease relationship extraction

## Dataset:
- 3,000 medical text examples
- 80/10/10 train/validation/test split
- **⚠️ CRITICAL**: Data is shuffled before splitting to ensure balanced task distribution
- Weights & Biases tracking enabled

## Important Note:
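**Data splitting MUST shuffle the examples** (`shuffle=True`, which scikit-learn's `train_test_split` also requires whenever `stratify` is set) to prevent task imbalance. If the source file is ordered by task, an unshuffled split can put all relationship-extraction examples into the validation/test sets, so the model never trains on the most important task!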
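A toy illustration of why this matters, assuming (hypothetically) that the source file stores examples grouped by task: holding out the last 20% of rows without shuffling yields a single-task validation/test pool, while shuffling first typically spreads the held-out rows across tasks. The 30-row dataset is made up for the demonstration.

```python
import random
from collections import Counter

# Hypothetical toy dataset: 30 examples stored in task order (10 per task).
tasks = ["chemical"] * 10 + ["disease"] * 10 + ["relationship"] * 10

def tail_split(labels, holdout_fraction=0.2):
    """Split without shuffling: the last `holdout_fraction` of rows is held out."""
    cut = int(len(labels) * (1 - holdout_fraction))
    return labels[:cut], labels[cut:]

# Unshuffled: the held-out 20% (6 rows) comes entirely from the last task in the file.
train, held_out = tail_split(tasks)
print("unshuffled held-out:", Counter(held_out))

# Shuffled first: held-out rows are drawn from across the file (exact mix depends on the seed).
shuffled = tasks[:]
random.Random(42).shuffle(shuffled)
train_s, held_out_s = tail_split(shuffled)
print("shuffled held-out:  ", Counter(held_out_s))
```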

## 0. Environment Variables Setup

⚠️ **IMPORTANT**: Set your credentials before running this notebook!

Required:
- `HF_TOKEN`: Your Hugging Face token (needed to save models to HF Hub)

Optional:
- `WANDB_API_KEY`: Your Weights & Biases API key (for training tracking)

In [None]:
import os

# Set your Hugging Face token (required for uploading to HF Hub).
# Never hard-code real credentials in a shared notebook: export HF_TOKEN in your shell
# (or use a secrets manager). setdefault only fills in the placeholder when HF_TOKEN is unset.
os.environ.setdefault("HF_TOKEN", "hf_YOUR_TOKEN_HERE")

# Set your Weights & Biases API key (optional, for training tracking).
# Leave it unset to fall back to cached `wandb login` credentials.
# os.environ["WANDB_API_KEY"] = "your_wandb_key_here"

# Verify environment variables
print("✓ Environment variables set")
print(f"  HF_TOKEN: {'✓ Set' if os.environ.get('HF_TOKEN') and os.environ['HF_TOKEN'] != 'hf_YOUR_TOKEN_HERE' else '✗ Not set - UPDATE THIS!'}")
print(f"  WANDB_API_KEY: {'✓ Set' if os.environ.get('WANDB_API_KEY') else '○ Optional (will use wandb login cache)'}")

## 1. Setup and Installation

First, let's install all required dependencies.

In [None]:
# Install required packages
!pip install -q transformers datasets peft accelerate bitsandbytes
!pip install -q huggingface-hub tokenizers trl scikit-learn
!pip install -q scipy sentencepiece protobuf wandb

print("✓ All packages installed successfully!")

## 2. Import Libraries

In [None]:
import json
import torch
import os
import random
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    EarlyStoppingCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from huggingface_hub import login
import wandb

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

## 3. Configuration

⚠️ **IMPORTANT**: Update `HF_USERNAME` with your Hugging Face username!

In [None]:
# Configuration Section
from datetime import datetime

HF_USERNAME = "albyos"  # Replace with your HF username

# Generate timestamp for checkpoint naming
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
HF_MODEL_ID = f"{HF_USERNAME}/llama3-medical-ner-lora-{TIMESTAMP}"
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_NAME = BASE_MODEL  # Alias for consistency

# LoRA Configuration
LORA_RANK = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# Training Configuration
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-4

# Data Configuration
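TRAIN_SPLIT_RATIO = 0.8  # 80% train; the remaining 20% is split evenly into validation and test
RANDOM_SEED = 42
RESHUFFLE_SPLITS_EACH_RUN = True  # When True, draw a fresh train/val/test split every run (the seed is printed for reproducibility)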
SPLIT_SEED = random.randint(0, 1_000_000) if RESHUFFLE_SPLITS_EACH_RUN else RANDOM_SEED

print("✓ Configuration loaded")
print(f"  Base model: {BASE_MODEL}")
print(f"  HF model ID: {HF_MODEL_ID}")
print(f"  Training timestamp: {TIMESTAMP}")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Training epochs: {NUM_EPOCHS}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"  Data split seed: {SPLIT_SEED} ({'reshuffled' if RESHUFFLE_SPLITS_EACH_RUN else 'fixed'})")

## 4. Hugging Face Authentication

Get your token from: https://huggingface.co/settings/tokens

In [None]:
# Authenticate with Hugging Face
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token and hf_token != "hf_YOUR_TOKEN_HERE":
    login(token=hf_token)
    print("✓ Logged in to Hugging Face")
else:
    print("⚠ HF_TOKEN not set. Please set it in Section 0 before continuing.")

## 4b. Weights & Biases Setup

Initialize W&B to track training metrics, validation loss, and experiments.
Get your API key from: https://wandb.ai/authorize

In [None]:
# Login to Weights & Biases
wandb_key = os.getenv('WANDB_API_KEY')

if wandb_key and wandb_key != 'your_wandb_key_here':
    wandb.login(key=wandb_key)
    print('✓ Logged in to Weights & Biases using WANDB_API_KEY')
else:
    print('⚠ Warning: WANDB_API_KEY not set. Attempting to use cached login...')
    try:
        wandb.login()
        print('✓ Logged in to Weights & Biases using cached credentials')
    except Exception as e:
        print(f'⚠ Warning: Could not log in to W&B: {e}')
        print('  Run wandb.login() interactively or set WANDB_API_KEY environment variable')

In [None]:
# Initialize Weights & Biases
wandb.init(
    project="medical-ner-finetuning",
    name=f"llama3-medical-ner-{TIMESTAMP}",
    config={
        "model": BASE_MODEL,
        "lora_rank": LORA_RANK,
        "lora_alpha": LORA_ALPHA,
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION,
    }
)

print("✓ Weights & Biases initialized")
print(f"  Project: medical-ner-finetuning")
print(f"  Run name: llama3-medical-ner-{TIMESTAMP}")
print(f"  Dashboard: https://wandb.ai")

## 5. Data Exploration

Let's examine the dataset structure.
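The loader in the next cell parses every line with `json.loads`, and the rest of the notebook reads only the `prompt` and `completion` fields. A minimal validation sketch (standard library only; `load_jsonl_records` and the two sample lines are hypothetical) that skips blank lines and fails early with a line number when a record lacks either field:

```python
import json

def load_jsonl_records(lines):
    """Parse JSONL lines and check each record has string 'prompt' and 'completion' fields."""
    records = []
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline at end of file
        record = json.loads(line)
        for key in ("prompt", "completion"):
            if not isinstance(record.get(key), str):
                raise ValueError(f"line {line_number}: missing or non-string '{key}'")
        records.append(record)
    return records

# Hypothetical two-line sample in the shape the notebook expects.
sample_lines = [
    '{"prompt": "List the chemicals mentioned in: ...", "completion": "aspirin"}',
    '{"prompt": "List the diseases mentioned in: ...", "completion": "headache"}',
    "",
]
records = load_jsonl_records(sample_lines)
print(len(records), records[0]["completion"])
```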

In [None]:
# Load and inspect the dataset
with open('both_rel_instruct_all.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]

print(f"Total samples: {len(data)}")
print(f"\nSample structure:")
print(json.dumps(data[0], indent=2)[:500] + "...")

In [None]:
# Analyze task distribution
task_counts = {}
for sample in data:
    # Same rule as get_task_type() below: check the relationship cue first, case-insensitively
    prompt_lower = sample['prompt'].lower()
    if "influences between" in prompt_lower:
        task = "Relationship Extraction"
    elif "chemicals mentioned" in prompt_lower:
        task = "Chemical Extraction"
    elif "diseases mentioned" in prompt_lower:
        task = "Disease Extraction"
    else:
        task = "Other"
    
    task_counts[task] = task_counts.get(task, 0) + 1

print("Task Distribution:")
for task, count in task_counts.items():
    print(f"  {task}: {count} ({count/len(data)*100:.1f}%)")

In [None]:
# Show example from each task type
print("="*80)
print("EXAMPLE: Chemical Extraction")
print("="*80)
chem_example = [s for s in data if "chemicals mentioned" in s['prompt']][0]
print(f"Prompt:\n{chem_example['prompt'][:300]}...")
print(f"\nCompletion:\n{chem_example['completion']}")

print("\n" + "="*80)
print("EXAMPLE: Disease Extraction")
print("="*80)
disease_example = [s for s in data if "diseases mentioned" in s['prompt']][0]
print(f"Prompt:\n{disease_example['prompt'][:300]}...")
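print(f"\nCompletion:\n{disease_example['completion']}")

print("\n" + "="*80)
print("EXAMPLE: Relationship Extraction")
print("="*80)
rel_example = [s for s in data if "influences between" in s['prompt'].lower()][0]
print(f"Prompt:\n{rel_example['prompt'][:300]}...")
print(f"\nCompletion:\n{rel_example['completion']}")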

## 6. Dataset Splitting

⚠️ **CRITICAL**: Use **stratified splitting** so every split keeps the same task mix!

**Previous issue**: Without shuffling, all relationship extraction examples ended up in the validation/test sets, causing poor model performance.

**Solution**: Stratified splitting preserves the task proportions of the full dataset in every split (up to rounding) instead of leaving them to chance.

Split into:
- **80% Training** (2,400 samples) - for fine-tuning
- **10% Validation** (300 samples) - for monitoring during training (W&B)
- **10% Test** (300 samples) - for final evaluation after training

**Expected distribution in each split**, assuming the full dataset is balanced across the three tasks:
- **33.3%** Chemical extraction
- **33.3%** Disease extraction
- **33.3%** Relationship extraction

**Why stratified?**
- `shuffle=True` alone gives ~33% ± 2-3% per task in a 300-sample split (probabilistic, usually good enough)
- `stratify=labels` reproduces the full dataset's proportions (deterministic up to rounding, better!)
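The idea behind stratification can be sketched in pure Python: group examples by label, shuffle each group, and hold out the same fraction of every group. This is an illustration of the concept, not scikit-learn's exact algorithm; `stratified_split` is a hypothetical helper, and the toy data assumes a perfectly balanced 3,000-example dataset.

```python
import random
from collections import defaultdict, Counter

def stratified_split(items, labels, test_fraction, seed=0):
    """Split so each label keeps (up to rounding) the same share in both parts."""
    by_label = defaultdict(list)
    for item, label in zip(items, labels):
        by_label[label].append(item)

    rng = random.Random(seed)
    train, test = [], []
    for label in sorted(by_label):  # sorted for a deterministic iteration order
        group = by_label[label][:]
        rng.shuffle(group)
        n_test = round(len(group) * test_fraction)
        test.extend((x, label) for x in group[:n_test])
        train.extend((x, label) for x in group[n_test:])
    # Shuffle the outputs so they are not ordered by task.
    rng.shuffle(train)
    rng.shuffle(test)
    return train, test

# Toy data: 3,000 examples, 1,000 per task (assumed balanced).
labels = ["chemical"] * 1000 + ["disease"] * 1000 + ["relationship"] * 1000
items = list(range(len(labels)))

train, temp = stratified_split(items, labels, test_fraction=0.2, seed=42)
val, test = stratified_split([x for x, _ in temp], [l for _, l in temp], test_fraction=0.5, seed=43)
print(Counter(l for _, l in train))  # 800 per task
print(Counter(l for _, l in val))    # 100 per task
print(Counter(l for _, l in test))   # 100 per task
```

With an unbalanced dataset, the same procedure would keep the original (unequal) shares rather than forcing 33.3% each.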

In [None]:
# Split data into train/val/test (80/10/10) with STRATIFIED sampling
random.seed(SPLIT_SEED)

# Helper function to classify task type for stratification
def get_task_type(prompt):
    """Classify the task type based on prompt for stratification."""
    prompt_lower = prompt.lower()
    if "influences between" in prompt_lower:
        return "relationship"
    elif "chemicals mentioned" in prompt_lower:
        return "chemical"
    elif "diseases mentioned" in prompt_lower:
        return "disease"
    return "other"

# Create stratification labels for all data
stratify_labels = [get_task_type(sample['prompt']) for sample in data]

from collections import Counter

print("Creating stratified splits to preserve the task distribution...")
print(f"Original task distribution: {dict(Counter(stratify_labels))}")

# First split: 80% train, 20% temp (for val + test)
# stratify= keeps the full dataset's task proportions (up to rounding) in both parts
train_data, temp_data, train_labels, temp_labels = train_test_split(
    data,
    stratify_labels,
    test_size=0.2,  # 20% for validation + test
    random_state=SPLIT_SEED,
    stratify=stratify_labels  # ✅ same task mix in train and temp (33.3% each if data is balanced)
)

# Second split: split the 20% into 10% val, 10% test
# Stratify again so validation and test keep the same task mix
val_data, test_data, val_labels, test_labels = train_test_split(
    temp_data,
    temp_labels,
    test_size=0.5,  # 50% of 20% = 10% of total
    random_state=SPLIT_SEED + 1,
    stratify=temp_labels  # ✅ same task mix in val and test
)

# Save splits
with open('train.jsonl', 'w', encoding='utf-8') as f:
    for item in train_data:
        f.write(json.dumps(item) + '\n')

with open('validation.jsonl', 'w', encoding='utf-8') as f:
    for item in val_data:
        f.write(json.dumps(item) + '\n')

with open('test.jsonl', 'w', encoding='utf-8') as f:
    for item in test_data:
        f.write(json.dumps(item) + '\n')

print(f"✓ Dataset split complete (seed={SPLIT_SEED}, stratified=True)")
print(f"  Train samples: {len(train_data)} ({len(train_data)/len(data)*100:.1f}%)")
print(f"  Validation samples: {len(val_data)} ({len(val_data)/len(data)*100:.1f}%) - for training monitoring")
print(f"  Test samples: {len(test_data)} ({len(test_data)/len(data)*100:.1f}%) - for final evaluation")
print(f"\n📊 Usage:")
print(f"  - Train: Used for fine-tuning")
print(f"  - Validation: Monitored during training (shown in W&B)")
print(f"  - Test: Used ONLY after training for final evaluation")

In [None]:
# Verify task distribution across splits
def get_task_type_display(prompt):
    """Classify the task type based on prompt for display."""
    prompt_lower = prompt.lower()
    if "influences between" in prompt_lower:
        return "Relationship Extraction"
    elif "chemicals mentioned" in prompt_lower:
        return "Chemical Extraction"
    elif "diseases mentioned" in prompt_lower:
        return "Disease Extraction"
    return "Other"

print("\n" + "="*80)
print("TASK DISTRIBUTION VERIFICATION (STRATIFIED SPLITTING)")
print("="*80)

for split_name, split_data in [("Train", train_data), ("Validation", val_data), ("Test", test_data)]:
    task_counts = {}
    for sample in split_data:
        task = get_task_type_display(sample['prompt'])
        task_counts[task] = task_counts.get(task, 0) + 1
    
    print(f"\n{split_name} ({len(split_data)} samples):")
    for task, count in sorted(task_counts.items()):
        percentage = count / len(split_data) * 100
        # Flag any task whose share deviates from a perfect 1/3 by 0.5 points or more
        is_perfect = abs(percentage - 100 / 3) < 0.5
        marker = "✅" if is_perfect else "⚠️"
        print(f"  {marker} {task}: {count} ({percentage:.1f}%)")

# Verify no data leakage between splits
train_prompts = set(s['prompt'] for s in train_data)
val_prompts = set(s['prompt'] for s in val_data)
test_prompts = set(s['prompt'] for s in test_data)

print(f"\n{'='*80}")
print("DATA INTEGRITY CHECK")
print("="*80)
overlap_train_val = len(train_prompts & val_prompts)
overlap_train_test = len(train_prompts & test_prompts)
overlap_val_test = len(val_prompts & test_prompts)

print(f"Train-Validation overlap: {overlap_train_val} samples {'✅ Perfect!' if overlap_train_val == 0 else '⚠️  WARNING - Data leakage detected!'}")
print(f"Train-Test overlap: {overlap_train_test} samples {'✅ Perfect!' if overlap_train_test == 0 else '⚠️  WARNING - Data leakage detected!'}")
print(f"Validation-Test overlap: {overlap_val_test} samples {'✅ Perfect!' if overlap_val_test == 0 else '⚠️  WARNING - Data leakage detected!'}")

if overlap_train_val == 0 and overlap_train_test == 0 and overlap_val_test == 0:
    print("\n✅ All splits are properly separated - no data leakage detected!")
    print("✅ Stratified splitting keeps the full dataset's task proportions in every split")
else:
    print("\n⚠️  WARNING: Data leakage detected! Splits contain overlapping samples!")

## 7. Data Formatting

Format data into Llama 3 chat format with system, user, and assistant roles.

In [None]:
def format_instruction(sample):
    """Format data into Llama 3 chat format."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{sample['completion']}<|eot_id|>"""

# Test formatting
formatted_example = format_instruction(train_data[0])
print("Formatted Example:")
print(formatted_example[:500] + "...")
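For generation after training, the model should receive everything up to the assistant header and produce the rest itself. A minimal sketch of a prompt-only variant of the template above (the helper names `format_prompt` and `format_training_text` are hypothetical; the strings mirror `format_instruction`), where the prompt is by construction a prefix of the training text:

```python
SYSTEM_MESSAGE = (
    "You are a medical NER expert. "
    "Extract the requested entities from medical texts accurately."
)

def format_prompt(sample):
    """Llama 3 chat prompt that stops at the assistant header, ready for generation."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{SYSTEM_MESSAGE}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

def format_training_text(sample):
    """Full training text: the prompt prefix plus the completion and end-of-turn token."""
    return format_prompt(sample) + f"{sample['completion']}<|eot_id|>"

example = {"prompt": "List the diseases mentioned in: ...", "completion": "headache"}
print(format_prompt(example))
```

Because the string already starts with `<|begin_of_text|>`, and the Llama 3 tokenizer normally prepends that token itself, tokenizing it with `add_special_tokens=False` avoids a duplicated BOS token.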

In [None]:
# Format all data
train_formatted = [{"text": format_instruction(sample)} for sample in train_data]
val_formatted = [{"text": format_instruction(sample)} for sample in val_data]
test_formatted = [{"text": format_instruction(sample)} for sample in test_data]

# Create HuggingFace datasets
train_dataset = Dataset.from_list(train_formatted)
val_dataset = Dataset.from_list(val_formatted)
test_dataset = Dataset.from_list(test_formatted)

print(f"✓ Datasets formatted:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Validation: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

## 8. Load Model and Tokenizer

Load Llama 3.2 3B with 4-bit quantization for memory efficiency.
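A back-of-the-envelope estimate of why 4-bit loading helps, assuming roughly 3.2 billion parameters (an approximate count) and counting weights only, not activations, gradients, or optimizer state:

```python
# Approximate parameter count for Llama 3.2 3B (assumption; check model.num_parameters()).
PARAMS = 3.2e9

BYTES_PER_PARAM = {
    "fp32": 4.0,
    "fp16/bf16": 2.0,
    "4-bit (NF4)": 0.5,
}

for fmt, nbytes in BYTES_PER_PARAM.items():
    print(f"{fmt:>12}: ~{PARAMS * nbytes / 1e9:.1f} GB of weights")

# 4-bit quantization also stores per-block scaling constants (double quantization shrinks them),
# and modules such as embeddings and norms typically stay in higher precision, so the value
# reported by model.get_memory_footprint() is expected to sit somewhat above the 4-bit figure.
```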

In [None]:
# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("✓ Quantization config created (4-bit NF4)")

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",
    add_eos_token=True,
)
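# Use a dedicated pad token when the tokenizer provides one: the causal-LM collator masks every
# pad-token position in the labels, so padding with EOS would also hide the real <|eot_id|>/EOS targets.
if "<|finetune_right_pad_id|>" in tokenizer.get_vocab():
    tokenizer.pad_token = "<|finetune_right_pad_id|>"
else:
    tokenizer.pad_token = tokenizer.eos_token

print(f"✓ Tokenizer loaded: {MODEL_NAME}")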
print(f"  Vocab size: {len(tokenizer)}")
print(f"  PAD token: {tokenizer.pad_token}")
print(f"  EOS token: {tokenizer.eos_token}")

In [None]:
# Load base model
print("Loading model... (this may take a few minutes)")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

print(f"✓ Base model loaded: {MODEL_NAME}")
print(f"  Model size: {model.get_memory_footprint() / 1e9:.2f} GB")

In [None]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
print("✓ Model prepared for k-bit training")

## 9. Configure LoRA

Apply Low-Rank Adaptation for efficient fine-tuning.
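LoRA freezes each targeted weight matrix W (shape d_out × d_in) and learns an update B·A, with A of shape r × d_in and B of shape d_out × r, scaled by alpha / r. Each adapted matrix therefore adds r·(d_in + d_out) trainable parameters. The sketch below counts them for the seven target modules used here, assuming the Llama 3.2 3B dimensions written in the code (verify them against `model.config` before relying on the total):

```python
# Assumed Llama 3.2 3B dimensions (verify against model.config):
HIDDEN = 3072        # hidden_size
INTERMEDIATE = 8192  # intermediate_size
NUM_LAYERS = 28      # num_hidden_layers
KV_DIM = 8 * 128     # num_key_value_heads * head_dim (grouped-query attention)

rank, alpha = 16, 32  # LORA_RANK and LORA_ALPHA from the configuration section

# (d_in, d_out) of each target module in one decoder layer.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in, d_out, r):
    """A is r x d_in and B is d_out x r, so one adapter adds r * (d_in + d_out) weights."""
    return r * (d_in + d_out)

per_layer = sum(lora_params(d_in, d_out, rank) for d_in, d_out in TARGET_SHAPES.values())
total = per_layer * NUM_LAYERS
print(f"LoRA params per layer:  {per_layer:,}")
print(f"LoRA params total:      {total:,}")
print(f"Scaling factor alpha/r: {alpha / rank}")
```

If the assumed dimensions are right, this total should agree with the trainable-parameter count that `model.print_trainable_parameters()` reports after `get_peft_model`.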

In [None]:
# LoRA configuration
lora_config = LoraConfig(
    r=LORA_RANK,                   # LoRA rank
    lora_alpha=LORA_ALPHA,         # LoRA alpha (scaling)
    target_modules=[               # Layers to apply LoRA
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
    lora_dropout=LORA_DROPOUT,     # Dropout for regularization
    bias="none",                   # No bias training
    task_type="CAUSAL_LM"          # Causal language modeling
)

print(f"✓ LoRA configuration:")
print(f"  Rank (r): {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Dropout: {lora_config.lora_dropout}")
print(f"  Target modules: {len(lora_config.target_modules)}")

In [None]:
# Apply LoRA to model
model = get_peft_model(model, lora_config)

print("✓ LoRA applied to model")
print("\nTrainable parameters:")
model.print_trainable_parameters()

## 10. Tokenize Datasets

In [None]:
def tokenize_function(examples):
    """Tokenize the texts. They already start with <|begin_of_text|>, so don't add a second BOS."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=2048,
        padding=False,
        add_special_tokens=False,
    )

# Tokenize datasets
print("Tokenizing datasets...")

tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    desc="Tokenizing train set"
)

tokenized_val = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
    desc="Tokenizing validation set"
)

print(f"✓ Train set tokenized: {len(tokenized_train)} samples")
print(f"✓ Validation set tokenized: {len(tokenized_val)} samples")

In [None]:
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM, not masked LM
)

print("✓ Data collator created")
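With `mlm=False`, `DataCollatorForLanguageModeling` copies `input_ids` into `labels` and sets every position equal to `tokenizer.pad_token_id` to -100, which the loss ignores. The pure-Python simulation below (toy token ids, not real Llama ids) shows a side effect worth checking: if the pad token has the same id as the EOS/end-of-turn token (print `tokenizer.eos_token` above to see which token that is), the genuine EOS positions are masked too, so the model gets no signal for when to stop.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default

def build_labels(input_ids, pad_token_id):
    """Mimic the causal-LM collator: copy the ids, then mask every pad-id position."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

# Toy ids: 1 = BOS, 2 = EOS/end-of-turn, 0 = a dedicated pad token, 11-13 = text tokens.
EOS, PAD = 2, 0
padded_with_eos = [1, 11, 12, 13, EOS, EOS, EOS]  # right-padded using EOS as the pad token
padded_with_pad = [1, 11, 12, 13, EOS, PAD, PAD]  # right-padded using a distinct pad token

print(build_labels(padded_with_eos, pad_token_id=EOS))
# [1, 11, 12, 13, -100, -100, -100]  -> the genuine EOS target is lost
print(build_labels(padded_with_pad, pad_token_id=PAD))
# [1, 11, 12, 13, 2, -100, -100]     -> EOS is still a training target
```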

## 11. Training Configuration

In [None]:
# Training arguments
training_args = TrainingArguments(
    # Output and logging
    output_dir="./llama3-medical-ner-lora",
    logging_dir="./logs",
    logging_steps=10,
    
    # Training parameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    
    # Optimization
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    
    # Evaluation
    eval_strategy="steps",
    eval_steps=50,  # Evaluate every 50 steps
    
    # Checkpointing - Save every 50 steps
    save_strategy="steps",
    save_steps=50,  # Checkpoint every 50 steps
    save_total_limit=None,  # Keep all checkpoints (will push to HF with unique names)
    
    # Memory optimization
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    
    # Mixed precision
    fp16=True,
    
    # Hugging Face Hub - Disable default push (we'll use custom callback)
    push_to_hub=False,  # Custom callback will handle timestamped uploads
    
    # Misc
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="wandb",  # Enable Weights & Biases logging
    run_name=f"llama3-medical-ner-{TIMESTAMP}",  # W&B run name
    seed=RANDOM_SEED,
)

print(f"✓ Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size (per device): {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Checkpoint frequency: Every {training_args.save_steps} steps")
print(f"  Base HF model ID: {HF_MODEL_ID}")
print(f"  ⚠️ Checkpoints will be pushed to HF with timestamp suffix")
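`warmup_ratio=0.03` with `lr_scheduler_type="cosine"` means a linear warmup over the first ~3% of optimizer steps, then a half-cosine decay towards zero. A standalone sketch of that schedule's shape, following the formula used by transformers' `get_cosine_schedule_with_warmup` and assuming ~450 total steps (2,400 training samples ÷ effective batch 16 × 3 epochs; the Trainer's real count can differ slightly, and early stopping may end training sooner):

```python
import math

def lr_at(step, total_steps, base_lr=2e-4, warmup_ratio=0.03):
    """Learning rate at `step`: linear warmup, then cosine decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

TOTAL_STEPS = 450  # assumed: (2400 // 16) * 3
print("warmup steps:", math.ceil(TOTAL_STEPS * 0.03))
for step in (0, 7, 14, 232, 450):
    print(f"step {step:>3}: lr = {lr_at(step, TOTAL_STEPS):.2e}")
```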

## 11b. Custom Checkpoint Upload Callback

This callback pushes every saved checkpoint (one every 50 steps) to its own new Hugging Face Hub repository with a unique, timestamped name.
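Because `save_steps=50`, a checkpoint (and therefore a new Hub repository) is produced at every multiple of 50 global steps. The sketch below lists those steps and reproduces the callback's naming scheme with a fixed timestamp so the result is deterministic; the helper names, the 450-step total, and `your-username` are placeholders for illustration.

```python
from datetime import datetime

def planned_upload_steps(total_steps, save_steps=50):
    """Global steps at which the Trainer saves (and the callback uploads) a checkpoint."""
    return list(range(save_steps, total_steps + 1, save_steps))

def checkpoint_repo_id(username, step, when):
    """Same format the callback builds: <user>/llama3-medical-ner-checkpoint-<step>-<timestamp>."""
    return f"{username}/llama3-medical-ner-checkpoint-{step}-{when.strftime('%Y%m%d_%H%M%S')}"

steps = planned_upload_steps(450)  # assumed total; early stopping may end training sooner
print(len(steps), "uploads at steps", steps)
print(checkpoint_repo_id("your-username", 50, datetime(2025, 1, 1, 12, 0, 0)))
```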

In [None]:
from transformers import TrainerCallback
from huggingface_hub import HfApi
import shutil
from pathlib import Path

class CheckpointUploadCallback(TrainerCallback):
    """
    Custom callback to upload checkpoints to Hugging Face Hub with timestamped names.
    
    Each checkpoint will be saved with format:
    {HF_USERNAME}/llama3-medical-ner-lora-checkpoint-{step}-{timestamp}
    """
    
    def __init__(self, base_model_id, hf_username):
        self.base_model_id = base_model_id
        self.hf_username = hf_username
        self.api = HfApi()
        
    def on_save(self, args, state, control, **kwargs):
        """
        Called when a checkpoint is saved.
        Uploads the checkpoint to HF Hub with a timestamped name.
        """
        # Get the checkpoint directory that was just saved
        checkpoint_dir = f"{args.output_dir}/checkpoint-{state.global_step}"
        
        if not Path(checkpoint_dir).exists():
            print(f"⚠️ Checkpoint directory not found: {checkpoint_dir}")
            return
        
        # Create timestamped model ID
        checkpoint_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_model_id = f"{self.hf_username}/llama3-medical-ner-checkpoint-{state.global_step}-{checkpoint_timestamp}"
        
        print(f"\n{'='*80}")
        print(f"📤 Uploading checkpoint to Hugging Face Hub")
        print(f"   Step: {state.global_step}")
        print(f"   Model ID: {checkpoint_model_id}")
        print(f"{'='*80}\n")
        
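        try:
            # upload_folder does not create repositories, so create this one first
            self.api.create_repo(repo_id=checkpoint_model_id, repo_type="model", exist_ok=True)
            # Upload the checkpoint folder to HF Hub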
            self.api.upload_folder(
                folder_path=checkpoint_dir,
                repo_id=checkpoint_model_id,
                repo_type="model",
                commit_message=f"Checkpoint at step {state.global_step}",
            )
            
            print(f"✅ Checkpoint uploaded successfully!")
            print(f"   URL: https://huggingface.co/{checkpoint_model_id}\n")
            
            # Log to wandb if available
            if wandb.run is not None:
                wandb.log({
                    "checkpoint_step": state.global_step,
                    "checkpoint_url": f"https://huggingface.co/{checkpoint_model_id}"
                })
                
        except Exception as e:
            print(f"❌ Failed to upload checkpoint: {e}")
            print(f"   Checkpoint saved locally at: {checkpoint_dir}\n")

# Initialize the callback
checkpoint_upload_callback = CheckpointUploadCallback(
    base_model_id=HF_MODEL_ID,
    hf_username=HF_USERNAME
)

print(f"✓ Checkpoint upload callback initialized")
print(f"  Checkpoints will be uploaded to: {HF_USERNAME}/llama3-medical-ner-checkpoint-<step>-<timestamp>")
print(f"  Upload frequency: Every {training_args.save_steps} steps")

In [None]:
# Preview expected checkpoint uploads
total_steps_estimate = (len(train_data) // (BATCH_SIZE * GRADIENT_ACCUMULATION)) * NUM_EPOCHS
checkpoint_count = total_steps_estimate // 50

print("="*80)
print("CHECKPOINT UPLOAD PREVIEW")
print("="*80)
print(f"\nTraining Configuration:")
print(f"  Total samples: {len(train_data)}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Gradient accumulation: {GRADIENT_ACCUMULATION}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Estimated total steps: ~{total_steps_estimate}")
print(f"\nCheckpoint Configuration:")
print(f"  Frequency: Every 50 steps")
print(f"  Expected checkpoints: ~{checkpoint_count}")
print(f"  Local storage: ./llama3-medical-ner-lora/checkpoint-<step>/")
print(f"\nHugging Face Upload:")
print(f"  Format: {HF_USERNAME}/llama3-medical-ner-checkpoint-<step>-<timestamp>")
print(f"\nExample checkpoint names:")
for i, step in enumerate(range(50, min(250, total_steps_estimate), 50), 1):
    example_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"  {i}. {HF_USERNAME}/llama3-medical-ner-checkpoint-{step}-{example_time}")
if checkpoint_count > 4:
    print(f"  ... (~{checkpoint_count - 4} more checkpoints)")
    
print(f"\nFinal model:")
print(f"  {HF_USERNAME}/llama3-medical-ner-lora-final-<timestamp>")
print("="*80)

## 12. Initialize Trainer

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    callbacks=[checkpoint_upload_callback],  # Add custom checkpoint upload callback
)

# Configure early stopping to prevent overfitting
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.0))

# Calculate training steps
total_steps = int((len(tokenized_train) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)) * training_args.num_train_epochs)

print(f"✓ Trainer initialized")
print(f"✓ Expected training steps: ~{total_steps}")
print(f"✓ Expected checkpoints: ~{max(1, total_steps // training_args.save_steps)}")
print(f"✓ Checkpoint upload callback enabled")
print("✓ Early stopping enabled (patience = 3 evaluations)")
print(f"\n📋 Checkpoint naming format:")
print(f"   {HF_USERNAME}/llama3-medical-ner-checkpoint-<step>-<timestamp>")
print(f"\n   Example: {HF_USERNAME}/llama3-medical-ner-checkpoint-50-20251104_143022")
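`early_stopping_patience=3` with `early_stopping_threshold=0.0` means training stops once three evaluations in a row fail to beat the best validation loss seen so far. A pure-Python sketch of that patience logic (the real `EarlyStoppingCallback` compares against the Trainer's recorded best metric; `early_stop_index` and the loss values are hypothetical):

```python
def early_stop_index(eval_losses, patience=3, threshold=0.0):
    """Return the index of the evaluation after which training would stop, or None.

    An evaluation counts as an improvement only if it is lower than the best loss
    so far by more than `threshold`; `patience` non-improving evaluations in a row stop training.
    """
    best = None
    bad_evals = 0
    for i, loss in enumerate(eval_losses):
        if best is None or (loss < best and best - loss > threshold):
            best = loss
            bad_evals = 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i
    return None

# Hypothetical validation losses logged every 50 steps.
losses = [1.20, 0.95, 0.90, 0.91, 0.93, 0.92, 0.89]
print(early_stop_index(losses))  # 5: three evaluations in a row without beating 0.90
```

Here training would stop after the sixth evaluation (step 300), and `load_best_model_at_end=True` would restore the checkpoint from the evaluation with loss 0.90.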

## 13. Start Training

⚠️ **This can take 2-3 hours on an A100 GPU** (less if early stopping triggers)

The training will:
- **Save checkpoints every 50 steps** to local disk
- **Upload each checkpoint to Hugging Face Hub** with timestamped names
  - Format: `{username}/llama3-medical-ner-checkpoint-{step}-{timestamp}`
  - Example: `albyos/llama3-medical-ner-checkpoint-50-20251104_143022`
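- Evaluate on the validation set every 50 steps
- Stop early if validation loss does not improve for 3 consecutive evaluations
- Reload the checkpoint with the lowest validation loss at the end of training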
- Log all metrics to Weights & Biases

**Checkpoint URLs will be printed during training and logged to W&B.**

In [None]:
# Start training
print("="*80)
print("STARTING TRAINING")
print("="*80)
print("This may take 2-3 hours on A100 GPU...\n")

trainer.train()

## 14. Save Final Model

In [None]:
# Save model locally
print("Saving final model...")
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")

print(f"✓ Model saved to: ./final_model")

In [None]:
# Push final model to Hugging Face Hub with timestamped name
print("Pushing final model to Hugging Face Hub...")

# Create final model ID with timestamp
final_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
final_model_id = f"{HF_USERNAME}/llama3-medical-ner-lora-final-{final_timestamp}"

try:
    # Push the final model
    model.push_to_hub(
        final_model_id,
        commit_message="Training complete - final model"
    )
    tokenizer.push_to_hub(
        final_model_id,
        commit_message="Training complete - final tokenizer"
    )
    
    print(f"✅ Final model pushed successfully!")
    print(f"   Model ID: {final_model_id}")
    print(f"   URL: https://huggingface.co/{final_model_id}")
    
    # Log to wandb
    if wandb.run is not None:
        wandb.log({
            "final_model_url": f"https://huggingface.co/{final_model_id}",
            "final_model_id": final_model_id
        })
        
except Exception as e:
    print(f"⚠ Failed to push to hub: {e}")
    print("  Final model saved locally at: ./final_model")
    print(f"  You can manually push later using:")
    print(f"    model.push_to_hub('{final_model_id}')")

## 15. Training Analysis

In [None]:
# Plot training metrics
import matplotlib.pyplot as plt

# Get training history
log_history = trainer.state.log_history

# Extract losses together with the global step at which each was logged
train_steps = [entry['step'] for entry in log_history if 'loss' in entry]
train_loss = [entry['loss'] for entry in log_history if 'loss' in entry]
eval_steps = [entry['step'] for entry in log_history if 'eval_loss' in entry]
eval_loss = [entry['eval_loss'] for entry in log_history if 'eval_loss' in entry]

# Plot
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_steps, train_loss, label='Training Loss', color='blue')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(eval_steps, eval_loss, label='Validation Loss', color='orange')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Validation Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('training_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training metrics plotted and saved to: training_metrics.png")

In [None]:
# Summary statistics
print("="*80)
print("TRAINING SUMMARY")
print("="*80)
print(f"Total training steps: {trainer.state.global_step}")
print(f"Final training loss: {train_loss[-1]:.4f}")
print(f"Final validation loss: {eval_loss[-1]:.4f}")
print(f"Best validation loss: {min(eval_loss):.4f}")
print(f"Loss reduction: {((train_loss[0] - train_loss[-1]) / train_loss[0] * 100):.1f}%")

## Next Steps

Training is complete! Your model has been saved.

**To evaluate your model:**
1. Open `Medical_NER_Evaluation.ipynb`
2. Run the evaluation on the test set
3. Test custom examples

**Model locations:**
- Local: `./final_model`
- HuggingFace Hub: Check the output above for your model URL
