# 🏥 Nursing FunctionGemma Training

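This notebook fine-tunes `google/medgemma-4b-it` (4-bit QLoRA) to perform **Nursing Function Calling**: simulated electronic patient record (EPR) actions such as recording vitals and administering medication, plus retrieval (RAG) over NMC standards.
It teaches the model to output structured `<function_call>` tags when a request requires one of these tools.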

## 1. Mount Google Drive
**Crucial step:** Drive access is needed both to read your training dataset and to save the model checkpoints.

In [None]:
from google.colab import drive
import os

# Mount Google Drive. Colab exposes the user's Drive folder under
# /content/drive/MyDrive (no space) -- the same root the dataset path in the
# training cell uses, so checkpoints and data live in one place.
drive.mount('/content/drive')

# Create checkpoint directory inside the mounted Drive folder.
output_dir = "/content/drive/MyDrive/nursing-function-gemma-checkpoints"
os.makedirs(output_dir, exist_ok=True)
print(f"Checkpoints will be saved to: {output_dir}")

In [None]:
# Install the training stack. transformers, peft and trl are installed from their
# GitHub main branches and torch is upgraded with -U, so the resolved versions
# change over time and the cell is not reproducible as written.
# NOTE(review): consider pinning versions (pkg==X.Y.Z) once a working set is known;
# an unpinned torch upgrade may also replace the runtime's preinstalled CUDA build -- confirm.
!pip install -q -U torch bitsandbytes git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/peft.git git+https://github.com/huggingface/trl.git accelerate datasets huggingface_hub

In [None]:
# Authenticate with the Hugging Face Hub. The training cell downloads the base
# model from the Hub and pushes checkpoints to a private Hub repo, so a token is
# needed; called without a token, login() asks for one interactively.
from huggingface_hub import login
login()

In [None]:
import torch
import torch.distributed as dist
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# --- FIX: Initialize Distributed Process Group ---
# Create a one-process "gloo" group (rank 0, world size 1) before training.
# NOTE(review): presumably needed because some library call expects an initialised
# process group even on a single GPU -- confirm which call required it.
# Failures are reported and ignored: this is a best-effort workaround.
import os
import tempfile

try:
    if not dist.is_available():
        print("⚠️ Warning: Could not init process group: torch.distributed is not available in this build")
    else:
        if not dist.is_initialized():
            # Use a fresh rendezvous file for every run. A fixed path such as
            # /tmp/unique_init_file can be left behind by a crashed or restarted
            # session, and reusing a stale FileStore file is unsafe.
            init_file = os.path.join(tempfile.mkdtemp(prefix="dist_init_"), "store")
            dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
        print("✅ Distributed process group initialized (dummy for single-GPU).")
except Exception as e:
    print(f"⚠️ Warning: Could not init process group: {e}")

# 1. Config
MODEL_ID = "google/medgemma-4b-it"
DATASET_PATH = "/content/drive/MyDrive/nmc_brain/data/nursing_functions_dataset.jsonl"
OUTPUT_DIR = output_dir  # checkpoint directory created in the Drive-mount cell

# Tool description prepended to every training prompt. It is identical for every
# example, so it is built once here instead of inside the formatting loop.
TOOLS_PROMPT = """You are a clinical AI agent. You have access to the following tools:
- record_vitals(systolic, diastolic, heart_rate, temp_c)
- administer_medication(drug_name, dose, route)
- search_nmc_standards(query)
If the user's request requires a tool, output the function call XML."""

# 2. Load Dataset
if not os.path.exists(DATASET_PATH):
    print(f"ERROR: Dataset not found at {DATASET_PATH}. Please check your Drive path!")
else:
    dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
    print(f"Loaded {len(dataset)} function examples")

    # Drop rows whose instruction or output is missing/empty. A null value would
    # otherwise be rendered into the prompt as the literal text "None" and the
    # model would be trained to produce it.
    n_before = len(dataset)
    dataset = dataset.filter(
        lambda ex: ex["instruction"] not in (None, "") and ex["output"] not in (None, "")
    )
    if len(dataset) != n_before:
        print(f"Skipped {n_before - len(dataset)} examples with a missing instruction or output")

    # 3. Formatting Function
    def formatting_prompts_func(example):
        """Render examples as Gemma chat-turn training strings.

        Accepts either a batch (column name -> list of values, as passed by
        ``Dataset.map(batched=True)``) or a single row (column name -> string).

        Returns:
            list[str]: one ``<start_of_turn>user ... <start_of_turn>model ...``
            string per example, with TOOLS_PROMPT ahead of the instruction.

        Raises:
            ValueError: if the instruction and output columns differ in length,
                rather than training on a placeholder target.
        """
        instructions = example['instruction']
        outputs = example['output']

        if isinstance(instructions, str):
            instructions = [instructions]
            outputs = [outputs]

        if len(instructions) != len(outputs):
            raise ValueError(
                f"Got {len(instructions)} instructions but {len(outputs)} outputs"
            )

        return [
            f"<start_of_turn>user\n{TOOLS_PROMPT}\n\n{inst}<end_of_turn>\n"
            f"<start_of_turn>model\n{out}<end_of_turn>"
            for inst, out in zip(instructions, outputs)
        ]

    # 4. Pre-format. The raw columns are removed so only the rendered "text"
    # field is handed to the trainer.
    dataset = dataset.map(
        lambda batch: {"text": formatting_prompts_func(batch)},
        batched=True,
        remove_columns=dataset.column_names,
    )

    # 5. Model Setup
    # Choose one precision for both the 4-bit compute dtype and the trainer's
    # mixed-precision mode, so bf16 compute is never paired with fp16 autocast.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    device_map = {"": 0}  # load the whole model onto GPU 0

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, quantization_config=bnb_config, device_map=device_map, trust_remote_code=True)
    model = prepare_model_for_kbit_training(model)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Keep the tokenizer's own pad token when it has one; aliasing pad to EOS
    # makes padding indistinguishable from a genuine end-of-sequence token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # 6. LoRA
    peft_config = LoraConfig(
        r=16, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM"
    )

    # --- FIX: Custom Data Collator for Gemma 3 (token_type_ids + labels) ---
    def data_collator(features):
        """Pad tokenized examples into a batch and add the fields Gemma 3 expects.

        Adds all-zero ``token_type_ids`` when the tokenizer did not produce them,
        and builds causal-LM ``labels`` from ``input_ids`` with padded positions
        set to -100 so the loss skips them.
        """
        batch = dict(tokenizer.pad(features, padding=True, return_tensors="pt"))

        # Gemma 3 REQUIRES token_type_ids
        if "token_type_ids" not in batch:
            batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])

        if "labels" not in batch:
            labels = batch["input_ids"].clone()
            # Mask by attention_mask, not by pad_token_id: when pad == eos, an id
            # comparison would also hide real EOS tokens from the loss.
            if "attention_mask" in batch:
                labels[batch["attention_mask"] == 0] = -100
            elif tokenizer.pad_token_id is not None:
                labels[labels == tokenizer.pad_token_id] = -100
            batch["labels"] = labels

        return batch

    # 7. SFTConfig
    # All options go through the constructor. Attributes assigned after
    # construction bypass the dataclass __post_init__ processing. Current TRL also
    # names the sequence cap `max_length`, so a `max_seq_length` attribute set
    # afterwards would simply be ignored.
    sft_config = SFTConfig(
        output_dir=OUTPUT_DIR,
        max_length=512,
        dataset_text_field="text",
        packing=False,
        num_train_epochs=5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        bf16=use_bf16,
        fp16=not use_bf16,
        logging_steps=10,
        push_to_hub=True,
        hub_model_id="NurseCitizenDeveloper/nursing-function-gemma",
        hub_private_repo=True,
        hub_strategy="checkpoint",
    )

    # 8. Trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=sft_config,
        data_collator=data_collator,  # supplies token_type_ids and labels
    )

    trainer.train()

    # Save & Push final
    trainer.model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    trainer.push_to_hub()
    print("TRAINING COMPLETE! ✅")