# Medical NER Fine-Tuning with Llama 3.2 3B + LoRA

This notebook implements fine-tuning of Llama 3.2 3B Instruct for medical Named Entity Recognition (NER) using:
- **SFT** (Supervised Fine-Tuning)
- **LoRA** (Low-Rank Adaptation)
- **Hugging Face Hub** integration for checkpoint uploads

## Tasks:
1. Chemical entity extraction
2. Disease entity extraction
3. Chemical-Disease relationship extraction

## Dataset:
- 3,000 medical text examples
- 80/10/10 train/validation/test split
- Weights & Biases tracking enabled

## 0. Environment Variables Setup

⚠️ **IMPORTANT**: Set your credentials before running this notebook!

Required:
- `HF_TOKEN`: Your Hugging Face token (needed to save models to HF Hub)

Optional:
- `WANDB_API_KEY`: Your Weights & Biases API key (for training tracking)
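Tokens written directly into a cell are easy to leak when the notebook is shared. One possible alternative is sketched below, assuming a hypothetical helper named `resolve_secret` that is not part of this notebook: it reads the variable from the environment, treats the placeholder strings as unset, and can optionally fall back to a hidden interactive prompt.

```python
import os
from getpass import getpass

# Placeholder strings that should never be treated as real credentials
PLACEHOLDERS = {"hf_YOUR_TOKEN_HERE", "your_wandb_key_here"}

def resolve_secret(name, env=None, prompt=False):
    """Return the secret stored under `name`, or None if it is missing or a placeholder.

    `resolve_secret` is a hypothetical helper used for illustration only.
    """
    env = os.environ if env is None else env
    value = env.get(name, "").strip()
    if value and value not in PLACEHOLDERS:
        return value
    if prompt:
        # getpass hides the typed value, so it never appears in the cell output
        value = getpass(f"Enter {name}: ").strip()
        return value or None
    return None
```

For example, `resolve_secret("HF_TOKEN")` returns `None` while the placeholder is still in place, which lets a later cell skip the login instead of sending an invalid token.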

In [1]:
import os

# Set your Hugging Face token (required for uploading to HF Hub).
# Do not commit real credentials: export HF_TOKEN in your shell, or replace the
# placeholder locally and clear it before sharing the notebook.
os.environ.setdefault("HF_TOKEN", "hf_YOUR_TOKEN_HERE")

# Set your Weights & Biases API key (optional, for training tracking)
os.environ.setdefault("WANDB_API_KEY", "your_wandb_key_here")

# Verify environment variables (placeholders count as "not set")
hf_set = os.environ["HF_TOKEN"] != "hf_YOUR_TOKEN_HERE"
wandb_set = os.environ["WANDB_API_KEY"] != "your_wandb_key_here"
print("✓ Environment variables set")
print(f"  HF_TOKEN: {'✓ Set' if hf_set else '✗ Not set - UPDATE THIS!'}")
print(f"  WANDB_API_KEY: {'✓ Set' if wandb_set else '○ Optional (will use wandb login cache)'}")

✓ Environment variables set
  HF_TOKEN: ✓ Set
  WANDB_API_KEY: ✓ Set


## 1. Setup and Installation

First, let's install all required dependencies.

In [1]:
# Install required packages
!pip install -q transformers datasets peft accelerate bitsandbytes
!pip install -q huggingface-hub tokenizers trl scikit-learn
!pip install -q scipy sentencepiece protobuf wandb

print("✓ All packages installed successfully!")

✓ All packages installed successfully!


## 2. Import Libraries

In [2]:
import json
import torch
import os
import random
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    EarlyStoppingCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from huggingface_hub import login
import wandb

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


PyTorch version: 2.2.2
CUDA available: False


## 3. Configuration

⚠️ **IMPORTANT**: Update `HF_USERNAME` with your Hugging Face username!

In [4]:
# Configuration Section
from datetime import datetime

HF_USERNAME = "albyos"  # Replace with your HF username

# Generate timestamp for checkpoint naming
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
HF_MODEL_ID = f"{HF_USERNAME}/llama3-medical-ner-lora-{TIMESTAMP}"
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_NAME = BASE_MODEL  # Alias for consistency

# LoRA Configuration
LORA_RANK = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# Training Configuration
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-4

# Data Configuration
TRAIN_SPLIT_RATIO = 0.8  # 80/10/10 split: the remaining 20% is halved into validation and test
RANDOM_SEED = 42
RESHUFFLE_SPLITS_EACH_RUN = True  # When True, create a fresh validation split every run
SPLIT_SEED = random.randint(0, 1_000_000) if RESHUFFLE_SPLITS_EACH_RUN else RANDOM_SEED

print("✓ Configuration loaded")
print(f"  Base model: {BASE_MODEL}")
print(f"  HF model ID: {HF_MODEL_ID}")
print(f"  Training timestamp: {TIMESTAMP}")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Training epochs: {NUM_EPOCHS}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"  Data split seed: {SPLIT_SEED} ({'reshuffled' if RESHUFFLE_SPLITS_EACH_RUN else 'fixed'})")

✓ Configuration loaded
  Base model: meta-llama/Llama-3.2-3B-Instruct
  HF model ID: albyos/llama3-medical-ner-lora-20251029_143110
  Training timestamp: 20251029_143110
  LoRA rank: 16
  Training epochs: 3
  Effective batch size: 16
  Data split seed: 644495 (reshuffled)


## 4. Hugging Face Authentication

Get your token from: https://huggingface.co/settings/tokens

In [5]:
# Authenticate with Hugging Face
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token and hf_token != "hf_YOUR_TOKEN_HERE":
    login(token=hf_token)
    print("✓ Logged in to Hugging Face")
else:
    print("⚠ HF_TOKEN not set. Please update the cell in section 0 before continuing.")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✓ Logged in to Hugging Face


## 4b. Weights & Biases Setup

Initialize W&B to track training metrics, validation loss, and experiments.
Get your API key from: https://wandb.ai/authorize

In [6]:
# Login to Weights & Biases
wandb_key = os.getenv('WANDB_API_KEY')

if wandb_key and wandb_key != 'your_wandb_key_here':
    wandb.login(key=wandb_key)
    print('✓ Logged in to Weights & Biases using WANDB_API_KEY')
else:
    print('⚠ Warning: WANDB_API_KEY not set. Attempting to use cached login...')
    try:
        wandb.login()
        print('✓ Logged in to Weights & Biases using cached credentials')
    except Exception as e:
        print(f'⚠ Warning: Could not login to W&B: {e}')
        print('  Run wandb.login() interactively or set WANDB_API_KEY environment variable')

wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: clemalb (alberto-clemente) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


✓ Logged in to Weights & Biases using WANDB_API_KEY


In [7]:
# Initialize Weights & Biases
wandb.init(
    project="medical-ner-finetuning",
    name=f"llama3-medical-ner-{TIMESTAMP}",
    config={
        "model": BASE_MODEL,
        "lora_rank": LORA_RANK,
        "lora_alpha": LORA_ALPHA,
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION,
    }
)

print("✓ Weights & Biases initialized")
print(f"  Project: medical-ner-finetuning")
print(f"  Run name: llama3-medical-ner-{TIMESTAMP}")
print(f"  Dashboard: https://wandb.ai")

✓ Weights & Biases initialized
  Project: medical-ner-finetuning
  Run name: llama3-medical-ner-20251029_143110
  Dashboard: https://wandb.ai


## 5. Data Exploration

Let's examine the dataset structure.

In [5]:
# Load and inspect the dataset
# Load data
with open('../data/both_rel_instruct_all.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]

print(f"Total samples: {len(data)}")
print(f"\nSample structure:")
print(json.dumps(data[0], indent=2)[:500] + "...")

Total samples: 3000

Sample structure:
{
  "prompt": "The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the chemicals mentioned.\n\nIn unanesthetized, spontaneously hypertensive rats the decrease in blood pressure and heart rate produced by intravenous clonidine, 5 to 20 micrograms/kg, was inhibited or reversed by nalozone, 0.2 to 2 mg/kg. The hypotensive effect of 100 mg/kg alpha-methyldopa was also partially reversed by naloxone. Naloxone alone did not affect either blood ...


In [10]:
# Analyze task distribution
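task_counts = {}
for sample in data:
    # Check the relationship prompt first: it also contains "diseases mentioned",
    # so testing for diseases first counts relationship samples as disease
    # (this presumably explains the 2000 "Disease" samples in the recorded output).
    if "influences between" in sample['prompt']:
        task = "Relationship Extraction"
    elif "chemicals mentioned" in sample['prompt']:
        task = "Chemical Extraction"
    elif "diseases mentioned" in sample['prompt']:
        task = "Disease Extraction"
    else:
        task = "Other"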
    
    task_counts[task] = task_counts.get(task, 0) + 1

print("Task Distribution:")
for task, count in task_counts.items():
    print(f"  {task}: {count} ({count/len(data)*100:.1f}%)")

Task Distribution:
  Chemical Extraction: 1000 (33.3%)
  Disease Extraction: 2000 (66.7%)


In [11]:
# Show example from each task type
print("="*80)
print("EXAMPLE: Chemical Extraction")
print("="*80)
chem_example = [s for s in data if "chemicals mentioned" in s['prompt']][0]
print(f"Prompt:\n{chem_example['prompt'][:300]}...")
print(f"\nCompletion:\n{chem_example['completion']}")

print("\n" + "="*80)
print("EXAMPLE: Disease Extraction")
print("="*80)
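# Exclude relationship prompts, which also contain "diseases mentioned"
disease_example = [s for s in data if "diseases mentioned" in s['prompt'] and "influences between" not in s['prompt']][0]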
print(f"Prompt:\n{disease_example['prompt'][:300]}...")
print(f"\nCompletion:\n{disease_example['completion']}")

EXAMPLE: Chemical Extraction
Prompt:
The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the chemicals mentioned.

In unanesthetized, spontaneously hypertensive rats the decrease in blood pressure and heart rate produced by intravenous clonidine, 5 to 20 micrograms/kg, was inhib...

Completion:
- clonidine
- nalozone
- alpha-methyldopa
- naloxone
- Naloxone
- [3H]-naloxone
- [3H]-dihydroergocryptine

EXAMPLE: Disease Extraction
Prompt:
The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the diseases mentioned.

In unanesthetized, spontaneously hypertensive rats the decrease in blood pressure and heart rate produced by intravenous clonidine, 5 to 20 micrograms/kg, was inhibi...

Completion:
- hypertensive
- hypotensive


## 6. Dataset Splitting

Split into:
- **80% Training** (2,400 samples) - for fine-tuning
- **10% Validation** (300 samples) - for monitoring during training (W&B)
- **10% Test** (300 samples) - for final evaluation after training

### ⚠️ CRITICAL FIX Applied:
**Problem**: The previous version used `shuffle=False`. The file appears to be grouped by task, so an unshuffled split caused severe task imbalance:
- Training: 41.5% chemical, 41.5% disease, **17% relationship** ← Underrepresented!
- Validation: **100% relationship** ← Wrong distribution!
- Test: **100% relationship** ← Wrong distribution!

**Solution**: The split now uses `shuffle=True`, so every split draws from all three tasks in roughly equal proportion. Shuffling gives approximate balance rather than a guarantee, which is why the cell below prints the task distribution of each split.
This means the model sees all three tasks (chemical, disease, relationship extraction) proportionally during training, and validation and test reflect the same mix.
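Shuffling alone balances the tasks only on average. If exact proportions matter, a stratified split is one option. The sketch below is a dependency-free illustration with a hypothetical `stratified_split` helper and toy data; in the notebook, passing a list of task labels through the `stratify=` argument of `train_test_split` would likely achieve the same effect.

```python
import random
from collections import defaultdict

def stratified_split(samples, label_fn, fractions=(0.8, 0.1, 0.1), seed=42):
    """Split samples so each label keeps the same train/val/test proportions.

    Hypothetical helper for illustration; not part of the notebook.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for sample in samples:
        by_label[label_fn(sample)].append(sample)

    train, val, test = [], [], []
    for label in sorted(by_label):
        group = by_label[label]
        rng.shuffle(group)
        n_train = int(len(group) * fractions[0])
        n_val = int(len(group) * fractions[1])
        train += group[:n_train]
        val += group[n_train:n_train + n_val]
        test += group[n_train + n_val:]  # whatever remains goes to test

    # Shuffle once more so tasks are interleaved inside each split
    for part in (train, val, test):
        rng.shuffle(part)
    return train, val, test

# Toy data: 10 samples for each of the three tasks
toy = [{"task": t, "id": i}
       for t in ("chemical", "disease", "relationship")
       for i in range(10)]
train, val, test = stratified_split(toy, label_fn=lambda s: s["task"])
print(len(train), len(val), len(test))
```

With this approach each task contributes exactly its share to every split, at the cost of needing a label function such as `get_task_type`.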

In [None]:
# Split data into train/val/test (80/10/10)
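# ⚠️ CRITICAL FIX: Enable shuffle=True to ensure balanced task distribution
# Pinned to the seed drawn in the configuration cell so the recorded split is reproducible;
# remove this line to use the SPLIT_SEED chosen there instead.
SPLIT_SEED = 644495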

# First split: 80% train, 20% temp (for val + test)
train_data, temp_data = train_test_split(
    data,
    test_size=0.2,  # 20% for validation + test
    random_state=SPLIT_SEED,
    shuffle=True  # ✅ FIXED: Was False, now True for balanced splits
)

# Second split: split the 20% into 10% val, 10% test
val_data, test_data = train_test_split(
    temp_data,
    test_size=0.5,  # 50% of 20% = 10% of total
    random_state=SPLIT_SEED + 1,
    shuffle=True  # ✅ FIXED: Was False, now True for balanced splits
)

# Analyze task distribution to verify balanced split
def get_task_type(prompt):
    prompt_lower = prompt.lower()
    if "influences between" in prompt_lower:
        return "relationship"
    elif "chemicals mentioned" in prompt_lower:
        return "chemical"
    elif "diseases mentioned" in prompt_lower:
        return "disease"
    return "other"

print("="*80)
print("TASK DISTRIBUTION ANALYSIS")
print("="*80)

for name, dataset in [("Train", train_data), ("Validation", val_data), ("Test", test_data)]:
    tasks = {}
    for sample in dataset:
        task = get_task_type(sample['prompt'])
        tasks[task] = tasks.get(task, 0) + 1
    
    print(f"\n{name} ({len(dataset)} samples):")
    for task, count in sorted(tasks.items()):
        print(f"  {task}: {count} ({count/len(dataset)*100:.1f}%)")

# Save splits
with open('train.jsonl', 'w', encoding='utf-8') as f:
    for item in train_data:
        f.write(json.dumps(item) + '\n')

with open('validation.jsonl', 'w', encoding='utf-8') as f:
    for item in val_data:
        f.write(json.dumps(item) + '\n')

with open('test.jsonl', 'w', encoding='utf-8') as f:
    for item in test_data:
        f.write(json.dumps(item) + '\n')

print(f"\n{'='*80}")
print(f"✓ Dataset split complete (seed={SPLIT_SEED})")
print(f"  Train samples: {len(train_data)} ({len(train_data)/len(data)*100:.1f}%)")
print(f"  Validation samples: {len(val_data)} ({len(val_data)/len(data)*100:.1f}%) - for training monitoring")
print(f"  Test samples: {len(test_data)} ({len(test_data)/len(data)*100:.1f}%) - for final evaluation")
print(f"\n📊 Usage:")
print(f"  - Train: Used for fine-tuning")
print(f"  - Validation: Monitored during training (shown in W&B)")
print(f"  - Test: Used ONLY after training for final evaluation")
print(f"\n⚠️  IMPORTANT: Splits are now SHUFFLED for balanced task distribution!")

✓ Dataset split complete (seed=644495)
  Train samples: 2400 (80.0%)
  Validation samples: 300 (10.0%) - for training monitoring
  Test samples: 300 (10.0%) - for final evaluation

📊 Usage:
  - Train: Used for fine-tuning
  - Validation: Monitored during training (shown in W&B)
  - Test: Used ONLY after training for final evaluation


## 7. Data Formatting

Format data into Llama 3 chat format with system, user, and assistant roles.

In [13]:
def format_instruction(sample):
    """Format data into Llama 3 chat format."""
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{sample['completion']}<|eot_id|>"""

# Test formatting
formatted_example = format_instruction(train_data[0])
print("Formatted Example:")
print(formatted_example[:500] + "...")

Formatted Example:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the influences between the chemicals and diseases mentioned.

BACKGROUND/AIMS: It is still unclear what happens in the glomerulus when proteinuria starts. Using puromycin aminonucleoside...
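At inference time the same layout is needed without the completion, ending on an open assistant header so the model writes the answer. The sketch below generalizes the format shown above; `build_llama3_prompt` and `build_inference_prompt` are hypothetical helpers, not functions from this notebook or from any library.

```python
SYSTEM_PROMPT = (
    "You are a medical NER expert. "
    "Extract the requested entities from medical texts accurately."
)

def build_llama3_prompt(messages, add_generation_prompt=False):
    """Render role/content messages in the Llama 3 chat layout used above."""
    parts = ["<|begin_of_text|>"]
    for message in messages:
        parts.append(
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content']}<|eot_id|>"
        )
    if add_generation_prompt:
        # Open an assistant turn and leave it empty for the model to complete
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)

def build_inference_prompt(user_prompt):
    """Prompt for generation: system + user turns, then an open assistant turn."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    return build_llama3_prompt(messages, add_generation_prompt=True)
```

Keeping training and inference prompts on one code path reduces the risk that a small difference in whitespace or special tokens degrades the fine-tuned model's output.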


In [14]:
# Format all data
train_formatted = [{"text": format_instruction(sample)} for sample in train_data]
val_formatted = [{"text": format_instruction(sample)} for sample in val_data]
test_formatted = [{"text": format_instruction(sample)} for sample in test_data]

# Create HuggingFace datasets
train_dataset = Dataset.from_list(train_formatted)
val_dataset = Dataset.from_list(val_formatted)
test_dataset = Dataset.from_list(test_formatted)

print(f"✓ Datasets formatted:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Validation: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

✓ Datasets formatted:
  Train: 2400 samples
  Validation: 300 samples
  Test: 300 samples


## 8. Load Model and Tokenizer

Load Llama 3.2 3B with 4-bit quantization for memory efficiency.

In [15]:
# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("✓ Quantization config created (4-bit NF4)")

✓ Quantization config created (4-bit NF4)


In [17]:
!pip install hf_transfer

Collecting hf_transfer
  Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 65.4 MB/s 0:00:00
Installing collected packages: hf_transfer
Successfully installed hf_transfer-0.1.9


In [18]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",
    add_eos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token

print(f"✓ Tokenizer loaded: {MODEL_NAME}")
print(f"  Vocab size: {len(tokenizer)}")
print(f"  PAD token: {tokenizer.pad_token}")
print(f"  EOS token: {tokenizer.eos_token}")

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✓ Tokenizer loaded: meta-llama/Llama-3.2-3B-Instruct
  Vocab size: 128256
  PAD token: <|eot_id|>
  EOS token: <|eot_id|>


In [19]:
# Load base model
print("Loading model... (this may take a few minutes)")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

print(f"✓ Base model loaded: {MODEL_NAME}")
print(f"  Model size: {model.get_memory_footprint() / 1e9:.2f} GB")

Loading model... (this may take a few minutes)


config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

✓ Base model loaded: meta-llama/Llama-3.2-3B-Instruct
  Model size: 2.20 GB
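The 2.20 GB figure can be sanity-checked with back-of-the-envelope arithmetic. The sketch below assumes architecture values from the public model configuration (hidden size 3072, 28 layers, MLP width 8192, key/value width 1024, tied embeddings); these are assumptions, not values printed by this notebook. It also assumes that only the linear projections are stored in 4 bits while embeddings and norms stay in 16 bits, and it ignores quantization constants and buffers.

```python
# Assumed architecture values (from the public model config, not from this notebook)
HIDDEN = 3072
INTERMEDIATE = 8192
KV_DIM = 1024        # assumed: 8 key/value heads x 128 dims per head
NUM_LAYERS = 28
VOCAB = 128256       # matches the vocabulary size printed by the tokenizer cell

linear_per_layer = (
    2 * HIDDEN * HIDDEN            # q_proj, o_proj
    + 2 * HIDDEN * KV_DIM          # k_proj, v_proj
    + 3 * HIDDEN * INTERMEDIATE    # gate_proj, up_proj, down_proj
)
linear_params = NUM_LAYERS * linear_per_layer
embedding_params = VOCAB * HIDDEN              # shared with the output head (tied weights)
norm_params = (2 * NUM_LAYERS + 1) * HIDDEN    # two norms per layer plus the final norm

total_params = linear_params + embedding_params + norm_params

footprint_bytes = (
    linear_params * 0.5       # 4-bit weights: half a byte each
    + embedding_params * 2    # 16-bit embeddings
    + norm_params * 2         # 16-bit norm weights
)
print(f"parameters: {total_params:,}")
print(f"estimated footprint: {footprint_bytes / 1e9:.2f} GB")  # → 2.20 GB
```

Under these assumptions the estimate agrees with the reported size, and it shows that the unquantized embedding matrix accounts for roughly a third of the footprint.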


In [20]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
print("✓ Model prepared for k-bit training")

✓ Model prepared for k-bit training


## 9. Configure LoRA

Apply Low-Rank Adaptation for efficient fine-tuning.

In [21]:
# LoRA configuration
lora_config = LoraConfig(
    r=LORA_RANK,                   # LoRA rank
    lora_alpha=LORA_ALPHA,         # LoRA alpha (scaling)
    target_modules=[               # Layers to apply LoRA
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
    lora_dropout=LORA_DROPOUT,     # Dropout for regularization (set in the configuration cell)
    bias="none",                   # No bias training
    task_type="CAUSAL_LM"          # Causal language modeling
)

print(f"✓ LoRA configuration:")
print(f"  Rank (r): {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Dropout: {lora_config.lora_dropout}")
print(f"  Target modules: {len(lora_config.target_modules)}")

✓ LoRA configuration:
  Rank (r): 16
  Alpha: 32
  Dropout: 0.05
  Target modules: 7


In [22]:
# Apply LoRA to model
model = get_peft_model(model, lora_config)

print("✓ LoRA applied to model")
print("\nTrainable parameters:")
model.print_trainable_parameters()

✓ LoRA applied to model

Trainable parameters:
trainable params: 24,313,856 || all params: 3,237,063,680 || trainable%: 0.7511
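The trainable parameter count follows from the LoRA construction: each adapted linear layer with `in` inputs and `out` outputs gains two matrices of shapes `r x in` and `out x r`, so it adds `r * (in + out)` parameters. The sketch below reproduces the reported number under assumed layer shapes (taken from the public model configuration, not printed in this notebook).

```python
LORA_RANK = 16
NUM_LAYERS = 28  # assumed number of transformer blocks

# Assumed (in_features, out_features) for each targeted projection
TARGET_SHAPES = {
    "q_proj": (3072, 3072),
    "k_proj": (3072, 1024),
    "v_proj": (3072, 1024),
    "o_proj": (3072, 3072),
    "gate_proj": (3072, 8192),
    "up_proj": (3072, 8192),
    "down_proj": (8192, 3072),
}

def lora_param_count(in_features, out_features, rank):
    """Parameters added by one LoRA adapter: A is rank x in, B is out x rank."""
    return rank * (in_features + out_features)

per_layer = sum(lora_param_count(i, o, LORA_RANK) for i, o in TARGET_SHAPES.values())
trainable = NUM_LAYERS * per_layer

ALL_PARAMS = 3_237_063_680  # "all params" as reported in the output above
print(f"trainable params: {trainable:,}")                    # → 24,313,856
print(f"trainable%: {100 * trainable / ALL_PARAMS:.4f}")      # → 0.7511
```

Because the count scales linearly with the rank, doubling `LORA_RANK` to 32 would double the trainable parameters under the same assumptions.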


## 10. Tokenize Datasets

In [23]:
def tokenize_function(examples):
    """Tokenize the texts."""
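    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=2048,
        padding=False,
        # The formatted text already starts with <|begin_of_text|>; skip the
        # tokenizer's own special tokens to avoid a duplicated BOS token.
        add_special_tokens=False,
    )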

# Tokenize datasets
print("Tokenizing datasets...")

tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    desc="Tokenizing train set"
)

tokenized_val = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
    desc="Tokenizing validation set"
)

print(f"✓ Train set tokenized: {len(tokenized_train)} samples")
print(f"✓ Validation set tokenized: {len(tokenized_val)} samples")

Tokenizing datasets...


Tokenizing train set:   0%|          | 0/2400 [00:00<?, ? examples/s]

Tokenizing validation set:   0%|          | 0/300 [00:00<?, ? examples/s]

✓ Train set tokenized: 2400 samples
✓ Validation set tokenized: 300 samples


In [24]:
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM, not masked LM
)

print("✓ Data collator created")

✓ Data collator created
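With `mlm=False`, the collator copies `input_ids` into `labels` and replaces every position equal to the pad token id with `-100`, which the loss ignores. Since this notebook sets the pad token to the EOS token (`<|eot_id|>`), the genuine end-of-turn token is masked as well. The sketch below imitates that rule in plain Python with toy token ids; `causal_lm_labels` is a hypothetical stand-in, not the library implementation.

```python
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def causal_lm_labels(input_ids, pad_token_id):
    """Copy input_ids to labels, masking every position equal to pad_token_id."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

# Toy ids for illustration only: 9 stands for <|eot_id|>, 0 for a dedicated pad token
EOT, PAD = 9, 0

# Case 1: pad token == end-of-turn token (the setup used in this notebook).
# The sequence [5, 6, EOT] is padded to length 5 with EOT itself.
shared = causal_lm_labels([5, 6, EOT, EOT, EOT], pad_token_id=EOT)
print(shared)    # the real end-of-turn at index 2 is masked too

# Case 2: a separate pad token keeps the end-of-turn token in the loss.
separate = causal_lm_labels([5, 6, EOT, PAD, PAD], pad_token_id=PAD)
print(separate)
```

If the fine-tuned model tends not to stop after its answer, using a dedicated pad token is one possible remedy; that is a suggestion under the assumptions above, not something this notebook does.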


## 11. Training Configuration

In [25]:
# Training arguments
training_args = TrainingArguments(
    # Output and logging
    output_dir="./llama3-medical-ner-lora",
    logging_dir="./logs",
    logging_steps=10,
    
    # Training parameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    
    # Optimization
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    
    # Evaluation
    eval_strategy="steps",
    eval_steps=100,
    
    # Checkpointing
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    
    # Memory optimization
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    
    # Mixed precision
    fp16=True,
    
    # Hugging Face Hub
    push_to_hub=True,
    hub_model_id=HF_MODEL_ID,
    hub_strategy="checkpoint",
    hub_private_repo=False,
    
    # Misc
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="wandb",  # Enable Weights & Biases logging
    run_name=f"llama3-medical-ner-{TIMESTAMP}",  # W&B run name
    seed=42,
)

print(f"✓ Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size (per device): {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Hub model ID: {HF_MODEL_ID}")

✓ Training configuration:
  Epochs: 3
  Batch size (per device): 4
  Gradient accumulation: 4
  Effective batch size: 16
  Learning rate: 0.0002
  Hub model ID: albyos/llama3-medical-ner-lora-20251029_143110


## 12. Initialize Trainer

In [26]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
)

# Configure early stopping to prevent overfitting
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.0))

# Calculate training steps
total_steps = (len(tokenized_train) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)) * training_args.num_train_epochs

print(f"✓ Trainer initialized")
print(f"✓ Expected training steps: ~{total_steps}")
print(f"✓ Expected checkpoints: ~{max(1, total_steps // training_args.save_steps)}")
print("✓ Early stopping enabled (patience = 3 evaluations)")

✓ Trainer initialized
✓ Expected training steps: ~450
✓ Expected checkpoints: ~4
✓ Early stopping enabled (patience = 3 evaluations)
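The step counts above follow from simple arithmetic on the configuration. The sketch below reproduces them, assuming a single GPU and assuming that the warmup length is the warmup ratio times the total steps, rounded up.

```python
import math

train_samples = 2400
per_device_batch = 4
grad_accum = 4
epochs = 3
warmup_ratio = 0.03
eval_every = 100   # eval_steps and save_steps are both 100

effective_batch = per_device_batch * grad_accum
steps_per_epoch = math.ceil(train_samples / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

# Optimizer steps at which evaluation and checkpointing happen
eval_points = list(range(eval_every, total_steps + 1, eval_every))

print(f"steps per epoch: {steps_per_epoch}")   # → 150
print(f"total steps: {total_steps}")           # → 450
print(f"warmup steps: {warmup_steps}")         # → 14
print(f"evaluations at: {eval_points}")        # → [100, 200, 300, 400]
```

With only four evaluation points in the whole run, a patience of 3 leaves little room for early stopping to act; a smaller `eval_steps` would give it more opportunities.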


## 13. Start Training

⚠️ **Budget 2-3 hours on an A100 GPU as an upper bound; the run recorded below finished in about 41 minutes (`train_runtime` ≈ 2,470 s).**

The training will:
- Save checkpoints every 100 steps
- Upload checkpoints to Hugging Face Hub
- Evaluate on validation set every 100 steps
- Save the best model based on validation loss

In [27]:
# Start training
print("="*80)
print("STARTING TRAINING")
print("="*80)
print("This may take 2-3 hours on A100 GPU...\n")

trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


STARTING TRAINING
This may take 2-3 hours on A100 GPU...



| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 100 | 1.3488 | 1.353626 |
| 200 | 1.0701 | 1.173005 |
| 300 | 0.876 | 0.977606 |
| 400 | 0.6529 | 0.894576 |


TrainOutput(global_step=450, training_loss=1.0581105242835152, metrics={'train_runtime': 2470.5835, 'train_samples_per_second': 2.914, 'train_steps_per_second': 0.182, 'total_flos': 6.17495844698112e+16, 'train_loss': 1.0581105242835152, 'epoch': 3.0})
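The run reached all 450 steps because validation loss improved at every evaluation, so the early stopping counter never advanced. The sketch below is a simplified re-implementation of patience-based stopping for a metric where lower is better; `early_stop_index` is a hypothetical helper, and the library callback may differ in details such as when the best metric is updated.

```python
def early_stop_index(eval_losses, patience=3, threshold=0.0):
    """Return the index of the evaluation that triggers a stop, or None.

    An evaluation counts as an improvement only if the loss drops below the
    best value seen so far by more than `threshold`.
    """
    best = None
    bad_evals = 0
    for i, loss in enumerate(eval_losses):
        if best is None or (loss < best and abs(loss - best) > threshold):
            best = loss
            bad_evals = 0
        else:
            bad_evals += 1
        if bad_evals >= patience:
            return i
    return None

# Validation losses recorded for this run (steps 100 to 400)
recorded = [1.353626, 1.173005, 0.977606, 0.894576]
print(early_stop_index(recorded))   # no stop: every evaluation improved

# Hypothetical curve that plateaus after the second evaluation
plateau = [1.0, 0.9, 0.95, 0.92, 0.91]
print(early_stop_index(plateau))    # stops at the third non-improving evaluation
```

The still-falling validation loss at step 400 suggests the model had not yet converged, so a longer run might be worth trying.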

## 14. Save Final Model

In [28]:
# Save model locally
print("Saving final model...")
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")

print(f"✓ Model saved to: ./final_model")

Saving final model...


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

✓ Model saved to: ./final_model


In [29]:
# Push to Hugging Face Hub
print("Pushing to Hugging Face Hub...")

try:
    trainer.push_to_hub(commit_message="Training complete - final model")
    print(f"✓ Model pushed to: https://huggingface.co/{HF_MODEL_ID}")
except Exception as e:
    print(f"⚠ Failed to push to hub: {e}")
    print("  You can manually push later using: trainer.push_to_hub()")

Pushing to Hugging Face Hub...


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

No files have been modified since last commit. Skipping to prevent empty commit.


✓ Model pushed to: https://huggingface.co/albyos/llama3-medical-ner-lora-20251029_143110


## 15. Training Analysis

In [31]:
!pip install matplotlib


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting matplotlib
  Downloading matplotlib-3.10.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.60.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (112 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.9-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (6.3 kB)
Downloading matplotlib-3.10.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.7/8.7 MB[0m [31m87

In [30]:
# Plot training metrics
import matplotlib.pyplot as plt

# Get training history
log_history = trainer.state.log_history

# Extract losses together with the global step at which each was logged
train_steps = [entry['step'] for entry in log_history if 'loss' in entry]
train_loss = [entry['loss'] for entry in log_history if 'loss' in entry]
eval_steps = [entry['step'] for entry in log_history if 'eval_loss' in entry]
eval_loss = [entry['eval_loss'] for entry in log_history if 'eval_loss' in entry]

# Plot
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_steps, train_loss, label='Training Loss', color='blue')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(eval_steps, eval_loss, label='Validation Loss', color='orange')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Validation Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('training_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training metrics plotted and saved to: training_metrics.png")

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Summary statistics
print("="*80)
print("TRAINING SUMMARY")
print("="*80)
print(f"Total training steps: {trainer.state.global_step}")  # len(train_loss) would count log entries, not steps
print(f"Final training loss: {train_loss[-1]:.4f}")
print(f"Final validation loss: {eval_loss[-1]:.4f}")
print(f"Best validation loss: {min(eval_loss):.4f}")
print(f"Loss reduction: {((train_loss[0] - train_loss[-1]) / train_loss[0] * 100):.1f}%")

## Next Steps

Training is complete! Your model has been saved.

**To evaluate your model:**
1. Open `Medical_NER_Evaluation.ipynb`
2. Run the evaluation on the test set
3. Test custom examples

**Model locations:**
- Local: `./final_model`
- Hugging Face Hub: Check the output above for your model URL
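Because completions are bullet lists of entities, one plausible way to score the test set is set-based precision, recall, and F1 per example. The sketch below is an assumption about how such scoring could look, not the content of `Medical_NER_Evaluation.ipynb`; the helper names are hypothetical, matching is case-insensitive by choice, and the predicted list is invented for illustration.

```python
def parse_entities(completion):
    """Turn a '- item' bullet list into a set of lowercased entity strings."""
    entities = set()
    for line in completion.splitlines():
        line = line.strip()
        if line.startswith("- "):
            entity = line[2:].strip().lower()
            if entity:
                entities.add(entity)
    return entities

def precision_recall_f1(predicted, reference):
    """Set-based scores for one example; empty sets yield 0.0."""
    pred, ref = parse_entities(predicted), parse_entities(reference)
    true_pos = len(pred & ref)
    precision = true_pos / len(pred) if pred else 0.0
    recall = true_pos / len(ref) if ref else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)

# Reference taken from the chemical example shown earlier (first five items)
reference = "- clonidine\n- nalozone\n- alpha-methyldopa\n- naloxone\n- Naloxone"
# Hypothetical model output, made up for this illustration
predicted = "- clonidine\n- naloxone\n- morphine"

p, r, f1 = precision_recall_f1(predicted, reference)
print(f"precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
```

Lowercasing merges `naloxone` and `Naloxone` into one entity; if the evaluation should respect case, drop the `.lower()` call.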
