In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import itertools

Dataset generic linear probing code

In [2]:
class BinaryMultiPathologyPresenceDataset(Dataset):
    def __init__(self, dataframe, pathology_columns, layer = "post_layer_norm"): # "q_former"
        self.dataframe = dataframe
        self.pathology_columns = pathology_columns
        self.layer = layer

    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract embeddings and drop the first dimension i.e. [1,128,768] -> [128,768]
        embeddings = torch.tensor(row[self.layer], dtype=torch.float)
        
        # Extract binary labels for pathologies
        labels_values = row[self.pathology_columns].values
        # set the dtype of each item in values to be an int
        labels_values = [int(item) for item in labels_values]
        labels_tensor = torch.tensor(labels_values, dtype=torch.float)

        # image_id 
        image_id = row['image_id']
        
        return image_id, embeddings, labels_tensor

In [3]:
def calculate_metrics(outputs, labels):
    # Apply sigmoid to get predicted probabilities 
    predictions = torch.sigmoid(outputs)  

    # Threshold for determining positive predictions (adjust as needed)
    predictions = predictions > 0.5   

    # Calculate metrics row-wise, considering all pathologies
    accuracy = (predictions == labels).all(dim=1).float().mean()  
    true_positives = torch.logical_and(predictions, labels).float().sum(dim=1)
    predicted_positives = predictions.sum(dim=1).float()
    actual_positives = labels.sum(dim=1).float()
    
    # Avoid division by zero using clamp
    precision = true_positives / predicted_positives.clamp(min=1)
    recall = true_positives / actual_positives.clamp(min=1)
    f1_score = 2 * (precision * recall) / (precision + recall).clamp(min=1)
    
    # Mean of metrics - After handling division by zero, no need for NaN replacement
    accuracy = accuracy.mean()
    precision = precision.mean()
    recall = recall.mean()
    f1_score = f1_score.mean()
    
    return accuracy.item(), precision.item(), recall.item(), f1_score.item()

In [4]:
class LinearClassifier(nn.Module):
    def __init__(self, num_features, num_classes):
        super(LinearClassifier, self).__init__()
        self.linear = nn.Linear(num_features, num_classes)
    
    def forward(self, x):
        x = self.linear(x)
        # Shape of x becomes [batch_size, num_classes]
        return x

In [7]:
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_linear_probe_early_stopping(model, train_loader, val_loader, num_epochs, criterion, learning_rate, early_stopping_patience = 30):

    early_stopping_patience = round(num_epochs * early_stopping_patience/100)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    best_val_loss = float('inf')
    best_val_accuracy = 0   
    best_epoch = 0
    early_stop = False

    for epoch in range(num_epochs):
        if early_stop:
            print(f'Stopping early at epoch {epoch} of {num_epochs} because no improvement in {early_stopping_patience} epochs.')
            break

        model.train()
        for image_id, embeddings, labels in train_loader:
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(embeddings.view(embeddings.size(0), -1))
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_accuracy = 0
            for image_id, embeddings, labels in val_loader:
                embeddings = embeddings.to(device)
                labels = labels.to(device)

                outputs = model(embeddings.view(embeddings.size(0), -1))
                val_loss += criterion(outputs, labels)
                batch_accuracy, _,_,_ = calculate_metrics(outputs, labels)
                val_accuracy += batch_accuracy

        val_loss /= len(val_loader)
        val_accuracy /= len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
        elif epoch - best_epoch >= early_stopping_patience:
            early_stop = True

        best_val_accuracy = max(val_accuracy, best_val_accuracy)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Val Loss: {val_loss.item()}, Val Accuracy: {val_accuracy}")

    return best_val_accuracy
    

VinDr specific code

In [27]:
vindr_pathologies = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
              "Clavicle fracture", "Consolidation", "Emphysema", "Enlarged PA",
              "ILD", "Infiltration", "Lung Opacity", "Lung cavity", "Lung cyst",
              "Mediastinal shift","Nodule/Mass", "Pleural effusion", "Pleural thickening",
              "Pneumothorax", "Pulmonary fibrosis","Rib fracture", "Other lesion",
              "No finding"] # edema in train set not test set but only 19 images
# TODO: ^ does order of pathologies matter? 

test_all_pathology_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/VinDr/no_tuning_4934e91451945c8218c267aae9c34929a7677829/collated_test_all.pkl")
test_all_pathology = pd.read_pickle(test_all_pathology_path)



