# Fine-tuning BioBERT model

## Package imports

In [14]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, roc_auc_score

## Data preparation

In [15]:
report_dir = "../download_data/textData/"
labels_file = "../download_data/metadata/edema+pleural_effusion_samples_v2.csv"
model_name = 'dmis-lab/biobert-base-cased-v1.1'
max_length = 256
tokenizer = BertTokenizer.from_pretrained(model_name)
class ReportDataset(Dataset):
    """
    Dataset for radiology reports stored as .txt files.
    Expects directory structure:
      /path/
          s<study_id>/report.txt
          ...
    And a labels file (CSV) with columns: study_id, edema, effusion
    """
    def __init__(self, report_dir, labels_file, tokenizer, max_length=512):
        import pandas as pd
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.reports = []
        self.labels = []
        # print(f"Loading reports from {report_dir} and labels from {labels_file}")

        df = pd.read_csv(labels_file)
        print(df)
        df['study_id'] = df['study_id'].astype(str)
        label_map = df.set_index('study_id')[['edema', 'effusion']].to_dict(orient='index')
        
        case_dirs = glob.glob(os.path.join(report_dir, '*/'))
       
        

        for case_dir in glob.glob(os.path.join(report_dir, '*/')):
            
            report_path = os.path.join(case_dir, 'report.txt')
            if not os.path.exists(report_path):
                print(f"Warning: Report file {report_path} does not exist, skipping.")
                continue
            
            case_name = os.path.basename(os.path.normpath(case_dir))
            
            study_id = case_name[1:]  # Remove the 's' prefix
            

            if study_id in label_map and os.path.exists(report_path):
                with open(report_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                self.reports.append(text)
                # Convert edema/effusion dict to list of 2 numbers
                labels = [label_map[study_id]['edema'], label_map[study_id]['effusion']]
                self.labels.append(labels)

    def __len__(self):
        return len(self.reports)

    def __getitem__(self, idx):
        text = self.reports[idx]
        labels = self.labels[idx]  # 2-element list [edema, effusion]

        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        item['labels'] = torch.tensor(labels, dtype=torch.float32)  # multi-label → float32 (for BCEWithLogitsLoss)
        return item
dataset = ReportDataset(report_dir, labels_file, tokenizer, max_length)
print(len(dataset)) 



      subject_id  study_id  edema  effusion
0       10068880  59302019      0         0
1       10095570  53351393      0         0
2       10119992  55836834      0         0
3       10124346  53643077      0         0
4       10150279  58104553      0         1
...          ...       ...    ...       ...
4995    19470900  51190081      1         1
4996    19275656  57501213      1         1
4997    17513800  52812616      1         1
4998    13506966  58911390      0         0
4999    11861017  53980574      1         1

[5000 rows x 4 columns]
5000


## Compute Metrics

In [16]:
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score
)

def compute_metrics(y_true, y_pred, y_probs):
    """
    y_true: true binary labels, shape (n_samples, n_labels)
    y_pred: predicted binary labels, shape (n_samples, n_labels)
    y_probs: predicted probabilities, shape (n_samples, n_labels)
    """
    metrics = {}
    for i, name in enumerate(['edema', 'effusion']):
        # Basic metrics
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true[:, i], y_pred[:, i], average='binary', zero_division=0
        )
        acc = accuracy_score(y_true[:, i], y_pred[:, i])

        # AUROC
        try:
            auroc = roc_auc_score(y_true[:, i], y_probs[:, i])
        except ValueError:
            auroc = float('nan')  # if only one class present

        # AUPRC
        try:
            auprc = average_precision_score(y_true[:, i], y_probs[:, i])
        except ValueError:
            auprc = float('nan')

        # Save metrics
        metrics[name] = {
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auroc': auroc,
            'auprc': auprc
        }

    return metrics


## Model Parameters and Tokenization

In [17]:
report_dir = "../download_data/textData/"
labels_file = "../download_data/metadata/edema+pleural_effusion_samples_v2.csv"
model_name = 'dmis-lab/biobert-base-cased-v1.1'
num_labels = 2
batch_size = 32
epochs = 5
lr = 2e-5
weight_decay = 0.01
dropout_rate = 0.1
max_length = 512
tokenizer = BertTokenizer.from_pretrained(model_name)
dataset = ReportDataset(report_dir, labels_file, tokenizer, max_length)



      subject_id  study_id  edema  effusion
