# Nursing Proficiency+ Model Training

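This notebook fine-tunes the MedGemma 4B instruction-tuned model (`google/medgemma-4b-it`) on the custom NMC dataset using QLoRA (a 4-bit quantized base model with LoRA adapters).
Checkpoints are saved to Google Drive so training can resume after a runtime disconnect, and are also pushed to a private Hugging Face Hub repository.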

In [None]:
# Show the attached GPU(s). The training cell loads the model with device_map="auto"
# and bf16 compute, so a GPU runtime is expected.
!nvidia-smi

## 1. Mount Google Drive
Checkpoints are written to Drive so they survive a runtime disconnect and training can resume from the latest one.

In [None]:
import os

from google.colab import drive

# Mount Google Drive so checkpoints outlive the Colab runtime.
drive.mount('/content/drive')

# NOTE(review): this path uses "My Drive" (with a space), while the dataset path in the
# training cell uses "MyDrive". Recent Colab mounts typically expose "MyDrive" — confirm
# this directory really lives on Drive, otherwise checkpoints will not persist.
checkpoint_dir = "/content/drive/My Drive/nursing-proficiency-plus-checkpoints"

# Idempotent: re-running the cell is harmless if the folder already exists.
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"Checkpoints will be saved to: {checkpoint_dir}")

## 2. Install Dependencies

In [None]:
!pip install -q -U torch bitsandbytes git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/peft.git git+https://github.com/huggingface/trl.git accelerate datasets huggingface_hub

## 3. Hugging Face Login

In [None]:
import os

from huggingface_hub import login

# Never hardcode the token in the notebook: it leaks whenever the notebook is shared.
# If HF_TOKEN is set in the environment it is used; otherwise login() prompts for one.
# The training cell pushes to the Hub, so the token needs write access.
login(token=os.environ.get("HF_TOKEN"))

## 4. Run Training
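Make sure `nmc_dataset_web.jsonl` is in your Google Drive at `MyDrive/nmc_brain/data/`. This is the `dataset_path` set in the configuration section of the cell below. Each record needs `instruction` and `output` fields.

**Note:** If you store the file elsewhere, update `dataset_path` to match. For example, you could put it in a Colab `data/` folder (create it with `!mkdir -p data` and upload the file there).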

In [None]:
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTConfig, SFTTrainer

# --- Configuration ---
model_id = "google/medgemma-4b-it"
# Dataset location on Google Drive (path supplied by the user).
dataset_path = "/content/drive/MyDrive/nmc_brain/data/nmc_dataset_web.jsonl"
# NOTE(review): "My Drive" (with a space) here vs "MyDrive" in dataset_path above.
# Recent Colab mounts typically expose "MyDrive"; confirm this directory is really on
# Drive, otherwise checkpoints will not survive a runtime disconnect.
output_dir = "/content/drive/My Drive/nursing-proficiency-plus-checkpoints"
hub_model_id = "NurseCitizenDeveloper/nursing-proficiency-plus"

# Verify dataset exists. Raise (which stops "Run All") rather than printing and
# silently skipping training. The `else` stays because the rest of this cell is
# indented under it.
if not os.path.exists(dataset_path):
    raise FileNotFoundError(
        f"ERROR: Dataset not found at {dataset_path}. Please check your Drive folders!"
    )
else:
    print("Dataset found. Starting setup...")

    # 1. Load Dataset
    dataset = load_dataset("json", data_files=dataset_path, split="train")
    print(f"Loaded {len(dataset)} examples")
    # The prompt formatter below reads these two fields from every row; fail early if absent.
    missing_columns = {"instruction", "output"} - set(dataset.column_names)
    if missing_columns:
        raise ValueError(f"Dataset is missing required columns: {sorted(missing_columns)}")

    # 2. BitsAndBytes config: 4-bit NF4 weights with double quantization, bf16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True
    )

    # 3. Load the quantized base model and prepare it for k-bit (LoRA) training.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    # The generation-time KV cache is not needed while training.
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Keep the tokenizer's own pad token when it defines one (presumably <pad> for
    # Gemma; confirm). Aliasing pad to EOS unconditionally would make pad-id label
    # masking also hide real EOS tokens from the loss.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right padding for training: real tokens first, padding after.
    tokenizer.padding_side = "right"

    # 4. LoRA config: adapters on the projection modules listed below (r=64, alpha=16).
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    )

    # 5. Training args. Everything goes through the constructor so the config's
    # __post_init__ validates and normalizes it; assigning attributes afterwards skips
    # those checks, and a renamed/misspelled field is silently ignored.
    # NOTE(review): bf16=True is validated here, so GPUs without bf16 support should
    # fail at this point rather than mid-run.
    sft_config = SFTConfig(
        output_dir=output_dir,
        # Current TRL name for the former `max_seq_length` (assigning that name after
        # construction is a no-op on recent TRL).
        max_length=1024,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch = 4 x 4 = 16 per device
        optim="paged_adamw_32bit",
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        logging_steps=10,
        learning_rate=2e-4,
        weight_decay=0.001,
        bf16=True,
        push_to_hub=True,
        hub_model_id=hub_model_id,
        hub_private_repo=True,
        hub_strategy="checkpoint",
        report_to="none",
        packing=False,
        # Column produced by the pre-formatting step further down in this cell.
        dataset_text_field="text",
    )

    # --- 1. Formatting Function (Combine Instruction + Output) ---
    def formatting_prompts_func(example):
        """Render instruction/output pairs as Gemma chat-turn strings.

        Accepts either a single row (``instruction`` and ``output`` are strings) or a
        batch from ``Dataset.map(batched=True)`` (both are lists).

        Returns:
            list[str]: one ``<start_of_turn>user ... <start_of_turn>model ...`` string
            per instruction.

        Raises:
            ValueError: if a batch has a different number of instructions and outputs.
        """
        instructions = example['instruction']
        outputs = example['output']

        # Normalize the single-row case to a batch of one.
        if isinstance(instructions, str):
            instructions = [instructions]
            outputs = [outputs]

        # Refuse mismatched batches instead of filling the gap with placeholder text,
        # which would end up in the training data as a model answer.
        if len(outputs) != len(instructions):
            raise ValueError(
                f"Got {len(instructions)} instructions but {len(outputs)} outputs"
            )

        return [
            f"<start_of_turn>user\n{inst}<end_of_turn>\n<start_of_turn>model\n{out}<end_of_turn>"
            for inst, out in zip(instructions, outputs)
        ]

    # PRE-PROCESSING: build the full prompt text up front and store it in a new
    # 'text' column, so SFTTrainer only tokenizes it (sidesteps TRL's 'add_eos'
    # handling of formatting_func).
    print("Pre-formatting dataset...")
    dataset = dataset.map(
        lambda batch: {"text": formatting_prompts_func(batch)},
        batched=True,
    )

    # --- 2. Custom Collator (Fixes Gemma 3 & Labels) ---
    def data_collator(features):
        """Pad a list of tokenized examples into a training batch.

        Args:
            features: list of per-example dicts from SFTTrainer's tokenization
                (at least ``input_ids``).

        Returns:
            dict of tensors from ``tokenizer.pad``, plus ``token_type_ids`` (all
            zeros) and ``labels`` when the padded batch does not already have them.
        """
        batch = dict(tokenizer.pad(features, padding=True, return_tensors="pt"))

        # All-zero token_type_ids when the tokenizer does not supply them
        # (the "Gemma 3" part of this fix; presumably required by the model's
        # training forward pass; confirm).
        if "token_type_ids" not in batch:
            batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])

        if "labels" not in batch:
            labels = batch["input_ids"].clone()
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                # Mask exactly the padded positions (-100 is ignored by the loss).
                # Matching on pad_token_id would also hide genuine tokens sharing
                # that id, e.g. EOS when pad == eos.
                labels[attention_mask == 0] = -100
            elif tokenizer.pad_token_id is not None:
                # Fallback when no attention mask is returned.
                labels[labels == tokenizer.pad_token_id] = -100
            batch["labels"] = labels

        return batch

    # Point SFTTrainer at the pre-formatted column created above.
    sft_config.dataset_text_field = "text"

    # 6. Trainer (the custom collator replaces TRL's default; formatting is already done).
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=sft_config,
        data_collator=data_collator,
    )

    # 7. Train, resuming from the newest "checkpoint-<step>" folder if one exists.
    # get_last_checkpoint only considers directories matching that pattern, so a stray
    # file or folder whose name merely starts with "checkpoint" no longer triggers a
    # resume that the Trainer cannot satisfy. It lists the folder, so guard with isdir.
    checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    if checkpoint is not None:
        print(f"Resuming from checkpoint in {output_dir}")

    trainer.train(resume_from_checkpoint=checkpoint)

    # 8. Save the trained model and tokenizer to output_dir, then push to the Hub.
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    trainer.push_to_hub()