# MedGemma Fine-Tuning v2 — LoopGuard Hypothesis Extraction

**Model:** `google/medgemma-1.5-4b-it`  
**Method:** LoRA (r=16, alpha=32) + 4-bit quantization  
**Data:** `training_final_400.json` (421 examples, flat array format)  
**Expected runtime:** ~3–4 hours on T4  

---

## Workflow
1. **Cell 1** — Install dependencies → restart kernel
2. **Cells 2–3** — Imports + data loading
3. ⚠️ **Cell 4** — DATA INSPECTION → **run and share output before continuing**
4. **Cell 5** — Load model & tokenizer
5. **Cell 6** — HuggingFace login (needed for MedGemma gated access)
6. **Cell 7** — Tokenize with prompt masking fix
7. ⚠️ **Cell 8** — TOKEN LENGTH CHECK → **run and share output before continuing**
8. **Cell 9** — LoRA config
9. **Cell 10** — Training args
10. **Cell 11** — Trainer setup
11. **Cell 12** — 🚀 TRAIN (3–4 hours)
12. **Cell 13** — Save model
13. **Cell 14** — Test inference
14. **Cell 15** — Training report

In [None]:
# ============================================================
# CELL 1: Install Dependencies
# Run once, then restart kernel before continuing
# ============================================================
print("üì¶ Installing dependencies...")

!pip uninstall -y -q transformers peft trl bitsandbytes accelerate
!pip install -q transformers>=4.47.0
!pip install -q peft>=0.13.0
!pip install -q trl>=0.11.0
!pip install -q accelerate>=0.34.0
!pip install -q bitsandbytes>=0.46.1

print("\n‚úÖ Done! ‚ö†Ô∏è  RESTART KERNEL NOW before running any other cells.")
print("   Kernel ‚Üí Restart ‚Üí then start from Cell 2")

In [None]:
# ============================================================
# CELL 2: Imports
# ============================================================
import torch
import json
import os
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer
from datasets import Dataset

print("‚úÖ Imports successful")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 3: Load Data
# File: training_final_400.json (flat array format)
# ============================================================
DATA_FILENAME = "training_final_400.json"


def find_data_file(filename):
    """Return the path of the first file named *filename* in the Kaggle dirs.

    ``/kaggle/input`` is walked first, then ``/kaggle/working``. ``None`` is
    returned when no directory in either tree lists the file.
    """
    walks = (os.walk(base) for base in ("/kaggle/input", "/kaggle/working"))
    matches = (
        os.path.join(folder, filename)
        for walk in walks
        for folder, _subdirs, names in walk
        if filename in names
    )
    # Generators keep the search lazy: it stops at the first hit.
    return next(matches, None)

# Locate the dataset. An explicit exception is raised instead of `assert`,
# because assertions are skipped when Python runs with -O.
data_path = find_data_file(DATA_FILENAME)
if data_path is None:
    raise FileNotFoundError(f"‚ùå Could not find '{DATA_FILENAME}' in /kaggle/input. Make sure you added the dataset.")
print(f"üìÇ Found data at: {data_path}")

# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(data_path, 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

# Handle both flat array and wrapped format
if isinstance(raw_data, list):
    examples = raw_data
elif isinstance(raw_data, dict) and 'examples' in raw_data:
    examples = raw_data['examples']
else:
    raise ValueError("‚ùå Unexpected data format. Expected a list or dict with 'examples' key.")

# The summary below (and every later cell) indexes examples[0]; stop here
# with a clear message rather than an IndexError on an empty dataset.
if not examples:
    raise ValueError(f"'{data_path}' contains no examples.")

print(f"‚úÖ Loaded {len(examples)} examples")
print(f"   Keys in first example: {list(examples[0].keys())}")
print(f"   Output fields: {list(examples[0]['output'].keys())}")

In [None]:
# ============================================================
# CELL 4: DATA INSPECTION  ⚠️ STOP HERE
# Run this cell and share the output before continuing.
# We need to verify data quality and set max_length correctly.
# ============================================================
# Summarise dataset health: malformed rows, missing output fields, the
# urgency mix, and a preview of the first example.
print("üìä DATA QUALITY INSPECTION")
print("=" * 60)

