# Medical NER Fine-Tuning with Llama 3.2 3B + LoRA

This notebook fine-tunes Llama 3.2 3B Instruct for medical Named Entity Recognition (NER) using:
- **SFT** (Supervised Fine-Tuning)
- **LoRA** (Low-Rank Adaptation)
- **Hugging Face Hub** integration for checkpoint uploads

## Tasks:
1. Chemical entity extraction
2. Disease entity extraction
3. Chemical-Disease relationship extraction

## Dataset:
- 3,000 medical text examples
- 80/10/10 train/validation/test split
- Weights & Biases tracking enabled

## 1. Setup and Installation

First, let's install all required dependencies.

In [None]:
# Install required packages
!pip install -q transformers datasets peft accelerate bitsandbytes
!pip install -q huggingface-hub tokenizers trl scikit-learn
!pip install -q scipy sentencepiece protobuf wandb

print("✓ All packages installed successfully!")

## 2. Import Libraries

In [None]:
import json
import torch
import os
import random
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from huggingface_hub import login
import wandb

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

## 3. Configuration

⚠️ **IMPORTANT**: Update `HF_USERNAME` with your Hugging Face username!

In [None]:
# Configuration Section
from datetime import datetime

HF_USERNAME = "albyos"  # Replace with your HF username

# Generate timestamp for checkpoint naming
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
HF_MODEL_ID = f"{HF_USERNAME}/llama3-medical-ner-lora-{TIMESTAMP}"
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_NAME = BASE_MODEL  # Alias for consistency

# LoRA Configuration
LORA_RANK = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# Training Configuration
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-4

# Data Configuration
TRAIN_SPLIT_RATIO = 0.8  # 80% train; the remaining 20% is split evenly into validation and test
RANDOM_SEED = 42

print("✓ Configuration loaded")
print(f"  Base model: {BASE_MODEL}")
print(f"  HF model ID: {HF_MODEL_ID}")
print(f"  Training timestamp: {TIMESTAMP}")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Training epochs: {NUM_EPOCHS}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")



## 4. Hugging Face Authentication

Get your token from: https://huggingface.co/settings/tokens

In [None]:
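# Option 1 (recommended): export HF_TOKEN in your shell before starting Jupyter.
# Never hard-code access tokens in a notebook; they leak into version control and shared copies.
from huggingface_hub import notebook_login

hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # Option 2: interactive login widget
    notebook_login()

print("✓ Logged in to Hugging Face")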

## 4b. Weights & Biases Setup

Initialize W&B to track training metrics, validation loss, and experiments.
Get your API key from: https://wandb.ai/authorize

In [None]:
# Login to Weights & Biases
# The API key is read from the WANDB_API_KEY environment variable; never paste keys into a notebook.
wandb_api_key = os.getenv("WANDB_API_KEY")

if wandb_api_key:
    wandb.login(key=wandb_api_key)
    print('✓ Logged in to Weights & Biases using WANDB_API_KEY')
else:
    print('⚠ Warning: WANDB_API_KEY not found. Attempting to use cached login...')
    try:
        wandb.login()
        print('✓ Logged in to Weights & Biases using cached credentials')
    except Exception as e:
        print(f'⚠ Warning: Could not login to W&B: {e}')
        print('  Run wandb.login() interactively or set the WANDB_API_KEY environment variable')

In [None]:
# Initialize Weights & Biases
wandb.init(
    project="medical-ner-finetuning",
    name=f"llama3-medical-ner-{TIMESTAMP}",
    config={
        "model": BASE_MODEL,
        "lora_rank": LORA_RANK,
        "lora_alpha": LORA_ALPHA,
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION,
    }
)

print("✓ Weights & Biases initialized")
print(f"  Project: medical-ner-finetuning")
print(f"  Run name: llama3-medical-ner-{TIMESTAMP}")
print(f"  Dashboard: https://wandb.ai")

## 5. Data Exploration

Let's examine the dataset structure.

In [None]:
# Load and inspect the dataset
with open('../data/both_rel_instruct_all.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]

print(f"Total samples: {len(data)}")
print(f"\nSample structure:")
print(json.dumps(data[0], indent=2)[:500] + "...")
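The loading cell assumes every line of the JSONL file parses and carries both a `prompt` and a `completion` key. As a sketch (the helper name `load_instruct_records` is hypothetical, not part of the original pipeline), a small validator can fail fast with the offending line number instead of surfacing a `KeyError` several cells later:

```python
import json
from typing import Iterable, List


def load_instruct_records(lines: Iterable[str], required=("prompt", "completion")) -> List[dict]:
    """Parse JSONL lines and check that every record has the required keys.

    Blank lines are skipped. Raises ValueError naming the 1-based line number
    of the first malformed or incomplete record.
    """
    records = []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines and a trailing newline
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            raise ValueError(f"line {lineno}: invalid JSON ({exc.msg})") from exc
        missing = [key for key in required if key not in record]
        if missing:
            raise ValueError(f"line {lineno}: missing keys {missing}")
        records.append(record)
    return records


# In-memory example; with the real file, pass the open file handle instead.
sample_lines = [
    '{"prompt": "Create a list only of the chemicals mentioned.", "completion": "aspirin"}\n',
    "\n",
]
print(len(load_instruct_records(sample_lines)))  # 1
```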

In [None]:
# Analyze task distribution
task_counts = {}
for sample in data:
    if "chemicals mentioned" in sample['prompt']:
        task = "Chemical Extraction"
    elif "diseases mentioned" in sample['prompt']:
        task = "Disease Extraction"
    elif "influences between" in sample['prompt']:
        task = "Relationship Extraction"
    else:
        task = "Other"
    
    task_counts[task] = task_counts.get(task, 0) + 1

print("Task Distribution:")
for task, count in task_counts.items():
    print(f"  {task}: {count} ({count/len(data)*100:.1f}%)")

In [None]:
# Show example from each task type
print("="*80)
print("EXAMPLE: Chemical Extraction")
print("="*80)
chem_example = [s for s in data if "chemicals mentioned" in s['prompt']][0]
print(f"Prompt:\n{chem_example['prompt'][:300]}...")
print(f"\nCompletion:\n{chem_example['completion']}")

print("\n" + "="*80)
print("EXAMPLE: Disease Extraction")
print("="*80)
disease_example = [s for s in data if "diseases mentioned" in s['prompt']][0]
print(f"Prompt:\n{disease_example['prompt'][:300]}...")
print(f"\nCompletion:\n{disease_example['completion']}")

## 6. Dataset Splitting

Split into:
- **80% Training** (2,400 samples) - for fine-tuning
- **10% Validation** (300 samples) - for monitoring during training (W&B)
- **10% Test** (300 samples) - for final evaluation after training
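The sizes follow directly from the two-stage split: 3,000 × 0.8 = 2,400 training samples, and the remaining 600 are halved into 300 validation and 300 test samples. The notebook does this with two calls to scikit-learn's `train_test_split`; the stdlib-only sketch below (hypothetical helper `three_way_split`) reproduces the same arithmetic, although the exact membership of each split will differ from scikit-learn's shuffle:

```python
import random


def three_way_split(items, train_frac=0.8, seed=42):
    """Shuffle a copy of `items` and cut it into train / validation / test.

    Whatever remains after the train slice is divided evenly between
    validation and test (test gets the extra item if the remainder is odd).
    """
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # local RNG: leaves the global random state untouched
    n_train = round(len(shuffled) * train_frac)
    n_val = (len(shuffled) - n_train) // 2
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test


train, val, test = three_way_split(range(3000))
print(len(train), len(val), len(test))  # 2400 300 300
```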

In [None]:
# Split data into train/val/test (80/10/10)
random.seed(RANDOM_SEED)

# First split: 80% train, 20% temp (for val + test)
train_data, temp_data = train_test_split(
    data,
    test_size=0.2,  # 20% for validation + test
    random_state=RANDOM_SEED,
    shuffle=True
)

# Second split: split the 20% into 10% val, 10% test
val_data, test_data = train_test_split(
    temp_data,
    test_size=0.5,  # 50% of 20% = 10% of total
    random_state=RANDOM_SEED,
    shuffle=True
)

# Save splits
with open('../data/train.jsonl', 'w', encoding='utf-8') as f:
    for item in train_data:
        f.write(json.dumps(item) + '\n')

with open('../data/validation.jsonl', 'w', encoding='utf-8') as f:
    for item in val_data:
        f.write(json.dumps(item) + '\n')

with open('../data/test.jsonl', 'w', encoding='utf-8') as f:
    for item in test_data:
        f.write(json.dumps(item) + '\n')

print(f"✓ Dataset split complete:")
print(f"  Train samples: {len(train_data)} ({len(train_data)/len(data)*100:.1f}%)")
print(f"  Validation samples: {len(val_data)} ({len(val_data)/len(data)*100:.1f}%) - for training monitoring")
print(f"  Test samples: {len(test_data)} ({len(test_data)/len(data)*100:.1f}%) - for final evaluation")
print(f"\n📊 Usage:")
print(f"  - Train: Used for fine-tuning")
print(f"  - Validation: Monitored during training (shown in W&B)")
print(f"  - Test: Used ONLY after training for final evaluation")

## 7. Data Formatting

Format data into Llama 3 chat format with system, user, and assistant roles.

In [None]:
def format_instruction(sample):
    """Format data into Llama 3 chat format."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{sample['completion']}<|eot_id|>"""

# Test formatting
formatted_example = format_instruction(train_data[0])
print("Formatted Example:")
print(formatted_example[:500] + "...")
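Training and inference must use identical chat layouts, so it can help to build both from one function. The sketch below (hypothetical `build_llama3_prompt`) mirrors the string produced by `format_instruction`; with `add_generation_prompt=True` it stops after the assistant header, which matches the prefix that `generate_response` feeds the model at test time. The tokenizer's built-in `apply_chat_template` is the library route, but Llama 3.2's bundled template may insert extra system-prompt text (such as a knowledge-cutoff date), so its output is not guaranteed to match this hand-written format.

```python
def build_llama3_prompt(messages, add_generation_prompt=False):
    """Render (role, content) pairs in the Llama 3 layout used by format_instruction.

    Each turn becomes <|start_header_id|>role<|end_header_id|>, a blank line,
    the content, then <|eot_id|>.
    """
    parts = ["<|begin_of_text|>"]
    for role, content in messages:
        parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>")
    if add_generation_prompt:
        # Open an empty assistant turn for the model to complete.
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


SYSTEM = "You are a medical NER expert. Extract the requested entities from medical texts accurately."
user = "Create a list only of the chemicals mentioned.\n\nAspirin relieved the headache."

train_text = build_llama3_prompt([("system", SYSTEM), ("user", user), ("assistant", "aspirin")])
infer_text = build_llama3_prompt([("system", SYSTEM), ("user", user)], add_generation_prompt=True)

# The inference prompt is an exact prefix of the training text.
print(train_text.startswith(infer_text))  # True
```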

In [None]:
# Format all data
train_formatted = [{"text": format_instruction(sample)} for sample in train_data]
val_formatted = [{"text": format_instruction(sample)} for sample in val_data]

# Create HuggingFace datasets
train_dataset = Dataset.from_list(train_formatted)
val_dataset = Dataset.from_list(val_formatted)

print(f"✓ Datasets formatted:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Validation: {len(val_dataset)} samples")

## 8. Load Model and Tokenizer

Load Llama 3.2 3B with 4-bit quantization for memory efficiency.

In [None]:
# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("✓ Quantization config created (4-bit NF4)")

In [None]:
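# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",
)
# No add_eos_token: every formatted training text already ends with <|eot_id|>.
# Use a pad token distinct from EOS: the causal-LM collator sets labels equal to
# pad_token_id to -100, so pad == eos would drop every <|eot_id|> from the loss and
# the model would never learn to stop. <|finetune_right_pad_id|> is a reserved
# padding token in the Llama 3.1/3.2 tokenizer (verify it exists in your version).
tokenizer.pad_token = "<|finetune_right_pad_id|>"
assert tokenizer.pad_token_id != tokenizer.eos_token_id

print(f"✓ Tokenizer loaded: {MODEL_NAME}")
print(f"  Vocab size: {len(tokenizer)}")
print(f"  PAD token: {tokenizer.pad_token}")
print(f"  EOS token: {tokenizer.eos_token}")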

In [None]:
# Load base model
print("Loading model... (this may take a few minutes)")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

print(f"✓ Base model loaded: {MODEL_NAME}")
print(f"  Model size: {model.get_memory_footprint() / 1e9:.2f} GB")
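A back-of-envelope estimate shows why 4-bit loading matters for a ~3B-parameter model. The sketch counts weights only (no activations, optimizer state, or KV cache) and assumes roughly 3.21 billion parameters for Llama 3.2 3B. Expect the footprint printed above to sit somewhat above the NF4 figure, because embeddings, norms, and the output layer are typically kept in 16-bit and quantization constants add overhead:

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Weight-only memory in GB (1 GB = 1e9 bytes); ignores activations and optimizer state."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 3.21e9  # assumption: approximate parameter count of Llama 3.2 3B

for label, bits in [("fp16/bf16", 16), ("int8", 8), ("nf4", 4)]:
    print(f"{label:>9}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
```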

In [None]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
print("✓ Model prepared for k-bit training")

## 9. Configure LoRA

Apply Low-Rank Adaptation for efficient fine-tuning.

In [None]:
# LoRA configuration
lora_config = LoraConfig(
    r=LORA_RANK,                   # LoRA rank
    lora_alpha=LORA_ALPHA,         # LoRA alpha (scaling)
    target_modules=[               # Layers to apply LoRA
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
    lora_dropout=LORA_DROPOUT,     # Dropout for regularization
    bias="none",                   # No bias training
    task_type="CAUSAL_LM"          # Causal language modeling
)

print(f"✓ LoRA configuration:")
print(f"  Rank (r): {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Dropout: {lora_config.lora_dropout}")
print(f"  Target modules: {len(lora_config.target_modules)}")

In [None]:
# Apply LoRA to model
model = get_peft_model(model, lora_config)

print("✓ LoRA applied to model")
print("\nTrainable parameters:")
model.print_trainable_parameters()
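LoRA keeps each frozen weight matrix `W` and learns a low-rank update `ΔW = (lora_alpha / r) · B · A`, where `A` is `r × d_in` and `B` is `d_out × r`; with `LORA_ALPHA = 32` and `LORA_RANK = 16` the scale `lora_alpha / r` is 2. Each adapted `Linear(d_in, d_out)` therefore adds `r · (d_in + d_out)` trainable parameters. The sketch below totals this over the seven target modules, assuming Llama 3.2 3B's published shapes (hidden size 3072, MLP size 8192, 28 layers, 24 query heads and 8 key/value heads). Check those values against `model.config` before relying on the result; it should be close to what `print_trainable_parameters()` reports.

```python
# Assumed Llama 3.2 3B shapes (verify against model.config)
HIDDEN = 3072          # hidden_size
INTERMEDIATE = 8192    # intermediate_size
NUM_LAYERS = 28        # num_hidden_layers
NUM_HEADS = 24         # num_attention_heads
NUM_KV_HEADS = 8       # num_key_value_heads (grouped-query attention)
HEAD_DIM = HIDDEN // NUM_HEADS  # 128


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters LoRA adds to one Linear(d_in, d_out): A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


def llama_lora_trainable(r: int) -> int:
    """Total LoRA parameters for q/k/v/o and gate/up/down projections in every layer."""
    kv_dim = NUM_KV_HEADS * HEAD_DIM  # k_proj / v_proj output width under GQA
    per_layer = (
        lora_params(HIDDEN, HIDDEN, r)                # q_proj
        + 2 * lora_params(HIDDEN, kv_dim, r)          # k_proj, v_proj
        + lora_params(HIDDEN, HIDDEN, r)              # o_proj
        + 2 * lora_params(HIDDEN, INTERMEDIATE, r)    # gate_proj, up_proj
        + lora_params(INTERMEDIATE, HIDDEN, r)        # down_proj
    )
    return per_layer * NUM_LAYERS


print(f"{llama_lora_trainable(16):,}")  # 24,313,856
```

Because the count is linear in `r`, halving the rank halves the number of trainable parameters.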

## 10. Tokenize Datasets

In [None]:
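def tokenize_function(examples):
    """Tokenize the formatted chat texts.

    add_special_tokens=False because format_instruction already inserts
    <|begin_of_text|>; letting the tokenizer add it too would duplicate the BOS token.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=2048,
        padding=False,  # padding is applied per batch by the data collator
        add_special_tokens=False,
    )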

# Tokenize datasets
print("Tokenizing datasets...")

tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    desc="Tokenizing train set"
)

tokenized_val = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
    desc="Tokenizing validation set"
)

print(f"✓ Train set tokenized: {len(tokenized_train)} samples")
print(f"✓ Validation set tokenized: {len(tokenized_val)} samples")

In [None]:
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM, not masked LM
)

print("✓ Data collator created")
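With `mlm=False`, `DataCollatorForLanguageModeling` pads each batch and copies `input_ids` into `labels`, replacing every position whose token id equals `pad_token_id` with `-100` so the loss ignores it. The plain-Python sketch below (hypothetical `causal_lm_collate`, not the library's code) reproduces that rule and shows why the choice of pad token matters: if the pad token is the EOS token (assumed here to be `<|eot_id|>`, id 128009), genuine end-of-turn tokens are masked too and the model is never trained to stop. The padding id 128004 is assumed to be `<|finetune_right_pad_id|>`.

```python
EOT = 128009  # assumed id of <|eot_id|> (the Instruct tokenizer's EOS)
PAD = 128004  # assumed id of <|finetune_right_pad_id|>


def causal_lm_collate(batch, pad_id):
    """Right-pad token id lists and build causal-LM labels.

    Labels copy input_ids, then every position whose id equals pad_id becomes
    -100 (ignored by the loss). Matching is by token id, not by attention mask.
    """
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        padded = list(ids) + [pad_id] * n_pad
        input_ids.append(padded)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        labels.append([-100 if tok == pad_id else tok for tok in padded])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = [[1, 2, 3, EOT], [4, 5, EOT]]

# pad == EOS: the real end-of-turn tokens are masked along with the padding
print(causal_lm_collate(batch, pad_id=EOT)["labels"])
# [[1, 2, 3, -100], [4, 5, -100, -100]]

# distinct pad id: only the padding position is masked
print(causal_lm_collate(batch, pad_id=PAD)["labels"])
# [[1, 2, 3, 128009], [4, 5, 128009, -100]]
```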

## 11. Training Configuration

In [None]:
# Training arguments
training_args = TrainingArguments(
    # Output and logging
    output_dir="./llama3-medical-ner-lora",
    logging_dir="./logs",
    logging_steps=10,
    
    # Training parameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    
    # Optimization
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    
    # Evaluation
    eval_strategy="steps",
    eval_steps=100,
    
    # Checkpointing
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    
    # Memory optimization
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    
    # Mixed precision
    fp16=True,
    
    # Hugging Face Hub
    push_to_hub=True,
    hub_model_id=HF_MODEL_ID,
    hub_strategy="checkpoint",
    hub_private_repo=False,
    
    # Misc
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="wandb",  # Enable Weights & Biases logging
    run_name=f"llama3-medical-ner-{TIMESTAMP}",  # W&B run name
    seed=RANDOM_SEED,
)

print(f"✓ Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size (per device): {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Hub model ID: {HF_MODEL_ID}")

## 12. Initialize Trainer

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
)

# Calculate training steps
total_steps = (len(tokenized_train) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)) * training_args.num_train_epochs

print(f"✓ Trainer initialized")
print(f"✓ Expected training steps: ~{total_steps}")
print(f"✓ Expected checkpoints: ~{total_steps // training_args.save_steps}")

## 13. Start Training

⚠️ **This may take 2-3 hours on an A100 GPU**

The training will:
- Save a checkpoint every 100 steps
- Push the latest checkpoint to the Hugging Face Hub each time one is saved
- Evaluate on the validation set every 100 steps
- Reload the best checkpoint (lowest validation loss) at the end of training
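With the configuration above, one optimizer step consumes `BATCH_SIZE × GRADIENT_ACCUMULATION = 16` examples, so 2,400 training samples give about 150 steps per epoch and about 450 over 3 epochs; evaluating and checkpointing every 100 steps therefore happens 4 times. The sketch below (hypothetical `training_schedule`) makes that arithmetic explicit. It approximates the Trainer's bookkeeping for a single device and may differ by a step depending on the `transformers` version.

```python
import math


def training_schedule(num_samples, per_device_batch, grad_accum, epochs, eval_every, num_devices=1):
    """Approximate optimizer steps per epoch, total steps, and evaluation/checkpoint steps."""
    batches_per_epoch = math.ceil(num_samples / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)  # one optimizer step per grad_accum batches
    total_steps = steps_per_epoch * epochs
    eval_points = list(range(eval_every, total_steps + 1, eval_every))
    return steps_per_epoch, total_steps, eval_points


per_epoch, total, evals = training_schedule(
    num_samples=2400, per_device_batch=4, grad_accum=4, epochs=3, eval_every=100
)
print(per_epoch, total, evals)  # 150 450 [100, 200, 300, 400]
```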

In [None]:
# Start training
print("="*80)
print("STARTING TRAINING")
print("="*80)
print("This may take 2-3 hours on an A100 GPU...\n")

trainer.train()

## 14. Save Final Model

In [None]:
# Save model locally
print("Saving final model...")
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")

print(f"✓ Model saved to: ./final_model")

In [None]:
# Push to Hugging Face Hub
print("Pushing to Hugging Face Hub...")

try:
    trainer.push_to_hub(commit_message="Training complete - final model")
    print(f"✓ Model pushed to: https://huggingface.co/{HF_MODEL_ID}")
except Exception as e:
    print(f"⚠ Failed to push to hub: {e}")
    print("  You can manually push later using: trainer.push_to_hub()")

## 15. Training Analysis

In [None]:
# Plot training metrics
import pandas as pd
import matplotlib.pyplot as plt

# Get training history (every entry records the global step it was logged at)
log_history = trainer.state.log_history

# Extract losses together with their global steps
train_entries = [entry for entry in log_history if 'loss' in entry]
eval_entries = [entry for entry in log_history if 'eval_loss' in entry]
train_steps = [entry['step'] for entry in train_entries]
train_loss = [entry['loss'] for entry in train_entries]
val_steps = [entry['step'] for entry in eval_entries]
eval_loss = [entry['eval_loss'] for entry in eval_entries]

# Plot both curves against the global training step
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_steps, train_loss, label='Training Loss', color='blue')
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(val_steps, eval_loss, label='Validation Loss', color='orange', marker='o')
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.title('Validation Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('training_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training metrics plotted and saved to: training_metrics.png")

In [None]:
# Summary statistics
print("="*80)
print("TRAINING SUMMARY")
print("="*80)
print(f"Total training steps: {trainer.state.global_step}")
print(f"Final training loss: {train_loss[-1]:.4f}")
print(f"Final validation loss: {eval_loss[-1]:.4f}")
print(f"Best validation loss: {min(eval_loss):.4f}")
print(f"Loss reduction: {((train_loss[0] - train_loss[-1]) / train_loss[0] * 100):.1f}%")

## 16. Final Evaluation on Test Set

Test the fine-tuned model on **test samples that were never seen during training or validation**.
The test set was used neither for gradient updates nor for checkpoint selection (validation loss picks the best checkpoint), so it gives the least biased estimate of the model's ability to generalize.

In [None]:
# Load the fine-tuned model for inference
print("Loading fine-tuned model for test-set evaluation...")

# Clear GPU memory first
del model
del trainer
torch.cuda.empty_cache()

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Load LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    "./final_model",
)
model.eval()

print("✓ Fine-tuned model loaded for inference")

In [None]:
def generate_response(prompt_text, max_new_tokens=512):
    """Generate the model's answer for one prompt, using the same chat layout as training."""
    formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    
    # The template already starts with <|begin_of_text|>, so skip the tokenizer's own BOS
    inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic, reproducible evaluation
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens instead of splitting the full text on "assistant"
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

print("✓ Inference function ready")

In [None]:
# Test on COMPLETELY UNSEEN test samples
# The test set was not used for training OR validation monitoring
with open('../data/test.jsonl', 'r', encoding='utf-8') as f:
    test_data = [json.loads(line) for line in f]

num_test_samples = 5
print(f"Testing on {num_test_samples} samples from TEST SET")
print(f"Total test set size: {len(test_data)}")
print(f"\n⚠️  IMPORTANT:")
print(f"  - Training set (80%): Used for fine-tuning")
print(f"  - Validation set (10%): Monitored during training (W&B)")
print(f"  - Test set (10%): Used ONLY NOW for final evaluation")

# Aggregate metrics
total_correct = 0
total_predicted = 0
total_expected = 0

for i, sample in enumerate(test_data[:num_test_samples]):
    print("\n" + "="*80)
    print(f"FINAL TEST EXAMPLE {i+1}/{num_test_samples}")
    print("="*80)
    
    # Show prompt (truncated for readability)
    print(f"\nüìù PROMPT:")
    prompt_preview = sample['prompt'][:250] + "..." if len(sample['prompt']) > 250 else sample['prompt']
    print(f"{prompt_preview}")
    
    # Show expected output
    print(f"\n‚úÖ EXPECTED OUTPUT:")
    print(f"{sample['completion']}")
    
    # Generate prediction
    print(f"\nü§ñ MODEL PREDICTION:")
    prediction = generate_response(sample['prompt'])
    print(f"{prediction}")
    
    # Calculate metrics
    expected_items = set([item.strip() for item in sample['completion'].split('\n') if item.strip()])
    predicted_items = set([item.strip() for item in prediction.split('\n') if item.strip()])
    
    common = expected_items & predicted_items
    missing = expected_items - predicted_items
    extra = predicted_items - expected_items
    
    # Update aggregate counts
    total_correct += len(common)
    total_predicted += len(predicted_items)
    total_expected += len(expected_items)
    
    # Per-sample metrics
    accuracy = len(common) / len(expected_items) * 100 if len(expected_items) > 0 else 0
    precision = len(common) / len(predicted_items) * 100 if len(predicted_items) > 0 else 0
    recall = len(common) / len(expected_items) * 100 if len(expected_items) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"\nüìä EVALUATION METRICS:")
    print(f"  ‚úì Correct extractions: {len(common)}/{len(expected_items)}")
    print(f"  ‚úó Missed extractions: {len(missing)}")
    print(f"  ‚ö† Extra extractions: {len(extra)}")
    print(f"\n  üìà Per-Sample Metrics:")
    print(f"    Accuracy:  {accuracy:.1f}%")
    print(f"    Precision: {precision:.1f}%")
    print(f"    Recall:    {recall:.1f}%")
    print(f"    F1 Score:  {f1:.1f}%")
    
    if missing:
        print(f"\n  Missed items: {list(missing)[:3]}")
    if extra:
        print(f"  Extra items: {list(extra)[:3]}")


In [None]:
# Aggregate Metrics across all test samples
print("\n" + "="*80)
print("AGGREGATE METRICS ACROSS TEST SAMPLES")
print("="*80)

# Calculate aggregate metrics
aggregate_precision = total_correct / total_predicted * 100 if total_predicted > 0 else 0
aggregate_recall = total_correct / total_expected * 100 if total_expected > 0 else 0
aggregate_f1 = 2 * (aggregate_precision * aggregate_recall) / (aggregate_precision + aggregate_recall) if (aggregate_precision + aggregate_recall) > 0 else 0
aggregate_accuracy = total_correct / total_expected * 100 if total_expected > 0 else 0

print(f"\nEvaluated on {num_test_samples} test samples:")
print(f"\n📊 Overall Performance:")
print(f"  Total expected entities:  {total_expected}")
print(f"  Total predicted entities: {total_predicted}")
print(f"  Correctly predicted:      {total_correct}")

print(f"\n📈 Aggregate Metrics:")
print(f"  Accuracy:  {aggregate_accuracy:.2f}% (identical to recall as computed here)")
print(f"  Precision: {aggregate_precision:.2f}% (higher = fewer false positives)")
print(f"  Recall:    {aggregate_recall:.2f}% (higher = fewer false negatives)")
print(f"  F1 Score:  {aggregate_f1:.2f}% (harmonic mean of precision and recall)")

print(f"\n💡 Interpretation:")
print(f"  - Accuracy: {aggregate_accuracy:.1f}% of expected entities were found (same as recall)")
print(f"  - Precision: Of all entities predicted, {aggregate_precision:.1f}% were correct")
print(f"  - Recall: Of all actual entities, {aggregate_recall:.1f}% were found")
print(f"  - F1: Harmonic mean balancing precision and recall")

print(f"\n🎯 What these metrics mean:")
print(f"  - High Precision, Low Recall → Model is conservative (misses entities)")
print(f"  - Low Precision, High Recall → Model is aggressive (predicts too many)")
print(f"  - High F1 Score → Good balance between precision and recall")
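The aggregate cell pools counts across samples before dividing (micro-averaging), so samples with many entities dominate the result. Averaging per-sample F1 scores instead (macro-averaging) weights every sample equally. The sketch below (hypothetical helpers `prf` and `micro_macro_f1`) contrasts the two on a toy case: one sample with 9 of 10 entities correct and one whose single prediction is wrong.

```python
def prf(correct, predicted, expected):
    """Precision, recall and F1 from raw counts (0.0 when undefined)."""
    precision = correct / predicted if predicted else 0.0
    recall = correct / expected if expected else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def micro_macro_f1(per_sample_counts):
    """per_sample_counts: list of (correct, predicted, expected) tuples, one per sample."""
    correct = sum(c for c, _, _ in per_sample_counts)
    predicted = sum(p for _, p, _ in per_sample_counts)
    expected = sum(e for _, _, e in per_sample_counts)
    micro = prf(correct, predicted, expected)[2]  # pool counts, then divide
    # average the per-sample F1 scores
    macro = sum(prf(*counts)[2] for counts in per_sample_counts) / len(per_sample_counts)
    return micro, macro


micro, macro = micro_macro_f1([(9, 10, 10), (0, 1, 1)])
print(f"micro F1 = {micro:.3f}, macro F1 = {macro:.3f}")  # micro F1 = 0.818, macro F1 = 0.450
```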

## 16b. Understanding the Metrics

### Accuracy
- **Formula**: `Correct / Total Expected`
- **Meaning**: Percentage of expected entities that were correctly predicted
- **Limitation**: Doesn't account for false positives (extra predictions); as defined here it is numerically identical to recall

### Precision
- **Formula**: `Correct / Total Predicted`
- **Meaning**: Of all entities the model predicted, how many were correct?
- **High Precision**: Model rarely makes false positive errors (rarely predicts wrong entities)

### Recall
- **Formula**: `Correct / Total Expected`
- **Meaning**: Of all actual entities, how many did the model find?
- **High Recall**: Model rarely makes false negative errors (rarely misses entities)

### F1 Score
- **Formula**: `2 × (Precision × Recall) / (Precision + Recall)`
- **Meaning**: Harmonic mean that balances precision and recall
- **Best used**: When false positives and false negatives matter equally

**Example**:
```
Ground truth: ['aspirin', 'ibuprofen', 'NSAIDs']
Prediction:   ['aspirin', 'ibuprofen']

Accuracy:  66.7% (2/3 found)
Precision: 100% (2/2 predicted were correct)
Recall:    66.7% (2/3 actual entities found)
F1 Score:  80.0% (balanced metric)
```
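The same numbers can be reproduced with a small set-based helper (hypothetical `entity_prf`) that uses the same exact-string matching as the evaluation loop in section 16:

```python
def entity_prf(expected, predicted):
    """Set-based precision, recall and F1 for one sample (exact string match)."""
    gold, pred = set(expected), set(predicted)
    tp = len(gold & pred)  # correctly predicted entities
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


p, r, f1 = entity_prf(["aspirin", "ibuprofen", "NSAIDs"], ["aspirin", "ibuprofen"])
print(f"P={p:.1%} R={r:.1%} F1={f1:.1%}")  # P=100.0% R=66.7% F1=80.0%
```

Because matching is exact, `Aspirin` versus `aspirin` counts as both a miss and a false positive. Normalizing strings (for example with `str.casefold()`) before comparing is a design choice the notebook does not make.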

## 17. Custom Test Cases - Comprehensive NER Evaluation

Test the model's ability to:
1. **Extract Chemicals** - Identify drug names and chemical compounds
2. **Extract Diseases** - Identify medical conditions and diseases
3. **Extract Relationships** - Identify which chemicals are related to which diseases

In [None]:
# Test 1: Chemical Extraction
print("="*80)
print("TEST 1: CHEMICAL EXTRACTION")
print("="*80)

chemical_test = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the chemicals mentioned.

A patient was treated with aspirin and ibuprofen for pain relief. The combination of these NSAIDs proved effective in reducing inflammation. Additionally, metformin was prescribed for glucose control.

List of extracted chemicals:
"""

print(f"\nüìù Prompt:\n{chemical_test}")
print("\nü§ñ Model Output:")
print(generate_response(chemical_test))

In [None]:
# Test 2: Disease Extraction
print("\n" + "="*80)
print("TEST 2: DISEASE EXTRACTION")
print("="*80)

disease_test = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the diseases mentioned.

The patient presented with hypertension, diabetes mellitus, and chronic kidney disease. Laboratory findings revealed proteinuria and elevated creatinine levels, suggesting diabetic nephropathy.

List of extracted diseases:
"""

print(f"\nüìù Prompt:\n{disease_test}")
print("\nü§ñ Model Output:")
print(generate_response(disease_test))

In [None]:
# Test 3: Chemical-Disease Relationship Extraction
print("\n" + "="*80)
print("TEST 3: RELATIONSHIP EXTRACTION - BASIC")
print("="*80)

relationship_test_1 = """The following article contains technical terms including diseases, drugs and chemicals. Extract the relationships between chemicals and diseases mentioned in the text.

Metformin is commonly prescribed for type 2 diabetes by improving insulin sensitivity and reducing hepatic glucose production. Aspirin is used in cardiovascular disease management in high-risk patients.

List the chemical-disease relationships:
"""

print(f"\nüìù Prompt:\n{relationship_test_1}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_1, max_new_tokens=600))

In [None]:
# Test 4: Multiple Relationship Extraction
print("\n" + "="*80)
print("TEST 4: RELATIONSHIP EXTRACTION - MULTIPLE PAIRS")
print("="*80)

relationship_test_2 = """The following article contains technical terms including diseases, drugs and chemicals. Identify all chemical-disease pairs and their relationships.

Long-term use of corticosteroids is associated with osteoporosis and increases the risk of bone fractures. NSAIDs are linked to chronic kidney disease and gastrointestinal bleeding in susceptible patients.

List of chemical-disease relationships:
"""

print(f"\nüìù Prompt:\n{relationship_test_2}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_2, max_new_tokens=600))

In [None]:
# Test 5: Complex Multi-Entity Relationship Extraction
print("\n" + "="*80)
print("TEST 5: COMPREHENSIVE EXTRACTION - ALL ENTITIES & RELATIONSHIPS")
print("="*80)

relationship_test_3 = """The following article contains technical terms including diseases, drugs and chemicals. Extract:
1. All chemicals mentioned
2. All diseases mentioned
3. All relationships between chemicals and diseases

The patient with rheumatoid arthritis was started on methotrexate for inflammatory joint disease. However, methotrexate is associated with hepatotoxicity and requires monitoring. The patient also has hypertension managed with lisinopril. Statins were prescribed for cardiovascular disease prevention given elevated cholesterol levels.

Extracted information:
"""

print(f"\nüìù Prompt:\n{relationship_test_3}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_3, max_new_tokens=800))

In [None]:
# Test 6: Relationship Extraction with TYPE explanation
print("\n" + "="*80)
print("TEST 6: RELATIONSHIP EXTRACTION - TREATMENT")
print("="*80)

relationship_test_1 = """The following article contains technical terms including diseases, drugs and chemicals. For each disease-chemical pair, identify the relationship and explain the TYPE of relationship (e.g., treats, prevents, causes, worsens, etc.).

Metformin is commonly prescribed to treat type 2 diabetes by improving insulin sensitivity and reducing hepatic glucose production. Aspirin is used to prevent cardiovascular disease in high-risk patients.

List the relationships between chemicals and diseases with their types:
"""

print(f"\nüìù Prompt:\n{relationship_test_1}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_1, max_new_tokens=600))

In [None]:
# Test 7: Relationship Extraction - Adverse Effects
print("\n" + "="*80)
print("TEST 7: RELATIONSHIP EXTRACTION - ADVERSE EFFECTS")
print("="*80)

relationship_test_2 = """The following article contains technical terms including diseases, drugs and chemicals. For each disease-chemical pair, identify the relationship and explain the TYPE of relationship (e.g., treats, prevents, causes, worsens, induces, etc.).

Long-term use of corticosteroids can cause osteoporosis and increase the risk of bone fractures. NSAIDs may worsen chronic kidney disease and induce gastrointestinal bleeding in susceptible patients.

List the relationships between chemicals and diseases with their types:
"""

print(f"\nüìù Prompt:\n{relationship_test_2}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_2, max_new_tokens=600))

In [None]:
# Test 8: Complex Multi-Relationship Scenario
print("\n" + "="*80)
print("TEST 8: COMPLEX MULTI-RELATIONSHIP SCENARIO")
print("="*80)

relationship_test_3 = """The following article contains technical terms including diseases, drugs and chemicals. Extract all chemicals, all diseases, and explain the TYPE of relationship between each chemical-disease pair (treats, prevents, causes, worsens, contraindicates, etc.).

The patient with rheumatoid arthritis was started on methotrexate, which effectively treats inflammatory joint disease. However, methotrexate can cause hepatotoxicity and must be monitored carefully. The patient also has hypertension controlled with lisinopril. Statins were prescribed to prevent cardiovascular disease given the patient's elevated cholesterol levels.

Provide:
1. List of chemicals
2. List of diseases  
3. Relationships with their types
"""

print(f"\nüìù Prompt:\n{relationship_test_3}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_3, max_new_tokens=800))

## 18. Model Information and Next Steps

In [None]:
print("="*80)
print("TRAINING COMPLETE! üéâ")
print("="*80)
print(f"\nYour fine-tuned model is available at:")
print(f"  üìÅ Local: ./final_model")
print(f"  ü§ó Hub: https://huggingface.co/{HF_MODEL_ID}")
print(f"\nModel details:")
print(f"  Base model: {MODEL_NAME}")
print(f"  Training samples: {len(train_data)}")
print(f"  Validation samples: {len(val_data)}")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Training epochs: {NUM_EPOCHS}")
print(f"\nNext steps:")
print(f"  1. Test on more validation examples")
print(f"  2. Try the model on completely new medical texts")
print(f"  3. Compare with base model (ablation study)")
print(f"  4. Deploy via Hugging Face Inference API")
print(f"  5. Share your model with the community!")

## 19. Usage Example

How to use the model in production:

In [None]:
# Example: How to load and use the model later
usage_code = '''
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto"
)

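# Load LoRA adapter from Hub
model = PeftModel.from_pretrained(
    base_model,
    "your-username/llama3-medical-ner-lora"  # Replace with your HF_MODEL_ID (it includes the training timestamp)
)
model.eval()

# Use the model
prompt = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the chemicals mentioned.

Patient was treated with metformin and insulin for diabetes management.

List of extracted chemicals:
"""

# Wrap the prompt in the same Llama 3 chat layout used during training
formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

# Generate greedily and decode only the newly generated tokens
inputs = tokenizer(formatted, return_tensors="pt", add_special_tokens=False).to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip())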
'''

print("Usage Example:")
print("="*80)
print(usage_code)

---

## Summary

This notebook successfully:
1. ✅ Loaded and analyzed 3,000 medical NER examples
2. ✅ Split data into train/validation/test sets (80/10/10)
3. ✅ Formatted data in Llama 3 chat format
4. ✅ Configured Weights & Biases for tracking
5. ✅ Loaded Llama 3.2 3B with 4-bit quantization
6. ✅ Applied LoRA for efficient fine-tuning
7. ✅ Trained the model with SFT (monitored via W&B)
8. ✅ Uploaded checkpoints to Hugging Face Hub
9. ✅ Evaluated on held-out test samples never used for training or checkpoint selection
10. ✅ Saved the final model locally and on Hub

**Your medical NER model is ready to use! 🚀**