# Package imports

In [None]:
import os
import pydicom
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from torchvision import transforms, datasets
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


## Hyperparameters

In [None]:
dropout_rate  = 0.2       # classifier dropout in ViT
weight_decay  = 1e-2      # optimizer weight decay
learning_rate = 3e-5      # optimizer learning rate
batch_size    = 16        # train & test batch size
num_epochs    = 10        # total epochs
random_seed   = 42        # seeds the train/test split and the numpy/torch RNGs below

# Seed the RNGs used downstream so the newly initialized classifier head, dropout
# masks, random augmentations and DataLoader shuffling repeat from run to run.
# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
np.random.seed(random_seed)
torch.manual_seed(random_seed)

## Data Preprocessing (Resize, Augmentation, Normalization)

In [None]:
class CustomMedicalImageDataset(Dataset):
    """Dataset that reads DICOM files and returns ``(image, label)`` pairs.

    Each file's pixel data is min-max scaled to 8 bits and replicated into a
    3-channel PIL RGB image. The image is then passed through ``transform`` if
    one is given.

    Args:
        image_paths: Sequence of paths to ``.dcm`` files.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _to_uint8(pixel_array, invert=False):
        """Min-max scale raw pixel values to a ``uint8`` array spanning [0, 255].

        Dividing by the maximum alone breaks on signed pixel data (negative
        minimum gives out-of-range values) and on all-zero images (NaNs). Scaling
        by (max - min) with a guard for constant images avoids both.

        Args:
            pixel_array: Numeric array of raw pixel values.
            invert: If True, flip intensities (``255 - value``). Used for DICOM
                MONOCHROME1 images, where the lowest value is displayed as white.

        Returns:
            ``np.uint8`` array with the same shape as ``pixel_array``.
        """
        pixels = np.asarray(pixel_array, dtype=np.float32)
        lo, hi = float(pixels.min()), float(pixels.max())
        if hi > lo:
            scaled = (pixels - lo) / (hi - lo) * 255.0
        else:
            # Constant image: no contrast to preserve; map it to black.
            scaled = np.zeros_like(pixels)
        if invert:
            scaled = 255.0 - scaled
        return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        # Load the DICOM image
        dicom_image = pydicom.dcmread(self.image_paths[idx])

        # MONOCHROME1 stores inverted grayscale (low values = white). Flip it so
        # every image follows the MONOCHROME2 convention (low values = black).
        invert = getattr(dicom_image, "PhotometricInterpretation", "") == "MONOCHROME1"
        pixels_u8 = self._to_uint8(dicom_image.pixel_array, invert=invert)

        # Replicate the grayscale channel to RGB for the 3-channel ViT input.
        image = Image.fromarray(pixels_u8).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


# Define transformations (resize, augment, normalize).
# This pipeline contains random augmentations, so it is meant for TRAINING.
# Evaluation should keep only the deterministic steps (resize, to-tensor,
# normalize) and drop the random ones.
#
# Inputs are grayscale images replicated to RGB, so saturation and hue jitter
# would leave them unchanged. Only brightness and contrast are jittered.
vit_mean = [0.5, 0.5, 0.5]  # image_mean of google/vit-base-patch16-224-in21k's preprocessor
vit_std = [0.5, 0.5, 0.5]   # image_std of the same preprocessor
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=vit_mean, std=vit_std),
    ]
)

## Load Image Paths and Labels

In [None]:
image_data_dir = "../download_data/mimic-cxr-download/imageData/"
labels_csv_path = "../download_data/metadata/pleural_effusion_samples.csv"

# Collect every DICOM under image_data_dir. os.walk's order depends on the
# filesystem, so sort the paths; otherwise the seeded train/test split below
# would still differ between machines.
all_image_paths = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(image_data_dir)
    for file in files
    if file.endswith(".dcm")
)

labels_df = pd.read_csv(labels_csv_path)

# Prepend "s" to study_id in the CSV
labels_df["study_id"] = "s" + labels_df["study_id"].astype(str)

# Map study_id -> label from the CSV's "class" column (values -1, 0 or 1)
labels_mapping = dict(zip(labels_df["study_id"], labels_df["class"]))

# Pair each image with its study's label. The study id is the name of the file's
# grandparent directory (NOTE(review): assumes a <study_id>/<subdir>/<file>.dcm
# layout under image_data_dir; confirm against the download).
# Images whose study has no label are skipped. Defaulting them to -1 would turn
# them into class 0 after the +1 shift, silently mixing unlabeled images into a
# real class.
image_paths = []
labels = []
num_unlabeled = 0
for image_path in all_image_paths:
    study_id = os.path.basename(os.path.dirname(os.path.dirname(image_path)))
    label = labels_mapping.get(study_id)
    if label is None or pd.isna(label):
        num_unlabeled += 1
        continue
    image_paths.append(image_path)
    # add 1 to labels to make them 0, 1, 2 instead of -1, 0, 1
    labels.append(int(label) + 1)

print(f"{len(image_paths)} labeled images; skipped {num_unlabeled} without a label in {labels_csv_path}")
assert image_paths, "no DICOM files matched a study_id in the labels CSV"

## Split the Dataset into Training and Test Sets (stratified 80:20)

In [None]:
train_image_paths, test_image_paths, train_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=random_seed, stratify=labels
)

# Evaluation must be deterministic. Reuse the training pipeline's resize,
# to-tensor and normalize steps, but drop its random augmentations so test
# metrics are not computed on randomly flipped, rotated or jittered images.
_random_augmentations = (
    transforms.RandomHorizontalFlip,
    transforms.RandomRotation,
    transforms.ColorJitter,
)
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, _random_augmentations)]
)

# Create DataLoader for training and testing datasets
train_dataset = CustomMedicalImageDataset(
    image_paths=train_image_paths, labels=train_labels, transform=transform
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    # Seeded generator: shuffle order and worker base seeds repeat across runs.
    generator=torch.Generator().manual_seed(random_seed),
)

test_dataset = CustomMedicalImageDataset(
    image_paths=test_image_paths, labels=test_labels, transform=eval_transform
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)

## Load Pre-trained ViT Model

In [None]:
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=3
)  # num_labels=3 for three classes

# Make `dropout_rate` actually affect the classification head by putting an
# explicit Dropout in front of it. Setting `config.classifier_dropout` after the
# model is built cannot change modules that already exist.
# NOTE(review): in current transformers releases the ViT head is a plain
# nn.Linear that does not read config.classifier_dropout; confirm for the
# installed version. The Linear's parameters are reused unchanged, but their
# state-dict keys become `classifier.1.*`.
model.classifier = nn.Sequential(nn.Dropout(dropout_rate), model.classifier)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Set up Optimizer and Loss Function

In [None]:
# NOTE(review): this only records the value on the config object. The model was
# already constructed in the previous cell, so changing the config here does not
# by itself add dropout to the forward pass. Confirm where dropout_rate takes effect.
model.config.classifier_dropout = dropout_rate

# Plain cross-entropy over the class logits (label smoothing explicitly off).
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)

# AdamW (decoupled weight decay) over every model parameter.
optimizer = optim.AdamW(
    params=model.parameters(), lr=learning_rate, weight_decay=weight_decay
)


## Training Loop

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# nn.Module.to moves parameters in place and returns the same module. Assigning
# the result also keeps Jupyter from echoing the full model repr.
model = model.to(device)


def train_model(model, train_loader, optimizer, criterion, epochs=10, device=None):
    """Fine-tune ``model`` in place for ``epochs`` passes over ``train_loader``.

    Args:
        model: Module whose forward returns an object exposing a ``.logits``
            tensor, e.g. a Hugging Face ``ViTForImageClassification``.
        train_loader: Iterable of ``(images, labels)`` tensor batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, labels)``. Assumed to use
            mean reduction (the default for ``nn.CrossEntropyLoss``).
        epochs: Number of passes over ``train_loader``.
        device: Device that batches are moved to. Defaults to the device of the
            model's first parameter, so no global ``device`` is required.

    Returns:
        List of per-epoch mean per-sample training losses (each is also printed).

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    if device is None:
        device = next(model.parameters()).device

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_samples = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images).logits
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Weight each batch's mean loss by its size so a smaller final batch
            # does not skew the epoch average.
            n_batch = labels.size(0)
            running_loss += loss.item() * n_batch
            num_samples += n_batch

        if num_samples == 0:
            raise ValueError("train_loader yielded no samples")
        avg_loss = running_loss / num_samples
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    return epoch_losses

## Fine-tuning the ViT Model

In [None]:
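# Earlier standalone fine-tuning call (20 epochs), kept for reference; the log below is from that run.
# Training now runs in the final cell, using `num_epochs` from the hyperparameter cell.
# train_model(model, train_loader, optimizer, criterion, epochs=20)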

Epoch 1/20, Loss: 0.8750
Epoch 2/20, Loss: 0.7936
Epoch 3/20, Loss: 0.6959
Epoch 4/20, Loss: 0.6582
Epoch 5/20, Loss: 0.5670
Epoch 6/20, Loss: 0.5255
Epoch 7/20, Loss: 0.5247
Epoch 8/20, Loss: 0.4516
Epoch 9/20, Loss: 0.4267
Epoch 10/20, Loss: 0.3365
Epoch 11/20, Loss: 0.3293
Epoch 12/20, Loss: 0.3055
Epoch 13/20, Loss: 0.2324
Epoch 14/20, Loss: 0.2527
Epoch 15/20, Loss: 0.2037
Epoch 16/20, Loss: 0.2128
Epoch 17/20, Loss: 0.1315
Epoch 18/20, Loss: 0.1009
Epoch 19/20, Loss: 0.0733
Epoch 20/20, Loss: 0.0708


## Evaluation

In [None]:
label_set = [0, 1, 2]  # class indices after the +1 label shift; must match num_labels


# Define evaluation metrics
def evaluate_model(model, test_loader, device=None):
    """Evaluate ``model`` on ``test_loader`` and print macro-averaged metrics.

    Precision, recall and F1 are macro-averaged over ``label_set``. AUROC is the
    macro average of one-vs-rest ROC AUCs computed from softmax probabilities.

    Args:
        model: Module whose forward returns an object exposing a ``.logits``
            tensor with one column per entry of ``label_set``.
        test_loader: Iterable of ``(images, labels)`` batches with tensor labels.
        device: Device that images are moved to. Defaults to the device of the
            model's first parameter, so no global ``device`` is required.

    Returns:
        dict with keys ``"precision"``, ``"recall"``, ``"f1"`` and ``"auroc"``.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for images, labels in test_loader:
            # Forward pass. Labels stay on the CPU; only sklearn consumes them.
            logits = model(images.to(device)).logits
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())  # probabilities for AUROC

    # zero_division=0 gives the same score sklearn defaults to for never-predicted
    # classes, but without emitting an UndefinedMetricWarning.
    precision = precision_score(
        all_labels, all_preds, average="macro", labels=label_set, zero_division=0
    )
    recall = recall_score(
        all_labels, all_preds, average="macro", labels=label_set, zero_division=0
    )
    f1 = f1_score(
        all_labels, all_preds, average="macro", labels=label_set, zero_division=0
    )
    # Compute AUROC one-vs-rest per class. `labels` explicitly ties each
    # probability column to its class index.
    auroc = roc_auc_score(
        all_labels,
        np.array(all_probs),
        multi_class="ovr",
        average="macro",
        labels=label_set,
    )

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"AUROC: {auroc:.4f}")
    return {"precision": precision, "recall": recall, "f1": f1, "auroc": auroc}


# Run Training & Evaluation
    
if __name__ == "__main__":
    # Echo this run's hyperparameters so the logged output describes itself.
    run_banner = (
        "Starting training with:",
        f"  dropout_rate  = {dropout_rate}",
        f"  weight_decay  = {weight_decay}",
        f"  learning_rate = {learning_rate}",
        f"  batch_size    = {batch_size}",
        f"  num_epochs    = {num_epochs}\n",
    )
    for banner_line in run_banner:
        print(banner_line)

    # Fine-tune, then report metrics on the held-out split.
    train_model(model, train_loader, optimizer, criterion, num_epochs)
    print("\nEvaluating on test set:")
    evaluate_model(model, test_loader)


Precision: 0.5949
Recall: 0.5683
F1-Score: 0.5649
AUROC: 0.7170