0       10068880  59302019      0         0
1       10095570  53351393      0         0
2       10119992  55836834      0         0
3       10124346  53643077      0         0
4       10150279  58104553      0         1
...          ...       ...    ...       ...
4995    19470900  51190081      1         1
4996    19275656  57501213      1         1
4997    17513800  52812616      1         1
4998    13506966  58911390      0         0
4999    11861017  53980574      1         1

[5000 rows x 4 columns]


## Dataset splitting

In [18]:
# Split into train/val/test 70:10:20
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size
print(total_size, train_size, val_size, test_size)
train_ds, val_ds, test_ds = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# DataLoaders
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size)

5000 3500 500 1000


## Model config, optimizer, and scheduler

In [19]:
config = BertConfig.from_pretrained(
    model_name,
    num_labels=num_labels,
    hidden_dropout_prob=dropout_rate,
    attention_probs_dropout_prob=dropout_rate
)

model = BertForSequenceClassification.from_pretrained(
    model_name,
    config=config
)

# Tell the model it’s for multi-label classification
model.config.problem_type = "multi_label_classification"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)


Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

## Train loop and validation

In [None]:
# 🟢 Training for 20 epochs with early stopping
epochs = 20
patience = 3  # Stop if no improvement in 3 epochs
best_val_loss = float('inf')
no_improve_epochs = 0

for epoch in range(1, epochs + 1):
    # ----------------
    # 🔥 Training phase
    # ----------------
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}/{epochs} - Train loss: {avg_train_loss:.4f}")

    # -------------------
    # 🧪 Validation phase
    # -------------------
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            logits = model(**inputs).logits
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
            val_loss += loss.item()
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).int()
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch}/{epochs} - Val loss: {avg_val_loss:.4f}")
    print("Validation Report:")
    print(classification_report(val_labels, val_preds, target_names=['edema', 'effusion'], zero_division=0))

    # -------------------
    # ⏳ Early Stopping Check
    # -------------------
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improve_epochs = 0
        # Save best model if needed
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        no_improve_epochs += 1
        if no_improve_epochs >= patience:
            print(f"Early stopping at epoch {epoch}")
            break


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## Evaluation

In [56]:
model.eval()
test_preds, test_labels = [], []
all_labels = []
all_probs = []

with torch.no_grad():
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        logits = model(**inputs).logits
        probs = torch.sigmoid(logits)
        preds = (probs >= 0.5).int()
        test_preds.extend(preds.cpu().tolist())
        test_labels.extend(labels.cpu().tolist())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

print("Test Classification Report:")
print(classification_report(test_labels, test_preds, target_names=['edema', 'effusion'], zero_division=0))

# 👉 **Here you calculate AUROC, AUPRC, etc. for each disease!**
import numpy as np
y_true = np.array(all_labels)
y_pred = np.array(test_preds)
y_probs = np.array(all_probs)

test_metrics = compute_metrics(y_true, y_pred, y_probs)
print("Test Metrics:")
print(test_metrics)


Test Classification Report:
              precision    recall  f1-score   support

       edema       0.99      0.99      0.99       491
    effusion       1.00      0.98      0.99       564

   micro avg       0.99      0.98      0.99      1055
   macro avg       0.99      0.98      0.99      1055
weighted avg       0.99      0.98      0.99      1055
 samples avg       0.63      0.63      0.63      1055

Test Metrics:
{'edema': {'accuracy': 0.988, 'precision': 0.9897750511247444, 'recall': 0.9857433808553971, 'f1': 0.9877551020408163, 'auroc': np.float64(0.9968469784210083), 'auprc': np.float64(0.9973469973208513)}, 'effusion': {'accuracy': 0.984, 'precision': 0.9963768115942029, 'recall': 0.975177304964539, 'f1': 0.985663082437276, 'auroc': np.float64(0.996466100592101), 'auprc': np.float64(0.9975877306845223)}}


In [22]:
import torch
print(torch.cuda.get_device_name(0))
print(torch.cuda.memory_summary(device=0, abbreviated=False))

NVIDIA GeForce RTX 4050 Laptop GPU
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  11931 MiB |  12027 MiB |  49494 GiB |  49482 GiB |
|       from large pool |  11927 MiB |  12023 MiB |  49299 GiB |  49287 GiB |
|       from small pool |      3 MiB |      4 MiB |    195 GiB |    195 GiB |
|---------------------------------------------------------------------------|
| Active memory         |  11931 MiB |  12027 MiB |  49494 GiB |  49482 GiB |
|       from large pool |  11927 MiB |  12023 MiB |  49299 GiB |  49287 GiB |
|       from small pool |      3 MiB |      4 MiB |    195 GiB |    195 GiB |
|----------------------------