In [None]:
# ============================================================
# CELL 1: Setup + Config
# ============================================================
import json, re, torch, gc, os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Allocator setting for the CUDA caching allocator; set before any CUDA work
# is done in this notebook.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# ── UPDATE THESE PATHS ──────────────────────────────────────
ADAPTER_DIR  = '/kaggle/input/medgemma-loopguard-v2/transformers/default/1'
PATIENT_FILE = '/kaggle/input/loopguard-patients/patients.json'
BASE_MODEL   = 'google/medgemma-1.5-4b-it'
OUTPUT_FILE  = '/kaggle/working/patients_with_ai_final.json'
# ────────────────────────────────────────────────────────────

# 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)

print('✅ Config ready')
# Querying device 0 raises when no GPU is attached, so only report VRAM when
# CUDA is actually available.
if torch.cuda.is_available():
    total_vram = torch.cuda.get_device_properties(0).total_memory
    free_vram = total_vram - torch.cuda.memory_allocated()
    print(f'   VRAM available: {free_vram/1e9:.2f} GB')
else:
    print('   ⚠️  No CUDA device detected — VRAM not reported')

In [None]:
# ============================================================
# CELL 2: Load Model + Adapter
# Requires BASE_MODEL, ADAPTER_DIR and BNB_CONFIG from Cell 1.
# ============================================================
print('Loading base MedGemma...')
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Base weights are quantized on load according to BNB_CONFIG and placed
# automatically across the available devices.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=BNB_CONFIG,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

print('Loading LoopGuard v2 adapter...')
# Wrap the quantized base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
model.eval()  # inference only: switch off training-mode behaviour

print(f'✅ Model ready | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB')

In [None]:
# ============================================================
# CELL 3: Inference + Parsing Functions
#
# Key lessons applied from eval phase:
#   1. decode_output: split on '<start_of_turn>model' to preserve
#      the PRIMARY HYPOTHESIS label (eval showed it was being truncated)
#   2. parse_fields: lookahead to next field boundary avoids
#      bleeding content between fields
#   3. Urgency normalization: 'moderate' -> 'medium', default 'medium'
#   4. clinical_note is a dict — always use ['text'] key
# ============================================================

# (result key, regex for the field's label), in the order the prompt asks for them.
FIELD_PATTERNS = [
    ('primary_hypothesis',      r'PRIMARY HYPOTHESIS:'),
    ('differential_diagnoses',  r'DIFFERENTIAL DIAGNOSES:'),
    ('key_supporting_evidence', r'KEY SUPPORTING EVIDENCE:'),
    ('urgency_level',           r'URGENCY LEVEL:'),
    ('tests_ordered',           r'TESTS ORDERED:'),
    ('clinical_reasoning',      r'CLINICAL REASONING:'),
]

# Matches any field label; used to find where the current field's value ends.
FIELD_LABEL_RE = (
    r'(?:PRIMARY HYPOTHESIS|DIFFERENTIAL DIAGNOSES|KEY SUPPORTING EVIDENCE|'
    r'URGENCY LEVEL|TESTS ORDERED|CLINICAL REASONING):'
)

# Cap on the value length when no further field label follows.
_LAST_FIELD_MAX_CHARS = 800

# Checked in this order, so 'high' wins over 'low', which wins over 'medium'.
# Whole-word patterns keep words such as "follow" or "below" from reading as 'low'.
_URGENCY_WORDS = (
    ('high',   r'\bhigh\b'),
    ('low',    r'\blow\b'),
    ('medium', r'\b(?:medium|moderate)\b'),
)


def decode_output(raw: str) -> str:
    """
    Return the model-generated portion of a decoded string.

    Everything after the LAST '<start_of_turn>model' marker is kept, so the
    leading 'PRIMARY HYPOTHESIS:' label survives intact. When the marker is
    absent the whole string is returned. Surrounding whitespace is stripped
    in both cases.
    """
    # rpartition yields ('', '', raw) when the marker is missing, so the tail
    # is the full input in that case and one expression covers both paths.
    _, _, tail = raw.rpartition('<start_of_turn>model')
    return tail.strip()


def _normalize_urgency(raw_value: str) -> str:
    """Map free-text urgency onto 'high' / 'medium' / 'low'; default 'medium'."""
    lowered = raw_value.lower()
    for level, word_re in _URGENCY_WORDS:
        if re.search(word_re, lowered):
            return level
    return 'medium'


