In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import itertools

pathologies = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
              "Clavicle fracture", "Consolidation", "Emphysema", "Enlarged PA",
              "ILD", "Infiltration", "Lung Opacity", "Lung cavity", "Lung cyst",
              "Mediastinal shift","Nodule/Mass", "Pleural effusion", "Pleural thickening",
              "Pneumothorax", "Pulmonary fibrosis","Rib fracture", "Other lesion",
              "No finding"] # edema in train set not test set but only 19 images
# TODO: ^ does order of pathologies matter? 

In [2]:
test_all_pathology_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/VinDr/no_tuning_4934e91451945c8218c267aae9c34929a7677829/collated_test_all.pkl")
test_all_pathology = pd.read_pickle(test_all_pathology_path)

In [8]:
VinDr_test_train_split_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_train_split.txt")
VinDr_test_val_split_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_val_split.txt")
VinDr_test_test_split_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_test_split.txt")
# split test_all_pathology into test, train, val based on the image_ids in the splits
with open(VinDr_test_train_split_path, 'r') as f:
    test_train_image_ids = f.read().splitlines()
with open(VinDr_test_val_split_path, 'r') as f:
    test_val_image_ids = f.read().splitlines()
with open(VinDr_test_test_split_path, 'r') as f:
    test_test_image_ids = f.read().splitlines()

test_train_pathology = test_all_pathology[test_all_pathology["image_id"].isin(test_train_image_ids)]
test_val_pathology = test_all_pathology[test_all_pathology["image_id"].isin(test_val_image_ids)]
test_test_pathology = test_all_pathology[test_all_pathology["image_id"].isin(test_test_image_ids)]

In [30]:
class VinDrWithBinaryPathologyPresence(Dataset):
    def __init__(self, dataframe, pathology_columns, layer = "post_layer_norm"): # "post_layer_norm"
    # def __init__(self, dataframe, pathology_columns, layer = "q_former"): # "post_layer_norm"
        self.dataframe = dataframe
        self.pathology_columns = pathology_columns
        self.layer = layer

    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract embeddings and drop the first dimension i.e. [1,128,768] -> [128,768]
        embeddings = torch.tensor(row[self.layer], dtype=torch.float)
        
        # Extract binary labels for pathologies
        labels_values = row[self.pathology_columns].values
        # set the dtype of each item in values to be an int
        labels_values = [int(item) for item in labels_values]
        labels_tensor = torch.tensor(labels_values, dtype=torch.float)

        # image_id 
        image_id = row['image_id']
        
        return image_id, embeddings, labels_tensor

sample_dataset = VinDrWithBinaryPathologyPresence(test_train_pathology, pathologies)



In [31]:
def calculate_metrics(outputs, labels):
    # Apply sigmoid to get predicted probabilities 
    predictions = torch.sigmoid(outputs)  

    # Threshold for determining positive predictions (adjust as needed)
    predictions = predictions > 0.5   

    # Calculate metrics row-wise, considering all pathologies
    accuracy = (predictions == labels).all(dim=1).float().mean()  
    true_positives = torch.logical_and(predictions, labels).float().sum(dim=1)
    predicted_positives = predictions.sum(dim=1).float()
    actual_positives = labels.sum(dim=1).float()
    
    # Avoid division by zero using clamp
    precision = true_positives / predicted_positives.clamp(min=1)
    recall = true_positives / actual_positives.clamp(min=1)
    f1_score = 2 * (precision * recall) / (precision + recall).clamp(min=1)
    
    # Mean of metrics - After handling division by zero, no need for NaN replacement
    accuracy = accuracy.mean()
    precision = precision.mean()
    recall = recall.mean()
    f1_score = f1_score.mean()
    
    return accuracy.item(), precision.item(), recall.item(), f1_score.item()

In [12]:
class LinearClassifier(nn.Module):
    def __init__(self, num_features, num_classes):
        super(LinearClassifier, self).__init__()
        self.linear = nn.Linear(num_features, num_classes)
    
    def forward(self, x):
        x = self.linear(x)
        # Shape of x becomes [batch_size, num_classes]
        return x

flattened_q_former_dimension_size = sample_dataset[0][1].view(1,-1).shape[1]
q_former_linear_probe_flattened = LinearClassifier(flattened_q_former_dimension_size, len(pathologies))

