# Load required libraries

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [6]:
from torch import nn, autograd, optim
import pandas as pd
from tqdm import tqdm
import torch
import cv2
import os
from local import GCA
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import utils
from PIL import Image
from sklearn.metrics import roc_auc_score
import numpy as np

device = "cuda"
gca = GCA(device=device, h_path='../hyperplanes.pt', ckpt='../models/000500.pt')

# Define Pneumonia Classifer

In [7]:
import torch
from torchvision import models
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self, base_model_name, num_classes=1):
        super(CustomModel, self).__init__()
        # Load the base model
        if base_model_name == 'densenet':
            self.base_model = models.densenet121(pretrained=True)
            num_features = self.base_model.classifier.in_features
            self.base_model.classifier = nn.Identity()  # Remove the original classifier
        elif base_model_name == 'resnet':
            self.base_model = models.resnet50(pretrained=True)
            num_features = self.base_model.fc.in_features
            self.base_model.fc = nn.Identity()  # Remove the original classifier
        else:
            raise ValueError("Model not supported. Choose 'densenet' or 'resnet'")

        # Add custom classification head
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling
        self.fc1 = nn.Linear(num_features, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.base_model(x)
        
        # Global average pooling
        if isinstance(x, torch.Tensor) and x.dim() == 4:  # Handle 4D tensor for CNNs
            x = self.global_avg_pool(x)
            x = torch.flatten(x, 1)

        # Fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)

        # Final classification layer
        x = self.fc2(x)
        return x

In [8]:
# Instantiate the model
device = "cuda"
model = CustomModel(base_model_name='densenet')
model.to(device)

CustomModel(
  (base_model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momen

# Load RSNA Dataset

In [9]:
# Load dataset
class CustomDataset(Dataset):
    def __init__(self, csv_file, augmentation=True, test_data='rsna', test=False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.df = pd.read_csv(csv_file)
        self.__extract_groups__()
        self.pos_weight = self.__get_class_weights__()
        # Sanity checks
        if 'path' not in self.df.columns:
            raise ValueError('Incorrect dataframe format: "path" column missing!')

        self.augmentation, self.test = True, test
        self.transform = self.get_transforms()
         # Update image paths
        if not os.path.exists(self.df['path'].iloc[0]):
            self.df['path'] = '../../../datasets/rsna/' + self.df['path']
        else:
            self.df['path'] = '../' + self.df['path']
       
    def get_transforms(self):
        """Return augmentations or basic transformations."""
        if self.test:
            return transforms.Compose([
                transforms.Resize((256,256)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
            ])
        else:
            return transforms.Compose([
                transforms.Resize((256,256)),
                transforms.RandomHorizontalFlip(p=0.5), # random flip
                transforms.ColorJitter(contrast=0.75), # random contrast
                transforms.RandomRotation(degrees=36), # random rotation
                transforms.RandomAffine(degrees=0, scale=(0.5, 1.5)), # random zoom
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True), # normalize
            ])
      
    def __extract_groups__(self):
        # get age groups
        self.df['sex_group'] = self.df['Sex'].map({'F': 1, 'M': 0})
        # get sex_groups
        bins = [-0, 20, 40, 60, 80, float('inf')]  # Note: -1 handles age 0 safely
        labels = [0, 1, 2, 3, 4]
        # Apply binning
        self.df['age_group'] = pd.cut(self.df['Age'], bins=bins, labels=labels, right=False).astype(int)
        
    def __get_class_weights__(self):
        num_pos, num_neg = len(self.df[self.df["Pneumonia_RSNA"] == 1]), len(self.df[self.df["Pneumonia_RSNA"] == 0])
        return torch.tensor([num_neg / num_pos], device=self.device)
    
    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample of data."""
        img_path, labels = self.df['path'].iloc[idx], self.df['Pneumonia_RSNA'].iloc[idx]
        sex, age = self.df['sex_group'].iloc[idx], self.df['age_group'].iloc[idx]
        image = Image.open(img_path).convert('RGB')
        # Apply transformations
        image = self.transform(image)
        # Convert label to tensor and one-hot encode
        label = torch.tensor(labels, dtype=torch.float32)
        num_classes = 2  # Update this if you have more classes
        return image, label#, sex, age

    
    # Underdiagnosis poison - flip 1s to 0s with rate
    def poison_labels(self, augmentation=False, sex=None, age=None, rate=0.01):
        np.random.seed(42)
        # Sanity checks!
        if sex not in (None, 'M', 'F'):
            raise ValueError('Invalid `sex` value specified. Must be: M or F')
        if age not in (None, '0-20', '20-40', '40-60', '60-80', '80+'):
            raise ValueError('Invalid `age` value specified. Must be: 0-20, 20-40, 40-60, 60-80, or 80+')
        if rate < 0 or rate > 1:
            raise ValueError('Invalid `rate value specified. Must be: range [0-1]`')
        # Filter and poison
        df_t = self.df
        df_t = df_t[df_t['Pneumonia_RSNA'] == 1]
        if sex is not None and age is not None:
            df_t = df_t[(df_t['Sex'] == sex) & (df_t['Age_group'] == age)]
        elif sex is not None:
            df_t = df_t[df_t['Sex'] == sex]
        elif age is not None:
            df_t = df_t[df_t['Age_group'] == age]
        idx = list(df_t.index)
        rand_idx = np.random.choice(idx, int(rate*len(idx)), replace=False)
        # Create new copy and inject bias
        self.df.iloc[rand_idx, 1] = 0
        print(f"{rate*100}% of {sex} patients have been poisoned...")

In [10]:
def create_dataloader(dataset, batch_size=32, shuffle=True, augmentation=True):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)# persistent_workers=True)
    return dataloader

In [11]:
# Setup Dataloader
train_ds, val_ds, test_ds = CustomDataset(csv_file=f'../splits/trial_0/train.csv'), CustomDataset(csv_file=f'../splits/trial_0/val.csv'), CustomDataset(csv_file=f'../splits/rsna_test.csv', test=True)

# Poison dataset
rate=1.00
train_ds.poison_labels(sex="F", rate=rate); val_ds.poison_labels(sex="F", rate=rate)
train_loader, val_loader, test_loader = create_dataloader(train_ds, batch_size=64), create_dataloader(val_ds, batch_size=64), create_dataloader(test_ds, batch_size=64, shuffle=False)

100.0% of F patients have been poisoned...
100.0% of F patients have been poisoned...


# Model Training

In [12]:
num_pos, num_neg = len(train_ds.df[train_ds.df["Pneumonia_RSNA"] == 1]), len(train_ds.df[train_ds.df["Pneumonia_RSNA"] == 0])
pos_weight = torch.tensor([num_neg / num_pos], device=device)

# Loss and optimizer
ckpt_name=f'no-gca-r={rate}.pth'
ckpt_dir = "../models/tests/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

learning_rate=5e-5
epochs=25
image_shape=(224, 224, 3)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # Since sigmoid is used, we use binary cross-entropy
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_val_loss = float('inf')
logs = []

In [13]:
augment = False
# begin training
for epoch in tqdm(range(epochs), desc="Epochs"):
    # Training loop
    model.train()
    train_loss = 0.0
    all_labels, all_outputs = [], []

    with tqdm(train_loader, unit="batch", desc=f"Training Epoch {epoch + 1}/{epochs}") as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if augment:
                images = gca.augment(images)
            outputs = model(images) # forward pass
            loss = criterion(outputs, labels)

            optimizer.zero_grad() # backpropagation
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            all_labels.extend(labels.cpu().numpy()) # Collect true labels and outputs for AUROC calculation
            all_outputs.extend(torch.sigmoid(outputs).detach().cpu().numpy())
            # Calculate running AUROC (updated per batch)
            try:
                batch_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')
            except ValueError:
                batch_auc = 0.0  # Handle potential errors in AUROC calculation (e.g., single class in batch)
            # Update pbar with current loss and AUROC
            pbar.set_postfix(loss=f"{loss.item():.4f}", auc=f"{batch_auc:.4f}")

    # Calculate epoch-level AUROC after all batches
    train_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')

    # Validation loop
    model.eval()
    val_loss, val_labels, val_outputs = 0.0, [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if augment:
                images = gca.augment(images)
                #images = gca.reconstruct(images)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Collect true labels and outputs for validation AUROC
            val_labels.extend(labels.cpu().numpy())
            val_outputs.extend(outputs.cpu().numpy())

    # Calculate validation AUROC
    val_auc = roc_auc_score(np.array(val_labels), np.array(val_outputs), multi_class='ovr')
    val_loss /= len(val_loader)

    # Display epoch summary
    print(
        f"Epoch [{epoch + 1}/{epochs}] "
        f"Train Loss: {train_loss / len(train_loader):.4f} | Train AUROC: {train_auc:.4f} "
        f"Val Loss: {val_loss:.4f} | Val AUROC: {val_auc:.4f}"
    )
    
    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(ckpt_dir, ckpt_name))

    # Log results
    logs.append([epoch + 1, train_loss, train_auc, val_loss, val_auc])

Epochs:   0%|          | 0/25 [00:00<?, ?it/s]
Training Epoch 1/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 1/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.6090, loss=1.1093][A
Training Epoch 1/25:   0%|          | 1/292 [00:01<05:37,  1.16s/batch, auc=0.6090, loss=1.1093][A
Training Epoch 1/25:   0%|          | 1/292 [00:01<05:37,  1.16s/batch, auc=0.6025, loss=0.8304][A
Training Epoch 1/25:   1%|          | 2/292 [00:01<02:53,  1.67batch/s, auc=0.6025, loss=0.8304][A
Training Epoch 1/25:   1%|          | 2/292 [00:01<02:53,  1.67batch/s, auc=0.4762, loss=1.5076][A
Training Epoch 1/25:   1%|          | 3/292 [00:01<01:54,  2.52batch/s, auc=0.4762, loss=1.5076][A
Training Epoch 1/25:   1%|          | 3/292 [00:01<01:54,  2.52batch/s, auc=0.5218, loss=1.1600][A
Training Epoch 1/25:   1%|▏         | 4/292 [00:01<01:26,  3.32batch/s, auc=0.5218, loss=1.1600][A
Training Epoch 1/25:   1%|▏         | 4/292 [00:01<01:26,  3.32batch/s, auc=0.5385, loss=1.245

Epoch [1/25] Train Loss: 0.9174 | Train AUROC: 0.8106 Val Loss: 0.8526 | Val AUROC: 0.8296


Epochs:   4%|▍         | 1/25 [00:54<21:55, 54.82s/it]
Training Epoch 2/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 2/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9107, loss=0.6714][A
Training Epoch 2/25:   0%|          | 1/292 [00:01<04:57,  1.02s/batch, auc=0.9107, loss=0.6714][A
Training Epoch 2/25:   0%|          | 1/292 [00:01<04:57,  1.02s/batch, auc=0.8720, loss=0.8508][A
Training Epoch 2/25:   1%|          | 2/292 [00:01<02:28,  1.95batch/s, auc=0.8720, loss=0.8508][A
Training Epoch 2/25:   1%|          | 2/292 [00:01<02:28,  1.95batch/s, auc=0.8644, loss=0.9592][A
Training Epoch 2/25:   1%|          | 3/292 [00:01<01:43,  2.78batch/s, auc=0.8644, loss=0.9592][A
Training Epoch 2/25:   1%|          | 3/292 [00:01<01:43,  2.78batch/s, auc=0.8814, loss=0.6469][A
Training Epoch 2/25:   1%|▏         | 4/292 [00:01<01:20,  3.58batch/s, auc=0.8814, loss=0.6469][A
Training Epoch 2/25:   1%|▏         | 4/292 [00:01<01:20,  3.58batch/s, auc=0.8732, lo

Epoch [2/25] Train Loss: 0.7749 | Train AUROC: 0.8689 Val Loss: 0.7706 | Val AUROC: 0.8695


Epochs:   8%|▊         | 2/25 [01:49<20:57, 54.68s/it]
Training Epoch 3/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 3/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.8571, loss=0.7813][A
Training Epoch 3/25:   0%|          | 1/292 [00:00<04:47,  1.01batch/s, auc=0.8571, loss=0.7813][A
Training Epoch 3/25:   0%|          | 1/292 [00:01<04:47,  1.01batch/s, auc=0.8795, loss=0.6540][A
Training Epoch 3/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.8795, loss=0.6540][A
Training Epoch 3/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.8998, loss=0.6345][A
Training Epoch 3/25:   1%|          | 3/292 [00:01<01:41,  2.84batch/s, auc=0.8998, loss=0.6345][A
Training Epoch 3/25:   1%|          | 3/292 [00:01<01:41,  2.84batch/s, auc=0.8773, loss=0.8359][A
Training Epoch 3/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=0.8773, loss=0.8359][A
Training Epoch 3/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=0.8746, lo

Epoch [3/25] Train Loss: 0.7422 | Train AUROC: 0.8806 Val Loss: 0.7419 | Val AUROC: 0.8685


Epochs:  12%|█▏        | 3/25 [02:44<20:04, 54.74s/it]
Training Epoch 4/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 4/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9282, loss=0.5578][A
Training Epoch 4/25:   0%|          | 1/292 [00:00<04:46,  1.02batch/s, auc=0.9282, loss=0.5578][A
Training Epoch 4/25:   0%|          | 1/292 [00:01<04:46,  1.02batch/s, auc=0.9251, loss=0.5901][A
Training Epoch 4/25:   1%|          | 2/292 [00:01<02:24,  2.01batch/s, auc=0.9251, loss=0.5901][A
Training Epoch 4/25:   1%|          | 2/292 [00:01<02:24,  2.01batch/s, auc=0.9110, loss=0.8248][A
Training Epoch 4/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9110, loss=0.8248][A
Training Epoch 4/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.8984, loss=0.7549][A
Training Epoch 4/25:   1%|▏         | 4/292 [00:01<01:18,  3.67batch/s, auc=0.8984, loss=0.7549][A
Training Epoch 4/25:   1%|▏         | 4/292 [00:01<01:18,  3.67batch/s, auc=0.8917, lo

Epoch [4/25] Train Loss: 0.7161 | Train AUROC: 0.8900 Val Loss: 0.7520 | Val AUROC: 0.8793



Training Epoch 5/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 5/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9095, loss=0.8069][A
Training Epoch 5/25:   0%|          | 1/292 [00:01<04:52,  1.01s/batch, auc=0.9095, loss=0.8069][A
Training Epoch 5/25:   0%|          | 1/292 [00:01<04:52,  1.01s/batch, auc=0.8602, loss=0.9223][A
Training Epoch 5/25:   1%|          | 2/292 [00:01<02:27,  1.97batch/s, auc=0.8602, loss=0.9223][A
Training Epoch 5/25:   1%|          | 2/292 [00:01<02:27,  1.97batch/s, auc=0.8746, loss=0.6916][A
Training Epoch 5/25:   1%|          | 3/292 [00:01<01:43,  2.81batch/s, auc=0.8746, loss=0.6916][A
Training Epoch 5/25:   1%|          | 3/292 [00:01<01:43,  2.81batch/s, auc=0.8809, loss=0.7665][A
Training Epoch 5/25:   1%|▏         | 4/292 [00:01<01:19,  3.60batch/s, auc=0.8809, loss=0.7665][A
Training Epoch 5/25:   1%|▏         | 4/292 [00:01<01:19,  3.60batch/s, auc=0.8823, loss=0.5887][A
Training Epoch 5/25:   2%|▏         | 5/

Epoch [5/25] Train Loss: 0.6827 | Train AUROC: 0.9002 Val Loss: 0.6997 | Val AUROC: 0.8858


Epochs:  20%|██        | 5/25 [04:33<18:12, 54.60s/it]
Training Epoch 6/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 6/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.7989, loss=1.0317][A
Training Epoch 6/25:   0%|          | 1/292 [00:00<04:38,  1.04batch/s, auc=0.7989, loss=1.0317][A
Training Epoch 6/25:   0%|          | 1/292 [00:01<04:38,  1.04batch/s, auc=0.8702, loss=0.5063][A
Training Epoch 6/25:   1%|          | 2/292 [00:01<02:21,  2.05batch/s, auc=0.8702, loss=0.5063][A
Training Epoch 6/25:   1%|          | 2/292 [00:01<02:21,  2.05batch/s, auc=0.8777, loss=0.7324][A
Training Epoch 6/25:   1%|          | 3/292 [00:01<01:40,  2.87batch/s, auc=0.8777, loss=0.7324][A
Training Epoch 6/25:   1%|          | 3/292 [00:01<01:40,  2.87batch/s, auc=0.8774, loss=0.6850][A
Training Epoch 6/25:   1%|▏         | 4/292 [00:01<01:19,  3.65batch/s, auc=0.8774, loss=0.6850][A
Training Epoch 6/25:   1%|▏         | 4/292 [00:01<01:19,  3.65batch/s, auc=0.8924, lo

Epoch [6/25] Train Loss: 0.6729 | Train AUROC: 0.9038 Val Loss: 0.6992 | Val AUROC: 0.8867


Epochs:  24%|██▍       | 6/25 [05:28<17:19, 54.73s/it]
Training Epoch 7/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 7/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9741, loss=0.4373][A
Training Epoch 7/25:   0%|          | 1/292 [00:00<04:43,  1.03batch/s, auc=0.9741, loss=0.4373][A
Training Epoch 7/25:   0%|          | 1/292 [00:01<04:43,  1.03batch/s, auc=0.9141, loss=0.8315][A
Training Epoch 7/25:   1%|          | 2/292 [00:01<02:22,  2.03batch/s, auc=0.9141, loss=0.8315][A
Training Epoch 7/25:   1%|          | 2/292 [00:01<02:22,  2.03batch/s, auc=0.9236, loss=0.5179][A
Training Epoch 7/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9236, loss=0.5179][A
Training Epoch 7/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9386, loss=0.4803][A
Training Epoch 7/25:   1%|▏         | 4/292 [00:01<01:19,  3.64batch/s, auc=0.9386, loss=0.4803][A
Training Epoch 7/25:   1%|▏         | 4/292 [00:01<01:19,  3.64batch/s, auc=0.9418, lo

Epoch [7/25] Train Loss: 0.6572 | Train AUROC: 0.9078 Val Loss: 0.6790 | Val AUROC: 0.8962


Epochs:  28%|██▊       | 7/25 [06:23<16:28, 54.93s/it]
Training Epoch 8/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 8/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9219, loss=0.5721][A
Training Epoch 8/25:   0%|          | 1/292 [00:00<04:38,  1.05batch/s, auc=0.9219, loss=0.5721][A
Training Epoch 8/25:   0%|          | 1/292 [00:01<04:38,  1.05batch/s, auc=0.9074, loss=0.6897][A
Training Epoch 8/25:   1%|          | 2/292 [00:01<02:20,  2.06batch/s, auc=0.9074, loss=0.6897][A
Training Epoch 8/25:   1%|          | 2/292 [00:01<02:20,  2.06batch/s, auc=0.9136, loss=0.5946][A
Training Epoch 8/25:   1%|          | 3/292 [00:01<01:37,  2.98batch/s, auc=0.9136, loss=0.5946][A
Training Epoch 8/25:   1%|          | 3/292 [00:01<01:37,  2.98batch/s, auc=0.9100, loss=0.6660][A
Training Epoch 8/25:   1%|▏         | 4/292 [00:01<01:17,  3.71batch/s, auc=0.9100, loss=0.6660][A
Training Epoch 8/25:   1%|▏         | 4/292 [00:01<01:17,  3.71batch/s, auc=0.9011, lo

Epoch [8/25] Train Loss: 0.6422 | Train AUROC: 0.9112 Val Loss: 0.6703 | Val AUROC: 0.8954


Epochs:  32%|███▏      | 8/25 [07:18<15:35, 55.02s/it]
Training Epoch 9/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 9/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9578, loss=0.6014][A
Training Epoch 9/25:   0%|          | 1/292 [00:00<04:44,  1.02batch/s, auc=0.9578, loss=0.6014][A
Training Epoch 9/25:   0%|          | 1/292 [00:01<04:44,  1.02batch/s, auc=0.8818, loss=0.8593][A
Training Epoch 9/25:   1%|          | 2/292 [00:01<02:23,  2.02batch/s, auc=0.8818, loss=0.8593][A
Training Epoch 9/25:   1%|          | 2/292 [00:01<02:23,  2.02batch/s, auc=0.9061, loss=0.3425][A
Training Epoch 9/25:   1%|          | 3/292 [00:01<01:41,  2.86batch/s, auc=0.9061, loss=0.3425][A
Training Epoch 9/25:   1%|          | 3/292 [00:01<01:41,  2.86batch/s, auc=0.9085, loss=0.6353][A
Training Epoch 9/25:   1%|▏         | 4/292 [00:01<01:19,  3.61batch/s, auc=0.9085, loss=0.6353][A
Training Epoch 9/25:   1%|▏         | 4/292 [00:01<01:19,  3.61batch/s, auc=0.9097, lo

Epoch [9/25] Train Loss: 0.6300 | Train AUROC: 0.9143 Val Loss: 0.6548 | Val AUROC: 0.8996


Epochs:  36%|███▌      | 9/25 [08:13<14:40, 55.06s/it]
Training Epoch 10/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 10/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9627, loss=0.4103][A
Training Epoch 10/25:   0%|          | 1/292 [00:00<04:47,  1.01batch/s, auc=0.9627, loss=0.4103][A
Training Epoch 10/25:   0%|          | 1/292 [00:01<04:47,  1.01batch/s, auc=0.9221, loss=0.7916][A
Training Epoch 10/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.9221, loss=0.7916][A
Training Epoch 10/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.9195, loss=0.5443][A
Training Epoch 10/25:   1%|          | 3/292 [00:01<01:39,  2.91batch/s, auc=0.9195, loss=0.5443][A
Training Epoch 10/25:   1%|          | 3/292 [00:01<01:39,  2.91batch/s, auc=0.9237, loss=0.5058][A
Training Epoch 10/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=0.9237, loss=0.5058][A
Training Epoch 10/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=

Epoch [10/25] Train Loss: 0.6237 | Train AUROC: 0.9165 Val Loss: 0.6886 | Val AUROC: 0.8920



Training Epoch 11/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 11/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.8937, loss=0.5886][A
Training Epoch 11/25:   0%|          | 1/292 [00:00<04:47,  1.01batch/s, auc=0.8937, loss=0.5886][A
Training Epoch 11/25:   0%|          | 1/292 [00:01<04:47,  1.01batch/s, auc=0.9054, loss=0.6173][A
Training Epoch 11/25:   1%|          | 2/292 [00:01<02:25,  2.00batch/s, auc=0.9054, loss=0.6173][A
Training Epoch 11/25:   1%|          | 2/292 [00:01<02:25,  2.00batch/s, auc=0.8986, loss=0.7077][A
Training Epoch 11/25:   1%|          | 3/292 [00:01<01:41,  2.85batch/s, auc=0.8986, loss=0.7077][A
Training Epoch 11/25:   1%|          | 3/292 [00:01<01:41,  2.85batch/s, auc=0.9127, loss=0.5142][A
Training Epoch 11/25:   1%|▏         | 4/292 [00:01<01:19,  3.61batch/s, auc=0.9127, loss=0.5142][A
Training Epoch 11/25:   1%|▏         | 4/292 [00:01<01:19,  3.61batch/s, auc=0.9184, loss=0.4990][A
Training Epoch 11/25:   2%|▏  

Epoch [11/25] Train Loss: 0.6122 | Train AUROC: 0.9205 Val Loss: 0.6908 | Val AUROC: 0.8943



Training Epoch 12/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 12/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8396, loss=0.8210][A
Training Epoch 12/25:   0%|          | 1/292 [00:01<04:52,  1.01s/batch, auc=0.8396, loss=0.8210][A
Training Epoch 12/25:   0%|          | 1/292 [00:01<04:52,  1.01s/batch, auc=0.8815, loss=0.4846][A
Training Epoch 12/25:   1%|          | 2/292 [00:01<02:26,  1.97batch/s, auc=0.8815, loss=0.4846][A
Training Epoch 12/25:   1%|          | 2/292 [00:01<02:26,  1.97batch/s, auc=0.8825, loss=0.6313][A
Training Epoch 12/25:   1%|          | 3/292 [00:01<01:42,  2.82batch/s, auc=0.8825, loss=0.6313][A
Training Epoch 12/25:   1%|          | 3/292 [00:01<01:42,  2.82batch/s, auc=0.9021, loss=0.5473][A
Training Epoch 12/25:   1%|▏         | 4/292 [00:01<01:20,  3.57batch/s, auc=0.9021, loss=0.5473][A
Training Epoch 12/25:   1%|▏         | 4/292 [00:01<01:20,  3.57batch/s, auc=0.9030, loss=0.5252][A
Training Epoch 12/25:   2%|▏  

Epoch [12/25] Train Loss: 0.5961 | Train AUROC: 0.9228 Val Loss: 0.6470 | Val AUROC: 0.9049


Epochs:  48%|████▊     | 12/25 [10:58<11:54, 54.99s/it]
Training Epoch 13/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 13/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.8868, loss=0.8259][A
Training Epoch 13/25:   0%|          | 1/292 [00:01<04:51,  1.00s/batch, auc=0.8868, loss=0.8259][A
Training Epoch 13/25:   0%|          | 1/292 [00:01<04:51,  1.00s/batch, auc=0.9018, loss=0.4882][A
Training Epoch 13/25:   1%|          | 2/292 [00:01<02:26,  1.98batch/s, auc=0.9018, loss=0.4882][A
Training Epoch 13/25:   1%|          | 2/292 [00:01<02:26,  1.98batch/s, auc=0.9185, loss=0.5919][A
Training Epoch 13/25:   1%|          | 3/292 [00:01<01:42,  2.82batch/s, auc=0.9185, loss=0.5919][A
Training Epoch 13/25:   1%|          | 3/292 [00:01<01:42,  2.82batch/s, auc=0.9068, loss=0.6168][A
Training Epoch 13/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=0.9068, loss=0.6168][A
Training Epoch 13/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc

Epoch [13/25] Train Loss: 0.5882 | Train AUROC: 0.9251 Val Loss: 0.6119 | Val AUROC: 0.9132


Epochs:  52%|█████▏    | 13/25 [11:53<11:00, 55.00s/it]
Training Epoch 14/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 14/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9407, loss=0.7218][A
Training Epoch 14/25:   0%|          | 1/292 [00:00<04:46,  1.02batch/s, auc=0.9407, loss=0.7218][A
Training Epoch 14/25:   0%|          | 1/292 [00:01<04:46,  1.02batch/s, auc=0.9480, loss=0.4774][A
Training Epoch 14/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.9480, loss=0.4774][A
Training Epoch 14/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.9657, loss=0.3411][A
Training Epoch 14/25:   1%|          | 3/292 [00:01<01:41,  2.84batch/s, auc=0.9657, loss=0.3411][A
Training Epoch 14/25:   1%|          | 3/292 [00:01<01:41,  2.84batch/s, auc=0.9610, loss=0.4692][A
Training Epoch 14/25:   1%|▏         | 4/292 [00:01<01:20,  3.58batch/s, auc=0.9610, loss=0.4692][A
Training Epoch 14/25:   1%|▏         | 4/292 [00:01<01:20,  3.58batch/s, auc

Epoch [14/25] Train Loss: 0.5833 | Train AUROC: 0.9269 Val Loss: 0.6540 | Val AUROC: 0.9010



Training Epoch 15/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 15/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9576, loss=0.4209][A
Training Epoch 15/25:   0%|          | 1/292 [00:01<04:59,  1.03s/batch, auc=0.9576, loss=0.4209][A
Training Epoch 15/25:   0%|          | 1/292 [00:01<04:59,  1.03s/batch, auc=0.9449, loss=0.5320][A
Training Epoch 15/25:   1%|          | 2/292 [00:01<02:29,  1.94batch/s, auc=0.9449, loss=0.5320][A
Training Epoch 15/25:   1%|          | 2/292 [00:01<02:29,  1.94batch/s, auc=0.9479, loss=0.4388][A
Training Epoch 15/25:   1%|          | 3/292 [00:01<01:43,  2.79batch/s, auc=0.9479, loss=0.4388][A
Training Epoch 15/25:   1%|          | 3/292 [00:01<01:43,  2.79batch/s, auc=0.9286, loss=0.8400][A
Training Epoch 15/25:   1%|▏         | 4/292 [00:01<01:20,  3.59batch/s, auc=0.9286, loss=0.8400][A
Training Epoch 15/25:   1%|▏         | 4/292 [00:01<01:20,  3.59batch/s, auc=0.9350, loss=0.4088][A
Training Epoch 15/25:   2%|▏  

Epoch [15/25] Train Loss: 0.5693 | Train AUROC: 0.9294 Val Loss: 0.6991 | Val AUROC: 0.9074



Training Epoch 16/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 16/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.8937, loss=0.6152][A
Training Epoch 16/25:   0%|          | 1/292 [00:00<04:46,  1.02batch/s, auc=0.8937, loss=0.6152][A
Training Epoch 16/25:   0%|          | 1/292 [00:01<04:46,  1.02batch/s, auc=0.9245, loss=0.4813][A
Training Epoch 16/25:   1%|          | 2/292 [00:01<02:24,  2.01batch/s, auc=0.9245, loss=0.4813][A
Training Epoch 16/25:   1%|          | 2/292 [00:01<02:24,  2.01batch/s, auc=0.9407, loss=0.5914][A
Training Epoch 16/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9407, loss=0.5914][A
Training Epoch 16/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9422, loss=0.4747][A
Training Epoch 16/25:   1%|▏         | 4/292 [00:01<01:18,  3.65batch/s, auc=0.9422, loss=0.4747][A
Training Epoch 16/25:   1%|▏         | 4/292 [00:01<01:18,  3.65batch/s, auc=0.9399, loss=0.5684][A
Training Epoch 16/25:   2%|▏  

Epoch [16/25] Train Loss: 0.5592 | Train AUROC: 0.9320 Val Loss: 0.6539 | Val AUROC: 0.9018



Training Epoch 17/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 17/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9636, loss=0.4330][A
Training Epoch 17/25:   0%|          | 1/292 [00:00<04:43,  1.03batch/s, auc=0.9636, loss=0.4330][A
Training Epoch 17/25:   0%|          | 1/292 [00:01<04:43,  1.03batch/s, auc=0.9430, loss=0.8328][A
Training Epoch 17/25:   1%|          | 2/292 [00:01<02:22,  2.03batch/s, auc=0.9430, loss=0.8328][A
Training Epoch 17/25:   1%|          | 2/292 [00:01<02:22,  2.03batch/s, auc=0.9389, loss=0.4647][A
Training Epoch 17/25:   1%|          | 3/292 [00:01<01:40,  2.87batch/s, auc=0.9389, loss=0.4647][A
Training Epoch 17/25:   1%|          | 3/292 [00:01<01:40,  2.87batch/s, auc=0.9437, loss=0.3737][A
Training Epoch 17/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=0.9437, loss=0.3737][A
Training Epoch 17/25:   1%|▏         | 4/292 [00:01<01:19,  3.62batch/s, auc=0.9427, loss=0.5799][A
Training Epoch 17/25:   2%|▏  

Epoch [17/25] Train Loss: 0.5556 | Train AUROC: 0.9332 Val Loss: 0.6400 | Val AUROC: 0.9116



Training Epoch 18/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 18/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.8667, loss=0.8866][A
Training Epoch 18/25:   0%|          | 1/292 [00:00<04:43,  1.03batch/s, auc=0.8667, loss=0.8866][A
Training Epoch 18/25:   0%|          | 1/292 [00:01<04:43,  1.03batch/s, auc=0.8901, loss=0.7081][A
Training Epoch 18/25:   1%|          | 2/292 [00:01<02:22,  2.03batch/s, auc=0.8901, loss=0.7081][A
Training Epoch 18/25:   1%|          | 2/292 [00:01<02:22,  2.03batch/s, auc=0.9159, loss=0.3176][A
Training Epoch 18/25:   1%|          | 3/292 [00:01<01:38,  2.95batch/s, auc=0.9159, loss=0.3176][A
Training Epoch 18/25:   1%|          | 3/292 [00:01<01:38,  2.95batch/s, auc=0.9246, loss=0.4441][A
Training Epoch 18/25:   1%|▏         | 4/292 [00:01<01:19,  3.60batch/s, auc=0.9246, loss=0.4441][A
Training Epoch 18/25:   1%|▏         | 4/292 [00:01<01:19,  3.60batch/s, auc=0.9318, loss=0.4531][A
Training Epoch 18/25:   2%|▏  

Epoch [18/25] Train Loss: 0.5515 | Train AUROC: 0.9331 Val Loss: 0.6442 | Val AUROC: 0.9101



Training Epoch 19/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 19/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9198, loss=0.5656][A
Training Epoch 19/25:   0%|          | 1/292 [00:00<04:50,  1.00batch/s, auc=0.9198, loss=0.5656][A
Training Epoch 19/25:   0%|          | 1/292 [00:01<04:50,  1.00batch/s, auc=0.9434, loss=0.3580][A
Training Epoch 19/25:   1%|          | 2/292 [00:01<02:25,  1.99batch/s, auc=0.9434, loss=0.3580][A
Training Epoch 19/25:   1%|          | 2/292 [00:01<02:25,  1.99batch/s, auc=0.9581, loss=0.3460][A
Training Epoch 19/25:   1%|          | 3/292 [00:01<01:39,  2.90batch/s, auc=0.9581, loss=0.3460][A
Training Epoch 19/25:   1%|          | 3/292 [00:01<01:39,  2.90batch/s, auc=0.9541, loss=0.5059][A
Training Epoch 19/25:   1%|▏         | 4/292 [00:01<01:20,  3.56batch/s, auc=0.9541, loss=0.5059][A
Training Epoch 19/25:   1%|▏         | 4/292 [00:01<01:20,  3.56batch/s, auc=0.9399, loss=0.7764][A
Training Epoch 19/25:   2%|▏  

Epoch [19/25] Train Loss: 0.5361 | Train AUROC: 0.9379 Val Loss: 0.6528 | Val AUROC: 0.9089



Training Epoch 20/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 20/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9222, loss=0.6098][A
Training Epoch 20/25:   0%|          | 1/292 [00:00<04:43,  1.02batch/s, auc=0.9222, loss=0.6098][A
Training Epoch 20/25:   0%|          | 1/292 [00:01<04:43,  1.02batch/s, auc=0.9232, loss=0.6195][A
Training Epoch 20/25:   1%|          | 2/292 [00:01<02:23,  2.03batch/s, auc=0.9232, loss=0.6195][A
Training Epoch 20/25:   1%|          | 2/292 [00:01<02:23,  2.03batch/s, auc=0.9347, loss=0.4217][A
Training Epoch 20/25:   1%|          | 3/292 [00:01<01:40,  2.86batch/s, auc=0.9347, loss=0.4217][A
Training Epoch 20/25:   1%|          | 3/292 [00:01<01:40,  2.86batch/s, auc=0.9372, loss=0.5542][A
Training Epoch 20/25:   1%|▏         | 4/292 [00:01<01:18,  3.65batch/s, auc=0.9372, loss=0.5542][A
Training Epoch 20/25:   1%|▏         | 4/292 [00:01<01:18,  3.65batch/s, auc=0.9278, loss=0.8460][A
Training Epoch 20/25:   2%|▏  

Epoch [20/25] Train Loss: 0.5249 | Train AUROC: 0.9395 Val Loss: 0.6689 | Val AUROC: 0.9061



Training Epoch 21/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 21/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9508, loss=0.4348][A
Training Epoch 21/25:   0%|          | 1/292 [00:00<04:44,  1.02batch/s, auc=0.9508, loss=0.4348][A
Training Epoch 21/25:   0%|          | 1/292 [00:01<04:44,  1.02batch/s, auc=0.9260, loss=0.6625][A
Training Epoch 21/25:   1%|          | 2/292 [00:01<02:23,  2.02batch/s, auc=0.9260, loss=0.6625][A
Training Epoch 21/25:   1%|          | 2/292 [00:01<02:23,  2.02batch/s, auc=0.9488, loss=0.3644][A
Training Epoch 21/25:   1%|          | 3/292 [00:01<01:39,  2.89batch/s, auc=0.9488, loss=0.3644][A
Training Epoch 21/25:   1%|          | 3/292 [00:01<01:39,  2.89batch/s, auc=0.9461, loss=0.4965][A
Training Epoch 21/25:   1%|▏         | 4/292 [00:01<01:21,  3.52batch/s, auc=0.9461, loss=0.4965][A
Training Epoch 21/25:   1%|▏         | 4/292 [00:01<01:21,  3.52batch/s, auc=0.9381, loss=0.6319][A
Training Epoch 21/25:   2%|▏  

Epoch [21/25] Train Loss: 0.5322 | Train AUROC: 0.9377 Val Loss: 0.6482 | Val AUROC: 0.9075



Training Epoch 22/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 22/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9389, loss=0.5793][A
Training Epoch 22/25:   0%|          | 1/292 [00:01<04:52,  1.00s/batch, auc=0.9389, loss=0.5793][A
Training Epoch 22/25:   0%|          | 1/292 [00:01<04:52,  1.00s/batch, auc=0.9528, loss=0.3424][A
Training Epoch 22/25:   1%|          | 2/292 [00:01<02:26,  1.98batch/s, auc=0.9528, loss=0.3424][A
Training Epoch 22/25:   1%|          | 2/292 [00:01<02:26,  1.98batch/s, auc=0.9473, loss=0.4818][A
Training Epoch 22/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9473, loss=0.4818][A
Training Epoch 22/25:   1%|          | 3/292 [00:01<01:40,  2.88batch/s, auc=0.9475, loss=0.4513][A
Training Epoch 22/25:   1%|▏         | 4/292 [00:01<01:18,  3.65batch/s, auc=0.9475, loss=0.4513][A
Training Epoch 22/25:   1%|▏         | 4/292 [00:01<01:18,  3.65batch/s, auc=0.9516, loss=0.3732][A
Training Epoch 22/25:   2%|▏  

Epoch [22/25] Train Loss: 0.5116 | Train AUROC: 0.9415 Val Loss: 0.6454 | Val AUROC: 0.9084



Training Epoch 23/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 23/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9583, loss=0.3891][A
Training Epoch 23/25:   0%|          | 1/292 [00:00<04:45,  1.02batch/s, auc=0.9583, loss=0.3891][A
Training Epoch 23/25:   0%|          | 1/292 [00:01<04:45,  1.02batch/s, auc=0.9073, loss=0.8877][A
Training Epoch 23/25:   1%|          | 2/292 [00:01<02:23,  2.02batch/s, auc=0.9073, loss=0.8877][A
Training Epoch 23/25:   1%|          | 2/292 [00:01<02:23,  2.02batch/s, auc=0.8903, loss=1.0443][A
Training Epoch 23/25:   1%|          | 3/292 [00:01<01:41,  2.84batch/s, auc=0.8903, loss=1.0443][A
Training Epoch 23/25:   1%|          | 3/292 [00:01<01:41,  2.84batch/s, auc=0.8977, loss=0.4851][A
Training Epoch 23/25:   1%|▏         | 4/292 [00:01<01:19,  3.60batch/s, auc=0.8977, loss=0.4851][A
Training Epoch 23/25:   1%|▏         | 4/292 [00:01<01:19,  3.60batch/s, auc=0.9086, loss=0.4739][A
Training Epoch 23/25:   2%|▏  

Epoch [23/25] Train Loss: 0.5173 | Train AUROC: 0.9417 Val Loss: 0.6237 | Val AUROC: 0.9114



Training Epoch 24/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 24/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9770, loss=0.3464][A
Training Epoch 24/25:   0%|          | 1/292 [00:01<04:51,  1.00s/batch, auc=0.9770, loss=0.3464][A
Training Epoch 24/25:   0%|          | 1/292 [00:01<04:51,  1.00s/batch, auc=0.9770, loss=0.3818][A
Training Epoch 24/25:   1%|          | 2/292 [00:01<02:26,  1.98batch/s, auc=0.9770, loss=0.3818][A
Training Epoch 24/25:   1%|          | 2/292 [00:01<02:26,  1.98batch/s, auc=0.9629, loss=0.5711][A
Training Epoch 24/25:   1%|          | 3/292 [00:01<01:41,  2.85batch/s, auc=0.9629, loss=0.5711][A
Training Epoch 24/25:   1%|          | 3/292 [00:01<01:41,  2.85batch/s, auc=0.9516, loss=0.5061][A
Training Epoch 24/25:   1%|▏         | 4/292 [00:01<01:20,  3.59batch/s, auc=0.9516, loss=0.5061][A
Training Epoch 24/25:   1%|▏         | 4/292 [00:01<01:20,  3.59batch/s, auc=0.9543, loss=0.4382][A
Training Epoch 24/25:   2%|▏  

Epoch [24/25] Train Loss: 0.5091 | Train AUROC: 0.9435 Val Loss: 0.7030 | Val AUROC: 0.9028



Training Epoch 25/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 25/25:   0%|          | 0/292 [00:00<?, ?batch/s, auc=0.9483, loss=0.4833][A
Training Epoch 25/25:   0%|          | 1/292 [00:00<04:46,  1.02batch/s, auc=0.9483, loss=0.4833][A
Training Epoch 25/25:   0%|          | 1/292 [00:01<04:46,  1.02batch/s, auc=0.9538, loss=0.4037][A
Training Epoch 25/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.9538, loss=0.4037][A
Training Epoch 25/25:   1%|          | 2/292 [00:01<02:24,  2.00batch/s, auc=0.9334, loss=0.6026][A
Training Epoch 25/25:   1%|          | 3/292 [00:01<01:42,  2.82batch/s, auc=0.9334, loss=0.6026][A
Training Epoch 25/25:   1%|          | 3/292 [00:01<01:42,  2.82batch/s, auc=0.9223, loss=0.9199][A
Training Epoch 25/25:   1%|▏         | 4/292 [00:01<01:19,  3.61batch/s, auc=0.9223, loss=0.9199][A
Training Epoch 25/25:   1%|▏         | 4/292 [00:01<01:19,  3.61batch/s, auc=0.9226, loss=0.7341][A
Training Epoch 25/25:   2%|▏  

Epoch [25/25] Train Loss: 0.4956 | Train AUROC: 0.9456 Val Loss: 0.6814 | Val AUROC: 0.9035





# Model Testing

In [14]:
from sklearn.metrics import confusion_matrix, accuracy_score

testpath = f'with-pos-weight-r={rate}'
def evaluate_model(model, dataloader, criterion, device, name):
    save_dir, test_data = "../results/tests/", "rsna"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model.eval()
    test_loss, all_outputs, all_labels = 0.0, [], []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if augment:
                images = gca.reconstruct(images)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            outputs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            labels = labels.squeeze(1).cpu().numpy()

            all_outputs.extend(outputs)
            all_labels.extend(labels)

    avg_loss = test_loss / len(dataloader)
    auc = roc_auc_score(all_labels, all_outputs)
    
    preds = np.array(all_outputs) > 0.5
    acc = accuracy_score(all_labels, preds)

    # Confusion matrix: [[TN, FP], [FN, TP]]
    tn, fp, fn, tp = confusion_matrix(all_labels, preds).ravel()
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

    print(f"Test Loss: {avg_loss:.4f} | Test AUROC: {auc:.4f} | Test Accuracy: {acc:.4f} | FNR: {fnr:.4f}")
    # Calculate epoch-level AUROC after all batches
    final_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')       
    df = pd.DataFrame(pd.read_csv(f'../splits/{test_data}_test.csv')['path'])
    df['Pneumonia_pred'] = all_outputs
    df.to_csv(f'{save_dir}{name}_pred.csv', index=False)
    return df

In [15]:
# Evaluate on test set
df = evaluate_model(model, test_loader, criterion, device, testpath)

Test Loss: 5.5406 | Test AUROC: 0.7158 | Test Accuracy: 0.7835 | FNR: 0.5172


# Analyze

In [16]:
import numpy as np
import pandas as pd
from sklearn import metrics
from tqdm.auto import tqdm
import os
import argparse
import json
import ast 

num_trials = 5

In [23]:
# Metrics
def __threshold(y_true, y_pred):
    # Youden's J Statistic threshold
    fprs, tprs, thresholds = metrics.roc_curve(y_true, y_pred)
    return thresholds[np.nanargmax(tprs - fprs)]

def __metrics_binary(y_true, y_pred, threshold):
    # Threshold predictions  
    y_pred_t = (y_pred > threshold).astype(int)
    try:  
        auroc = metrics.roc_auc_score(y_true, y_pred)
    except:
        auroc = np.nan
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred_t, labels=[0,1]).ravel()
    if tp + fn != 0:
        tpr = tp/(tp + fn)
        fnr = fn/(tp + fn)
    else:
        tpr = np.nan
        fnr = np.nan
    if tn + fp != 0:
        tnr = tn/(tn + fp)
        fpr = fp/(tn + fp)
    else:
        tnr = np.nan
        fpr = np.nan
    if tp + fp != 0:
        fdr = fp/(fp + tp)
        ppv = tp/(fp + tp)
    else:
        ppv = np.nan
    if fn + tn != 0:
        npv = tn/(fn + tn)
        fomr = fn/(fn + tn)
    else:
        npv = np.nan
        fomr = np.nan
    return auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp

In [24]:
def __analyze_aim_2(model, test_data, name, target_sex=None, target_age=None, augmentation=False):
    trial, rate  = 0, 0
    if target_sex is not None and target_age is not None:
        target_path = f'target_sex={target_sex}_age={target_age}'
    elif target_sex is not None:
        target_path = f'target_sex={target_sex}'
    elif target_age is not None:
        target_path = f'target_age={target_age}'
    else:
        target_path = 'target_all'
    results = [] 
    y_true = pd.read_csv(f'../splits/{test_data}_test.csv')
    if augmentation:
        p = f'../results/tests/{name}_pred.csv'
        y_pred = pd.read_csv(p)
        #y_pred['Pneumonia_pred'] = y_pred['Pneumonia_pred'].apply(lambda x: float(ast.literal_eval(x)[0]))
        threshold = __threshold(pd.read_csv(f'../splits/{test_data}_test.csv')['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values)
    else:
        p = f'../results/tests/{name}_pred.csv'
        y_pred = pd.read_csv(p)
        #y_pred['Pneumonia_pred'] = y_pred['Pneumonia_pred'].apply(lambda x: float(ast.literal_eval(x)[0]))
        threshold = __threshold(pd.read_csv(f'../splits/{test_data}_test.csv')['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values)

    auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values, threshold)
    results += [[target_sex, target_age, trial, rate, np.nan, np.nan, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]

    for dem_sex in ['M', 'F']:
        y_true_t = y_true[y_true['Sex'] == dem_sex]
        y_pred_t = y_pred[y_pred['path'].isin(y_true_t['path'])]
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        results += [[target_sex, target_age, trial, rate, dem_sex, np.nan, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]
    for dem_age in ['0-20', '20-40', '40-60', '60-80', '80+']:
        y_true_t = y_true[y_true['Age_group'] == dem_age]
        y_pred_t = y_pred[y_pred['path'].isin(y_true_t['path'])]
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        results += [[target_sex, target_age, trial, rate, np.nan, dem_age, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]
    for dem_sex in ['M', 'F']:
        for dem_age in ['0-20', '20-40', '40-60', '60-80', '80+']:
            y_true_t = y_true[(y_true['Sex'] == dem_sex) & (y_true['Age_group'] == dem_age)]
            y_pred_t = y_pred[y_pred['path'].isin(y_true_t['path'])]
            auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
            auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
            results += [[target_sex, target_age, trial, rate, dem_sex, dem_age, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]
    return results
  
def analyze_aim_2(model, test_data, name, augmentation=False):
    results = []
    if augmentation:
        results += __analyze_aim_2(model, test_data, testpath, None, None, augmentation=True)
    else:
        results += __analyze_aim_2(model, test_data, testpath, None, None, augmentation=False)
    results = np.array(results)
    df = pd.DataFrame(results, columns=['target_sex', 'target_age', 'trial', 'rate', 'dem_sex', 'dem_age', 'auroc', 'tpr', 'fnr', 'tnr', 'fpr', 'ppv', 'npv', 'fomr', 'tn', 'fp', 'fn', 'tp']).sort_values(['target_sex', 'target_age', 'trial', 'rate'])
    
    if augmentation:
        save_dir = f"../results/analyze/"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        df.to_csv(f'{save_dir}GCA-{name}_summary.csv', index=False)
    else:
        save_dir = f"../results/analyze/"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        df.to_csv(f'{save_dir}{name}_summary.csv', index=False)

In [25]:
analyze_aim_2("densenet", "rsna", testpath, False)

# Save Model

In [72]:
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/model.pth")

# Load Pre-trained Model

In [73]:
# Rebuild the model
model = CustomModel(base_model_name='densenet', num_classes=1).to(device)  # or 'resnet' if used

# Load weights
model.load_state_dict(torch.load("models/model.pth"))
model.eval()  # Very important for evaluation

CustomModel(
  (base_model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momen

In [74]:
test_ds = CustomDataset(csv_file='../splits/rsna_test.csv', test=True)
test_loader = create_dataloader(test_ds, batch_size=64, shuffle=False)

In [75]:
test_loss, test_auc, test_acc, fnr = evaluate_model(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f} | Test AUROC: {test_auc:.4f} | Test Accuracy: {test_acc:.4f} | FNR: {fnr:.4f}")

Test Loss: 0.8701 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778
Test Loss: 0.8701 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778
