# LoRA Fine-tuning for NER Label Generation

This notebook fine-tunes the Qwen2.5-0.5B-Instruct model with LoRA (Low-Rank Adaptation), framing NER label generation as a text-generation task.


In [None]:
# Imports
import json
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments
)
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType
from trl import SFTTrainer
import os


In [None]:
# Read the training and validation instruction data from their JSON files
with open("outputs/data/train_instruction_data.json", encoding="utf-8") as train_file:
    train_data = json.load(train_file)

with open("outputs/data/val_instruction_data.json", encoding="utf-8") as val_file:
    val_data = json.load(val_file)

# Report how many examples each split contains
print(
    f"Loaded {len(train_data)} training examples",
    f"Loaded {len(val_data)} validation examples",
    sep="\n",
)

# Wrap both splits as Hugging Face Dataset objects
train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

# Preview at most the first 500 characters of the first training example
sample_preview = train_dataset[0]["text"][:500]
print("\nSample training example:")
print(sample_preview + "...")


In [None]:
# Model configuration: local directory holding the base checkpoint
model_path = "models/Qwen2.5-0.5B-Instruct"

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Fall back to EOS for padding only when the tokenizer has no pad token of its
# own. Aliasing pad to EOS unconditionally means that collators which mask
# pad-token ids in the labels also mask the real end-of-sequence token, so the
# fine-tuned model may never be trained to stop generating.
# NOTE(review): Qwen2.5 tokenizers are expected to ship a dedicated pad token;
# confirm for the checkpoint stored at model_path.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pad on the right so real tokens keep their positions during training
tokenizer.padding_side = "right"

# Load model without quantization
print("Loading model...")
# NOTE(review): the weights are loaded in fp16 and the training arguments also
# enable fp16 mixed precision; some PEFT/transformers version combinations
# reject fp16 trainable weights ("Attempting to unscale FP16 gradients").
# Confirm against the installed versions.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",  # let the library decide device placement
    trust_remote_code=True,
    torch_dtype=torch.float16  # Use fp16 for efficiency
)

print(f"Model loaded: {model.__class__.__name__}")


In [None]:
# LoRA adapter settings: rank-16 updates on the attention and MLP projections
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
)

# Wrap the base model with the LoRA adapters
print("Applying LoRA configuration...")
model = get_peft_model(model, lora_config)

# Tally parameter counts: every parameter, and the subset that receives gradients
param_stats = [(p.numel(), p.requires_grad) for _, p in model.named_parameters()]
all_param = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, needs_grad in param_stats if needs_grad)

print(f"Trainable params: {trainable_params:,} || All params: {all_param:,} || Trainable%: {100 * trainable_params / all_param:.2f}")


In [None]:
# Training hyperparameters, grouped by concern
training_args = TrainingArguments(
    # Output location and checkpoint retention
    output_dir="outputs/checkpoints",
    save_steps=200,
    save_total_limit=3,
    # Schedule length and batching: 4 examples per device, accumulated over
    # 4 steps before each optimizer update
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    group_by_length=True,
    # Optimizer and learning-rate schedule
    optim="adamw_torch",
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=1.0,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Mixed precision: fp16 on, bf16 off
    fp16=True,
    bf16=False,
    # Evaluation runs every 200 steps, the same interval as checkpoint saves;
    # the checkpoint with the lowest eval_loss is reloaded when training ends
    evaluation_strategy="steps",
    eval_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Logging
    logging_steps=50,
    report_to="none",  # Disable wandb/tensorboard
)

# Data formatting function
def formatting_func(example):
    """Return the value stored under the "text" key of *example*, unchanged."""
    text = example["text"]
    return text

# Build the supervised fine-tuning trainer from the pieces prepared above
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    # Data splits and the function that extracts the text to train on
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    formatting_func=formatting_func,
    # Maximum sequence length passed to the trainer
    max_seq_length=512,
)

print("Trainer initialized successfully!")


In [None]:
# Run the fine-tuning loop
print("Starting training...")
trainer.train()

# Write the trained model and the tokenizer into the same directory
print("\nSaving the final model...")
final_model_dir = "outputs/final_model"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

print("Training completed and model saved!")


In [None]:
# Save training history
import json
import os

# Get training history: the list of log entries recorded on the trainer state
history = trainer.state.log_history

# Save to file. Create the results directory first: open() does not create
# missing parent directories, and no other visible cell creates this one, so
# without this step the write could fail only after training has finished.
history_path = "outputs/results/training_history.json"
os.makedirs(os.path.dirname(history_path), exist_ok=True)
with open(history_path, "w", encoding="utf-8") as f:
    json.dump(history, f, indent=2)

# Display final metrics: the numeric fields of the last log entry
if history:
    final_metrics = history[-1]
    print("Final metrics:")
    for key, value in final_metrics.items():
        # Skip non-numeric values, which the .4f format cannot render
        if isinstance(value, (int, float)):
            print(f"{key}: {value:.4f}")