In [28]:
VinDr_test_train_split_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_train_split.txt")
VinDr_test_val_split_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_val_split.txt")
VinDr_test_test_split_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_test_split.txt")
# split test_all_pathology into test, train, val based on the image_ids in the splits
with open(VinDr_test_train_split_path, 'r') as f:
    test_train_image_ids = f.read().splitlines()
with open(VinDr_test_val_split_path, 'r') as f:
    test_val_image_ids = f.read().splitlines()
with open(VinDr_test_test_split_path, 'r') as f:
    test_test_image_ids = f.read().splitlines()

test_train_pathology = test_all_pathology[test_all_pathology["image_id"].isin(test_train_image_ids)]
test_val_pathology = test_all_pathology[test_all_pathology["image_id"].isin(test_val_image_ids)]
test_test_pathology = test_all_pathology[test_all_pathology["image_id"].isin(test_test_image_ids)]

In [17]:
layers = ["post_layer_norm", "q_former"]
batch_sizes = [64, 128, 256, 512, 1024, 3000]
num_epochs = [10, 20, 40]
learning_rates = [0.0001, 0.001, 0.01]

best_val_accuracy_vindr = 0
best_hyperparameters_vindr = None
top_n_hyperparam_configurations_vindr = []



for layer in layers:
    flattened_layer_dimension_size = 128*768 if layer == "q_former" else 1408 # layer = "post_layer_norm"
    for batch_size in batch_sizes:
        
        train_loader = DataLoader(BinaryMultiPathologyPresenceDataset(test_train_pathology, vindr_pathologies,layer), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(test_val_pathology, vindr_pathologies,layer), batch_size=batch_size, shuffle=True)
        
        for num_epoch, learning_rate in itertools.product(num_epochs, learning_rates):
            layer_linear_probe_flattened = LinearClassifier(flattened_layer_dimension_size, len(vindr_pathologies))
            print(f"Hyperparameters: layer={layer} batch_size={batch_size}, num_epochs={num_epoch}, learning_rate={learning_rate}")
            
            val_accuracy = train_linear_probe_early_stopping(layer_linear_probe_flattened, train_loader, val_loader, num_epoch, criterion, learning_rate)
            if val_accuracy > best_val_accuracy_vindr:
                best_val_accuracy = val_accuracy
                best_hyperparameters_vindr = (batch_size, num_epoch, learning_rate)
                top_n_hyperparam_configurations_vindr.append((layer, batch_size, num_epoch, learning_rate, val_accuracy))
            

Hyperparameters: layer=post_layer_norm batch_size=64, num_epochs=10, learning_rate=0.0001
Epoch 1/10, Loss: 0.25064048171043396, Val Loss: 0.20408566296100616, Val Accuracy: 0.6758522748947143
Epoch 2/10, Loss: 0.1991337090730667, Val Loss: 0.15143299102783203, Val Accuracy: 0.6815340995788575
Epoch 3/10, Loss: 0.12995848059654236, Val Loss: 0.13291417062282562, Val Accuracy: 0.6875
Epoch 4/10, Loss: 0.06664096564054489, Val Loss: 0.12442968040704727, Val Accuracy: 0.6803977251052856
Epoch 5/10, Loss: 0.06215767189860344, Val Loss: 0.11543500423431396, Val Accuracy: 0.6815340995788575
Epoch 6/10, Loss: 0.09381000697612762, Val Loss: 0.11150573939085007, Val Accuracy: 0.6772727251052857
Epoch 7/10, Loss: 0.09265833348035812, Val Loss: 0.107738196849823, Val Accuracy: 0.66875
Epoch 8/10, Loss: 0.11733511835336685, Val Loss: 0.10171462595462799, Val Accuracy: 0.6920454621315002
Epoch 9/10, Loss: 0.08838209509849548, Val Loss: 0.09971822798252106, Val Accuracy: 0.6880681872367859
Epoch 10/

In [22]:

# Print top 10 hyperparameter configurations
# Sort by validation accuracy 
top_n_hyperparam_configurations_vindr.sort(key=lambda x: x[4], reverse=True)
print("Top 10 hyperparameter configurations:")
for i, (layer, batch_size, num_epochs, learning_rate, val_accuracy) in enumerate(top_n_hyperparam_configurations_vindr[:10]):
    print(f"Layer: {layer} Batch size: {batch_size}, Num epochs: {num_epochs}, Learning rate: {learning_rate}, Validation accuracy: {val_accuracy}")


