# LoRA Training Pipeline

This notebook demonstrates how to fine-tune MedGemma using the `model-training` component.

In [None]:
# Install dependencies and local package
!pip install -q -e ../components/model-training

In [None]:
import os
from model_training.dataset import load_training_data, format_extraction_prompt, format_grounding_prompt
from model_training.config import get_lora_config, load_model_and_tokenizer
from model_training.train import train_model

## 1. Load Data
Load training data from JSONL files.

In [None]:
# Location of the exported JSONL training set, relative to this notebook.
DATA_PATH = "../data/training/criteria_train.jsonl"

# Load the data only when the export is present; otherwise point the user at the
# missing prerequisite. Note that `train_data` is left undefined in that case, and
# later cells rely on that to decide whether to run.
if os.path.exists(DATA_PATH):
    train_data = load_training_data(DATA_PATH)
    print(f"Loaded {len(train_data)} examples")
else:
    print(f"Warning: {DATA_PATH} not found. Please export data first.")

## 2. Format Data
Format the data for the specific task (Extraction or Grounding). The example below uses the Extraction format.

In [None]:
# Example: Formatting for Extraction Task.
# Runs only when the load cell defined `train_data` (i.e. the JSONL export was found).
if 'train_data' in locals():
    # Apply the extraction prompt formatter to every example.
    # NOTE(review): `.map` and the 'text' field suggest a Hugging Face `datasets.Dataset` -- confirm.
    formatted_dataset = train_data.map(format_extraction_prompt)

    # Preview the first formatted example; guard against an empty dataset so the
    # cell does not fail with an index error when the export contains no rows.
    if len(formatted_dataset) > 0:
        print(formatted_dataset[0]['text'])
    else:
        print("Formatted dataset is empty; nothing to preview.")

## 3. Configure Model
Load a 4-bit/8-bit quantized model (this example loads in 8-bit) and apply LoRA.

In [None]:
# Identifier of the base model to fine-tune. Replace with accessible model if needed.
MODEL_NAME = "google/medgemma-4b-it"

# LoRA settings (rank, alpha, dropout) built by the model-training helper.
lora_config = get_lora_config(r=16, alpha=32, dropout=0.05)

# Loading in 8-bit requires a GPU. Pass load_in_8bit=False to try on CPU, but expect
# it to be slow or to run out of memory.
# The broad `except` is deliberate: on machines without a GPU or network access the
# notebook reports the failure and carries on instead of aborting.
try:
    model, tokenizer = load_model_and_tokenizer(
        MODEL_NAME,
        lora_config,
        load_in_8bit=True,
    )
    model.print_trainable_parameters()
except Exception as e:
    print(f"Could not load model (expected if no GPU/internet): {e}")

## 4. Train
Run the training loop.

In [None]:
# Directory handed to `train_model` as `output_dir`.
OUTPUT_DIR = "../models/extraction_lora"

# Train only when every prerequisite from the earlier cells exists: the model and
# tokenizer (model-loading cell) and the formatted dataset (formatting cell).
if 'model' in locals() and 'tokenizer' in locals() and 'formatted_dataset' in locals():
    train_model(
        model=model,
        tokenizer=tokenizer,
        train_dataset=formatted_dataset,
        output_dir=OUTPUT_DIR,
        num_epochs=3,
        batch_size=4,
    )
else:
    # Say why nothing happened instead of finishing silently.
    print(
        "Skipping training: `model`/`tokenizer` or `formatted_dataset` is not available. "
        "Make sure the data-loading, formatting and model-loading cells ran successfully."
    )