# In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EvalPrediction
)
from sklearn.metrics import f1_score, roc_auc_score, hamming_loss

# Run on the GPU when torch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"\U0001F5A5️  Using device: {device}")

# Fetch every split of the MIMIC ICD visit dataset from the Hugging Face Hub.
dataset = load_dataset("rntc/mimic-icd-visit")

# Label vocabulary = every ICD code seen in any split (train, validation, test).
# NOTE(review): codes that occur only in validation/test get an output unit that
# never receives a positive training example.
all_codes = [
    code
    for split_name in ("train", "validation", "test")
    for visit_codes in dataset[split_name]["icd_code"]
    for code in visit_codes
]

unique_codes = sorted(set(all_codes))
NUM_ICD_CODES = len(unique_codes)
# Stable code -> column mapping (alphabetical order of the sorted vocabulary).
code_to_index = dict(zip(unique_codes, range(NUM_ICD_CODES)))

# Encode labels
def encode_labels(example):
    """Attach a multi-hot ``labels`` vector to a single dataset example.

    Position ``code_to_index[code]`` is set to 1.0 for every ICD code listed in
    ``example['icd_code']``; all other positions are 0.0. The vector has
    ``NUM_ICD_CODES`` entries.

    The values are floats on purpose. The model is loaded with
    ``problem_type="multi_label_classification"``, whose loss
    (``BCEWithLogitsLoss`` in transformers) requires floating-point targets.
    Integer labels become int64 tensors after ``set_format("torch")`` and make
    the loss fail with a dtype error on the first training step.

    Args:
        example: one dataset row; mutated in place and returned, as
            ``datasets.Dataset.map`` expects.

    Returns:
        The same ``example`` with a ``'labels'`` list added.

    Raises:
        KeyError: if an ICD code is not present in ``code_to_index``.
    """
    vec = [0.0] * NUM_ICD_CODES
    for code in example['icd_code']:
        vec[code_to_index[code]] = 1.0
    example['labels'] = vec
    return example

# Add the multi-hot `labels` column to every split.
dataset = dataset.map(encode_labels)

# Tokenizer for the same Bio_ClinicalBERT checkpoint the model is loaded from.
tokenizer = AutoTokenizer.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT"
)

def tokenize_function(example):
    """Tokenize a batch of notes with the Bio_ClinicalBERT tokenizer.

    Used with ``dataset.map(..., batched=True)``, so ``example["cleaned_text"]``
    holds a whole batch (presumably a list of note strings — confirm against
    the dataset schema).

    Every sequence is truncated and padded to exactly 512 tokens. Fixed-length
    rows are what allow ``get_preds``'s plain ``DataLoader`` (no custom collate)
    to stack batches.
    """
    notes = example["cleaned_text"]
    return tokenizer(notes, max_length=512, truncation=True, padding="max_length")

# Tokenize all splits in batches, then expose only the tensors the model needs.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

# Bio_ClinicalBERT encoder with one output per ICD code, scored independently
# (multi-label). NOTE(review): the classification head is presumably freshly
# initialized, since the checkpoint was not trained on this label set.
model = AutoModelForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    problem_type="multi_label_classification",
    num_labels=NUM_ICD_CODES,
)
model.to(device)

