# 🏥 Nursing FunctionGemma Training v2

**IMPROVED VERSION** – This notebook fine-tunes MedGemma 4B with LoRA to convert clinical notes into function calls, extracting the actual parameter values (e.g. blood pressure, drug dose) from each note.

**Key Improvements:**
- 550+ diverse training examples
- Higher LoRA rank (32) for better learning
- 8 training epochs
- Clearer system prompt with extraction cues

## Step 1: Mount Google Drive

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

# Keep checkpoints on the mounted Google Drive rather than the Colab VM's local
# disk. exist_ok=True makes re-running this cell harmless.
output_dir = "/content/drive/MyDrive/nursing-function-gemma-v2-checkpoints"
os.makedirs(output_dir, exist_ok=True)
print(f"✅ Checkpoints will be saved to: {output_dir}")

## Step 2: Install Dependencies

In [None]:
# Install / upgrade the training stack.
# NOTE(review): transformers, peft and trl are installed from their GitHub main
# branches (unpinned), so APIs may change between runs (e.g. renamed SFTConfig
# fields). Consider pinning versions once a working combination is known.
!pip install -q -U torch bitsandbytes 
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/trl.git
!pip install -q -U accelerate datasets huggingface_hub

## Step 3: Login to Hugging Face

In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub. Step 8 enables push_to_hub to a
# private repo, which needs a token with write access.
# NOTE(review): the MedGemma base model is presumably gated on the Hub, so the
# download in Step 6 likely depends on this login as well — confirm.
login()

## Step 4: Load Dataset (v2 - Diverse Examples)

In [None]:
from datasets import load_dataset

import os  # also imported in Step 1; repeated so this cell stands on its own

# Path to v2 dataset - MAKE SURE YOU UPLOADED THIS FILE!
DATASET_PATH = "/content/drive/MyDrive/nmc_brain/data/nursing_functions_dataset_v2.jsonl"

# Fail fast: every later cell needs `dataset`, so carrying on after a missing
# file would only surface later as a confusing NameError.
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(
        f"❌ ERROR: Dataset not found at {DATASET_PATH}\n"
        "Please upload nursing_functions_dataset_v2.jsonl to your Drive!"
    )

# Each JSONL record must provide 'instruction' and 'output' fields; both are
# read below and by the formatting step.
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
if len(dataset) == 0:
    raise ValueError(f"Dataset at {DATASET_PATH} contains no examples")

print(f"✅ Loaded {len(dataset)} function examples (v2 diverse dataset)")
print("\nSample example:")
print(f"  Input: {dataset[0]['instruction']}")
print(f"  Output: {dataset[0]['output']}")

## Step 5: Format Dataset for Training

In [None]:
# System prompt prepended to every training example. Defined once instead of
# being rebuilt on every loop iteration. NOTE(review): an inference-time prompt
# should presumably match this text exactly.
_TOOLS_PROMPT = """You are a clinical AI agent. Convert clinical notes into function calls.

Functions:
- record_vitals(systolic=X, diastolic=Y, heart_rate=Z, temp_c=T)
- administer_medication(drug_name='X', dose='Y', route='Z')
- search_nmc_standards(query='X')

Extract the actual values from the input and output the correct function call."""


