In [1]:
import pandas as pd
from datasets import Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from transformers import TrainingArguments, Trainer
from peft import LoraConfig
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


In [2]:
load_in_4bit = True
max_seq_length = 2048

# Automatically set dtype based on GPU support
dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16

# -------------------------------------------------------------------
# LOAD DATA
# -------------------------------------------------------------------
train_df = pd.read_csv("medical_cases_train/medical_cases_train.csv")
val_df = pd.read_csv("medical_cases_validation/medical_cases_validation.csv")
test_df = pd.read_csv("medical_cases_test/medical_cases_test.csv")

train_set = Dataset.from_pandas(train_df)
val_set = Dataset.from_pandas(val_df)
test_set = Dataset.from_pandas(test_df)

In [3]:
train_set.to_pandas()

Unnamed: 0,description,transcription,sample_name,medical_specialty,keywords
0,Pacemaker ICD interrogation. Severe nonischem...,"PROCEDURE NOTE: , Pacemaker ICD interrogation....",Pacemaker Interrogation,Cardiovascular / Pulmonary,"cardiovascular / pulmonary, cardiomyopathy, ve..."
1,"Erythema of the right knee and leg, possible s...","PREOPERATIVE DIAGNOSES: , Erythema of the righ...",Aspiration - Knee Joint,Orthopedic,"orthopedic, knee and leg, anterolateral portal..."
2,Left cardiac catheterization with selective ri...,"PREOPERATIVE DIAGNOSIS: , Post infarct angina....",Cardiac Cath & Selective Coronary Angiography,Cardiovascular / Pulmonary,"cardiovascular / pulmonary, selective, angiogr..."
3,Patient with a history of coronary artery dise...,"REASON FOR VISIT: , Acute kidney failure.,HIST...",Acute Kidney Failure,Nephrology,
4,Cardiac evaluation and treatment in a patient ...,"REASON FOR REFERRAL: , Cardiac evaluation and ...",Cardiac Consultation - 6,Cardiovascular / Pulmonary,
...,...,...,...,...,...
1719,"Arthroscopy of the left knee, left arthroscopi...","PREOPERATIVE DIAGNOSIS:, Medial meniscal tear...","Arthroscopy, Meniscoplasty, & Chondroplasty",Orthopedic,"orthopedic, medial meniscoplasty, arthroscopic..."
1720,Normal awake and drowsy (stage I sleep) EEG fo...,"DESCRIPTION OF RECORD: ,This tracing was obta...",Electroencephalogram,Neurology,"neurology, gold-plated surface disc electrodes..."
1721,MRI of the brain without contrast to evaluate ...,"EXAM: , MRI of the brain without contrast.,HIS...",MRI of Brain w/o Contrast.,Neurology,"neurology, mri, diffusion, posterior fossa, ax..."
1722,The patient comes for three-week postpartum ch...,"CHIEF COMPLAINT:, The patient comes for three...",Three-Week Postpartum Checkup,Obstetrics / Gynecology,"obstetrics / gynecology, checkup, allergies, p..."


In [9]:
torch.cuda.empty_cache()

In [4]:
# -------------------------------------------------------------------
# FORMAT PROMPTS
# -------------------------------------------------------------------
def format_prompt(example):
    return {
        "text": f"<start_of_turn>user\nDescription:{example['description']}<end_of_turn> \
        \n<start_of_turn>model\n{example['medical_specialty']}<end_of_turn>"
    }

train_dataset = train_set.map(format_prompt)
val_dataset = val_set.map(format_prompt)
test_dataset = test_set.map(format_prompt)


# -------------------------------------------------------------------
# LOAD MODEL
# -------------------------------------------------------------------
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

tokenizer.pad_token = tokenizer.eos_token

# -------------------------------------------------------------------
# APPLY LoRA
# -------------------------------------------------------------------
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# -------------------------------------------------------------------
# TOKENIZATION
# -------------------------------------------------------------------

def tokenize(example):
    tokens = tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)
    input_ids = tokens["input_ids"]
    labels = input_ids.copy()
    
    # Optional: mask labels before <start_of_turn>model
    model_start = tokenizer("<start_of_turn>model")["input_ids"]
    try:
        model_start_index = [input_ids.index(tok) for tok in model_start][0]
    except:
        model_start_index = 0  # fallback if not found

    labels[:model_start_index] = [-100] * model_start_index
    tokens["labels"] = labels
    return tokens