# Fields each example's 'output' dict is expected to contain
REQUIRED_FIELDS = ['primary_hypothesis', 'differential_diagnoses', 'key_symptoms', 'urgency', 'tests_ordered', 'reasoning']
missing_field_counts = dict.fromkeys(REQUIRED_FIELDS, 0)
urgency_counts = {}
bad_examples = []

for idx, example in enumerate(examples):
    output = example.get('output', {})
    for name in REQUIRED_FIELDS:
        # bool adds as 0/1, so this counts the examples lacking the field
        missing_field_counts[name] += name not in output
    level = output.get('urgency', 'MISSING')
    urgency_counts.setdefault(level, 0)
    urgency_counts[level] += 1
    # An example is malformed when either side is absent or empty
    if not (example.get('input') and example.get('output')):
        bad_examples.append(idx)

total = len(examples)
print(f"\nTotal examples: {total}")
print(f"Malformed examples (missing input or output): {len(bad_examples)}")
if bad_examples:
    print(f"  Indices: {bad_examples[:10]}")

print("\nMissing field counts (0 = all good):")
for name, n_missing in missing_field_counts.items():
    status = "‚ùå" if n_missing else "‚úÖ"
    print(f"  {status} {name}: {n_missing} missing")

print("\nUrgency distribution:")
for level in sorted(urgency_counts):
    n_level = urgency_counts[level]
    share = 100 * n_level / total
    print(f"  {level}: {n_level} ({share:.1f}%)")

first = examples[0]
print("\nSample input (first example):")
print(first['input'][:300] + "...")
print("\nSample output (first example):")
print(json.dumps(first['output'], indent=2)[:500])

print("\n" + "=" * 60)
print("‚ö†Ô∏è  SHARE THIS OUTPUT before running Cell 5+")

In [None]:
# ============================================================
# CELL 5: Load Model & Tokenizer
# ~1–2 min
# ============================================================
# Load the tokenizer and the 4-bit quantized base model that later cells
# wrap with LoRA adapters.
# NOTE(review): Cell 6 performs the Hugging Face login described as required
# for gated MedGemma access, but this cell downloads the model first —
# confirm credentials are already available when this runs.
model_id = "google/medgemma-1.5-4b-it"
print(f"\nüî• Loading {model_id}...")

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print("‚úÖ Tokenizer loaded")

# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    dtype=torch.bfloat16,
)

# Reports the CUDA memory currently allocated by tensors, in GB.
print(f"‚úÖ Model loaded ({torch.cuda.memory_allocated()/1e9:.2f} GB VRAM used)")

In [None]:
# ============================================================
# CELL 6: HuggingFace Login
# Required for gated MedGemma model access
# ============================================================
# Open the Hugging Face login prompt for this notebook session.
# NOTE(review): Cell 5 already downloads the gated model before this cell
# runs — confirm a token is available by then, or run this cell first.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# ============================================================
# CELL 7: Tokenize with Prompt Masking (v2 fix)
#
# KEY CHANGE FROM V1: We now set labels = -100 for the prompt
# portion so the model only learns to predict the output, not
# repeat back the input. This improves structured output quality.
# ============================================================

# Chat-turn prompt: the clinical note goes in the user turn, and the trailing
# "<start_of_turn>model\n" opens the model turn the answer is appended to.
PROMPT_TEMPLATE = (
    "<start_of_turn>user\n"
    "Extract diagnostic information from this clinical note.\n\n"
    "Clinical Note:\n{note}<end_of_turn>\n"
    "<start_of_turn>model\n"
)


def _join_items(value):
    """Render a list-valued output field as a comma-separated string."""
    if value is None:
        return ''
    if isinstance(value, str):
        # Joining a bare string would put ", " between its characters.
        return value
    # str() lets non-string items (e.g. numbers) through instead of raising.
    return ', '.join(str(item) for item in value)