In [13]:
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_linear_probe_early_stopping(model, train_loader, val_loader, num_epochs, criterion, learning_rate, early_stopping_patience = 50):

    early_stopping_patience = round(num_epochs * early_stopping_patience/100)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    best_val_loss = float('inf')
    best_val_accuracy = 0   
    best_epoch = 0
    early_stop = False

    for epoch in range(num_epochs):
        if early_stop:
            print(f'Stopping early at epoch {epoch} of {num_epochs} because no improvement in {early_stopping_patience} epochs.')
            break

        model.train()
        for image_id, embeddings, labels in train_loader:
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(embeddings.view(embeddings.size(0), -1))
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_accuracy = 0
            for image_id, embeddings, labels in val_loader:
                embeddings = embeddings.to(device)
                labels = labels.to(device)

                outputs = model(embeddings.view(embeddings.size(0), -1))
                val_loss += criterion(outputs, labels)
                batch_accuracy, _,_,_ = calculate_metrics(outputs, labels)
                val_accuracy += batch_accuracy

        val_loss /= len(val_loader)
        val_accuracy /= len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
        elif epoch - best_epoch >= early_stopping_patience:
            early_stop = True

        best_val_accuracy = max(val_accuracy, best_val_accuracy)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Val Loss: {val_loss.item()}, Val Accuracy: {val_accuracy}")

    return best_val_accuracy
    

In [18]:
num_epochs = [10, 20, 40]
learning_rates = [0.0001, 0.001, 0.01]
batch_sizes = [64, 128, 256, 512, 1024, 3000]

# Generate all possible combinations of hyperparameters
n=10
top_n_hyperparam_configurations = []

best_val_accuracy = 0
best_hyperparameters = None

