# Fine-tuning BioBERT model

## Package imports

In [6]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, roc_auc_score

## Data preparation

In [7]:
class ReportDataset(Dataset):
    """
    Dataset for radiology reports stored as .txt files.
    Expects directory structure:
      /path/
          s<study_id>/report.txt
          ...
    And a labels file (CSV) with columns: subject_id,study_id,class
    """
    def __init__(self, report_dir, labels_file, tokenizer, max_length=512):
        import pandas as pd
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.reports = []
        self.labels = []

        df = pd.read_csv(labels_file)
        # Adjust labels to be 0, 1, 2 instead of -1, 0, 1
        df['class'] = df['class'] + 1
        label_map = dict(zip(df['study_id'].astype(str), df['class']))

        for case_dir in glob.glob(os.path.join(report_dir, '*/')):
            report_path = os.path.join(case_dir, 'report.txt')
            case_name = os.path.basename(os.path.normpath(case_dir))
            study_id = case_name[1:]  # Remove the 's' prefix to get study_id
            if study_id in label_map and os.path.exists(report_path):
                with open(report_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                self.reports.append(text)
                self.labels.append(label_map[study_id])

    def __len__(self):
        return len(self.reports)

    def __getitem__(self, idx):
        text = self.reports[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        item['labels'] = torch.tensor(label, dtype=torch.long)
        return item

## Compute Metrics

In [8]:
def compute_metrics(y_true, y_pred):
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    acc = accuracy_score(y_true, y_pred)
    return {'accuracy': acc, 'precision': precision, 'recall': recall, 'f1': f1}

## Model Parameters and Tokenization

In [9]:
report_dir = "../download_data/mimic-cxr-download/textData/"
labels_file = "../download_data/metadata/pleural_effusion_brand_new.csv"
model_name = 'dmis-lab/biobert-base-cased-v1.1'
num_labels = 3
batch_size = 8
epochs = 5
lr = 2e-5
weight_decay = 0.01
dropout_rate = 0.1
max_length = 512
tokenizer = BertTokenizer.from_pretrained(model_name)
dataset = ReportDataset(report_dir, labels_file, tokenizer, max_length)

## Dataset splitting

In [10]:
# Split into train/val/test 70:10:20
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size
train_ds, val_ds, test_ds = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# DataLoaders
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size)

## Model config, optimizer, and scheduler

In [11]:
# Model configuration with dropout
config = BertConfig.from_pretrained(
    model_name,
    num_labels=num_labels,
    hidden_dropout_prob=dropout_rate,
    attention_probs_dropout_prob=dropout_rate
)
model = BertForSequenceClassification.from_pretrained(model_name, config=config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train loop and validation

In [12]:
# Training and validation
epochs = 1
for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}/{epochs} - Train loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            logits = model(**inputs).logits
            preds = torch.argmax(logits, dim=1)
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())
    print("Validation Report:")
    print(classification_report(val_labels, val_preds))

Epoch 1/1 - Train loss: 0.9972
Validation Report:
              precision    recall  f1-score   support

           0       0.62      0.77      0.69        31
           1       0.79      0.93      0.85        40
           2       0.88      0.61      0.72        49

    accuracy                           0.76       120
   macro avg       0.76      0.77      0.75       120
weighted avg       0.78      0.76      0.76       120



## Evaluation

In [16]:
# Test set evaluation
model.eval()
test_preds, test_labels = [], []
all_labels = []
all_probs = []
with torch.no_grad():
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        logits = model(**inputs).logits
        preds = torch.argmax(logits, dim=1)
        test_preds.extend(preds.cpu().tolist())
        test_labels.extend(labels.cpu().tolist())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
print("Test Report:")
print(classification_report(test_labels, test_preds))

Test Report:
              precision    recall  f1-score   support

           0       0.69      0.72      0.70        82
           1       0.81      0.95      0.87        83
           2       0.77      0.57      0.66        75

    accuracy                           0.75       240
   macro avg       0.75      0.75      0.74       240
weighted avg       0.75      0.75      0.75       240



In [18]:
import numpy as np
auroc = roc_auc_score(
        all_labels, np.array(all_probs), multi_class="ovr", average="macro"
    )
print(f"AUROC: {auroc:.4f}")

AUROC: 0.8986
