# Patient Triage: Urgent vs. Non-Urgent Case Classification with PEFT

**Features:**
- Synthetic dataset for urgent/non-urgent free-text symptoms
- BlueBERT (or ClinicalBERT) + LoRA parameter-efficient fine-tuning
- Training, evaluation (accuracy, precision, recall, F1)
- Inference example
- (Bonus) API deployment guidance

In [ ]:
# Install the third-party libraries this notebook imports directly:
# transformers (tokenizer, model, Trainer), datasets, peft (LoRA) and
# scikit-learn (metrics).
# NOTE(review): pandas and torch are imported in later cells but are not
# listed here; presumably they arrive as transitive dependencies - confirm.
!pip install transformers datasets peft scikit-learn

## 1. Synthetic Data Creation

In [ ]:
import pandas as pd
import random

# Template complaints that are labelled urgent (label 1).
urgent_symptoms = [
    "sudden severe chest pain and shortness of breath",
    "loss of consciousness and slurred speech",
    "massive bleeding from a wound",
    "severe abdominal pain with vomiting blood",
    "high fever with neck stiffness",
    "new onset seizure",
    "severe allergic reaction and swelling of lips",
    "acute confusion and inability to move arm or leg"
]

# Template complaints that are labelled non-urgent (label 0).
nonurgent_symptoms = [
    "mild headache for two days, no other symptoms",
    "intermittent knee pain when walking",
    "itchy skin with mild redness",
    "occasional cough without fever",
    "runny nose and mild sore throat",
    "feeling tired after work",
    "low back pain for one week after lifting box",
    "seasonal allergies causing sneezing"
]


def generate_triage_rows(n_samples=120, seed=42):
    """Return ``n_samples`` synthetic ``{"symptom", "label"}`` records.

    Each record is drawn with equal probability from ``urgent_symptoms``
    (label 1) or ``nonurgent_symptoms`` (label 0).

    Args:
        n_samples: Number of records to generate; 0 yields an empty list.
        seed: Seed for a private ``random.Random`` instance, so the same seed
            always produces the same records and the global ``random`` state
            is left untouched.

    Returns:
        A list of dicts with keys ``"symptom"`` (str) and ``"label"`` (int).
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n_samples):
        if rng.random() < 0.5:
            rows.append({"symptom": rng.choice(urgent_symptoms), "label": 1})  # urgent
        else:
            rows.append({"symptom": rng.choice(nonurgent_symptoms), "label": 0})  # non-urgent
    return rows


# Seeded generation makes the CSV identical across notebook runs; previously
# only the shuffle below was seeded, so the sampled rows changed every run.
data = generate_triage_rows()

df = pd.DataFrame(data)
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
df.to_csv("synthetic_triage_data.csv", index=False)
df.head()

## 2. LoRA (PEFT) Fine-Tuning with BlueBERT/ClinicalBERT

In [ ]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType

# Base checkpoint: BlueBERT is active; ClinicalBERT is the alternative.
# model_name = "emilyalsentzer/Bio_ClinicalBERT"
model_name = "bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reload the synthetic CSV, rename the columns to "text" / "labels", and hold
# out 20% of the rows as the test split (fixed seed).
df = pd.read_csv("synthetic_triage_data.csv")
dataset = (
    Dataset.from_pandas(df)
    .rename_column("symptom", "text")
    .rename_column("label", "labels")
    .train_test_split(test_size=0.2, seed=42)
)

# LoRA settings for sequence classification.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

# Two-label classifier on top of the base checkpoint, wrapped with LoRA.
model = get_peft_model(
    AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2),
    lora_config,
)

def preprocess(example):
    """Tokenize the ``"text"`` field of a single dataset example.

    The text is truncated and padded to a fixed length of 64
    (``padding="max_length"``, ``max_length=64``). The tokenizer's output is
    returned as-is.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 64,
    }
    return tokenizer(example["text"], **tokenizer_options)

import inspect

# Tokenize every example individually (``preprocess`` handles one example).
tokenized_ds = dataset.map(preprocess, batched=False)

# The per-epoch evaluation option was renamed from ``evaluation_strategy`` to
# ``eval_strategy`` in newer transformers releases. The install cell does not
# pin a version, so use whichever name the installed release accepts.
_eval_strategy_arg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./triage_bluebert_lora",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint on the same schedule so the best checkpoint can
    # be restored at the end of training.
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=10,
    load_best_model_at_end=True,
    # Must match a key returned by ``compute_metrics``.
    metric_for_best_model="eval_f1",
    **{_eval_strategy_arg: "epoch"},
)