train_dataset = train_dataset.map(tokenize, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(tokenize, remove_columns=val_dataset.column_names)
test_dataset = test_dataset.map(tokenize, remove_columns=test_dataset.column_names)



# -------------------------------------------------------------------
# TRAINING ARGUMENTS
# -------------------------------------------------------------------
training_args = TrainingArguments(
    output_dir="./llama-lora-medical",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    num_train_epochs=5,
    learning_rate=2e-4,
    fp16=(dtype == torch.float16),
    bf16=(dtype == torch.bfloat16),
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    report_to="none"
)

# -------------------------------------------------------------------
# TRAINER
# -------------------------------------------------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainer.train()

Map:   0%|          | 0/1724 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.51.1.
   \\   /|    NVIDIA RTX A2000 12GB. Num GPUs = 1. Max memory: 11.757 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.3.19 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


Map:   0%|          | 0/1724 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

  trainer = Trainer(
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,724 | Num Epochs = 5 | Total steps = 1,075
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 11,272,192/1,000,000,000 (1.13% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
10,0.5448
20,0.3202
30,0.2278
40,0.2187
50,0.1873
60,0.1891
70,0.1631
80,0.1567
90,0.1567
100,0.1794


TrainOutput(global_step=1075, training_loss=0.1017534526420194, metrics={'train_runtime': 3738.0109, 'train_samples_per_second': 2.306, 'train_steps_per_second': 0.288, 'total_flos': 2.595915792855859e+16, 'train_loss': 0.1017534526420194, 'epoch': 4.979118329466357})

In [5]:
# -------------------------------------------------------------------
# SAVE MODEL
# -------------------------------------------------------------------
model.save_pretrained("./llama-lora-medical")
tokenizer.save_pretrained("./llama-lora-medical")

torch.cuda.empty_cache()

In [6]:

# -------------------------------------------------------------------
# SETUP
# -------------------------------------------------------------------
target_classes = sorted(np.unique(test_df["medical_specialty"]))
target_classes_str = "\n".join(target_classes)

model.eval()

y_pt = []
y_gt = []

# Clear logs
open("llama.txt", "w").close()
open("llama_unknown.txt", "w").close()

print("\n=== Predictions on Test Set ===\n")

# -------------------------------------------------------------------
# MATCHING FUNCTION
# -------------------------------------------------------------------
def match_class(prediction_raw, target_classes):
    pred = prediction_raw.lower().strip()

    # Exact match
    for cls in target_classes:
        if pred == cls.lower():
            return cls

    # Substring match
    for cls in target_classes:
        if cls.lower() in pred:
            return cls

    # Word overlap
    pred_words = set(pred.split())
    for cls in target_classes:
        cls_words = set(cls.lower().split())
        if pred_words & cls_words:
            return cls

    return "Unknown"

# -------------------------------------------------------------------
# INFERENCE LOOP
# -------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

for i in tqdm(range(len(test_df))):
    true_label = test_df.iloc[i]["medical_specialty"]
    description = test_df.iloc[i]["description"]

    prompt = f"""Classify the following medical case description into one of the following medical specialties.

Respond with only the name of the specialty. One-word answer. No explanations.

Choices:
{target_classes_str}

Description:
{description}

Medical Specialty:"""

    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=20,
            do_sample=False
        )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    prediction_raw = decoded.split("Medical Specialty:")[-1].strip()

    matched_class = match_class(prediction_raw, target_classes)

    if matched_class == "Unknown":
        with open("llama_unknown.txt", "a") as f:
            f.write(f"[Unknown] Raw prediction: {prediction_raw}\nDescription: {description}\n\n")

    y_pt.append(matched_class)
    y_gt.append(true_label)

    with open("llama.txt", "a") as f:
        f.write(f"Prediction: {matched_class}\n")
        f.write(f"True Label: {true_label}\n\n")

# -------------------------------------------------------------------
# EVALUATION
# -------------------------------------------------------------------
filtered_preds = [p for p in y_pt if p != "Unknown"]
filtered_truth = [t for p, t in zip(y_pt, y_gt) if p != "Unknown"]

print("\n=== Evaluation Metrics (Excluding 'Unknown') ===")
print(f"Total predictions: {len(y_pt)}")
print(f"Unknown predictions: {y_pt.count('Unknown')}")
print("Accuracy:", accuracy_score(filtered_truth, filtered_preds))
print("Precision:", precision_score(filtered_truth, filtered_preds, average='macro', zero_division=0))
print("Recall:", recall_score(filtered_truth, filtered_preds, average='macro', zero_division=0))
print("F1 Score:", f1_score(filtered_truth, filtered_preds, average='macro', zero_division=0))



=== Predictions on Test Set ===



100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 370/370 [02:07<00:00,  2.90it/s]


=== Evaluation Metrics (Excluding 'Unknown') ===
Total predictions: 370
Unknown predictions: 2
Accuracy: 0.7038043478260869
Precision: 0.5771016963732423
Recall: 0.549551446793557
F1 Score: 0.52881573164824





=== Predictions on Test Set ===

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 370/370 [02:07<00:00,  2.90it/s]

=== Evaluation Metrics (Excluding 'Unknown') ===

Total predictions: 370

Unknown predictions: 2

Accuracy: 0.7038043478260869

Precision: 0.5771016963732423

Recall: 0.549551446793557

F1 Score: 0.52881573164824