def formatting_prompts_func(example):
    """Format examples as Gemma chat transcripts with clear extraction instructions.

    Each example becomes one user turn (the tool-calling system prompt
    followed by ``Input: <instruction>``) and one model turn holding the
    target function call.

    Args:
        example: Mapping with ``'instruction'`` and ``'output'`` keys. Values
            are either single strings (one example) or equal-length lists
            (a batch, as passed by ``Dataset.map(..., batched=True)``).

    Returns:
        A list of formatted training strings, one per example.

    Raises:
        ValueError: If a batch has a different number of instructions and
            outputs.
    """
    instructions = example['instruction']
    outputs = example['output']

    # Single (non-batched) example: wrap both fields so the code below only
    # has to handle lists.
    if isinstance(instructions, str):
        instructions = [instructions]
        outputs = [outputs]

    # A length mismatch means the data is corrupt. Filling the gap with a
    # placeholder target would train the model to emit that placeholder, so
    # fail loudly instead.
    if len(instructions) != len(outputs):
        raise ValueError(
            f"Got {len(instructions)} instructions but {len(outputs)} outputs"
        )

    return [
        f"<start_of_turn>user\n{_TOOLS_PROMPT}\n\nInput: {inst}<end_of_turn>\n"
        f"<start_of_turn>model\n{out}<end_of_turn>"
        for inst, out in zip(instructions, outputs)
    ]

# Apply formatting: add a "text" column (read by SFTTrainer through
# dataset_text_field="text") with one formatted transcript per example.
dataset = dataset.map(lambda x: {"text": formatting_prompts_func(x)}, batched=True)
print(f"✅ Dataset formatted! {len(dataset)} examples ready.")

## Step 6: Load Base Model (MedGemma 4B)

In [None]:
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training

import os
import tempfile

# Fix for Gemma 3 distributed requirement: create a single-process
# (rank 0, world_size 1) gloo process group if none exists yet.
# A fresh rendezvous file is used on every run. With file:// initialisation,
# PyTorch reuses whatever file it is given, and a stale file left by an earlier
# run (e.g. re-running this cell on the same Colab VM) can make
# init_process_group hang or fail.
try:
    if not dist.is_initialized():
        _rendezvous_file = os.path.join(tempfile.mkdtemp(prefix="dist_init_"), "store")
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{_rendezvous_file}",
            rank=0,
            world_size=1,
        )
    print("✅ Distributed process group initialized")
except Exception as e:
    # Deliberately best-effort: report the problem and carry on.
    print(f"⚠️ Warning: {e}")

# Model config
MODEL_ID = "google/medgemma-4b-it"

# 4-bit NF4 weights with double quantisation; matmuls compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

print(f"Loading base model: {MODEL_ID}")
# NOTE(review): Gemma 3 is a native transformers architecture, so
# trust_remote_code=True is likely unnecessary — consider removing it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map={"": 0},  # place the whole model on GPU 0
    trust_remote_code=True
)
model = prepare_model_for_kbit_training(model)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Only fall back to EOS for padding when the tokenizer defines no pad token.
# If pad and EOS share an id, masking labels by pad id also removes genuine
# EOS tokens from the loss.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("✅ Model loaded!")

## Step 7: Configure LoRA (Increased Rank for Better Learning)

In [None]:
# LoRA adapter configuration. Rank and alpha are both doubled relative to v1,
# so the alpha / r scaling factor stays at 2.
peft_config = LoraConfig(
    r=32,           # Increased from 16
    lora_alpha=64,  # Increased from 32
    # Attention projections only.
    # NOTE(review): on the multimodal MedGemma checkpoint these names may also
    # match vision-tower modules. Check model.named_modules() if only the
    # language model should be adapted.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05
)

print("✅ LoRA config ready")
print(f"   Rank: {peft_config.r}")
print(f"   Alpha: {peft_config.lora_alpha}")

## Step 8: Setup Trainer

In [None]:
from trl import SFTTrainer, SFTConfig