# Custom metrics
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
def compute_metrics(pred):
    """Compute accuracy and per-class precision/recall/F1 for an evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (per-class scores; the class is
            taken as the argmax over axis 1) and ``label_ids`` (true labels).

    Returns:
        Dict with ``accuracy``, precision/recall/F1 for the non-urgent class
        (label 0) and the urgent class (label 1), and ``eval_f1``, the
        unweighted mean of the two per-class F1 scores.
    """
    preds = pred.predictions.argmax(axis=1)
    labels = pred.label_ids
    acc = accuracy_score(labels, preds)
    # zero_division=0 keeps the score at 0 when a class is never predicted or
    # absent, without emitting UndefinedMetricWarning on every evaluation.
    p, r, f1, _ = precision_recall_fscore_support(
        labels, preds, average=None, labels=[0, 1], zero_division=0
    )
    return {
        "accuracy": acc,
        "precision_nonurgent": p[0],
        "recall_nonurgent": r[0],
        "f1_nonurgent": f1[0],
        "precision_urgent": p[1],
        "recall_urgent": r[1],
        "f1_urgent": f1[1],
        "eval_f1": f1.mean()
    }

import inspect

# Newer transformers releases take the tokenizer via ``processing_class`` and
# have deprecated the ``tokenizer`` keyword. The install cell does not pin a
# version, so use whichever name the installed release accepts.
_tokenizer_arg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    compute_metrics=compute_metrics,
    **{_tokenizer_arg: tokenizer},
)

trainer.train()

# Persist the trained model and the tokenizer side by side so the inference
# cell can load both from the same directory.
model.save_pretrained("./triage_bluebert_lora")
tokenizer.save_pretrained("./triage_bluebert_lora")

## 3. Evaluation: See metrics in logs and get predictions

In [ ]:
# Evaluate on test set
from sklearn.metrics import classification_report

# True labels of the held-out split.
y_true = tokenized_ds["test"]["labels"]

# Run the trainer on the same split and take the highest-scoring class.
test_output = trainer.predict(tokenized_ds["test"])
y_pred = test_output.predictions.argmax(axis=1)

# Per-class precision / recall / F1; label 0 is "non-urgent", label 1 "urgent".
report = classification_report(y_true, y_pred, target_names=["non-urgent", "urgent"])
print(report)

## 4. Inference Example

In [ ]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

from peft import AutoPeftModelForSequenceClassification

# "./triage_bluebert_lora" was written by ``save_pretrained`` on the
# PEFT-wrapped model, so it holds the LoRA adapter rather than a full
# checkpoint. Load it through peft, which reads the adapter config, rebuilds
# the base model and applies the adapter weights.
# NOTE(review): plain AutoModelForSequenceClassification.from_pretrained on an
# adapter-only directory either fails or may not restore the trained
# classification head, depending on the transformers version - confirm.
model = AutoPeftModelForSequenceClassification.from_pretrained("./triage_bluebert_lora")
model.eval()  # inference only: make sure dropout is disabled
tokenizer = AutoTokenizer.from_pretrained("./triage_bluebert_lora")

def triage_predict(symptom_text):
    """Classify one free-text symptom description.

    The text is tokenized (truncated to 64 tokens), run through the model
    without gradient tracking, and the highest-scoring class is mapped to a
    label.

    Returns:
        ``"Urgent"`` when the predicted class is 1, otherwise ``"Non-Urgent"``.
    """
    encoded = tokenizer(symptom_text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        scores = model(**encoded).logits
    predicted_class = scores.argmax(dim=1).item()
    if predicted_class == 1:
        return "Urgent"
    return "Non-Urgent"

# Example inference: one sentence taken from the urgent templates and one from
# the non-urgent templates used to build the synthetic dataset.
test_symptoms = [
    "acute confusion and inability to move arm or leg",
    "runny nose and mild sore throat"
]
# Print each input alongside the label returned by triage_predict.
for text in test_symptoms:
    print(f'Symptom: {text}\nPredicted triage: {triage_predict(text)}\n')

## 5. (Bonus) API Deployment Guidance
- Save the fine-tuned model and tokenizer with `save_pretrained` as in Section 2 (both are written to `./triage_bluebert_lora`)
- Use [FastAPI](https://fastapi.tiangolo.com/) or [Flask](https://flask.palletsprojects.com/) for a REST API:
    - Load model at startup
    - Accept POST requests with a symptom text
    - Return the triage prediction
- Example FastAPI endpoint (assumes the model, tokenizer, and `triage_predict` from Section 4 are loaded in the same module):
```python
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()

# Request body for POST /triage: a single free-text symptom description.
class Input(BaseModel):
    symptom: str

@app.post("/triage")
def triage(payload: Input):
    """Classify the submitted symptom text and return it under ``"triage"``."""
    # The parameter is named ``payload`` rather than ``input`` so it does not
    # shadow the builtin; the request body shape is unchanged.
    # ``triage_predict`` must already be defined or imported in this module.
    pred = triage_predict(payload.symptom)
    return {"triage": pred}
```
- Deploy on [Hugging Face Spaces](https://huggingface.co/spaces), [AWS Lambda](https://aws.amazon.com/lambda/), or your own server.