Top 10 hyperparameter configurations:
Layer: post_layer_norm Batch size: 256, Num epochs: 20, Learning rate: 0.01, Validation accuracy: 0.7716619372367859
Layer: post_layer_norm Batch size: 256, Num epochs: 10, Learning rate: 0.01, Validation accuracy: 0.7638494372367859
Layer: post_layer_norm Batch size: 256, Num epochs: 20, Learning rate: 0.001, Validation accuracy: 0.7602983117103577
Layer: q_former Batch size: 256, Num epochs: 40, Learning rate: 0.0001, Validation accuracy: 0.7567471861839294
Layer: post_layer_norm Batch size: 256, Num epochs: 40, Learning rate: 0.001, Validation accuracy: 0.7508878111839294
Layer: post_layer_norm Batch size: 256, Num epochs: 40, Learning rate: 0.01, Validation accuracy: 0.7469815611839294
Layer: q_former Batch size: 256, Num epochs: 20, Learning rate: 0.0001, Validation accuracy: 0.7434304058551788
Layer: q_former Batch size: 256, Num epochs: 20, Learning rate: 0.001, Validation accuracy: 0.7411221861839294
Layer: q_former Batch size: 128, Num epo

In [34]:
setting_1_vindr = ("post_layer_norm",256,20,0.01)
setting_2_vindr = ("post_layer_norm",256,10,0.01)
setting_3_vindr = ("post_layer_norm",256,20,0.001)
setting_4_vindr = ("q_former",256,40,0.0001)
setting_5_vindr = ("q_former",256,20,0.0001)
setting_6_vindr = ("q_former",256,20,0.001)

best_vindr_hyperparameters = [setting_1_vindr, setting_2_vindr, setting_3_vindr, setting_4_vindr, setting_5_vindr, setting_6_vindr]

for best_hyperparameter in best_vindr_hyperparameters:
    # train on train set and evaluate on test set
    layer, batch_size, num_epochs, learning_rate = best_hyperparameter
    flattened_layer_dimension_size = 128*768 if layer == "q_former" else 1408 # layer = "post_layer_norm"

    # create join train and val loaders to train on both
    train_val_pathology = pd.concat([test_train_pathology, test_val_pathology])
    train_val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(train_val_pathology, vindr_pathologies,layer), batch_size=batch_size, shuffle=True)

    test_loader = DataLoader(BinaryMultiPathologyPresenceDataset(test_test_pathology, vindr_pathologies,layer), batch_size=batch_size, shuffle=True)

    linear_probe_flattened = LinearClassifier(flattened_layer_dimension_size, len(vindr_pathologies))
    best_test_accuracy = train_linear_probe_early_stopping(linear_probe_flattened, train_val_loader, test_loader, num_epochs, criterion, learning_rate)
    print(f"Layer: {layer} Batch size: {batch_size}, Num epochs: {num_epochs}, Learning rate: {learning_rate}, \n Test accuracy: {best_test_accuracy}")


Epoch 1/20, Loss: 0.3526707887649536, Val Loss: 0.2661289870738983, Val Accuracy: 0.6452158391475677
Epoch 2/20, Loss: 0.1921747326850891, Val Loss: 0.21238574385643005, Val Accuracy: 0.6364569664001465
Epoch 3/20, Loss: 0.15520823001861572, Val Loss: 0.14742806553840637, Val Accuracy: 0.6625523269176483
Epoch 4/20, Loss: 0.1160830706357956, Val Loss: 0.13621190190315247, Val Accuracy: 0.665371298789978
Epoch 5/20, Loss: 0.12750495970249176, Val Loss: 0.11350804567337036, Val Accuracy: 0.6690359115600586
Epoch 6/20, Loss: 0.08672092109918594, Val Loss: 0.10714013874530792, Val Accuracy: 0.6775531470775604
Epoch 7/20, Loss: 0.06730435788631439, Val Loss: 0.0994916632771492, Val Accuracy: 0.6623912453651428
Epoch 8/20, Loss: 0.048101577907800674, Val Loss: 0.09674905985593796, Val Accuracy: 0.6716937720775604
Epoch 9/20, Loss: 0.05710184946656227, Val Loss: 0.09347943961620331, Val Accuracy: 0.6787209808826447
Epoch 10/20, Loss: 0.0601414293050766, Val Loss: 0.09009617567062378, Val Accu