# Get predictions function
def get_preds(model, dataset, batch_size=4, device=None):
    """Run ``model`` over ``dataset`` and collect sigmoid probabilities and labels.

    Args:
        model: classifier whose forward accepts ``input_ids`` positionally plus
            an ``attention_mask`` keyword, and returns an object with ``.logits``.
        dataset: indexable dataset whose items are dicts holding equal-length
            ``input_ids``, ``attention_mask`` and ``labels`` tensors (e.g. a
            torch-formatted ``datasets`` split).
        batch_size: inference batch size (default 4, as before).
        device: where to run inference. Defaults to the device of the model's
            first parameter (CPU for a parameterless model), so the function no
            longer depends on a global ``device`` variable.

    Returns:
        ``(probs, labels)``: 2-D numpy arrays stacked over all batches, one row
        per example. ``probs`` are float32 sigmoid probabilities.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    probs, labels = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            output = model(input_ids, attention_mask=attention_mask)
            # Sigmoid, not softmax: each ICD code is an independent yes/no decision.
            probs.append(torch.sigmoid(output.logits).float().cpu().numpy())
            labels.append(batch['labels'].numpy())

    if not probs:
        raise ValueError("get_preds() received an empty dataset; nothing to predict")
    return np.vstack(probs), np.vstack(labels)

# Save predictions from pretrained model.
# NOTE(review): the classification head was created for NUM_ICD_CODES labels when
# the model was loaded, so it is presumably untrained; this "pretrained" baseline
# reflects an untrained head rather than a trained classifier.
pre_probs, true_labels = get_preds(model, tokenized_datasets['validation'])

# Training arguments
import inspect

# transformers renamed `evaluation_strategy` to `eval_strategy` (v4.41) and later
# removed the old name, so passing it unconditionally raises TypeError on current
# releases. Pass whichever keyword the installed TrainingArguments accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    # Saving and evaluating must share a schedule for load_best_model_at_end.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    # Key returned by compute_metrics; the Trainer adds the "eval_" prefix itself.
    metric_for_best_model="micro_f1",
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    logging_dir="./logs",
    **{_eval_strategy_kwarg: "epoch"},
)

# Metrics function
def compute_metrics(eval_pred: EvalPrediction):
    """Compute multi-label validation metrics for the Hugging Face Trainer.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the Trainer. Predictions are
            logits with one column per ICD code and are passed through a sigmoid
            here. ``label_ids`` are the multi-hot targets.

    Returns:
        dict with ``micro_f1``, ``macro_f1``, ``hamming_loss`` and ``roc_auc``.
        Predictions are thresholded at probability 0.5, and F1 scores use
        ``zero_division=0``. ``roc_auc`` is the macro average over label columns
        that contain both a positive and a negative example (NaN if none do).
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):  # predictions may carry extra model outputs after the logits
        logits = logits[0]
    labels = np.asarray(labels)
    # Cast to float32 so fp16 logits from mixed-precision evaluation are handled too.
    probs = torch.sigmoid(torch.as_tensor(np.asarray(logits), dtype=torch.float32)).numpy()
    preds = (probs > 0.5).astype(int)

    # roc_auc_score raises ValueError for any column with a single class. That is
    # near-certain here: the label vocabulary spans all splits, so many codes never
    # occur in validation. Average only over columns where AUC is defined; this
    # equals sklearn's macro average restricted to those columns.
    scorable = np.flatnonzero(labels.min(axis=0) != labels.max(axis=0))
    per_label_auc = [roc_auc_score(labels[:, j], probs[:, j]) for j in scorable]
    roc_auc = float(np.mean(per_label_auc)) if per_label_auc else float("nan")

    return {
        "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
        "hamming_loss": hamming_loss(labels, preds),
        "roc_auc": roc_auc,
    }

# Fine-tune on the train split, evaluating on validation on the schedule
# configured in training_args.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)
trainer.train()

# Persist the trainer's model (the best checkpoint, since training_args sets
# load_best_model_at_end=True) next to the tokenizer, in one directory.
output_path = "./best_bio_clinicalbert_icd_model"
trainer.save_model(output_path)
tokenizer.save_pretrained(output_path)

# Re-score the validation split with the fine-tuned weights (labels are unchanged).
ft_probs = get_preds(trainer.model, tokenized_datasets['validation'])[0]

# Evaluation function
def evaluate(true, probs):
    """Score multi-label probabilities against multi-hot ground truth.

    Args:
        true: array of shape ``(n_examples, n_labels)`` with 0/1 targets.
        probs: array of the same shape with per-label probabilities.

    Returns:
        dict with ``micro_f1``, ``macro_f1``, ``hamming_loss`` and ``roc_auc``,
        computed the same way as ``compute_metrics``. Predictions are
        thresholded at 0.5, and F1 uses ``zero_division=0`` (labels with no true
        and no predicted positives count as 0 instead of emitting warnings).
        ``roc_auc`` is the macro average over columns of ``true`` that contain
        both classes (NaN if none do).
    """
    true = np.asarray(true)
    probs = np.asarray(probs)
    preds = (probs > 0.5).astype(int)

    # roc_auc_score(average="macro") raises ValueError if any column is
    # single-class, which is common for rare ICD codes. Skip those columns.
    scorable = np.flatnonzero(true.min(axis=0) != true.max(axis=0))
    per_label_auc = [roc_auc_score(true[:, j], probs[:, j]) for j in scorable]
    roc_auc = float(np.mean(per_label_auc)) if per_label_auc else float("nan")

    return {
        "micro_f1": f1_score(true, preds, average="micro", zero_division=0),
        "macro_f1": f1_score(true, preds, average="macro", zero_division=0),
        "hamming_loss": hamming_loss(true, preds),
        "roc_auc": roc_auc,
    }

# Score both sets of validation probabilities against the same ground truth.
pre_metrics, ft_metrics = (
    evaluate(true_labels, scored) for scored in (pre_probs, ft_probs)
)

# Grouped bar chart: one group per metric, two side-by-side bars per group.
# NOTE: hamming_loss is lower-is-better; the other metrics are higher-is-better.
labels = list(pre_metrics)
pre_vals = [pre_metrics[metric_name] for metric_name in labels]
ft_vals = [ft_metrics[metric_name] for metric_name in labels]

x = np.arange(len(labels))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width / 2, pre_vals, width, label='Pretrained')
ax.bar(x + width / 2, ft_vals, width, label='Fine-tuned')

ax.set_ylabel("Score")
ax.set_title("Model Performance: Pretrained vs Fine-tuned")
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()