def parse_fields(text: str) -> dict:
    """
    Extract the 6 structured fields from generated text.

    Each field's value runs from the end of its label to the start of the
    next field label of any kind (or at most 800 characters when no label
    follows). Labels are matched case-insensitively and markdown bold
    markers are removed first.

    Returns:
        dict with one entry per key in FIELD_PATTERNS. Fields whose label is
        not found are ''. 'urgency_level' is always one of
        'high' / 'medium' / 'low', falling back to 'medium' when it is
        missing or unrecognised.
    """
    # Strip markdown bold that some outputs include
    clean = text.replace('**', '')
    result = {}

    for field_key, pattern in FIELD_PATTERNS:
        label = re.search(pattern, clean, re.IGNORECASE)
        if label is None:
            result[field_key] = ''
            continue

        start = label.end()
        following = re.search(FIELD_LABEL_RE, clean[start:], re.IGNORECASE)
        span = following.start() if following else _LAST_FIELD_MAX_CHARS
        # Drop template-style brackets, then any padding left inside them.
        result[field_key] = clean[start:start + span].strip().strip('[]').strip()

    result['urgency_level'] = _normalize_urgency(result.get('urgency_level', ''))
    return result


def run_inference(note_text: str, max_input_tokens: int = 1024,
                  max_new_tokens: int = 600) -> tuple:
    """
    Run LoopGuard v2 on clinical note text.

    Args:
        note_text: Free-text clinical note to analyse.
        max_input_tokens: Token budget for the whole prompt. A note that does
            not fit is trimmed so the instruction block and the model-turn
            marker at the end of the prompt are never cut off.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        (parsed_dict, raw_generated): the output of parse_fields and the
        generated text it was parsed from (prompt excluded).
    """
    prompt_head = (
        '<start_of_turn>user\n'
        'Extract diagnostic information from this clinical note.\n\n'
        'Clinical Note:\n'
    )
    prompt_tail = (
        '\n\n'
        'Output ONLY these 6 fields:\n'
        'PRIMARY HYPOTHESIS: [main diagnosis]\n'
        'DIFFERENTIAL DIAGNOSES: [comma-separated alternatives]\n'
        'KEY SUPPORTING EVIDENCE: [comma-separated findings]\n'
        'URGENCY LEVEL: [high/medium/low]\n'
        'TESTS ORDERED: [comma-separated tests]\n'
        'CLINICAL REASONING: [brief explanation]'
        '<end_of_turn>\n<start_of_turn>model\n'
    )

    # Tokenizer truncation removes tokens from the END of the prompt, which is
    # where the field instructions and the model-turn marker live. Trim the
    # note instead so the scaffold always survives. The small slack absorbs
    # token-boundary differences when the pieces are joined.
    scaffold_len = len(tokenizer(prompt_head + prompt_tail)['input_ids'])
    note_budget = max(max_input_tokens - scaffold_len - 8, 0)
    note_ids = tokenizer(note_text, add_special_tokens=False)['input_ids']
    if len(note_ids) > note_budget:
        note_text = tokenizer.decode(note_ids[:note_budget], skip_special_tokens=True)

    prompt = f'{prompt_head}{note_text}{prompt_tail}'

    inputs = tokenizer(
        prompt, return_tensors='pt',
        truncation=True, max_length=max_input_tokens
    ).to(model.device)
    prompt_len = inputs['input_ids'].shape[1]

    # A pad id of 0 is valid; `or` would wrongly treat it as missing.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=1.1,
            pad_token_id=pad_id,
        )

    # Decode only the newly generated tokens. Decoding prompt + completion
    # with skip_special_tokens=True can strip the '<start_of_turn>model'
    # marker, after which the prompt's own "PRIMARY HYPOTHESIS: [main
    # diagnosis]" template would be parsed as if it were model output.
    completion = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    generated = decode_output(completion)
    parsed = parse_fields(generated)
    return parsed, generated


print('✅ Functions ready')

In [None]:
# ============================================================
# CELL 4: Load Patients + Smoke Test on the first patient
# Verify decode and parse are correct before full batch
# ============================================================
# JSON is read as UTF-8 explicitly rather than with the platform default.
with open(PATIENT_FILE, encoding='utf-8') as f:
    patients = json.load(f)['patient_scenarios']

print(f'✅ Loaded {len(patients)} patients')

# Smoke test
test = patients[0]
note_text = test['clinical_note']['text']  # clinical_note is a dict

print(f'\n🔬 Smoke test: {test["patient_id"]} ({test["ground_truth_diagnosis"]})')
parsed, raw = run_inference(note_text)

print(f'\n--- RAW OUTPUT (first 400 chars) ---')
print(raw[:400])
print(f'\n--- PARSED ---')
for k, v in parsed.items():
    status = '✅' if v else '❌ MISSING'
    print(f'  {status} {k}: {v[:80] if v else ""}')