def format_output(out):
    """Convert structured output dict to flat text format.

    Args:
        out: Mapping that may contain ``primary_hypothesis``,
            ``differential_diagnoses``, ``key_symptoms``, ``urgency``,
            ``tests_ordered`` and ``reasoning``. Missing keys render as
            empty values.

    Returns:
        Six ``LABEL: value`` lines joined by newlines, with no trailing
        newline. List fields are comma-separated; a list field given as a
        single string is emitted unchanged and ``None`` is emitted as empty.
    """
    diff_dx = _join_items(out.get('differential_diagnoses'))
    key_symptoms = _join_items(out.get('key_symptoms'))
    tests = _join_items(out.get('tests_ordered'))
    return (
        f"PRIMARY HYPOTHESIS: {out.get('primary_hypothesis', '')}\n"
        f"DIFFERENTIAL DIAGNOSES: {diff_dx}\n"
        f"KEY SUPPORTING EVIDENCE: {key_symptoms}\n"
        f"URGENCY LEVEL: {out.get('urgency', '')}\n"
        f"TESTS ORDERED: {tests}\n"
        f"CLINICAL REASONING: {out.get('reasoning', '')}"
    )

def tokenize_with_masking(ex, max_length=768):
    """Tokenize and mask prompt tokens in labels (output-only supervision).

    The prompt and the formatted output (closed by ``<end_of_turn>``) are
    tokenized as a single sequence and truncated to ``max_length``. Label
    positions covered by the prompt are set to -100 so that only the output
    contributes to the loss.

    Args:
        ex: Example with an ``'input'`` note and an ``'output'`` dict.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        The tokenizer encoding of the full text, with ``labels`` and an
        all-zero ``token_type_ids`` added (both the same length as
        ``input_ids``).

    Warns:
        UserWarning: when truncation leaves no output token, i.e. every
            label is -100 and the example carries no training signal.
    """
    import warnings  # local import keeps this cell self-contained

    prompt = PROMPT_TEMPLATE.format(note=ex['input'])
    full_text = prompt + format_output(ex['output']) + "<end_of_turn>"

    # Tokenize full sequence
    tokenized = tokenizer(
        full_text,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_attention_mask=True,
    )
    input_ids = tokenized['input_ids']

    # Tokenize prompt only to find its token length.
    # NOTE(review): assumes the prompt tokenizes to the same ids as the
    # prefix of the full text — confirm for this tokenizer.
    prompt_len = len(
        tokenizer(
            prompt,
            truncation=True,
            max_length=max_length,
            padding=False,
        )['input_ids']
    )

    # Never mask more positions than the (possibly truncated) sequence has.
    n_masked = min(prompt_len, len(input_ids))
    if n_masked == len(input_ids):
        # Truncation cuts from the right, so a long note can push the whole
        # answer out of the window; surface that instead of training on it
        # silently.
        warnings.warn(
            f"Prompt fills all {len(input_ids)} tokens (max_length={max_length}); "
            "no output tokens remain to supervise. Increase max_length or shorten the note."
        )

    # Labels: -100 for prompt tokens (ignored in loss), real ids for output.
    # By construction this is exactly as long as input_ids.
    tokenized['labels'] = [-100] * n_masked + input_ids[n_masked:]
    tokenized['token_type_ids'] = [0] * len(input_ids)

    return tokenized

# Tokenize every example, then hold out 10% for validation (fixed seed).
print("üîÑ Tokenizing examples with prompt masking...")
formatted = list(map(tokenize_with_masking, examples))
dataset = Dataset.from_list(formatted)

split = dataset.train_test_split(test_size=0.1, seed=42)
train_data, val_data = split['train'], split['test']

print(f"‚úÖ Train: {len(train_data)}, Val: {len(val_data)}")
print(f"   Columns: {train_data.column_names}")
# The leading labels belong to the prompt, so they should be masked.
first_labels = train_data[0]['labels'][:10]
print(f"   Sample labels (first 10): {first_labels}  ‚Üê should be mostly -100")

In [None]:
# ============================================================
# CELL 8: TOKEN LENGTH CHECK  ⚠️ STOP HERE
# Run this and share the output.
# If P95 > 700, we need to increase max_length and re-run Cell 7.
# ============================================================
# Distribution of tokenized sequence lengths, used to judge whether the
# 768-token limit is appropriate.
print("üìè TOKEN LENGTH ANALYSIS")
print("=" * 60)