for batch_size in batch_sizes:
    
    train_loader = DataLoader(VinDrWithBinaryPathologyPresence(test_train_pathology, pathologies), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(VinDrWithBinaryPathologyPresence(test_val_pathology, pathologies), batch_size=batch_size, shuffle=True)
    
    for num_epoch, learning_rate in itertools.product(num_epochs, learning_rates):
        q_former_linear_probe_flattened = LinearClassifier(flattened_q_former_dimension_size, len(pathologies))
        print(f"Hyperparameters: batch_size={batch_size}, num_epochs={num_epoch}, learning_rate={learning_rate}")
        
        val_accuracy = train_linear_probe_early_stopping(q_former_linear_probe_flattened, train_loader, val_loader, num_epoch, criterion, learning_rate)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_hyperparameters = (batch_size, num_epoch, learning_rate)
            top_n_hyperparam_configurations.append((batch_size, num_epoch, learning_rate, val_accuracy))
            

Hyperparameters: batch_size=64, num_epochs=10, learning_rate=0.0001
Epoch 1/10, Loss: 0.07202397286891937, Val Loss: 0.10276070982217789, Val Accuracy: 0.6630681872367858
Epoch 2/10, Loss: 0.08086004108190536, Val Loss: 0.09274834394454956, Val Accuracy: 0.6536931872367859
Epoch 3/10, Loss: 0.06350249797105789, Val Loss: 0.0906132236123085, Val Accuracy: 0.6846590995788574
Epoch 4/10, Loss: 0.08738981187343597, Val Loss: 0.0999673530459404, Val Accuracy: 0.6971590995788575
Epoch 5/10, Loss: 0.062247615307569504, Val Loss: 0.09220340847969055, Val Accuracy: 0.6960227251052856
Epoch 6/10, Loss: 0.0742533802986145, Val Loss: 0.09233229607343674, Val Accuracy: 0.70625
Epoch 7/10, Loss: 0.07672018557786942, Val Loss: 0.09487368911504745, Val Accuracy: 0.678125
Epoch 8/10, Loss: 0.08376694470643997, Val Loss: 0.10767143964767456, Val Accuracy: 0.6894886374473572
Stopping early at epoch 8 of 10 because no improvement in 5 epochs.
Hyperparameters: batch_size=64, num_epochs=10, learning_rate=0.

In [19]:

# Print top 10 hyperparameter configurations
# Sort by validation accuracy 
top_n_hyperparam_configurations.sort(key=lambda x: x[3], reverse=True)
print("Top 10 hyperparameter configurations:")
for i, (batch_size, num_epochs, learning_rate, val_accuracy) in enumerate(top_n_hyperparam_configurations[:10]):
    print(f"{i+1}. Batch size: {batch_size}, Num epochs: {num_epochs}, Learning rate: {learning_rate}, Validation accuracy: {val_accuracy}")


Top 10 hyperparameter configurations:
1. Batch size: 256, Num epochs: [10, 20, 40], Learning rate: 0.0001, Validation accuracy: 0.7642045617103577
2. Batch size: 256, Num epochs: [10, 20, 40], Learning rate: 0.0001, Validation accuracy: 0.7626065611839294
3. Batch size: 128, Num epochs: [10, 20, 40], Learning rate: 0.0001, Validation accuracy: 0.7462121248245239
4. Batch size: 64, Num epochs: [10, 20, 40], Learning rate: 0.0001, Validation accuracy: 0.7096590995788574
5. Batch size: 64, Num epochs: [10, 20, 40], Learning rate: 0.0001, Validation accuracy: 0.7071022748947143
6. Batch size: 64, Num epochs: [10, 20, 40], Learning rate: 0.0001, Validation accuracy: 0.70625


In [29]:
# best config: batch_size=256, num_epochs=40, learning_rate=0.0001
# 2nd best config: batch_size=256, num_epochs=10, learning_rate=0.0001
# 3rd best config: bbatch_size=128, num_epochs=10, learning_rate=0.0001
best_hyperparameter = (256, 40, 0.0001)
second_best_hyperparameter = (256, 10, 0.0001)
third_best_hyperparameter = (128, 10, 0.0001)
fourth_best_hyperparameter = (64, 40, 0.0001)
best_hyperparameters_on_val_set = [best_hyperparameter, second_best_hyperparameter, third_best_hyperparameter, fourth_best_hyperparameter]

for best_hyperparameter in best_hyperparameters_on_val_set:
    # train on train set and evaluate on test set
    batch_size, num_epochs, learning_rate = best_hyperparameter
    train_loader = DataLoader(VinDrWithBinaryPathologyPresence(test_train_pathology, pathologies), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(VinDrWithBinaryPathologyPresence(test_val_pathology, pathologies), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(VinDrWithBinaryPathologyPresence(test_test_pathology, pathologies), batch_size=batch_size, shuffle=True)

    q_former_linear_probe_flattened = LinearClassifier(flattened_q_former_dimension_size, len(pathologies))
    best_val_accuracy = train_linear_probe_early_stopping(q_former_linear_probe_flattened, train_loader, val_loader, num_epochs, criterion, learning_rate)
    q_former_linear_probe_flattened.eval()
    with torch.no_grad():
        test_accuracy = 0
        for image_id, embeddings, labels in test_loader:
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            outputs = q_former_linear_probe_flattened(embeddings.view(embeddings.size(0), -1))
            batch_accuracy, _,_,_ = calculate_metrics(outputs, labels)
            test_accuracy += batch_accuracy
        test_accuracy /= len(test_loader)
        print(f"Training with best hyperparameters: batch_size={batch_size}, num_epochs={num_epochs}, learning_rate={learning_rate}")
        print(f"Best validation accuracy: {best_val_accuracy}")
        print(f"Test accuracy: {test_accuracy}")






Epoch 1/40, Loss: 0.1554984748363495, Val Loss: 0.1712925136089325, Val Accuracy: 0.6518110930919647
Epoch 2/40, Loss: 0.11412486433982849, Val Loss: 0.12363976240158081, Val Accuracy: 0.6677911877632141
Epoch 3/40, Loss: 0.1028212159872055, Val Loss: 0.10794439911842346, Val Accuracy: 0.7034801244735718
Epoch 4/40, Loss: 0.09268356114625931, Val Loss: 0.08039233088493347, Val Accuracy: 0.7414772808551788
Epoch 5/40, Loss: 0.07767555862665176, Val Loss: 0.0908319428563118, Val Accuracy: 0.6787997186183929
Epoch 6/40, Loss: 0.08528684079647064, Val Loss: 0.07929840683937073, Val Accuracy: 0.7073863744735718
Epoch 7/40, Loss: 0.09020264446735382, Val Loss: 0.08906228840351105, Val Accuracy: 0.6983309686183929
Epoch 8/40, Loss: 0.06281240284442902, Val Loss: 0.08578537404537201, Val Accuracy: 0.6811079680919647
Epoch 9/40, Loss: 0.07869873940944672, Val Loss: 0.08064129948616028, Val Accuracy: 0.7077414989471436
Epoch 10/40, Loss: 0.06050151214003563, Val Loss: 0.0789322480559349, Val Acc