# NOTE: parse_fields always fills urgency_level (it defaults to 'medium'),
# so that field counts as present even when the model omitted it.
fields_ok = sum(1 for v in parsed.values() if v)
print(f'\n  {fields_ok}/6 fields | urgency={parsed["urgency_level"]}')
print('\n✅ Smoke test passed — proceed to full batch' if fields_ok >= 5 else '\n⚠️  Check output before proceeding')

In [None]:
# ============================================================
# CELL 5: Full Batch — All Patients
# ============================================================
results = []
field_keys = [k for k, _ in FIELD_PATTERNS]

# Use the real cohort size instead of a hard-coded 20 so progress and summary
# lines stay correct if the patient file changes.
n_patients = len(patients)

print(f'Running LoopGuard v2 on all {n_patients} patients...\n')

for i, patient in enumerate(patients, start=1):
    pid = patient['patient_id']
    note_text = patient['clinical_note']['text']

    print(f'[{i:02d}/{n_patients:02d}] {pid}...', end=' ', flush=True)

    parsed, raw = run_inference(note_text)

    # Count of non-empty parsed fields (urgency_level is always non-empty).
    fields_present = sum(1 for v in parsed.values() if v)
    urgency = parsed['urgency_level']

    tick = '✅' if fields_present == 6 else '⚠️'
    print(f'{tick} {fields_present}/6 | urgency={urgency}')

    # Preserve all original patient data, add ai_analysis block
    enriched = {
        **patient,
        'ai_analysis': {
            **parsed,
            '_raw_output': raw,
            '_fields_present': fields_present,
        }
    }
    results.append(enriched)

print(f'\n✅ Done: {len(results)}/{n_patients} patients')

In [None]:
# ============================================================
# CELL 6: QC Report + Save
# Requires `results` and `field_keys` from Cell 5.
# ============================================================
print('=== BATCH QC REPORT ===')
print(f'{"Patient":<8} {"Fields":<8} {"Urgency":<10} Missing')
print('-' * 55)

complete_count = 0
urgency_dist = {'high': 0, 'medium': 0, 'low': 0}
missing_tally = {k: 0 for k in field_keys}

for r in results:
    ai = r['ai_analysis']
    fp = ai['_fields_present']
    urg = ai['urgency_level']
    missing = [k for k in field_keys if not ai.get(k)]

    if fp == 6:
        complete_count += 1
    urgency_dist[urg] = urgency_dist.get(urg, 0) + 1
    for k in missing:
        missing_tally[k] += 1

    tick = '✅' if fp == 6 else '⚠️'
    print(f"{r['patient_id']:<8} {tick}{fp}/6   {urg:<10} {missing or 'none'}")

print()
# Denominator is the number of processed patients, not a hard-coded 20.
print(f'Complete (6/6):     {complete_count}/{len(results)}')
print(f'Urgency dist:       {urgency_dist}')
print(f'Missing field tally:{[(k,v) for k,v in missing_tally.items() if v > 0]}')

# Save
output = {'patient_scenarios': results}
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=2)

# Report the size of the file actually written; re-serializing without
# indentation would understate it and costs a second full dump.
size_kb = os.path.getsize(OUTPUT_FILE) / 1024
print(f'\n✅ Saved → {OUTPUT_FILE} ({size_kb:.1f} KB)')

In [None]:
# ============================================================
# CELL 7: Spot Check — 3 Patients Full Output
# P001 (ovarian cancer, high), P009 (hyponatremia, high), P019 (thyroid ca, low)
# Manually verify raw output and parsing look correct
# Requires `results` and `field_keys` from Cell 5.
# ============================================================
spot_ids = ['P001', 'P009', 'P019']
spot_seen = set()

for r in results:
    if r['patient_id'] not in spot_ids:
        continue
    spot_seen.add(r['patient_id'])
    ai = r['ai_analysis']
    print(f"\n{'='*65}")
    print(f"  {r['patient_id']} | GT: {r['ground_truth_diagnosis']}")
    print(f"{'='*65}")
    print(f"RAW (first 500 chars):\n{ai['_raw_output'][:500]}")
    print(f"\nPARSED:")
    for k in field_keys:
        val = ai.get(k, '')
        icon = '✅' if val else '❌'
        print(f"  {icon} {k}: {val[:100] if val else 'EMPTY'}")

# Make it visible when a requested patient was never found, instead of
# silently printing fewer sections than expected.
spots_not_found = [sid for sid in spot_ids if sid not in spot_seen]
if spots_not_found:
    print(f"\n⚠️  Spot-check IDs not found in results: {spots_not_found}")