all_lengths = sorted(len(ex['input_ids']) for ex in formatted)
n = len(all_lengths)

# Percentile = element at index int(q * n) of the sorted lengths
p50, p75, p90, p95, p99 = (
    all_lengths[int(q * n)] for q in (0.50, 0.75, 0.90, 0.95, 0.99)
)
min_len, max_len = all_lengths[0], all_lengths[-1]
mean_len = sum(all_lengths) // n

print(f"Min:  {min_len}")
print(f"Mean: {mean_len}")
print(f"P50:  {p50}")
print(f"P75:  {p75}")
print(f"P90:  {p90}")
print(f"P95:  {p95}  ‚Üê KEY NUMBER")
print(f"P99:  {p99}")
print(f"Max:  {max_len}")

# Sequences that reached the limit are counted as truncated
truncated = sum(length >= 768 for length in all_lengths)
print(f"\nExamples truncated at 768 tokens: {truncated} ({100*truncated/n:.1f}%)")

if p95 > 700:
    verdict = "\n‚ö†Ô∏è  P95 > 700: Re-run Cell 7 with max_length=1024"
elif p95 > 500:
    verdict = "\n‚úÖ P95 in range. Current max_length=768 is appropriate."
else:
    verdict = "\n‚úÖ Sequences are short. Could reduce max_length=512 to speed up training."
print(verdict)

print("\n" + "=" * 60)
print("‚ö†Ô∏è  SHARE THIS OUTPUT before running Cell 9+")

In [None]:
# ============================================================
# CELL 9: LoRA Configuration
# Identical to v1 (proven config)
# ============================================================
# Prepare the quantized model for k-bit training and attach LoRA adapters.
print("üîß Applying LoRA...")

model = prepare_model_for_kbit_training(model)

# Names of the projection modules that receive adapters
lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_targets,
)

model = get_peft_model(model, lora_config)

# Report how small the trainable share is relative to the whole model
trainable_params, total_params = model.get_nb_trainable_parameters()
trainable_pct = 100 * trainable_params / total_params
print(f"‚úÖ LoRA applied: {trainable_params:,} trainable / {total_params:,} total ({trainable_pct:.2f}%)")

In [None]:
# ============================================================
# CELL 10: Training Arguments
# Same as v1 — batch_size=1, grad_accum=16, lr=2e-4, 3 epochs
# warmup_steps replaces deprecated warmup_ratio
# ============================================================

# Estimated optimizer steps for warmup calculation
# Hyperparameters used by both the warmup estimate and TrainingArguments are
# defined once, so the two cannot drift apart when one of them is changed.
NUM_EPOCHS = 3
PER_DEVICE_BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 16

# Estimated optimizer steps for warmup calculation.
# This is an estimate: examples that do not fill a whole accumulation window
# are ignored by the floor division.
estimated_steps_per_epoch = len(train_data) // (PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS)
total_steps = estimated_steps_per_epoch * NUM_EPOCHS
warmup_steps = max(10, int(0.05 * total_steps))  # 5% warmup, never fewer than 10 steps

print(f"Estimated optimizer steps: {estimated_steps_per_epoch}/epoch, {total_steps} total")
print(f"Warmup steps: {warmup_steps}")

# NOTE(review): bf16=True needs bfloat16 support on the GPU; the notebook
# header targets a T4 — confirm this is accepted in the target environment.
training_args = TrainingArguments(
    output_dir="/kaggle/working/medgemma-v2-checkpoints",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=warmup_steps,
    fp16=False,
    bf16=True,
    logging_steps=10,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint
    # with the lowest eval_loss.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    dataloader_num_workers=0,
    # The dataset is already tokenized; keep every column for the collator.
    remove_unused_columns=False,
    label_names=["labels"],
)

print("‚úÖ Training args ready")

