In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

In [42]:
class GatedResidualLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # self.fc = nn.Linear(hidden_size, hidden_size) # avoiding linear layer for now
        self.attention = nn.MultiheadAttention(hidden_size, 1)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.gate = nn.Parameter(torch.tensor(0.0)) # gating could be more fine-grained (one per dimension) nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # dimensions are [batch_size,1,128,768] change to [batch_size,128,768]
        x = x.squeeze(1)
        residual = x
        x = self.attention(x, x, x)[0]  
        x = F.gelu(x)
        x = self.layer_norm(x)
        x = x * self.gate + residual
        x = x.unsqueeze(1)
        return x

class PredictionHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(PredictionHead, self).__init__()
        # Global Average Pooling across the sequence dimension
        self.gap = nn.AdaptiveAvgPool1d(1)
        # Optional: Additional linear layers can be added here for more complexity
        self.fc1 = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        # dimensions are [batch_size,1,128,768] change to [batch_size,128,768]
        x = x.squeeze(1)
        # Applying global average pooling
        x = x.transpose(1, 2) # Change to [batch_size, 768, 128] to match pooling dimension
        x = self.gap(x).squeeze() # After pooling, size: [batch_size, 768]
        x = self.fc1(x)
        return x

# model = join together the GatedResidualLayer and PredictionHead
class AddedGatedResidualLayer(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.gated_residual_layer = GatedResidualLayer(hidden_size)
        self.prediction_head = PredictionHead(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.gated_residual_layer(x)
        x = self.prediction_head(x)
        return x

In [62]:
class VinDrWithBinaryPathologyLocations(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Extract just the columns that indicate pathologies (left or right in their names)
        self.pathology_columns = [col for col in dataframe.columns if 'left' in col or 'right' in col]
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract embeddings and drop the first dimension i.e. [1,128,768] -> [128,768]
        embeddings = torch.tensor(row['q_former'], dtype=torch.float)
        
        # Extract binary labels for pathologies
        labels_values = row[self.pathology_columns].values
        # set the dtype of each item in values to be an int
        labels_values = [int(item) for item in labels_values]
        labels_tensor = torch.tensor(labels_values, dtype=torch.float)

        # image_id 
        image_id = row['image_id']
        
        return image_id, embeddings, labels_tensor
    
    @staticmethod
    def proportions_of_positives_vs_negatives(dataset):
        # go through dataset and calculate total no of 0s and 1s in the labels
        total_0s = 0
        total_1s = 0
        for i in range(len(dataset)):
            _, _, labels = dataset[i]
            total_0s += len(labels) - labels.sum()
            total_1s += labels.sum()
        return total_0s, total_1s

In [4]:
stored_df_for_training_layer_path = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/no_tuning_4934e91451945c8218c267aae9c34929a7677829/test_VinDr_df_for_training_layer.pkl')
df = pd.read_pickle(stored_df_for_training_layer_path)

In [84]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Implements Focal Loss.
        Args:
        - alpha (float): Weighting factor for the binary class at index 1 (default is 0.25).
        - gamma (float): Focusing parameter to adjust the rate at which easy examples are down-weighted (default is 2).
        - reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the focal loss given the model output (inputs) and the target labels.
        Args:
        - inputs (Tensor): Predictions from model (logits before Sigmoid).
        - targets (Tensor): Target labels.
        """
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.float32)
        at = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        pt = torch.exp(-BCE_loss)
        F_loss = at * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == 'mean':
            return F_loss.mean()
        elif self.reduction == 'sum':
            return F_loss.sum()
        else:
            return F_loss

In [91]:
# hyperparams
training_proportion = 0.8
validation_proportion = 0.1
batch_size = 32

learning_rate = 0.0001
weight_decay = 0.0001
num_epochs = 20

# loss function
weight_for_0 = 1 
weight_for_1 = 15 # since only 2% of labels are 1
weights = torch.tensor([weight_for_0, weight_for_1])
pos_weights = weights[1]/weights[0]
criterion = nn.BCEWithLogitsLoss(pos_weights)

# criterion = FocalLoss(alpha=0.85, gamma=2.0)  # Assuming reduction='mean'

In [92]:
dataset = VinDrWithBinaryPathologyLocations(df)
# total_0s,total_1s = VinDrWithBinaryPathologyLocations.proportions_of_positives_vs_negatives(dataset)
# print(f'Proportion of 0s: {total_0s/(total_0s+total_1s)}, Proportion of 1s: {total_1s/(total_0s+total_1s)}')

# split into training, validation and test sets with 0.8, 0.1, 0.1 , then create dataloaders from these
train_size = int(training_proportion * len(dataset))
val_size = int(validation_proportion * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [78]:
def calculate_metrics(outputs, labels):
    # Apply sigmoid to get predicted probabilities 
    predictions = torch.sigmoid(outputs)  

    # Threshold for determining positive predictions (adjust as needed)
    predictions = predictions > 0.5   
   

    # Calculate metrics row-wise, considering all pathologies
    accuracy = (predictions == labels).all(dim=1).float().mean()  
    # Avoid division by zero using clamp
    true_positives = torch.logical_and(predictions, labels).float().sum(dim=1)
    predicted_positives = predictions.sum(dim=1).float()
    actual_positives = labels.sum(dim=1).float()
    
    precision = true_positives / predicted_positives.clamp(min=1)
    recall = true_positives / actual_positives.clamp(min=1)
    f1_score = 2 * (precision * recall) / (precision + recall).clamp(min=1)
    
    # Mean of metrics - After handling division by zero, no need for NaN replacement
    accuracy = accuracy.mean()
    precision = precision.mean()
    recall = recall.mean()
    f1_score = f1_score.mean()
    
    return accuracy.item(), precision.item(), recall.item(), f1_score.item()

In [86]:
model = AddedGatedResidualLayer(768, len(dataset.pathology_columns))

In [93]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device) # Step 2: Move the model to the selected device
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    for data in train_loader: # Assuming you have a dataloader
        _, embeddings, labels = data
        embeddings, labels = embeddings.to(device), labels.to(device) # Move data to the same device as model
        
        optimizer.zero_grad()
        
        outputs = model(embeddings)
        
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        
    # calculate validation metrics per epoch
    with torch.no_grad():
        overall_epoch_accuracy = 0
        overall_val_loss = 0
        
        for data in val_loader:
            _,embeddings, labels = data
            embeddings, labels = embeddings.to(device), labels.to(device)  # Move data to the same device as model
            outputs = model(embeddings)
            loss = criterion(outputs, labels)
            batch_accuracy, batch_precision, batch_recall, batch_f1 = calculate_metrics(outputs,labels)
            overall_val_loss += loss.item()
            overall_epoch_accuracy += batch_accuracy
        
        overall_val_loss /= len(val_loader)
        overall_epoch_accuracy /= len(val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {overall_val_loss}, Validation Accuracy: {overall_epoch_accuracy}')

Epoch [1/20], Validation Loss: 1.0300984680652618, Validation Accuracy: 0.7104166686534882
Epoch [2/20], Validation Loss: 0.8819947510957717, Validation Accuracy: 0.7208333373069763
Epoch [3/20], Validation Loss: 0.8270096182823181, Validation Accuracy: 0.7270833373069763
Epoch [4/20], Validation Loss: 0.8250262796878814, Validation Accuracy: 0.715625
Epoch [5/20], Validation Loss: 0.8700105965137481, Validation Accuracy: 0.703125
Epoch [6/20], Validation Loss: 0.7886792063713074, Validation Accuracy: 0.7239583373069763
Epoch [7/20], Validation Loss: 0.8116195261478424, Validation Accuracy: 0.7083333373069763
Epoch [8/20], Validation Loss: 0.7842211186885834, Validation Accuracy: 0.71875
Epoch [9/20], Validation Loss: 0.7776874989271164, Validation Accuracy: 0.715625
Epoch [10/20], Validation Loss: 0.7887473225593566, Validation Accuracy: 0.7166666686534882
Epoch [11/20], Validation Loss: 0.8366533130407333, Validation Accuracy: 0.6927083343267441
Epoch [12/20], Validation Loss: 0.7749

In [94]:
# calculate test metrics
with torch.no_grad():
    overall_test_accuracy = 0
    overall_test_loss = 0
    overall_test_precision = 0
    overall_test_recall = 0
    overall_test_f1 = 0

    for data in test_loader:
        image_id, embeddings, labels = data
        embeddings, labels = embeddings.to(device), labels.to(device)  # Move data to the same device as model
        outputs = model(embeddings)
        predictions = torch.sigmoid(outputs)  

        loss = criterion(outputs, labels)
        batch_accuracy, batch_precision, batch_recall, batch_f1 = calculate_metrics(outputs,labels)
        overall_test_loss += loss.item()
        overall_test_accuracy += batch_accuracy
        overall_test_precision += batch_precision
        overall_test_recall += batch_recall
        overall_test_f1 += batch_f1


    overall_test_loss /= len(test_loader)
    overall_test_accuracy /= len(test_loader)
    overall_test_precision /= len(test_loader)
    overall_test_recall /= len(test_loader)
    overall_test_f1 /= len(test_loader)
    print(f'Test Loss: {overall_test_loss}, Test Accuracy: {overall_test_accuracy}, Test Precision: {overall_test_precision}, Test Recall: {overall_test_recall}, Test F1: {overall_test_f1}')

Test Loss: 0.6819055855274201, Test Accuracy: 0.7083333373069763, Test Precision: 0.12265625149011612, Test Recall: 0.081579864397645, Test F1: 0.09093874171376229


In [95]:
# save the gated residual layer of the model
path = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/fine_tuning/saved_layers')
torch.save(model.gated_residual_layer.state_dict(), path/'gated_residual_layer_70.pth')