## Transformer-Based Experiment: Using `medicalai/ClinicalBERT`

This notebook evaluates the performance of the `medicalai/ClinicalBERT` model for multiclass classification of primary progressive aphasia (PPA) subtypes using clinical interview transcripts.

### Objective

To benchmark a transformer-based model against traditional machine learning pipelines by fine-tuning it directly for text classification.

### Preventing Data Leakage

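To ensure valid evaluation, a **Group K-Fold cross-validation** strategy (`GroupKFold` with 5 splits) is applied:

- All utterances from a participant (`SubjectID`) are assigned to the same fold, so each participant appears in the validation set of exactly one fold.
- As a result, no individual contributes data to both the training and validation sets of any fold. Without this grouping, the model could exploit speaker-specific characteristics shared between training and validation utterances, which would inflate the performance estimates.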
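The grouping guarantee can be checked directly. The sketch below assumes only scikit-learn and NumPy: it runs `GroupKFold` on a toy set of ten utterances from five hypothetical participants (`S1` to `S5`) and asserts that no participant appears on both sides of any split. `validation_subjects_per_fold` is an illustrative helper, not part of the notebook.

```python
import numpy as np
from sklearn.model_selection import GroupKFold


def validation_subjects_per_fold(subject_ids, n_splits):
    """Return the set of validation participants for each GroupKFold split (illustrative helper)."""
    subject_ids = np.asarray(subject_ids)
    features = np.zeros((len(subject_ids), 1))  # placeholder; the split is driven by the groups
    folds = []
    for train_idx, val_idx in GroupKFold(n_splits=n_splits).split(features, groups=subject_ids):
        train_subjects = set(subject_ids[train_idx])
        val_subjects = set(subject_ids[val_idx])
        # No participant may contribute to both training and validation in the same fold
        assert train_subjects.isdisjoint(val_subjects)
        folds.append(val_subjects)
    return folds


# Toy data: two utterances from each of five hypothetical participants
toy_subjects = ["S1", "S1", "S2", "S2", "S3", "S3", "S4", "S4", "S5", "S5"]
for fold, subjects in enumerate(validation_subjects_per_fold(toy_subjects, n_splits=5), start=1):
    print(f"Fold {fold}: validation participants {sorted(subjects)}")
```

With five participants and five splits, each participant ends up as the sole validation participant of exactly one fold.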

### Experiment Details

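- **Model**: `medicalai/ClinicalBERT` from Hugging Face Transformers (a DistilBERT-based checkpoint, as the `DistilBertForSequenceClassification` warning in the output shows), with a newly initialized classification head
- **Tokenization**: `AutoTokenizer` with truncation and padding to a maximum length of 128 tokens
- **Training** (a fresh pretrained model is loaded for each fold):
  - Optimizer: AdamW (learning rate 2e-5)
  - Epochs: 10
  - Batch size: 16
- **Evaluation Metrics** (computed on each validation fold):
  - F1-score (weighted)
  - Balanced Accuracy (the unweighted mean of per-class recall)
  - Precision (weighted)
  - Recall (weighted; for single-label predictions this equals accuracy)
  - Hamming Loss (for single-label multiclass predictions this equals 1 - accuracy)
  - AUC (one-vs-rest, weighted by class support)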
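The metric identities noted in the list are easy to confirm on a toy example. The sketch below uses made-up labels for eight utterances (not data from this study) and checks that Hamming loss equals 1 - accuracy, that weighted recall equals accuracy, and that balanced accuracy is the plain mean of per-class recall. The same relationship is visible in the fold results printed by the notebook, where each fold's Recall and Hamming Loss sum to 1.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    hamming_loss,
    recall_score,
)

# Made-up labels for eight utterances across four classes (0-3)
y_true = [0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0, 0, 0, 1, 1, 0, 2, 2]

accuracy = accuracy_score(y_true, y_pred)
hamming = hamming_loss(y_true, y_pred)                               # fraction of wrong labels
weighted_recall = recall_score(y_true, y_pred, average="weighted")   # support-weighted recall
per_class_recall = recall_score(y_true, y_pred, average=None)
balanced = balanced_accuracy_score(y_true, y_pred)                   # mean of per-class recall

print(f"accuracy={accuracy}, hamming={hamming}, weighted recall={weighted_recall}")
print(f"per-class recall={per_class_recall}, balanced accuracy={balanced}")
```

Here accuracy is 0.625, so Hamming loss is 0.375 and weighted recall is 0.625. Balanced accuracy is 0.5625 because class 3 is never predicted correctly, which pulls the unweighted per-class mean down. This is why balanced accuracy can sit well below weighted recall when minority classes are hard.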

### Dataset Description

The dataset contains transcribed utterances labeled by subtype. It includes four target classes:

- Logopenic Variant (lvPPA)
- Semantic Variant (svPPA)
- Nonfluent Variant (nfvPPA)
- Healthy Controls

Each entry is associated with:
- `SubjectID` (participant ID)
- `Text` (utterance)
- `Subtype` (target label)

### Output

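The notebook prints:

- The mean training loss for each epoch of each fold
- Fold-wise performance metrics
- Averaged scores across all five folds (AUC is averaged with `np.nanmean`, so any fold where AUC is undefined is skipped)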

### Notes

This experiment complements the others in the study: the transformer is fine-tuned end to end, with all encoder weights updated together with the classification head, rather than being used as a frozen feature extractor.

```python
# Standard library
import random

# Third-party libraries
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    balanced_accuracy_score,
    f1_score,
    hamming_loss,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import LabelEncoder
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
```

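```python
# Load the dataset here into a pandas DataFrame named `df` with the
# columns `SubjectID`, `Text`, and `Subtype` (loading code not shown).
```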
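Because the evaluation depends on `SubjectID`, `Text`, and `Subtype` being present and consistent, a quick sanity check right after loading can catch problems early. The sketch below is a hypothetical helper (`validate_ppa_dataframe` is not part of the notebook). It checks the required columns, rejects missing utterance text, verifies that each participant carries a single subtype label (assuming each participant has one diagnosis), and returns the number of distinct participants per subtype, which is the unit `GroupKFold` actually splits on. The toy DataFrame and its label strings are made up.

```python
import pandas as pd

REQUIRED_COLUMNS = ("SubjectID", "Text", "Subtype")


def validate_ppa_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical sanity check; returns the number of participants per subtype."""
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    if df["Text"].isna().any():
        raise ValueError("Some utterances have no text")
    # Assumption: every participant has exactly one subtype label
    labels_per_subject = df.groupby("SubjectID")["Subtype"].nunique()
    conflicting = labels_per_subject[labels_per_subject > 1].index.tolist()
    if conflicting:
        raise ValueError(f"Participants with conflicting labels: {conflicting}")
    # Distinct participants per subtype (the grouping unit for GroupKFold)
    return df.groupby("Subtype")["SubjectID"].nunique().rename("n_subjects").to_frame()


# Made-up example: three participants, five utterances
toy = pd.DataFrame({
    "SubjectID": ["S1", "S1", "S2", "S3", "S3"],
    "Text": ["utterance a", "utterance b", "utterance c", "utterance d", "utterance e"],
    "Subtype": ["lvPPA", "lvPPA", "svPPA", "Control", "Control"],
})
summary = validate_ppa_dataframe(toy)
print(summary)
```

A subtype with very few participants is worth spotting here, since `GroupKFold` may then leave it out of some validation folds entirely.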

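```python
# For reproducibility: seed Python, NumPy, and PyTorch (CPU and all GPUs).
# Run this cell before any model is created.
seed_value = 42
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
np.random.seed(seed_value)
random.seed(seed_value)
```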
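The same seeding can be wrapped in a reusable function and checked for repeatability. The sketch below uses a hypothetical `set_seed` helper; Hugging Face Transformers also ships a comparable `transformers.set_seed` utility. The PyTorch import is optional here so the sketch also runs without PyTorch installed. Seeding fixes weight initialization of the new classification head and the `DataLoader` shuffling order, but it does not guarantee bit-identical results on a GPU, where some operations may remain nondeterministic.

```python
import random

import numpy as np


def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and (if installed) PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available; Python and NumPy are still seeded
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe to call even without a GPU


# Re-seeding reproduces the same random draws
set_seed(42)
first_draw = np.random.rand(3)
set_seed(42)
second_draw = np.random.rand(3)
print(np.array_equal(first_draw, second_draw))  # True
```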

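```python
# Map subtype strings to integer class IDs (0 .. n_classes - 1);
# label_encoder.classes_ keeps the original names for decoding predictions
label_encoder = LabelEncoder()
df['Subtype'] = label_encoder.fit_transform(df['Subtype'])
```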
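`LabelEncoder` assigns integer codes in sorted order of the original label strings (plain string comparison, so uppercase letters sort before lowercase), which means the mapping depends on how the subtypes are spelled in the data. The sketch below uses hypothetical label strings (`Control`, `lvPPA`, `nfvPPA`, `svPPA`; the actual values in `df['Subtype']` may differ) to show how to recover the mapping from `classes_` and decode integer predictions with `inverse_transform`.

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical label strings; the real values in df['Subtype'] may differ
raw_labels = ["svPPA", "lvPPA", "Control", "nfvPPA", "lvPPA"]

encoder = LabelEncoder()
encoded = encoder.fit_transform(raw_labels)

# classes_ is sorted, and each class's position is its integer code
mapping = {str(name): code for code, name in enumerate(encoder.classes_)}
print(mapping)                                    # {'Control': 0, 'lvPPA': 1, 'nfvPPA': 2, 'svPPA': 3}
print(encoded.tolist())                           # [3, 1, 0, 2, 1]
print(encoder.inverse_transform([0, 3]).tolist())  # ['Control', 'svPPA']
```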

```python
# Load the tokenizer that matches the pretrained checkpoint
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
```
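With `padding='max_length'` and `truncation=True`, every utterance becomes exactly `MAX_LEN` token IDs plus an attention mask that marks real tokens (1) versus padding (0). The pure-Python sketch below mimics that behaviour for a single sequence so the shapes are easy to inspect; it is a simplification, not the tokenizer's implementation. The special-token IDs in the example (101, 102, 0) are placeholders that match the standard `bert-base-uncased` vocabulary; for the real tokenizer, read `tokenizer.cls_token_id`, `tokenizer.sep_token_id`, and `tokenizer.pad_token_id`.

```python
from typing import List, Tuple


def pad_and_truncate(
    token_ids: List[int], max_len: int, cls_id: int, sep_id: int, pad_id: int
) -> Tuple[List[int], List[int]]:
    """Simplified single-sequence version of truncation plus padding='max_length'."""
    # Keep at most max_len - 2 content tokens, leaving room for [CLS] and [SEP]
    body = token_ids[: max_len - 2]
    input_ids = [cls_id] + body + [sep_id]
    attention_mask = [1] * len(input_ids)
    # Pad on the right up to max_len; padded positions are masked out
    n_pad = max_len - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    return input_ids, attention_mask


# Placeholder special-token IDs; read the real ones from the tokenizer
short_ids, short_mask = pad_and_truncate([5, 6, 7], max_len=8, cls_id=101, sep_id=102, pad_id=0)
print(short_ids)   # [101, 5, 6, 7, 102, 0, 0, 0]
print(short_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]

long_ids, long_mask = pad_and_truncate(list(range(1000, 1300)), max_len=128, cls_id=101, sep_id=102, pad_id=0)
print(len(long_ids), long_ids[-1], sum(long_mask))  # 128 102 128
```

With `MAX_LEN = 128`, any utterance longer than 126 word-piece tokens loses its tail, which may matter for longer transcript segments.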

```python
class TextDataset(Dataset):
    """Wraps utterances and integer labels as tokenized tensors for a DataLoader."""

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        text = str(self.texts[index])
        label = self.labels[index]
        # Add special tokens, truncate to max_len, and pad every sequence to max_len
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a batch dimension of size 1; flatten it away
        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "labels": torch.tensor(label, dtype=torch.long),
        }


MAX_LEN = 128
```


```python
# Use the GPU when available; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
```

```python
# Hyperparameters
MAX_LEN = 128
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 2e-5
N_SPLITS = 5
NUM_LABELS = len(label_encoder.classes_)

# 5-fold cross-validation grouped by participant: all utterances from one
# SubjectID land in the same fold, so no speaker is in both train and validation
groups = df["SubjectID"]
kf = GroupKFold(n_splits=N_SPLITS)

# Per-fold metric storage
f1_scores = []
balanced_accuracies = []
precisions = []
recalls = []
hamming_losses = []
auc_scores = []

for fold, (train_index, val_index) in enumerate(kf.split(df, groups=groups)):
    print(f"\nFold {fold + 1}")

    # Split the data for the current fold
    train_df, val_df = df.iloc[train_index], df.iloc[val_index]

    # Datasets and dataloaders (only the training data is shuffled)
    train_dataset = TextDataset(
        train_df["Text"].to_numpy(), train_df["Subtype"].to_numpy(), tokenizer, MAX_LEN
    )
    val_dataset = TextDataset(
        val_df["Text"].to_numpy(), val_df["Subtype"].to_numpy(), tokenizer, MAX_LEN
    )
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Fresh pretrained model and optimizer for every fold, so no fold starts
    # from weights that were fine-tuned on another fold's validation data
    model = AutoModelForSequenceClassification.from_pretrained(
        "medicalai/ClinicalBERT", num_labels=NUM_LABELS
    ).to(device)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Passing labels makes the model return the cross-entropy loss
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}")

    # Evaluation loop
    model.eval()
    true_labels, pred_labels, probabilities = [], [], []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions = torch.argmax(logits, dim=-1)
            probs = torch.softmax(logits, dim=-1)  # class probabilities for AUC

            true_labels.extend(batch["labels"].numpy())
            pred_labels.extend(predictions.cpu().numpy())
            probabilities.extend(probs.cpu().numpy())

    # Metrics for this fold (zero_division=0 silences warnings for classes that
    # are never predicted; the resulting value matches scikit-learn's default)
    f1 = f1_score(true_labels, pred_labels, average="weighted", zero_division=0)
    balanced_acc = balanced_accuracy_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels, average="weighted", zero_division=0)
    recall = recall_score(true_labels, pred_labels, average="weighted", zero_division=0)
    hamming = hamming_loss(true_labels, pred_labels)

    # One-vs-rest AUC is undefined (ValueError) if a class is absent from this fold
    try:
        auc = roc_auc_score(
            true_labels, np.array(probabilities), multi_class="ovr", average="weighted"
        )
    except ValueError:
        auc = np.nan

    f1_scores.append(f1)
    balanced_accuracies.append(balanced_acc)
    precisions.append(precision)
    recalls.append(recall)
    hamming_losses.append(hamming)
    auc_scores.append(auc)

    print(
        f"Fold {fold + 1} - F1-Score: {f1:.4f}, Balanced Accuracy: {balanced_acc:.4f}, "
        f"Precision: {precision:.4f}, Recall: {recall:.4f}, Hamming Loss: {hamming:.4f}, AUC: {auc:.4f}"
    )

    # Release this fold's model before loading the next one
    del model, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Average metrics across folds (folds with an undefined AUC are skipped)
avg_f1 = np.mean(f1_scores)
avg_balanced_acc = np.mean(balanced_accuracies)
avg_precision = np.mean(precisions)
avg_recall = np.mean(recalls)
avg_hamming = np.mean(hamming_losses)
avg_auc = np.nanmean(auc_scores)

print("\n5-Fold Cross-Validation Results:")
print(f"Average F1-Score: {avg_f1:.4f}")
print(f"Average Balanced Accuracy: {avg_balanced_acc:.4f}")
print(f"Average Precision: {avg_precision:.4f}")
print(f"Average Recall: {avg_recall:.4f}")
print(f"Average Hamming Loss: {avg_hamming:.4f}")
print(f"Average AUC: {avg_auc:.4f}")
```
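The `try`/`except` around `roc_auc_score` matters because one-vs-rest AUC needs every class that has a probability column to appear among the fold's true labels. The sketch below reproduces both cases with made-up softmax outputs for four classes: when all classes are present, the score is computed; when a class is missing, scikit-learn raises `ValueError`, which becomes `NaN`. `safe_ovr_auc` is a hypothetical helper that mirrors the notebook's handling. Because `np.nanmean` then averages over the remaining folds, the reported AUC would rest on fewer than five folds if this happened (in the run shown here, every fold produced a numeric AUC).

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def safe_ovr_auc(y_true, probabilities):
    """Weighted one-vs-rest AUC, or NaN when it is undefined for this fold (hypothetical helper)."""
    try:
        return roc_auc_score(y_true, probabilities, multi_class="ovr", average="weighted")
    except ValueError:
        return np.nan


# Made-up softmax outputs for four utterances; each row sums to 1
probs = np.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
])

auc_all_classes = safe_ovr_auc([0, 1, 2, 3], probs)    # every class present
auc_missing_class = safe_ovr_auc([0, 1, 2, 2], probs)  # class 3 absent: ValueError -> NaN
print(auc_all_classes, auc_missing_class)              # 1.0 nan
print(np.nanmean([auc_all_classes, auc_missing_class]))  # 1.0 (the NaN fold is skipped)
```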



```text
Fold 1
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1, Loss: 1.1978942192437356
Epoch 2, Loss: 0.9757279690943266
Epoch 3, Loss: 0.8176007668177286
Epoch 4, Loss: 0.7140979834815913
Epoch 5, Loss: 0.5714350683908713
Epoch 6, Loss: 0.4639645451516436
Epoch 7, Loss: 0.4033481708744116
Epoch 8, Loss: 0.33469880470319796
Epoch 9, Loss: 0.27604547339050395
Epoch 10, Loss: 0.23660970585453406
Fold 1 - F1-Score: 0.6167, Balanced Accuracy: 0.4963, Precision: 0.6541, Recall: 0.5978, Hamming Loss: 0.4022, AUC: 0.7959

Fold 2
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1, Loss: 1.158223145303473
Epoch 2, Loss: 0.931246665199246
Epoch 3, Loss: 0.7725615145358364
Epoch 4, Loss: 0.6704349446613177
Epoch 5, Loss: 0.5541819811394785
Epoch 6, Loss: 0.43856917097505216
Epoch 7, Loss: 0.3559082484087058
Epoch 8, Loss: 0.2991864084208434
Epoch 9, Loss: 0.23872401988941483
Epoch 10, Loss: 0.22557250196031764
Fold 2 - F1-Score: 0.5677, Balanced Accuracy: 0.5010, Precision: 0.5680, Recall: 0.5768, Hamming Loss: 0.4232, AUC: 0.7923

Fold 3
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1, Loss: 1.162524049260975
Epoch 2, Loss: 0.9390158460731
Epoch 3, Loss: 0.7884995024816125
Epoch 4, Loss: 0.674178183869978
Epoch 5, Loss: 0.5450555994711092
Epoch 6, Loss: 0.4252887060288834
Epoch 7, Loss: 0.34968222427157175
Epoch 8, Loss: 0.31551559635363846
Epoch 9, Loss: 0.236993528968465
Epoch 10, Loss: 0.20678359858559825
Fold 3 - F1-Score: 0.5073, Balanced Accuracy: 0.4669, Precision: 0.5154, Recall: 0.5077, Hamming Loss: 0.4923, AUC: 0.7452

Fold 4
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1, Loss: 1.2059516723741566
Epoch 2, Loss: 0.981443931136215
Epoch 3, Loss: 0.8500826353566688
Epoch 4, Loss: 0.7286178929763928
Epoch 5, Loss: 0.6021450927905869
Epoch 6, Loss: 0.5237217410876039
Epoch 7, Loss: 0.4112189434991594
Epoch 8, Loss: 0.3255449426932293
Epoch 9, Loss: 0.28638087505507365
Epoch 10, Loss: 0.2828703832510336
Fold 4 - F1-Score: 0.5953, Balanced Accuracy: 0.4745, Precision: 0.6301, Recall: 0.6173, Hamming Loss: 0.3827, AUC: 0.8073

Fold 5
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1, Loss: 1.1453742060744971
Epoch 2, Loss: 0.9158372188869276
Epoch 3, Loss: 0.7831282537234457
Epoch 4, Loss: 0.6563846709435446
Epoch 5, Loss: 0.5522841124942428
Epoch 6, Loss: 0.4571115236010468
Epoch 7, Loss: 0.3618010597532256
Epoch 8, Loss: 0.2929921398233426
Epoch 9, Loss: 0.25807741745130014
Epoch 10, Loss: 0.21446907636301035
Fold 5 - F1-Score: 0.4809, Balanced Accuracy: 0.4751, Precision: 0.4956, Recall: 0.4989, Hamming Loss: 0.5011, AUC: 0.7390

5-Fold Cross-Validation Results:
Average F1-Score: 0.5536
Average Balanced Accuracy: 0.4828
Average Precision: 0.5726
Average Recall: 0.5597
Average Hamming Loss: 0.4403
Average AUC: 0.7759
```
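Averages alone hide how much the folds disagree. A small sketch that reports the mean and sample standard deviation across folds follows. The per-fold values are copied from the printed results above, and `summarize_folds` is a hypothetical helper. It uses `np.nanmean`/`np.nanstd` so that a fold with an undefined AUC would be skipped, matching the notebook's treatment of AUC.

```python
import numpy as np

# Per-fold scores copied from the printed cross-validation output
fold_scores = {
    "F1-Score (weighted)": [0.6167, 0.5677, 0.5073, 0.5953, 0.4809],
    "Balanced Accuracy": [0.4963, 0.5010, 0.4669, 0.4745, 0.4751],
    "AUC (OvR, weighted)": [0.7959, 0.7923, 0.7452, 0.8073, 0.7390],
}


def summarize_folds(scores):
    """Return (mean, sample standard deviation) across folds, ignoring NaN folds."""
    values = np.asarray(scores, dtype=float)
    return float(np.nanmean(values)), float(np.nanstd(values, ddof=1))


summary = {name: summarize_folds(values) for name, values in fold_scores.items()}
for name, (mean, std) in summary.items():
    print(f"{name}: {mean:.4f} ± {std:.4f}")
```

The means reproduce the notebook's averages (0.5536, 0.4828, 0.7759). For weighted F1, the sample standard deviation is roughly 0.06, so with only five folds, differences of a few hundredths between models should be interpreted cautiously.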