In [None]:
# ============================================================
# CELL 11: Data Collator + Trainer Setup
# ============================================================
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class PromptMaskedCollator:
    """Right-pad a batch to its longest sequence without unmasking labels.

    ``input_ids`` are padded with the tokenizer's ``pad_token_id``,
    ``attention_mask`` and ``token_type_ids`` with 0, and ``labels`` with
    -100 so padded positions stay out of the loss.
    """
    # Only ``pad_token_id`` is read from this object.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate *features* into equally sized ``torch.long`` tensors."""
        target_len = max(len(feat['input_ids']) for feat in features)

        # Fill value for each column; key order fixes the order of the result.
        fills = {
            'input_ids': self.tokenizer.pad_token_id,
            'attention_mask': 0,
            'token_type_ids': 0,
            'labels': -100,
        }
        columns = {key: [] for key in fills}

        for feat in features:
            ids = feat['input_ids']
            # Every column is extended by the shortfall of input_ids.
            gap = target_len - len(ids)
            row = {
                'input_ids': ids,
                'attention_mask': feat['attention_mask'],
                # Fall back to all zeros when the feature has no type ids.
                'token_type_ids': feat.get('token_type_ids', [0] * len(ids)),
                'labels': feat['labels'],
            }
            for key, fill in fills.items():
                columns[key].append(row[key] + [fill] * gap)

        return {
            key: torch.tensor(rows, dtype=torch.long)
            for key, rows in columns.items()
        }

# Wire the model, datasets and padding collator into the trainer.
data_collator = PromptMaskedCollator(tokenizer=tokenizer)

trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    processing_class=tokenizer,
    data_collator=data_collator,
)
trainer = SFTTrainer(**trainer_kwargs)

# Examples consumed per optimizer step
effective_batch = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
)

print("‚úÖ Trainer ready")
print(f"   Train: {len(train_data)} examples")
print(f"   Val:   {len(val_data)} examples")
print(f"   Effective batch size: {effective_batch}")

In [None]:
# ============================================================
# CELL 12: 🚀 TRAIN
# Expected: ~3–4 hours on T4
# Watch for val_loss dropping each epoch.
# v1 baseline: 1.348 → 1.101 → 1.075
# v2 target:   ~1.1  → ~0.9  → ~0.8
# ============================================================
# Run training, time it, and compare the validation loss with the v1 run.
rule = "=" * 70
print("\n" + rule)
print("üöÄ STARTING FINE-TUNING v2")
print(rule)
print(f"Start: {datetime.now().strftime('%H:%M:%S')}")
print(rule + "\n")

train_start = datetime.now()
result = trainer.train()
train_end = datetime.now()

elapsed = train_end - train_start
print("\n‚úÖ Training complete!")
print(f"   Duration: {elapsed}")
print(f"   Final train loss: {result.training_loss:.4f}")

eval_result = trainer.evaluate()
final_val_loss = eval_result['eval_loss']
print(f"   Final val loss: {final_val_loss:.4f}")
# 1.075 is the v1 validation loss used as the baseline
print(f"   v1 val loss was: 1.075 ‚Äî improvement: {1.075 - final_val_loss:.3f}")

In [None]:
# ============================================================
# CELL 13: Save Model
# ============================================================
# Write the model and tokenizer to the output directory, then list what
# was saved together with each entry's size.
output_dir = "/kaggle/working/medgemma-hypothesis-extraction-v2"
os.makedirs(output_dir, exist_ok=True)

for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)

print(f"‚úÖ Model saved to {output_dir}")
print("\nüìÅ Files:")
for entry in os.listdir(output_dir):
    n_bytes = os.path.getsize(os.path.join(output_dir, entry))
    print(f"   {entry}: {n_bytes / 1e6:.1f} MB")

In [None]:
# ============================================================
# CELL 14: Test Inference — Compare v1 vs v2 Output Quality
# ============================================================
print("\nüß™ Testing fine-tuned model...\n")

# Load base model + v2 adapter fresh
# NOTE(review): the training model is still referenced at this point, so this
# second copy adds to GPU memory use — confirm it fits, or release the
# trainer/model first.
base = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    dtype=torch.bfloat16,
)
ft_model = PeftModel.from_pretrained(base, output_dir)
ft_model.eval()
print("‚úÖ Fine-tuned model loaded\n")