In [37]:
# Save best VinDr Models at each layer
best_post_layer_norm_setting = ("post_layer_norm",256,20,0.001)
best_q_former_setting = ("q_former",256,20,0.0001)

# train models on train and eval set 
train_val_pathology = pd.concat([test_train_pathology, test_val_pathology])
train_val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(train_val_pathology, vindr_pathologies,best_post_layer_norm_setting[0]), batch_size=best_post_layer_norm_setting[1], shuffle=True)
test_loader = DataLoader(BinaryMultiPathologyPresenceDataset(test_test_pathology, vindr_pathologies,best_post_layer_norm_setting[0]), batch_size=best_post_layer_norm_setting[1], shuffle=True)

post_layer_norm_model = LinearClassifier(1408, len(vindr_pathologies))
train_linear_probe_early_stopping(post_layer_norm_model, train_val_loader, test_loader, best_post_layer_norm_setting[2], criterion, best_post_layer_norm_setting[3])
torch.save(post_layer_norm_model.state_dict(), "post_layer_norm_best_vindr_model.pth")

train_val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(train_val_pathology, vindr_pathologies,best_q_former_setting[0]), batch_size=best_q_former_setting[1], shuffle=True)
test_loader = DataLoader(BinaryMultiPathologyPresenceDataset(test_test_pathology, vindr_pathologies,best_q_former_setting[0]), batch_size=best_q_former_setting[1], shuffle=True)

q_former_model = LinearClassifier(128*768, len(vindr_pathologies))
train_linear_probe_early_stopping(q_former_model, train_val_loader, test_loader, best_q_former_setting[2], criterion, best_q_former_setting[3])
torch.save(q_former_model.state_dict(), "q_former_best_vindr_model.pth")


Epoch 1/20, Loss: 0.14910230040550232, Val Loss: 0.146973118185997, Val Accuracy: 0.6780968010425568
Epoch 2/20, Loss: 0.12837937474250793, Val Loss: 0.11619473993778229, Val Accuracy: 0.6615455746650696
Epoch 3/20, Loss: 0.11557658761739731, Val Loss: 0.09643515199422836, Val Accuracy: 0.6769289672374725
Epoch 4/20, Loss: 0.07889621704816818, Val Loss: 0.08847151696681976, Val Accuracy: 0.6677069664001465
Epoch 5/20, Loss: 0.0939568430185318, Val Loss: 0.08578624576330185, Val Accuracy: 0.6838756203651428
Epoch 6/20, Loss: 0.08463204652070999, Val Loss: 0.08236835896968842, Val Accuracy: 0.6748147308826447
Epoch 7/20, Loss: 0.06663776934146881, Val Loss: 0.08184206485748291, Val Accuracy: 0.6819224953651428
Epoch 8/20, Loss: 0.08060187101364136, Val Loss: 0.08088463544845581, Val Accuracy: 0.6787209808826447
Epoch 9/20, Loss: 0.08126815408468246, Val Loss: 0.08013282716274261, Val Accuracy: 0.6788015365600586
Epoch 10/20, Loss: 0.07993566244840622, Val Loss: 0.08038546144962311, Val A

CheXpert Specific

In [6]:
cheXpert_pathologies = ['No Finding','Enlarged Cardiomediastinum','Cardiomegaly','Lung Opacity',
        'Lung Lesion','Edema','Consolidation','Pneumonia','Atelectasis','Pneumothorax',
        'Pleural Effusion','Pleural Other','Fracture','Support Devices']

cheXpert_train_df = pd.read_pickle("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/CheXpert-small/collated_train_5000_df.pkl")
cheXpert_val_df = pd.read_pickle("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/CheXpert-small/collated_valid_df.pkl")
cheXpert_test_df = pd.read_pickle("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/CheXpert-small/collated_test_df.pkl")


In [9]:
layers = ["post_layer_norm","q_former"]
batch_sizes = [64, 128, 256, 512, 1024]
num_epochs = [10, 20, 40]
learning_rates = [0.00001, 0.0001, 0.001] # far better with smaller learning rates
best_val_accuracy_chexpert = 0
best_hyperparameters_chexpert = None
top_n_hyperparam_configurations_chexpert = []


