In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

In [2]:
class GatedResidualLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # self.fc = nn.Linear(hidden_size, hidden_size) # avoiding linear layer for now
        self.attention = nn.MultiheadAttention(hidden_size, 1)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.gate = nn.Parameter(torch.tensor(0.0)) # gating could be more fine-grained (one per dimension) nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        residual = x
        x = self.attention(x, x, x)[0]  
        x = F.gelu(x)
        x = self.layer_norm(x)
        x = x * self.gate + residual
        return x

class PredictionHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(PredictionHead, self).__init__()
        # Global Average Pooling across the sequence dimension
        self.gap = nn.AdaptiveAvgPool1d(1)
        # Optional: Additional linear layers can be added here for more complexity
        self.fc1 = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        # Applying global average pooling
        x = x.transpose(1, 2) # Change to [batch_size, 768, 128] to match pooling dimension
        x = self.gap(x).squeeze() # After pooling, size: [batch_size, 768]
        x = self.fc1(x)
        return x

# model = join together the GatedResidualLayer and PredictionHead
class AddedGatedResidualLayer(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.gated_residual_layer = GatedResidualLayer(hidden_size)
        self.prediction_head = PredictionHead(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.gated_residual_layer(x)
        x = self.prediction_head(x)
        return x

In [30]:
class VinDrWithBinaryPathologyLocations(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Extract just the columns that indicate pathologies (left or right in their names)
        self.pathology_columns = [col for col in dataframe.columns if 'left' in col or 'right' in col]
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract embeddings and drop the first dimension i.e. [1,128,768] -> [128,768]
        embeddings = torch.tensor(row['q_former'], dtype=torch.float).squeeze()
        
        # Extract binary labels for pathologies
        labels_values = row[self.pathology_columns].values
        # set the dtype of each item in values to be an int
        labels_values = [int(item) for item in labels_values]
        labels_tensor = torch.tensor(labels_values, dtype=torch.float)
        
        return embeddings, labels_tensor

In [3]:
stored_df_for_training_layer_path = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/no_tuning_4934e91451945c8218c267aae9c34929a7677829/test_VinDr_df_for_training_layer.pkl')
df = pd.read_pickle(stored_df_for_training_layer_path)

In [41]:
# hyperparams
training_proportion = 0.8
validation_proportion = 0.1
batch_size = 32

learning_rate = 0.0001
weight_decay = 0.0001
num_epochs = 10

In [31]:
dataset = VinDrWithBinaryPathologyLocations(df)
# split into training, validation and test sets with 0.8, 0.1, 0.1 , then create dataloaders from these
train_size = int(training_proportion * len(dataset))
val_size = int(validation_proportion * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [37]:
def calculate_metrics(outputs, labels):
    # Apply sigmoid to get predicted probabilities 
    predictions = torch.sigmoid(outputs)  

    # Threshold for determining positive predictions (adjust as needed)
    predictions = predictions > 0.5   

    # Calculate metrics row-wise, considering all pathologies
    accuracy = (predictions == labels).float().mean(dim=1)  
    precision = torch.logical_and(predictions, labels).float().sum(dim=1) / predictions.sum(dim=1)
    recall = torch.logical_and(predictions, labels).float().sum(dim=1) / labels.sum(dim=1) 
    f1_score = 2 * (precision * recall) / (precision + recall)

    # Handle rows where predictions or ground truth are all zeros (to avoid NaNs)
    precision[torch.isnan(precision)] = 0
    recall[torch.isnan(recall)] = 0
    f1_score[torch.isnan(f1_score)] = 0

    # Average metrics across the batch
    accuracy = accuracy.mean()
    precision = precision.mean()
    recall = recall.mean()
    f1_score = f1_score.mean()

    return accuracy, precision, recall, f1_score

In [33]:
model = AddedGatedResidualLayer(768, len(dataset.pathology_columns))

In [43]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device) # Step 2: Move the model to the selected device
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    for data in train_loader: # Assuming you have a dataloader
        embeddings, labels = data
        embeddings, labels = embeddings.to(device), labels.to(device) # Move data to the same device as model
        
        optimizer.zero_grad()
        
        outputs = model(embeddings)
        
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        
    # calculate validation metrics per epoch
    with torch.no_grad():
        overall_epoch_accuracy = 0
        overall_val_loss = 0
        
        for data in val_loader:
            embeddings, labels = data
            embeddings, labels = embeddings.to(device), labels.to(device)  # Move data to the same device as model
            outputs = model(embeddings)
            loss = criterion(outputs, labels)
            batch_accuracy, batch_precision, batch_recall, batch_f1 = calculate_metrics(outputs,labels)
            overall_val_loss += loss.item()
            overall_epoch_accuracy += batch_accuracy
        
        overall_val_loss /= len(val_loader)
        overall_epoch_accuracy /= len(val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {overall_val_loss}, Validation Accuracy: {overall_epoch_accuracy}')

Epoch [1/10], Validation Loss: 0.057257329672575, Validation Accuracy: 0.9818947911262512
Epoch [2/10], Validation Loss: 0.05819852780550718, Validation Accuracy: 0.9810764193534851
Epoch [3/10], Validation Loss: 0.05590075068175793, Validation Accuracy: 0.9814236760139465
Epoch [4/10], Validation Loss: 0.055868754722177984, Validation Accuracy: 0.9815227389335632
Epoch [5/10], Validation Loss: 0.05572538673877716, Validation Accuracy: 0.9814483523368835
Epoch [6/10], Validation Loss: 0.05697436258196831, Validation Accuracy: 0.9816965460777283
Epoch [7/10], Validation Loss: 0.05842150785028934, Validation Accuracy: 0.9806795120239258
Epoch [8/10], Validation Loss: 0.059920340962708, Validation Accuracy: 0.9808531999588013
Epoch [9/10], Validation Loss: 0.05861774533987045, Validation Accuracy: 0.9809524416923523
Epoch [10/10], Validation Loss: 0.05782097969204188, Validation Accuracy: 0.9812995791435242
