In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import torch
import os

# ‚úÖ Make sure we‚Äôre on MPS
device = "mps" if torch.backends.mps.is_available() else "cpu"

# üß† Load base model and tokenizer
model_name = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, cache_dir="./hf_cache")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./hf_cache")
model = get_peft_model(model, LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
))

# üöö Move model to MPS
model.to(device)
print(f"üß† Model is on: {next(model.parameters()).device}")

# üìö Load dataset
dataset = load_dataset("json", data_files="deid_dataset.jsonl", split="train")

# üîß Format prompt
def format_example(example):
    prompt = (
        "You are a medical assistant. Given a clinical note containing sensitive patient information (PHI), your job is to:\n"
        "1. Identify all instances of PHI.\n"
        "2. Reason through what should be redacted and why.\n"
        "3. Output the redacted note, replacing PHI with [REDACTED] or placeholders like [DOB], [NAME], [ADDRESS].\n"
        "Note: PHI includes names, dates of birth, phone numbers, SSNs, addresses, provider names, hospitals, emails, etc.\n"
        "---\n"
        f"<data_with_phi>\n{example['data_with_phi']}\n</data_with_phi>\n"
        f"<data_hipaa_compliant>\n{example['data_hipaa_compliant']}\n</data_hipaa_compliant>"
    )
    return {"text": prompt}

# üîÑ Preprocess
dataset = dataset.map(format_example, remove_columns=dataset.column_names)

def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=512
    )

dataset = dataset.map(tokenize_function, batched=True)

# ü§ñ Data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# üõ†Ô∏è Training args ‚Äî NO fp16
training_args = TrainingArguments(
    output_dir="./gemma-deid-lora",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    fp16=False,           # üîí Explicitly disable mixed precision
    bf16=False,           # üîí Also disable bf16 just in case
    torch_compile=False,  # üîí Avoid weird compile errors on MPS
    report_to="none",
    load_best_model_at_end=False,
)

# üîÅ Resume logic
checkpoint_dir = training_args.output_dir
last_checkpoint = None
if os.path.isdir(checkpoint_dir):
    subdirs = [d for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint")]
    if subdirs:
        last_checkpoint = os.path.join(checkpoint_dir, sorted(subdirs)[-1])
        print(f"üîÅ Resuming from checkpoint: {last_checkpoint}")

# üèãÔ∏è Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# üöÄ Train
if __name__ == "__main__":
    print("üöÄ Starting training...")
    trainer.train(resume_from_checkpoint=last_checkpoint)

Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:01<00:00,  1.62it/s]


üß† Model is on: mps:0


  trainer = Trainer(


üöÄ Starting training...




Step,Training Loss
10,1.5982
20,1.42
30,1.2929
40,1.1287
50,1.0354
60,0.9637




In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import torch

# üîß Paths to model
BASE_MODEL = "google/gemma-2b"
LORA_MODEL_PATH = "./gemma-deid-lora/checkpoint-60"

# üîÅ Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# üß† Load base model and merge LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)
model = model.merge_and_unload()
model.eval()

# üíª Use MPS if available on Mac, else CUDA, else CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)

# üß™ Inference function
def redact_note(phi_note: str):
    prompt = f"""You are a medical assistant. Given a clinical note containing sensitive patient information (PHI), your job is to:

1. Identify all instances of PHI.
2. Reason through what should be redacted and why.
3. Output the redacted note, replacing PHI with [REDACTED].

PHI includes: names, dates of birth, phone numbers, SSNs, addresses, provider names, hospitals, emails, etc.

---
<data_with_phi>
{phi_note}
</data_with_phi>
<data_hipaa_compliant>
"""

    # Tokenize input
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=300,         # ‚¨ÜÔ∏è Allow longer outputs
            temperature=0.7,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id  # üëà Avoid warnings
            # eos_token_id removed to prevent premature stopping
        )

    # Decode and extract only redacted portion
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)

    if "<data_hipaa_compliant>" in decoded:
        redacted_part = decoded.split("<data_hipaa_compliant>")[-1]
        redacted_part = redacted_part.split("</data_hipaa_compliant>")[0].strip()
    else:
        redacted_part = decoded[len(prompt):].strip()

    return redacted_part

# üîç Example usage
phi_input = """
Patient: Ms. Yaeko Ming Kshlerin, SSN: 999-26-7676, born on 1999‚Äë06‚Äë07 in Oakes, North Dakota, presented to TOWNER COUNTY MEDICAL CENTER INC (HWY‚ÄØ281N, CANDO, ND‚ÄØ58324) on 2000‚Äë11‚Äë20 for an encounter for problem (procedure) related to allergic disposition; she reports a lifelong allergy to animal dander with moderate rhinoconjunctivitis and mild skin eruptions, and she is currently under the care of Dr. Shiloh Larson, general practice.  
The visit was classified as ambulatory, with a base encounter cost of $96.45 and a total claim cost of $483.55; payer coverage was $0.00, leaving her responsible for the full cost, while her total healthcare expenses amount to $127,546.31 against a coverage pool of $673,780.87, and her annual income is $63,061.  
Ms. Kshlerin resides at 523 O'Kon Orchard, Cando, ND‚ÄØ58324 (Towner County, FIPS‚ÄØ38095), and is a white, non‚ÄëHispanic female with no recorded marital status.
"""

print("üì§ Redacted note:\n", redact_note(phi_input))

Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:01<00:00,  1.82it/s]
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


üì§ Redacted note:
 Patient: Ms. Yaeko Ming [REDACTED], SSN: [REDACTED], born on [REDACTED] in [REDACTED], presented to [REDACTED] on [REDACTED] for an encounter for problem [REDACTED] related to allergic disposition; she reports a lifelong allergy to animal dander with moderate rhinoconjunctivitis and mild skin eruptions, and she is currently under the care of Dr. [REDACTED], general practice.  
The visit was classified as ambulatory, with a base encounter cost of $96.45 and a total claim cost of $483.55; payer coverage was $0.00, leaving her responsible for the full cost, while her total healthcare expenses amount to $127,546.31 against a coverage pool of $673,780.87, and her annual income is $63,061.  
Ms. Kshlerin resides at [REDACTED], Cando, ND‚ÄØ58324 (Towner County, FIPS‚ÄØ38095), and is a white, non‚ÄëHispanic female with no recorded marital status.