for layer in layers:
    flattened_layer_dimension_size = 128*768 if layer == "q_former" else 1408 # layer = "post_layer_norm"
    
    for batch_size in batch_sizes:
            train_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_train_df, cheXpert_pathologies,layer), batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_val_df, cheXpert_pathologies,layer), batch_size=batch_size, shuffle=True)
            
            for num_epoch, learning_rate in itertools.product(num_epochs, learning_rates):
                q_former_linear_probe_flattened = LinearClassifier(flattened_layer_dimension_size, len(cheXpert_pathologies))
                print(f"Hyperparameters: layer={layer} batch_size={batch_size}, num_epochs={num_epoch}, learning_rate={learning_rate}")
                
                val_accuracy = train_linear_probe_early_stopping(q_former_linear_probe_flattened, train_loader, val_loader, num_epoch, criterion, learning_rate)
                if val_accuracy > best_val_accuracy_chexpert:
                    best_val_accuracy_chexpert = val_accuracy
                    best_hyperparameters_chexpert = (batch_size, num_epoch, learning_rate)
                    top_n_hyperparam_configurations_chexpert.append((layer, batch_size, num_epoch, learning_rate, val_accuracy))

Hyperparameters: layer=post_layer_norm batch_size=64, num_epochs=10, learning_rate=1e-05
Epoch 1/10, Loss: 0.6035724878311157, Val Loss: 0.6299343109130859, Val Accuracy: 0.0
Epoch 2/10, Loss: 0.4603298306465149, Val Loss: 0.5508776903152466, Val Accuracy: 0.02566964365541935
Epoch 3/10, Loss: 0.4434868395328522, Val Loss: 0.5147615075111389, Val Accuracy: 0.06659226212650537
Epoch 4/10, Loss: 0.3761373460292816, Val Loss: 0.49841371178627014, Val Accuracy: 0.06845238106325269
Epoch 5/10, Loss: 0.4218677580356598, Val Loss: 0.4911556541919708, Val Accuracy: 0.07626488106325269
Epoch 6/10, Loss: 0.37976324558258057, Val Loss: 0.48332494497299194, Val Accuracy: 0.09226190485060215
Epoch 7/10, Loss: 0.35438355803489685, Val Loss: 0.47533369064331055, Val Accuracy: 0.09988839365541935
Epoch 8/10, Loss: 0.32678061723709106, Val Loss: 0.47326475381851196, Val Accuracy: 0.1119791679084301
Epoch 9/10, Loss: 0.46662241220474243, Val Loss: 0.47090911865234375, Val Accuracy: 0.10788690485060215
E

In [23]:

# Print top 10 hyperparameter configurations
# Sort by validation accuracy 
top_n_hyperparam_configurations_chexpert.sort(key=lambda x: x[4], reverse=True)
print("Top 10 hyperparameter configurations:")
for i, (layer, batch_size, num_epochs, learning_rate, val_accuracy) in enumerate(top_n_hyperparam_configurations_chexpert[:5]):
    print(f"Layer: {layer} Batch size: {batch_size}, Num epochs: {num_epochs}, Learning rate: {learning_rate}, Validation accuracy: {val_accuracy}")


Top 10 hyperparameter configurations:
Layer: q_former Batch size: 64, Num epochs: 10, Learning rate: 0.0001, Validation accuracy: 0.237351194024086
Layer: q_former Batch size: 64, Num epochs: 10, Learning rate: 1e-05, Validation accuracy: 0.229538694024086
Layer: post_layer_norm Batch size: 64, Num epochs: 40, Learning rate: 0.0001, Validation accuracy: 0.2276785746216774
Layer: post_layer_norm Batch size: 64, Num epochs: 10, Learning rate: 0.0001, Validation accuracy: 0.198288694024086
Layer: post_layer_norm Batch size: 64, Num epochs: 10, Learning rate: 1e-05, Validation accuracy: 0.1315104179084301


In [39]:
setting_1_chexpert = ("q_former", 512,10,0.00001)
setting_2_chexpert = ("q_former", 128,20,0.00001)
setting_3_chexpert = ("q_former", 1024,10,0.00001)
setting_4_chexpert = ("post_layer_norm", 128,20,0.001)
setting_5_chexpert = ("post_layer_norm", 64,40,0.001)
setting_6_chexpert = ("post_layer_norm", 256,10,0.001)

best_chexpert_hyperparameters = [setting_1_chexpert, setting_2_chexpert, setting_3_chexpert, setting_4_chexpert, setting_5_chexpert, setting_6_chexpert]