# Pick a genuinely held-out example: rows of val_data were never trained on,
# whereas an arbitrary entry of `examples` may belong to the train split.
sample = val_data[0]  # change the index to inspect another validation example

# Prompt positions carry label -100 (prompt masking), so the labels mark
# where the prompt ends and the reference answer begins.
pairs = list(zip(sample['input_ids'], sample['labels']))
prompt_ids = [tok for tok, lab in pairs if lab == -100]
target_ids = [tok for tok, lab in pairs if lab != -100]
expected = tokenizer.decode(target_ids, skip_special_tokens=True)
expected = expected.split("<end_of_turn>")[0].strip()

# Feed exactly the token ids the model saw as its prompt during training.
input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=ft_model.device)
inputs = {'input_ids': input_ids, 'attention_mask': torch.ones_like(input_ids)}
prompt_len = input_ids.shape[1]

with torch.no_grad():
    outputs = ft_model.generate(
        **inputs,
        max_new_tokens=400,
        # Greedy decoding. `temperature` only applies when sampling, so it is
        # not passed alongside do_sample=False.
        do_sample=False,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens. Decoding the whole sequence and
# splitting on "<start_of_turn>model" is unreliable: when special tokens are
# skipped the marker is gone and the prompt leaks into the "generated" text.
generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
generated = generated.split("<end_of_turn>")[0].strip()

print("=" * 70)
print("EXPECTED OUTPUT:")
print("=" * 70)
print(expected)
print("\n" + "=" * 70)
print("FINE-TUNED MODEL (v2) OUTPUT:")
print("=" * 70)
print(generated)

# Quick field check
print("\n" + "=" * 70)
print("üìä FIELD PRESENCE CHECK:")
for field in ["PRIMARY HYPOTHESIS", "DIFFERENTIAL DIAGNOSES", "KEY SUPPORTING EVIDENCE", "URGENCY LEVEL", "TESTS ORDERED", "CLINICAL REASONING"]:
    present = "‚úÖ" if field in generated else "‚ùå"
    print(f"  {present} {field}")

In [None]:
# ============================================================
# CELL 15: Training Report
# ============================================================
# Assemble a JSON report of the run. Hyperparameters are read from the live
# config objects instead of being retyped, so the report cannot disagree
# with what was actually trained.
final_val_loss = float(eval_result['eval_loss'])
V1_VAL_LOSS = 1.075  # v1 validation loss, kept for comparison

report = {
    "model_name": "medgemma-hypothesis-extraction-v2",
    "base_model": model_id,
    "training_date": train_start.isoformat(),
    "training_duration": str(train_end - train_start),
    "dataset": {
        "file": DATA_FILENAME,
        "total_examples": len(examples),
        "train_examples": len(train_data),
        "val_examples": len(val_data),
    },
    "training_config": {
        "epochs": training_args.num_train_epochs,
        "learning_rate": training_args.learning_rate,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.lora_alpha,
        "batch_size": training_args.per_device_train_batch_size,
        "gradient_accumulation": training_args.gradient_accumulation_steps,
        "warmup_steps": warmup_steps,
        # Must mirror the max_length used when tokenizing (Cell 7).
        "max_seq_length": 768,
        "prompt_masking": True,
    },
    "performance": {
        "final_train_loss": float(result.training_loss),
        "final_val_loss": final_val_loss,
        "v1_val_loss_for_comparison": V1_VAL_LOSS,
        "improvement_over_v1": round(V1_VAL_LOSS - final_val_loss, 4),
    },
    "usage_instructions": {
        "prompt_template": PROMPT_TEMPLATE,
        # Greedy decoding: no temperature is listed because it has no effect
        # when do_sample is False.
        "generation_params": {
            "max_new_tokens": 400,
            "do_sample": False,
            "repetition_penalty": 1.1,
        },
    },
}

report_path = "/kaggle/working/training_report_v2.json"
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2)

print("‚úÖ Report saved to", report_path)
print(json.dumps(report, indent=2))