# Custom data collator for Gemma 3
def data_collator(features, pad_tokenizer=None):
    """Pad tokenized examples into a training batch for Gemma 3.

    Args:
        features: List of per-example dicts from SFTTrainer's tokenization
            (at least ``input_ids``).
        pad_tokenizer: Tokenizer used for padding. Defaults to the notebook's
            global ``tokenizer``. Exposed so the collator can be reused or
            tested with another tokenizer.

    Returns:
        A plain ``dict`` of tensors: whatever ``pad`` returns, plus
        ``token_type_ids`` (all zeros) and ``labels`` if they were not
        already present.
    """
    tok = tokenizer if pad_tokenizer is None else pad_tokenizer
    batch = dict(tok.pad(features, padding=True, return_tensors="pt"))

    # Gemma 3 expects token_type_ids during training. All zeros presumably
    # marks every token as text (no image tokens) — the inputs are text-only.
    if "token_type_ids" not in batch:
        batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])

    if "labels" not in batch:
        labels = batch["input_ids"].clone()
        # Exclude padding from the loss. Prefer the attention mask over
        # comparing ids: if pad and EOS share an id (as when pad_token is set
        # to eos_token), an id comparison would also drop genuine EOS tokens.
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        elif tok.pad_token_id is not None:
            labels[labels == tok.pad_token_id] = -100
        batch["labels"] = labels

    return batch

import dataclasses

# Training config - 8 epochs for better learning. The values are named here
# so the summary printed below always matches what the trainer uses.
NUM_TRAIN_EPOCHS = 8            # Increased from 5
PER_DEVICE_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
MAX_SEQ_LENGTH = 512

# Mixed precision: the 4-bit layers compute in bfloat16 (bnb_4bit_compute_dtype
# in Step 6). GPUs with compute capability >= 8 (e.g. A100, L4) support bf16
# natively and get matching bf16 autocast. Older cards such as the T4 keep fp16.
_use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

# Newer TRL releases renamed SFTConfig.max_seq_length to max_length. Assigning
# the old name after construction is silently ignored there, so pass whichever
# field this TRL version defines.
_sft_field_names = {f.name for f in dataclasses.fields(SFTConfig)}
_seq_len_kwarg = "max_length" if "max_length" in _sft_field_names else "max_seq_length"

# Pass everything through the constructor, rather than setting attributes
# afterwards, so SFTConfig/TrainingArguments __post_init__ validates the final
# values (and unknown field names fail loudly).
sft_config = SFTConfig(
    output_dir=output_dir,
    dataset_text_field="text",
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=2e-4,          # Slightly higher
    bf16=_use_bf16,
    fp16=not _use_bf16,
    logging_steps=10,
    save_steps=100,
    push_to_hub=True,
    hub_model_id="NurseCitizenDeveloper/nursing-function-gemma",
    hub_private_repo=True,
    hub_strategy="checkpoint",   # also pushes the latest checkpoint for resuming
    packing=False,
    **{_seq_len_kwarg: MAX_SEQ_LENGTH},
)

# Create trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=sft_config,
    data_collator=data_collator
)

_effective_batch = sft_config.per_device_train_batch_size * sft_config.gradient_accumulation_steps
print("✅ Trainer ready!")
print(f"   Dataset: {len(dataset)} examples")
print(f"   Epochs: {sft_config.num_train_epochs}")
print(f"   Batch size: {sft_config.per_device_train_batch_size} (effective: {_effective_batch})")
print(f"   Precision: {'bf16' if _use_bf16 else 'fp16'}")

## Step 9: Train! üöÄ

In [None]:
print("üöÄ Starting training v2...")
print("   This will take approximately 2-3 hours on T4 GPU")
print("="*50)

trainer.train()

print("="*50)
print("üéâ Training complete!")

## Step 10: Save and Push to Hub

In [None]:
# Save locally: trainer.model wraps the LoRA adapter (peft_config was passed to
# SFTTrainer), and the tokenizer is stored next to it on Google Drive.
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"✅ Model saved to: {output_dir}")

# Push to the Hub repo configured in SFTConfig (hub_model_id).
trainer.push_to_hub()
print("✅ Pushed to Hugging Face Hub!")
print("\n🎉 TRAINING v2 COMPLETE! ✅")

# Build the URL from the trainer config so it cannot drift from hub_model_id.
print("\nYour improved model is now at:")
print(f"https://huggingface.co/{trainer.args.hub_model_id}")