for best_hyperparameter in best_chexpert_hyperparameters:
    layer, batch_size, num_epochs, learning_rate = best_hyperparameter
    flattened_layer_dimension_size = 128*768 if layer == "q_former" else 1408 # layer = "post_layer_norm"

    train_val_pathology = pd.concat([cheXpert_train_df, cheXpert_val_df])
    train_val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(train_val_pathology, cheXpert_pathologies,layer), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_test_df, cheXpert_pathologies,layer), batch_size=batch_size, shuffle=True)
    
    linear_probe_flattened = LinearClassifier(flattened_layer_dimension_size, len(cheXpert_pathologies))
    best_test_accuracy = train_linear_probe_early_stopping(linear_probe_flattened, train_val_loader, test_loader, num_epochs, criterion, learning_rate)
    print(f"Layer: {layer} Batch size: {batch_size}, Num epochs: {num_epochs}, Learning rate: {learning_rate}, \n Test accuracy: {best_test_accuracy}")

Epoch 1/10, Loss: 0.341776043176651, Val Loss: 0.39075154066085815, Val Accuracy: 0.17357772588729858
Epoch 2/10, Loss: 0.3227037191390991, Val Loss: 0.3623430132865906, Val Accuracy: 0.18642327934503555
Epoch 3/10, Loss: 0.30798694491386414, Val Loss: 0.324956476688385, Val Accuracy: 0.19198217242956161
Epoch 4/10, Loss: 0.3089774250984192, Val Loss: 0.3205399513244629, Val Accuracy: 0.19200721383094788
Epoch 5/10, Loss: 0.2873215973377228, Val Loss: 0.332613468170166, Val Accuracy: 0.20956029742956161
Epoch 6/10, Loss: 0.30704638361930847, Val Loss: 0.33618274331092834, Val Accuracy: 0.18156550824642181
Epoch 7/10, Loss: 0.3189432919025421, Val Loss: 0.3201506733894348, Val Accuracy: 0.21639623492956161
Epoch 8/10, Loss: 0.30164018273353577, Val Loss: 0.3262619078159332, Val Accuracy: 0.1967397853732109
Epoch 9/10, Loss: 0.2990100681781769, Val Loss: 0.3252356946468353, Val Accuracy: 0.20607972890138626
Epoch 10/10, Loss: 0.2924249768257141, Val Loss: 0.3315170407295227, Val Accuracy

In [40]:
#save the best CheXpert models
best_q_former_setting = ("q_former", 512,10,0.00001)
best_post_layer_norm_setting = ("post_layer_norm", 256,10,0.001)

train_val_pathology = pd.concat([cheXpert_train_df, cheXpert_val_df])
train_val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(train_val_pathology, cheXpert_pathologies,best_q_former_setting[0]), batch_size=best_q_former_setting[1], shuffle=True)
test_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_test_df, cheXpert_pathologies,best_q_former_setting[0]), batch_size=best_q_former_setting[1], shuffle=True)

q_former_model = LinearClassifier(128*768, len(cheXpert_pathologies))
train_linear_probe_early_stopping(q_former_model, train_val_loader, test_loader, best_q_former_setting[2], criterion, best_q_former_setting[3])
torch.save(q_former_model.state_dict(), "q_former_best_chexpert_model.pth")

train_val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(train_val_pathology, cheXpert_pathologies,best_post_layer_norm_setting[0]), batch_size=best_post_layer_norm_setting[1], shuffle=True)
test_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_test_df, cheXpert_pathologies,best_post_layer_norm_setting[0]), batch_size=best_post_layer_norm_setting[1], shuffle=True)

post_layer_norm_model = LinearClassifier(1408, len(cheXpert_pathologies))
train_linear_probe_early_stopping(post_layer_norm_model, train_val_loader, test_loader, best_post_layer_norm_setting[2], criterion, best_post_layer_norm_setting[3])
torch.save(post_layer_norm_model.state_dict(), "post_layer_norm_best_chexpert_model.pth")

