In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import random
import itertools


In [2]:
# print the columns of the following dataframes
collated_dataframe_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/no_tuning_4934e91451945c8218c267aae9c34929a7677829/collated_dataframe.pkl")
test_VinDr_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/no_tuning_4934e91451945c8218c267aae9c34929a7677829/test_VinDr_df_for_training_layer.pkl")

collated_dataframe = pd.read_pickle(collated_dataframe_path)
test_VinDr = pd.read_pickle(test_VinDr_path)

print(collated_dataframe.columns)
print(test_VinDr.columns)

Index(['image_id', 'patch_embeddings', 'post_layer_norm', 'q_former',
       'language_projection', 'Infiltration', 'Pleural effusion',
       'Consolidation', 'Clavicle fracture', 'Aortic enlargement',
       'Enlarged PA', 'Pulmonary fibrosis', 'Lung cavity', 'Nodule/Mass',
       'Other lesion', 'Rib fracture', 'Lung Opacity', 'Atelectasis',
       'Calcification', 'Pleural thickening', 'Lung cyst', 'No finding',
       'Emphysema', 'ILD', 'Pneumothorax', 'Mediastinal shift', 'Cardiomegaly',
       'Infiltration position', 'Pleural effusion position',
       'Consolidation position', 'Clavicle fracture position',
       'Aortic enlargement position', 'Enlarged PA position',
       'Pulmonary fibrosis position', 'Lung cavity position',
       'Nodule/Mass position', 'Other lesion position',
       'Rib fracture position', 'Lung Opacity position',
       'Atelectasis position', 'Calcification position',
       'Pleural thickening position', 'Lung cyst position',
       'No finding pos

In [14]:
class GatedResidualLayer(nn.Module):
    def __init__(self, hidden_size, attention_heads=1 ,fine_grained_gating=False):
        super().__init__()
        # self.fc = nn.Linear(hidden_size, hidden_size) # avoiding linear layer for now
        self.attention = nn.MultiheadAttention(hidden_size, attention_heads)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.fine_grained_gating = fine_grained_gating
        if fine_grained_gating:
            # When fine-grained gating is True, initialize a gate with one parameter per dimension
            self.gate = nn.Parameter(torch.zeros(hidden_size))
        else:
            # When fine-grained gating is False, use a single scalar parameter for gating
            self.gate = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        # dimensions are [batch_size,1,128,768] change to [batch_size,128,768]
        x = x.squeeze(1)
        residual = x
        x = self.attention(x, x, x)[0]  
        x = F.gelu(x)
        x = self.layer_norm(x)
        if self.fine_grained_gating:
            # Apply fine-grained gating by multiplying x with the gate vector (element-wise)
            x = x * torch.sigmoid(self.gate) + residual
        else:
            x = x * self.gate + residual
            
        x = x.unsqueeze(1)
        return x

class PredictionHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(PredictionHead, self).__init__()
        # Global Average Pooling across the sequence dimension
        self.gap = nn.AdaptiveAvgPool1d(1)
        # Optional: Additional linear layers can be added here for more complexity
        self.fc1 = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        # dimensions are [batch_size,1,128,768] change to [batch_size,128,768]
        x = x.squeeze(1)
        # Applying global average pooling
        x = x.transpose(1, 2) # Change to [batch_size, 768, 128] to match pooling dimension
        x = self.gap(x).squeeze() # After pooling, size: [batch_size, 768]
        x = self.fc1(x)
        return x

In [15]:
# model def: join together the GatedResidualLayer and PredictionHead
class AddedGatedResidualLayer(nn.Module):
    def __init__(self, hidden_size, num_classes, attention_heads=1, fine_grained_gating=False):
        super().__init__()
        self.gated_residual_layer = GatedResidualLayer(hidden_size)
        self.prediction_head = PredictionHead(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.gated_residual_layer(x)
        x = self.prediction_head(x)
        return x

In [16]:
# model def: Stacked Added Gated Residual Layer
class StackedAddedGatedResidualLayer(nn.Module):
    def __init__(self, hidden_size, num_classes, attention_heads=1, fine_grained_gating=False):
        super().__init__()
        self.gated_residual_layer = GatedResidualLayer(hidden_size)
        self.gated_residual_layer2 = GatedResidualLayer(hidden_size)
        self.prediction_head = PredictionHead(hidden_size, num_classes)
        
    def forward(self, x):
        x = self.gated_residual_layer(x)
        x = self.gated_residual_layer2(x)
        x = self.prediction_head(x)
        return x

In [5]:
stored_df_for_training_layer_path = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/model_inspection/embeddings/no_tuning_4934e91451945c8218c267aae9c34929a7677829/test_VinDr_df_for_training_layer.pkl')
df = pd.read_pickle(stored_df_for_training_layer_path)

In [6]:
class VinDrWithBinaryPathologyLocations(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Extract just the columns that indicate pathologies (left or right in their names)
        self.pathology_columns = [col for col in dataframe.columns if 'left' in col or 'right' in col]
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Extract embeddings and drop the first dimension i.e. [1,128,768] -> [128,768]
        embeddings = torch.tensor(row['q_former'], dtype=torch.float)
        
        # Extract binary labels for pathologies
        labels_values = row[self.pathology_columns].values
        # set the dtype of each item in values to be an int
        labels_values = [int(item) for item in labels_values]
        labels_tensor = torch.tensor(labels_values, dtype=torch.float)

        # image_id 
        image_id = row['image_id']
        
        return image_id, embeddings, labels_tensor
    
    @staticmethod
    def proportions_of_positives_vs_negatives(dataset):
        # go through dataset and calculate total no of 0s and 1s in the labels
        total_0s = 0
        total_1s = 0
        for i in range(len(dataset)):
            _, _, labels = dataset[i]
            total_0s += len(labels) - labels.sum()
            total_1s += labels.sum()
        return total_0s, total_1s

In [7]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Implements Focal Loss.
        Args:
        - alpha (float): Weighting factor for the binary class at index 1 (default is 0.25).
        - gamma (float): Focusing parameter to adjust the rate at which easy examples are down-weighted (default is 2).
        - reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the focal loss given the model output (inputs) and the target labels.
        Args:
        - inputs (Tensor): Predictions from model (logits before Sigmoid).
        - targets (Tensor): Target labels.
        """
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.float32)
        at = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        pt = torch.exp(-BCE_loss)
        F_loss = at * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == 'mean':
            return F_loss.mean()
        elif self.reduction == 'sum':
            return F_loss.sum()
        else:
            return F_loss

In [9]:
# data hyperparams
training_proportion = 0.8
validation_proportion = 0.1
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [10]:
dataset = VinDrWithBinaryPathologyLocations(df)
# total_0s,total_1s = VinDrWithBinaryPathologyLocations.proportions_of_positives_vs_negatives(dataset)
# print(f'Proportion of 0s: {total_0s/(total_0s+total_1s)}, Proportion of 1s: {total_1s/(total_0s+total_1s)}')

# split into training, validation and test sets with 0.8, 0.1, 0.1 , then create dataloaders from these
train_size = int(training_proportion * len(dataset))
val_size = int(validation_proportion * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [11]:
def calculate_metrics(outputs, labels):
    # Apply sigmoid to get predicted probabilities 
    predictions = torch.sigmoid(outputs)  

    # Threshold for determining positive predictions (adjust as needed)
    predictions = predictions > 0.5   
   

    # Calculate metrics row-wise, considering all pathologies
    accuracy = (predictions == labels).all(dim=1).float().mean()  
    # Avoid division by zero using clamp
    true_positives = torch.logical_and(predictions, labels).float().sum(dim=1)
    predicted_positives = predictions.sum(dim=1).float()
    actual_positives = labels.sum(dim=1).float()
    
    precision = true_positives / predicted_positives.clamp(min=1)
    recall = true_positives / actual_positives.clamp(min=1)
    f1_score = 2 * (precision * recall) / (precision + recall).clamp(min=1)
    
    # Mean of metrics - After handling division by zero, no need for NaN replacement
    accuracy = accuracy.mean()
    precision = precision.mean()
    recall = recall.mean()
    f1_score = f1_score.mean()
    
    return accuracy.item(), precision.item(), recall.item(), f1_score.item()

In [12]:
def train_injected_layer_early_stopping(model, train_loader, val_loader, num_epochs, criterion, learning_rate, weight_decay, early_stopping_patience=10):

    early_stopping_patience = num_epochs // 10
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    max_val_accuracy = 0
    epochs_no_improve = 0
    early_stop = False

    for epoch in range(num_epochs):
        if early_stop:
            print(f'Stopping early at epoch {epoch} of {num_epochs} as there has been no improvement in {early_stopping_patience} epochs.')
            break

        for data in train_loader: 
            _, embeddings, labels = data
            embeddings, labels = embeddings.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(embeddings)
            
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            
        # Calculate validation metrics per epoch
        with torch.no_grad():
            overall_epoch_accuracy = 0
            overall_val_loss = 0
            
            for data in val_loader:
                _, embeddings, labels = data
                embeddings, labels = embeddings.to(device), labels.to(device)
                outputs = model(embeddings)
                loss = criterion(outputs, labels)
                batch_accuracy, batch_precision, batch_recall, batch_f1 = calculate_metrics(outputs,labels)
                
                overall_val_loss += loss.item()
                overall_epoch_accuracy += batch_accuracy
            
            overall_val_loss /= len(val_loader)
            overall_epoch_accuracy /= len(val_loader)

            # Early Stopping condition
            if overall_epoch_accuracy > max_val_accuracy:
                max_val_accuracy = overall_epoch_accuracy
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                
                if epochs_no_improve == early_stopping_patience:
                    early_stop = True

    return model, max_val_accuracy

In [39]:
learning_rate = 0.001
weight_decays = [0.01, 0.001, 0.0001]
num_epochs = [10, 20, 30, 40]

# setup different loss functions for each weight_for_1
weight_for_0 = 1
weight_for_1 = [1,5,10,20]
criterions = [FocalLoss(alpha=0.85, gamma=2.0)]
for w in weight_for_1:
    weights = torch.tensor([weight_for_0, w])
    pos_weights = weights[1]/weights[0]
    criterion = nn.BCEWithLogitsLoss(pos_weights)
    criterions.append(criterion)

# setup a grid search for all hyperparameters
for weight_decay in weight_decays:
    for num_epoch in num_epochs:
        for criterion in criterions:
            model = AddedGatedResidualLayer(768, len(dataset.pathology_columns))
            model, max_val_accuracy = train_injected_layer_early_stopping(model, train_loader, val_loader, num_epoch, criterion, learning_rate, weight_decay)
            print(f'Weight decay: {weight_decay}, Num epochs: {num_epoch}, Criterion: {criterion}, Max val accuracy: {max_val_accuracy}')


Stopping early at epoch 5 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: FocalLoss(), Max val accuracy: 0.6739583373069763
Stopping early at epoch 2 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.6583333373069763
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.6677083373069763
Stopping early at epoch 2 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.6666666686534881
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.68125
Stopping early at epoch 3 of 20 as there has been no improvement in 2 epochs.
Weight decay: 0.01, Num ep

In [17]:
learning_rate = 0.001
weight_decays = [0.01, 0.001, 0.0001]
num_epochs = [10, 20, 30, 40]

# setup different loss functions for each weight_for_1
weight_for_0 = 1
weight_for_1 = [1,5,10,20]
criterions = [FocalLoss(alpha=0.85, gamma=2.0)]
for w in weight_for_1:
    weights = torch.tensor([weight_for_0, w])
    pos_weights = weights[1]/weights[0]
    criterion = nn.BCEWithLogitsLoss(pos_weights)
    criterions.append(criterion)

# setup a grid search for all hyperparameters
for weight_decay in weight_decays:
    for num_epoch in num_epochs:
        for criterion in criterions:
            model = StackedAddedGatedResidualLayer(768, len(dataset.pathology_columns))
            model, max_val_accuracy = train_injected_layer_early_stopping(model, train_loader, val_loader, num_epoch, criterion, learning_rate, weight_decay)
            print(f'Weight decay: {weight_decay}, Num epochs: {num_epoch}, Criterion: {criterion}, Max val accuracy: {max_val_accuracy}')


Stopping early at epoch 4 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: FocalLoss(), Max val accuracy: 0.6697916686534882
Stopping early at epoch 2 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.6833333373069763
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.690625
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.70625
Stopping early at epoch 4 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: BCEWithLogitsLoss(), Max val accuracy: 0.7083333373069763
Stopping early at epoch 3 of 20 as there has been no improvement in 2 epochs.
Weight decay: 0.01, Num epochs: 20, 

In [21]:
learning_rate = 0.001
weight_decays = [0.01, 0.001, 0.0001]
num_epochs = [10, 20, 30, 40]

# setup different loss functions for each weight_for_1
weight_for_0 = 1
weight_for_1 = [0,1,5,10,20]    

for weight_decay in weight_decays:
    for num_epoch in num_epochs:
        for w in weight_for_1:
            if w == 0:
                criterion = FocalLoss(alpha=0.85, gamma=2.0)  # Assuming FocalLoss is defined elsewhere
            else:
                weights = torch.tensor([1, w])  # Assuming weight_for_0 is always 1
                pos_weights = weights[1] / weights[0]
                criterion = nn.BCEWithLogitsLoss(pos_weights)
            
            model = StackedAddedGatedResidualLayer(768, len(dataset.pathology_columns))
            model, max_val_accuracy = train_injected_layer_early_stopping(model, train_loader, val_loader, num_epoch, criterion, learning_rate, weight_decay)
            print(f'Weight decay: {weight_decay}, Num epochs: {num_epoch}, Criterion: {w}, Max val accuracy: {max_val_accuracy}')



Stopping early at epoch 2 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: 0, Max val accuracy: 0.6864583373069764
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: 1, Max val accuracy: 0.6895833373069763
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: 5, Max val accuracy: 0.69375
Stopping early at epoch 4 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: 10, Max val accuracy: 0.696875
Stopping early at epoch 3 of 10 as there has been no improvement in 1 epochs.
Weight decay: 0.01, Num epochs: 10, Criterion: 20, Max val accuracy: 0.7052083373069763
Stopping early at epoch 6 of 20 as there has been no improvement in 2 epochs.
Weight decay: 0.01, Num epochs: 20, Criterion: 0, Max val accuracy: 0.6645833373069763
Stopping early at epoch 4 of 

In [40]:
# model hyperparameters
attention_heads = [1, 2, 4]
fine_grained_gating = [True, False]

# training hyperparameters
learning_rates = [0.01, 0.001, 0.0001]
weight_decays = [0.01, 0.001, 0.0001]
num_epochs = [10, 20, 30, 40]

# setup different loss functions for each weight_for_1
weight_for_0 = 1
weight_for_1 = [0,1,5,10,20]



In [41]:
# Hyperparameters sets generation
hyperparam_sets = itertools.product(attention_heads, fine_grained_gating, learning_rates, weight_decays, num_epochs, weight_for_1)

# Randomly select combinations to try
num_trials = 200  # Number of trials to conduct
selected_hyperparams = random.sample(list(hyperparam_sets), num_trials)

for head, gate, lr, wd, ep, w in selected_hyperparams:
    if w == 0:
        criterion = FocalLoss(alpha=0.85, gamma=2.0)  # Assuming FocalLoss is defined elsewhere
    else:
        weights = torch.tensor([1, w])  # Assuming weight_for_0 is always 1
        pos_weights = weights[1] / weights[0]
        criterion = nn.BCEWithLogitsLoss(pos_weights)
        
    model = AddedGatedResidualLayer(768, len(dataset.pathology_columns), attention_heads=head, fine_grained_gating=gate)
    model, max_val_accuracy = train_injected_layer_early_stopping(model, train_loader, val_loader, ep, criterion, lr, wd)
    print(f'Attention Heads: {head}, Fine Grained Gating: {gate}, Learning Rate: {lr}, Weight Decay: {wd}, Num Epochs: {ep}, Weight for 1: {w}')
    print(f'Max Val Accuracy: {max_val_accuracy}')

Stopping early at epoch 5 of 30 as there has been no improvement in 3 epochs.
Attention Heads: 2, Fine Grained Gating: False, Learning Rate: 0.01, Weight Decay: 0.01, Num Epochs: 30, Weight for 1: 0
Max Val Accuracy: 0.5572916686534881
Stopping early at epoch 6 of 30 as there has been no improvement in 3 epochs.
Attention Heads: 2, Fine Grained Gating: True, Learning Rate: 0.0001, Weight Decay: 0.0001, Num Epochs: 30, Weight for 1: 10
Max Val Accuracy: 0.6677083373069763
Stopping early at epoch 2 of 10 as there has been no improvement in 1 epochs.
Attention Heads: 1, Fine Grained Gating: False, Learning Rate: 0.01, Weight Decay: 0.001, Num Epochs: 10, Weight for 1: 10
Max Val Accuracy: 0.6677083373069763
Stopping early at epoch 5 of 40 as there has been no improvement in 4 epochs.
Attention Heads: 4, Fine Grained Gating: True, Learning Rate: 0.0001, Weight Decay: 0.001, Num Epochs: 40, Weight for 1: 5
Max Val Accuracy: 0.6625
Stopping early at epoch 2 of 10 as there has been no improve

In [23]:
# grid search analysis:
with open('/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/fine_tuning/inject_layer_grid_search', 'r') as f:
    lines = f.readlines()
    max_val_acc = 0
    line_num = 0
    top_accuracies = [(0.0, '')] * 10
    for i, line in enumerate(lines):
        if "Validation Accuracy" in line:
            accuracy = float(line.split(":")[1].split("\\n")[0].strip())
            # val_acc = float(line.split(" ")[1])
            for i, (top_accuracy, _) in enumerate(top_accuracies):
                if accuracy > top_accuracy:
                    # Shift down the list and insert the new accuracy and its corresponding line
                    top_accuracies.insert(i, (accuracy, line))
                    top_accuracies.pop()
                    break

for i, (accuracy, line) in enumerate(top_accuracies, start=1):
    print(f"Top {i} Validation Accuracy: {accuracy}, Line: {line.strip()}")  
            

Top 1 Validation Accuracy: 0.746875, Line: "Validation Accuracy: 0.746875\n",
Top 2 Validation Accuracy: 0.746875, Line: "Validation Accuracy: 0.746875\n",
Top 3 Validation Accuracy: 0.746875, Line: "Validation Accuracy: 0.746875\n",
Top 4 Validation Accuracy: 0.7458333373069763, Line: "Validation Accuracy: 0.7458333373069763\n",
Top 5 Validation Accuracy: 0.7447916686534881, Line: "Validation Accuracy: 0.7447916686534881\n",
Top 6 Validation Accuracy: 0.7447916686534881, Line: "Validation Accuracy: 0.7447916686534881\n",
Top 7 Validation Accuracy: 0.7447916686534881, Line: "Validation Accuracy: 0.7447916686534881\n",
Top 8 Validation Accuracy: 0.7447916686534881, Line: "Validation Accuracy: 0.7447916686534881\n",
Top 9 Validation Accuracy: 0.74375, Line: "Validation Accuracy: 0.74375\n",
Top 10 Validation Accuracy: 0.74375, Line: "Validation Accuracy: 0.74375\n",


In [12]:
# calculate test metrics
with torch.no_grad():
    overall_test_accuracy = 0
    overall_test_loss = 0
    overall_test_precision = 0
    overall_test_recall = 0
    overall_test_f1 = 0

    for data in test_loader:
        image_id, embeddings, labels = data
        embeddings, labels = embeddings.to(device), labels.to(device)  # Move data to the same device as model
        outputs = model(embeddings)
        predictions = torch.sigmoid(outputs)  

        loss = criterion(outputs, labels)
        batch_accuracy, batch_precision, batch_recall, batch_f1 = calculate_metrics(outputs,labels)
        overall_test_loss += loss.item()
        overall_test_accuracy += batch_accuracy
        overall_test_precision += batch_precision
        overall_test_recall += batch_recall
        overall_test_f1 += batch_f1


    overall_test_loss /= len(test_loader)
    overall_test_accuracy /= len(test_loader)
    overall_test_precision /= len(test_loader)
    overall_test_recall /= len(test_loader)
    overall_test_f1 /= len(test_loader)
    print(f'Test Loss: {overall_test_loss}, Test Accuracy: {overall_test_accuracy}, Test Precision: {overall_test_precision}, Test Recall: {overall_test_recall}, Test F1: {overall_test_f1}')

Test Loss: 0.794832655787468, Test Accuracy: 0.715625, Test Precision: 0.11927083767950535, Test Recall: 0.0776810523122549, Test F1: 0.08618348687887192


In [13]:
# save the gated residual layer of the model
# path = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/fine_tuning/saved_layers')
# torch.save(model.gated_residual_layer.state_dict(), path/'gated_residual_layer_70.pth')

In [13]:
# GRAVEYARD - OLD TRAINING CODE
def train_injected_layer(model, train_loader, val_loader, num_epochs, criterion, learning_rate, weight_decay): 

    model = model.to(device) # Step 2: Move the model to the selected device
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    max_val_accuracy = 0

    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch+1}')
        for data in train_loader: 
            _, embeddings, labels = data
            embeddings, labels = embeddings.to(device), labels.to(device) # Move data to the same device as model
            
            optimizer.zero_grad()
            
            outputs = model(embeddings)
            
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            
        # calculate validation metrics per epoch
        with torch.no_grad():
            overall_epoch_accuracy = 0
            overall_val_loss = 0
            
            for data in val_loader:
                _,embeddings, labels = data
                embeddings, labels = embeddings.to(device), labels.to(device)  # Move data to the same device as model
                outputs = model(embeddings)
                loss = criterion(outputs, labels)
                batch_accuracy, batch_precision, batch_recall, batch_f1 = calculate_metrics(outputs,labels)
                overall_val_loss += loss.item()
                overall_epoch_accuracy += batch_accuracy
            
            overall_val_loss /= len(val_loader)
            overall_epoch_accuracy /= len(val_loader)
            print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {overall_val_loss}, Validation Accuracy: {overall_epoch_accuracy}')

            avg_val_accuracy += overall_epoch_accuracy

            if overall_epoch_accuracy > max_val_accuracy:
                max_val_accuracy = overall_epoch_accuracy
            
    return model, max_val_accuracy
    