Epoch 1/10, Loss: 0.32516977190971375, Val Loss: 0.3966605067253113, Val Accuracy: 0.18431991338729858
Epoch 2/10, Loss: 0.3126753866672516, Val Loss: 0.3757255971431732, Val Accuracy: 0.20177283883094788
Epoch 3/10, Loss: 0.2937929034233093, Val Loss: 0.33201730251312256, Val Accuracy: 0.20315004140138626
Epoch 4/10, Loss: 0.2807410955429077, Val Loss: 0.32741495966911316, Val Accuracy: 0.19716546684503555
Epoch 5/10, Loss: 0.2641494572162628, Val Loss: 0.3409136235713959, Val Accuracy: 0.2006460353732109
Epoch 6/10, Loss: 0.2995471954345703, Val Loss: 0.3310456871986389, Val Accuracy: 0.20803285390138626
Epoch 7/10, Loss: 0.28887754678726196, Val Loss: 0.3292463421821594, Val Accuracy: 0.20607972890138626
Stopping early at epoch 7 of 10 because no improvement in 3 epochs.
Epoch 1/10, Loss: 0.3202350437641144, Val Loss: 0.37853550910949707, Val Accuracy: 0.18102297186851501
Epoch 2/10, Loss: 0.3096252381801605, Val Loss: 0.3651854693889618, Val Accuracy: 0.175914799173673
Epoch 3/10, 

Chexpert probed using 10,000 training examples

In [14]:
cheXpert_train_part_2_df = pd.read_pickle("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/CheXpert-small/collated_train_5002_10001_df.pkl")
# join together the two parts of the train set
cheXpert_train_10000_df = pd.concat([cheXpert_train_df, cheXpert_train_part_2_df])

In [17]:
layers = ["post_layer_norm","q_former"]
batch_sizes = [64, 128, 256, 512]
num_epochs = [10, 20, 40]
learning_rates = [0.00001, 0.0001, 0.001] # far better with smaller learning rates
best_val_accuracy_chexpert = 0
best_hyperparameters_chexpert = None
top_n_hyperparam_configurations_chexpert = []


for layer in layers:
    flattened_layer_dimension_size = 128*768 if layer == "q_former" else 1408 # layer = "post_layer_norm"
    
    for batch_size in batch_sizes:
            train_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_train_10000_df, cheXpert_pathologies,layer), batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(BinaryMultiPathologyPresenceDataset(cheXpert_val_df, cheXpert_pathologies,layer), batch_size=batch_size, shuffle=True)
            
            for num_epoch, learning_rate in itertools.product(num_epochs, learning_rates):
                q_former_linear_probe_flattened = LinearClassifier(flattened_layer_dimension_size, len(cheXpert_pathologies))
                print(f"Hyperparameters: layer={layer} batch_size={batch_size}, num_epochs={num_epoch}, learning_rate={learning_rate}")
                
                val_accuracy = train_linear_probe_early_stopping(q_former_linear_probe_flattened, train_loader, val_loader, num_epoch, criterion, learning_rate)
                if val_accuracy > best_val_accuracy_chexpert:
                    best_val_accuracy_chexpert = val_accuracy
                    best_hyperparameters_chexpert = (batch_size, num_epoch, learning_rate)
                    top_n_hyperparam_configurations_chexpert.append((layer, batch_size, num_epoch, learning_rate, val_accuracy))

Hyperparameters: layer=post_layer_norm batch_size=64, num_epochs=10, learning_rate=1e-05
Epoch 1/10, Loss: 0.46690213680267334, Val Loss: 0.5206497311592102, Val Accuracy: 0.04501488106325269
Epoch 2/10, Loss: 0.401942640542984, Val Loss: 0.47493183612823486, Val Accuracy: 0.08240327425301075
Epoch 3/10, Loss: 0.3974604308605194, Val Loss: 0.46405029296875, Val Accuracy: 0.1101190485060215
Epoch 4/10, Loss: 0.34037289023399353, Val Loss: 0.45940762758255005, Val Accuracy: 0.1236979179084301
Epoch 5/10, Loss: 0.3546602427959442, Val Loss: 0.4512726664543152, Val Accuracy: 0.11365327425301075
Epoch 6/10, Loss: 0.2983509302139282, Val Loss: 0.4524664580821991, Val Accuracy: 0.10565476212650537
Epoch 7/10, Loss: 0.35166868567466736, Val Loss: 0.44619208574295044, Val Accuracy: 0.10174851212650537
Epoch 8/10, Loss: 0.36208972334861755, Val Loss: 0.4438728988170624, Val Accuracy: 0.10770089365541935
Epoch 9/10, Loss: 0.32184898853302, Val Loss: 0.4400619864463806, Val Accuracy: 0.12146577425