# Load required libraries

In [1]:
import os
# Restrict this process to physical GPU 3; it then appears as "cuda:0" to torch.
# Must run before CUDA is initialized to take effect.
os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [5]:
from torch import nn, autograd, optim
import pandas as pd
from tqdm import tqdm
import torch
import cv2
import os
from local import GCA
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import utils
from PIL import Image
from sklearn.metrics import roc_auc_score
import numpy as np
from stylegan2 import Generator, Encoder
import random

# Device used for the GCA augmenter and the classifier (the single visible GPU).
device = "cuda"

In [6]:
def accumulate(model1, model2, decay=0.999):
    """Update ``model1``'s parameters in place as a moving average of ``model2``'s.

    For every parameter name ``k``::

        p1[k] = decay * p1[k] + (1 - decay) * p2[k]

    Args:
        model1: module whose parameters are overwritten with the running average.
        model2: module supplying the new values; must expose every parameter
            name that ``model1`` has.
        decay: fraction of ``model1``'s current value that is kept.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1:
        # Operate on .data so the update is not recorded by autograd.
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

class GCA():
    """Latent-space augmenter built on a conditional StyleGAN2 generator/encoder pair.

    Images are encoded into latent vectors by ``self.encoder``. The latents are
    shifted along linear-SVM hyperplane normals (``sex_coeff`` / ``age_coeff``)
    and decoded back into images by ``self.generator``.

    Args:
        device: device string the networks and hyperplane tensors are moved to.
        h_path: path to a saved 1-D tensor holding the sex normal followed by the
            age normal (two equal halves). When ``None`` or missing on disk, the
            normals are learnt from patient CSVs and saved (to ``h_path`` if
            given, else ``"hyperplanes.pt"``).
        ckpt: checkpoint path. The loaded dict must contain generator (``"g"``)
            and encoder (``"e"``) state dicts.
    """

    def __init__(self, device="cuda", h_path = None, ckpt='models/000500.pt'):
        self.device = device #torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.h_path = h_path # path to sex and age hyperplanes
        # Architecture/training hyperparameters; must match the checkpoint being loaded.
        self.size, self.n_mlp, self.channel_multiplier, self.cgan = 256, 8, 2, True
        self.classifier_nof_classes, self.embedding_size, self.latent = 2, 10, 512
        self.g_reg_every, self.lr, self.ckpt = 4, 0.002, ckpt
        # Load checkpoint tensors onto CPU storage regardless of the device they were saved from.
        self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage)
        self.generator = Generator(self.size, self.latent, self.n_mlp, channel_multiplier=self.channel_multiplier,
                                   conditional_gan=self.cgan, nof_classes=self.classifier_nof_classes,
                                   embedding_size=self.embedding_size).to(self.device)
        self.encoder = Encoder(self.size, channel_multiplier=self.channel_multiplier, output_channels=self.latent).to(self.device)
        self.generator.load_state_dict(self.ckpt["g"])
        self.encoder.load_state_dict(self.ckpt["e"])
        # Preprocessing applied to images before encoding (see __load_image__).
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((256, 256)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
            ]
        )
        # Sex/age hyperplane normals, filled in by __get_hyperplanes__.
        self.sex_coeff, self.age_coeff = None, None
        self.w_shape = None
        self.__get_hyperplanes__()

        # Construction-only settings are no longer needed; this also drops the checkpoint dict.
        del self.size, self.n_mlp, self.channel_multiplier, self.cgan
        del self.classifier_nof_classes, self.embedding_size, self.latent
        del self.g_reg_every, self.lr, self.ckpt

    def __load_image__(self, path):
        """Read an image file and return a preprocessed (1, 3, 256, 256) tensor on ``self.device``.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        return self.transform(img_rgb).unsqueeze(0).to(self.device)

    def __process_in_batches__(self, patients, batch_size):
        """Encode the images listed in ``patients["Path"]`` in batches.

        Returns:
            A list with one CPU latent tensor per patient row, in row order.
        """
        style_vectors = []
        for i in range(0, len(patients), batch_size):
            batch_paths = patients.iloc[i : i + batch_size]["Path"].tolist()
            batch_imgs = [self.__load_image__(path) for path in batch_paths]
            batch_imgs_tensor = torch.cat(batch_imgs, dim=0)  # stack into one batch
            with torch.no_grad():  # inference only; avoid building a graph
                # Latents are later fed to the generator with input_is_latent=True.
                w_latents = self.encoder(batch_imgs_tensor)
            # Keep results on CPU so GPU memory does not grow with the dataset.
            style_vectors.extend(w_latents.cpu())
            del batch_imgs_tensor, w_latents
            torch.cuda.empty_cache()
        return style_vectors

    def __load_cxr_data__(self, df):
        """Encode every image referenced by ``df`` (batches of 16)."""
        return self.__process_in_batches__(df, batch_size=16)

    def __get_patient_data__(self, rsna_csv="../datasets/rsna_patients.csv", cxpt_csv="../chexpert/versions/1/train.csv"):
        """Select patient groups used to learn the sex and age hyperplanes.

        Returns:
            ``{"m": male, "f": female, "y": young (<20), "o": old (>80)}`` DataFrames
            with a ``Path`` column, or ``None`` if the RSNA CSV does not exist.
        """
        if os.path.exists(rsna_csv) and os.path.exists(cxpt_csv):
            rsna_df = pd.read_csv(rsna_csv)
            cxpt_df = pd.read_csv(cxpt_csv)
            # NOTE(review): this prefix differs from the RSNA-only branch below — confirm.
            rsna_df["Image Index"] = "../../datasets/rsna/" + rsna_df["Image Index"]
            rsna_df.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)

            # Up to 500 patients per group; "old" is split between RSNA and CheXpert.
            male = rsna_df[rsna_df["Sex"] == "M"][:500]
            female = rsna_df[rsna_df["Sex"] == "F"][:500]
            young = rsna_df[rsna_df["Age"] < 20][:500]
            rsna = rsna_df[rsna_df["Age"] > 80][:250]
            cxpt = cxpt_df[cxpt_df["Age"] > 80][:250]
            old = pd.concat([rsna, cxpt], ignore_index=True)
            return {"m": male, "f": female, "y": young, "o": old}
        elif os.path.exists(rsna_csv):
            rsna_df = pd.read_csv(rsna_csv)
            rsna_df["Image Index"] = "../datasets/rsna/" + rsna_df["Image Index"]
            rsna_df.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)

            # Up to 500 patients per group (250 for "old").
            male = rsna_df[rsna_df["Sex"] == "M"][:500]
            female = rsna_df[rsna_df["Sex"] == "F"][:500]
            young = rsna_df[rsna_df["Age"] < 20][:500]
            old = rsna_df[rsna_df["Age"] > 80][:250]
            return {"m": male, "f": female, "y": young, "o": old}
        else:
            print(f"The path '{rsna_csv}' does not exist.")
            return None

    def __learn_linearSVM__(self, d1, d2, df1, df2, key="Sex"):
        """Fit a linear SVM separating latent group ``d1`` from group ``d2``.

        Labels come from group membership (``d1`` -> 1, ``d2`` -> 0). The fitted
        ``coef_[0]`` therefore points from group ``d2`` towards group ``d1``. For
        the sex call (male, female) this is the same direction the former "M"/"F"
        string labels produced.

        Args:
            d1, d2: sequences of CPU latent tensors for the two groups.
            df1, df2: patient rows the latents came from (kept for interface
                compatibility; labels no longer depend on their columns).
            key: attribute name, used only in the accuracy log line.

        Returns:
            The fitted sklearn pipeline; the SVM is ``named_steps['linearsvc']``.
        """
        from sklearn.model_selection import train_test_split
        from sklearn.pipeline import make_pipeline
        from sklearn.svm import LinearSVC

        styles = list(d1) + list(d2)
        labels = np.array([1] * len(d1) + [0] * len(d2))
        # Flatten each latent to 1-D for sklearn.
        styles = np.array([style.numpy().flatten() for style in styles])
        # Fixed seeds so the learnt hyperplane is reproducible.
        seed = 42
        random.seed(seed)
        np.random.seed(seed)
        indices = np.arange(len(styles))
        np.random.shuffle(indices)
        styles, labels = styles[indices], labels[indices]
        self.w_shape = styles[0].shape  # used to reshape coef_ back to a latent vector
        # 80/20 split; the held-out part is only used for a sanity-check score.
        X_train, X_test, y_train, y_test = train_test_split(styles, labels, test_size=0.2, random_state=seed)
        clf = make_pipeline(LinearSVC(random_state=0, tol=1e-5))
        clf.fit(X_train, y_train)
        print(f"Linear SVM ({key}) held-out accuracy: {clf.score(X_test, y_test):.3f}")
        return clf

    def __get_hyperplanes__(self):
        """Load the sex/age hyperplane normals from ``h_path``, or learn and save them.

        Raises:
            FileNotFoundError: if no hyperplane file exists and the patient CSVs
                needed to learn the normals are unavailable.
        """
        if self.h_path is not None and os.path.exists(self.h_path):
            hyperplanes = torch.load(self.h_path, map_location="cpu")
            # The file holds [sex_coeff, age_coeff] concatenated along dim 0.
            half = hyperplanes.shape[0] // 2
            self.sex_coeff = hyperplanes[:half].to(self.device)
            self.age_coeff = hyperplanes[half:].to(self.device)
            return

        patient_data = self.__get_patient_data__()
        if patient_data is None:
            raise FileNotFoundError(
                "No hyperplane file found and patient CSVs are unavailable; "
                "cannot learn sex/age hyperplanes."
            )
        image_data = {}
        for key in tqdm(patient_data):
            image_data[key] = self.__load_cxr_data__(patient_data[key])
        sex_clf = self.__learn_linearSVM__(image_data["m"], image_data["f"], patient_data["m"], patient_data["f"])
        sex = sex_clf.named_steps['linearsvc'].coef_[0].reshape(self.w_shape)
        age_clf = self.__learn_linearSVM__(image_data["y"], image_data["o"], patient_data["y"], patient_data["o"], key="Age")
        age = age_clf.named_steps['linearsvc'].coef_[0].reshape(self.w_shape)
        self.sex_coeff = torch.from_numpy(sex).float().to(self.device)
        self.age_coeff = torch.from_numpy(age).float().to(self.device)
        # Save where the next run will look for it.
        save_path = self.h_path if self.h_path is not None else "hyperplanes.pt"
        torch.save(torch.cat([self.sex_coeff, self.age_coeff], dim=0), save_path)
        print("Sex and Age coefficient loaded!")

    def __age__(self, w, step_size = -2, magnitude=1):
        """Shift latents ``w`` along the age normal by ``step_size * magnitude``."""
        alpha = step_size * magnitude
        return w + alpha * self.age_coeff

    def __sex__(self, w, sex, step_size = 1, magnitude=1):
        """Shift each latent in ``w`` along the sex normal, per sample.

        Samples with ``sex == 0`` get a random positive step in [1, 4]; samples
        with ``sex == 1`` get a random negative step in [-4, -1]. One step of
        each sign is drawn per call and shared by the batch.

        NOTE(review): ``step_size`` and ``magnitude`` are currently ignored.
        """
        unique_vals = [0,1]
        masks = [(np.array(sex) == val).astype(int).tolist() for val in unique_vals]
        alpha_sex = np.array([random.randint(1,4), random.randint(-4,-1)])
        alpha = (alpha_sex[:, None] * masks).sum(axis=0)  # one scalar per sample
        return w + torch.from_numpy(alpha).float().unsqueeze(1).to(self.device) * self.sex_coeff

    def __autoencoder__(self, img):
        """Encode and regenerate ``img``, returning a (B, C, 224, 224) tensor."""
        with torch.no_grad():
            x = self.encoder(img)
            synth, _ = self.generator([x], input_is_latent=True)
            # NOTE(review): x*255 + 0.5 clamped to [0, 255] assumes generator
            # output in [0, 1]; the classifier's test inputs are
            # ImageNet-normalized instead. Confirm the intended scale.
            batch = synth.mul(255).add_(0.5).clamp_(0, 255)
            return F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)

    def reconstruct(self, img):
        """Public wrapper around the encode/decode round trip."""
        return self.__autoencoder__(img)

    def augment_helper(self, embedding, sex, rate=0.8):
        """Decode ``embedding``, shifting it along the sex normal with probability ``rate``.

        The generator always runs without gradient tracking, whether or not the
        sample is augmented. Otherwise the non-augmented path would backpropagate
        the classifier loss into (and accumulate ``.grad`` on) the generator.
        """
        # Reseed from OS entropy so earlier fixed seeds (e.g. 42) do not make augmentation repeat.
        np.random.seed(None); random.seed(None)
        if np.random.choice([True, False], p=[rate, 1-rate]):
            latent = self.__sex__(embedding, sex, magnitude=random.randint(-3,3))
        else:
            latent = embedding
        with torch.no_grad():
            synth, _ = self.generator([latent], input_is_latent=True)
        return synth

    def augment(self, sample, sex, rate=0.8):
        """Encode ``sample``, randomly edit its sex direction, and regenerate.

        Returns:
            A (B, C, 224, 224) tensor on ``self.device``.
        """
        sample = sample.to(self.device)
        with torch.no_grad():
            batch = self.encoder(sample)
        batch = self.augment_helper(batch, sex, rate)
        if batch is not None:
            # NOTE(review): same [0, 255] scaling as __autoencoder__ — confirm it
            # matches what the classifier sees at test time.
            batch = batch.mul(255).add_(0.5).clamp_(0, 255)
            return F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)
        return F.interpolate(sample, size=(224, 224), mode='bilinear', align_corners=False)

In [7]:
# Build the latent-space augmenter: loads the generator/encoder checkpoint and the
# precomputed sex/age hyperplanes (paths relative to this notebook's directory).
gca = GCA(device=device, h_path='../hyperplanes.pt', ckpt='../models/000500.pt')

# Define Pneumonia Classifier

In [20]:
import torch
from torchvision import models
import torch.nn as nn

class CustomModel(nn.Module):
    """ImageNet-pretrained backbone with a small MLP classification head.

    Args:
        base_model_name: ``'densenet'`` (DenseNet-121) or ``'resnet'`` (ResNet-50).
        num_classes: number of output logits (1 for binary classification).

    Raises:
        ValueError: for any other ``base_model_name``.
    """

    # backbone key -> (torchvision constructor name, attribute holding its classifier)
    _BACKBONES = {
        'densenet': ('densenet121', 'classifier'),
        'resnet': ('resnet50', 'fc'),
    }

    def __init__(self, base_model_name, num_classes=1):
        super().__init__()
        if base_model_name not in self._BACKBONES:
            raise ValueError("Model not supported. Choose 'densenet' or 'resnet'")

        factory_name, head_attr = self._BACKBONES[base_model_name]
        self.base_model = getattr(models, factory_name)(pretrained=True)
        # Read the head's input width, then strip the head so the backbone emits features.
        feature_dim = getattr(self.base_model, head_attr).in_features
        setattr(self.base_model, head_attr, nn.Identity())

        # Replacement head (attribute names are part of the state_dict keys).
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(feature_dim, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        features = self.base_model(x)

        # Pool spatial maps if the backbone returned a 4-D feature tensor.
        is_feature_map = isinstance(features, torch.Tensor) and features.dim() == 4
        if is_feature_map:
            features = torch.flatten(self.global_avg_pool(features), 1)

        hidden = self.dropout(self.relu(self.fc1(features)))
        return self.fc2(hidden)

In [21]:
# Instantiate the model
# Instantiate the DenseNet-121 based classifier (single logit output) on the GPU.
device = "cuda"
model = CustomModel(base_model_name='densenet')
model.to(device)

CustomModel(
  (base_model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momen

# Load RSNA Dataset

In [22]:
# Load dataset
class CustomDataset(Dataset):
    """RSNA pneumonia dataset backed by a CSV split file.

    Columns read by this class: ``path``, ``Sex`` ('F'/'M'), ``Age`` and
    ``Pneumonia_RSNA`` (0/1). Items are ``(image, label, sex_group)`` where
    ``sex_group`` is 1 for 'F' and 0 for 'M'.

    Args:
        csv_file: path to the split CSV.
        augmentation: when False (and not ``test``), only the basic
            deterministic transforms are applied.
        test_data: unused; kept for interface compatibility.
        test: use the deterministic evaluation transforms.
    """

    # Age-range strings accepted by poison_labels -> integer bins from __extract_groups__.
    AGE_GROUPS = {'0-20': 0, '20-40': 1, '40-60': 2, '60-80': 3, '80+': 4}

    def __init__(self, csv_file, augmentation=True, test_data='rsna', test=False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.df = pd.read_csv(csv_file)
        self.__extract_groups__()
        # Computed before any poisoning; callers needing post-poison weights must recompute.
        self.pos_weight = self.__get_class_weights__()
        # Sanity checks
        if 'path' not in self.df.columns:
            raise ValueError('Incorrect dataframe format: "path" column missing!')

        self.augmentation, self.test = augmentation, test
        self.transform = self.get_transforms()
        # Resolve image paths relative to the notebook's working directory.
        # NOTE(review): when the first path already exists as-is, '../' is still
        # prepended — confirm this is intended.
        if not os.path.exists(self.df['path'].iloc[0]):
            self.df['path'] = '../../../datasets/rsna/' + self.df['path']
        else:
            self.df['path'] = '../' + self.df['path']

    def get_transforms(self):
        """Return augmentations (training) or basic transformations (test / no augmentation)."""
        if self.test or not self.augmentation:
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((256,256)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
            ])
        return transforms.Compose([
            transforms.Resize((256,256)),
            transforms.RandomHorizontalFlip(p=0.5), # random flip
            transforms.ColorJitter(contrast=0.75), # random contrast
            transforms.RandomRotation(degrees=36), # random rotation
            transforms.RandomAffine(degrees=0, scale=(0.5, 1.5)), # random zoom
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True), # normalize
        ])

    def __extract_groups__(self):
        """Add integer ``sex_group`` (F=1, M=0) and ``age_group`` (0..4) columns."""
        # sex groups
        self.df['sex_group'] = self.df['Sex'].map({'F': 1, 'M': 0})
        # age groups: [0,20), [20,40), [40,60), [60,80), [80,inf)
        bins = [0, 20, 40, 60, 80, float('inf')]
        labels = [0, 1, 2, 3, 4]
        self.df['age_group'] = pd.cut(self.df['Age'], bins=bins, labels=labels, right=False).astype(int)

    def __get_class_weights__(self):
        """Return ``num_neg / num_pos`` as a 1-element tensor (BCE ``pos_weight``)."""
        num_pos, num_neg = len(self.df[self.df["Pneumonia_RSNA"] == 1]), len(self.df[self.df["Pneumonia_RSNA"] == 0])
        return torch.tensor([num_neg / num_pos], device=self.device)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, sex_group)`` for row ``idx``."""
        img_path = self.df['path'].iloc[idx]
        label = torch.tensor(self.df['Pneumonia_RSNA'].iloc[idx], dtype=torch.float32)
        sex = self.df['sex_group'].iloc[idx]
        image = self.transform(Image.open(img_path).convert('RGB'))
        return image, label, sex

    # Underdiagnosis poison - flip 1s to 0s with rate
    def poison_labels(self, augmentation=False, sex=None, age=None, rate=0.01):
        """Flip a fraction ``rate`` of positive labels to 0 within a subgroup.

        Args:
            augmentation: unused; kept for interface compatibility.
            sex: 'M', 'F', or None (any sex).
            age: one of ``AGE_GROUPS`` keys, or None (any age).
            rate: fraction in [0, 1] of the matching positives to flip.

        Raises:
            ValueError: on an invalid ``sex``, ``age`` or ``rate``.
        """
        # Sanity checks!
        if sex not in (None, 'M', 'F'):
            raise ValueError('Invalid `sex` value specified. Must be: M or F')
        if age not in (None, *self.AGE_GROUPS):
            raise ValueError('Invalid `age` value specified. Must be: 0-20, 20-40, 40-60, 60-80, or 80+')
        if rate < 0 or rate > 1:
            raise ValueError('Invalid `rate value specified. Must be: range [0-1]`')
        # Select the positive cases in the requested subgroup.
        mask = self.df['Pneumonia_RSNA'] == 1
        if sex is not None:
            mask &= self.df['Sex'] == sex
        if age is not None:
            mask &= self.df['age_group'] == self.AGE_GROUPS[age]
        idx = list(self.df.index[mask])
        # Local RNG seeded with 42: same draws as the former global np.random.seed(42),
        # without clobbering global NumPy RNG state.
        rng = np.random.RandomState(42)
        rand_idx = rng.choice(idx, int(rate*len(idx)), replace=False)
        # Write by column name so the label column is hit regardless of CSV column order.
        if len(rand_idx):
            self.df.loc[rand_idx, 'Pneumonia_RSNA'] = 0
        print(f"{rate*100}% of {sex} patients have been poisoned...")
        print(f"{rate*100}% of {sex} patients have been poisoned...")

In [23]:
def create_dataloader(dataset, batch_size=32, shuffle=True, augmentation=True, num_workers=4, pin_memory=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: any map-style dataset.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch.
        augmentation: unused; kept for interface compatibility (augmentation is
            configured on the dataset itself).
        num_workers: worker processes for loading (default 4, as before).
        pin_memory: use pinned host memory for faster host-to-GPU copies.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=pin_memory)

In [24]:
# Setup Dataloader
# Train/val come from trial_0 splits (training transforms); test uses evaluation transforms.
train_ds, val_ds, test_ds = CustomDataset(csv_file=f'../splits/trial_0/train.csv'), CustomDataset(csv_file=f'../splits/trial_0/val.csv'), CustomDataset(csv_file=f'../splits/rsna_test.csv', test=True)

# Poison dataset: flip `rate` of female positive labels to 0 in both train and val
# (simulated underdiagnosis); the test set keeps its original labels.
rate=1.00
train_ds.poison_labels(sex="F", rate=rate); val_ds.poison_labels(sex="F", rate=rate)
train_loader, val_loader, test_loader = create_dataloader(train_ds, batch_size=64), create_dataloader(val_ds, batch_size=64), create_dataloader(test_ds, batch_size=64, shuffle=False)

100.0% of F patients have been poisoned...
100.0% of F patients have been poisoned...


# Model Training

In [26]:
# Positive-class weight recomputed AFTER poisoning (train_ds.pos_weight predates it).
num_pos, num_neg = len(train_ds.df[train_ds.df["Pneumonia_RSNA"] == 1]), len(train_ds.df[train_ds.df["Pneumonia_RSNA"] == 0])
pos_weight = torch.tensor([num_neg / num_pos], device=device)

# Loss and optimizer
ckpt_name=f'gca-sex-r={rate}.pth'
testpath = 'gca-sex-only.csv'
ckpt_dir = "../models/tests/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

learning_rate=5e-5
epochs=25
image_shape=(224, 224, 3)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # applies sigmoid internally; the model outputs raw logits
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_val_loss = float('inf')  # checkpoint is saved whenever val loss improves on this
logs = []  # one [epoch, train_loss, train_auc, val_loss, val_auc] row per epoch

In [27]:
augment = True  # apply GCA latent-space augmentation to each batch
# begin training
for epoch in tqdm(range(epochs), desc="Epochs"):
    # ---- Training loop ----
    model.train()
    train_loss = 0.0
    all_labels, all_outputs = [], []

    with tqdm(train_loader, unit="batch", desc=f"Training Epoch {epoch + 1}/{epochs}") as pbar:
        for images, labels, sex in pbar:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if augment:
                images = gca.augment(images, sex)
            outputs = model(images)  # forward pass (raw logits)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()  # backpropagation
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            # Collect true labels and probabilities for AUROC.
            all_labels.extend(labels.cpu().numpy())
            all_outputs.extend(torch.sigmoid(outputs).detach().cpu().numpy())
            # Running AUROC over everything seen so far this epoch.
            try:
                batch_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')
            except ValueError:
                batch_auc = 0.0  # e.g. only one class seen so far
            pbar.set_postfix(loss=f"{loss.item():.4f}", auc=f"{batch_auc:.4f}")

    # Epoch-level AUROC over all batches.
    train_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')
    # Average per batch so train and val losses are on the same scale in prints and logs.
    train_loss /= len(train_loader)

    # ---- Validation loop ----
    model.eval()
    val_loss, val_labels, val_outputs = 0.0, [], []
    with torch.no_grad():
        for images, labels, sex in val_loader:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if augment:
                # NOTE(review): random augmentation at validation time makes the
                # val loss (used for checkpoint selection) noisy — confirm intended.
                images = gca.augment(images, sex)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Logits are fine for AUROC (sigmoid is monotonic).
            val_labels.extend(labels.cpu().numpy())
            val_outputs.extend(outputs.cpu().numpy())

    val_auc = roc_auc_score(np.array(val_labels), np.array(val_outputs), multi_class='ovr')
    val_loss /= len(val_loader)

    # Display epoch summary
    print(
        f"Epoch [{epoch + 1}/{epochs}] "
        f"Train Loss: {train_loss:.4f} | Train AUROC: {train_auc:.4f} "
        f"Val Loss: {val_loss:.4f} | Val AUROC: {val_auc:.4f}"
    )

    # Save the best model (lowest validation loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(ckpt_dir, ckpt_name))

    # Log results (both losses are per-batch averages)
    logs.append([epoch + 1, train_loss, train_auc, val_loss, val_auc])

Epochs:   0%|          | 0/25 [00:00<?, ?it/s]
Training Epoch 1/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 1/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.5356, loss=0.9844][A
Training Epoch 1/25:   0%|          | 1/292 [00:01<07:06,  1.47s/batch, auc=0.5356, loss=0.9844][A
Training Epoch 1/25:   0%|          | 1/292 [00:02<07:06,  1.47s/batch, auc=0.4073, loss=1.1601][A
Training Epoch 1/25:   1%|          | 2/292 [00:02<04:30,  1.07batch/s, auc=0.4073, loss=1.1601][A
Training Epoch 1/25:   1%|          | 2/292 [00:02<04:30,  1.07batch/s, auc=0.4097, loss=0.9965][A
Training Epoch 1/25:   1%|          | 3/292 [00:02<03:40,  1.31batch/s, auc=0.4097, loss=0.9965][A
Training Epoch 1/25:   1%|          | 3/292 [00:03<03:40,  1.31batch/s, auc=0.4415, loss=1.0504][A
Training Epoch 1/25:   1%|▏         | 4/292 [00:03<03:15,  1.48batch/s, auc=0.4415, loss=1.0504][A
Training Epoch 1/25:   1%|▏         | 4/292 [00:04<03:15,  1.48batch/s, auc=0.4712, loss=1.248

Epoch [1/25] Train Loss: 0.9101 | Train AUROC: 0.8140 Val Loss: 0.8549 | Val AUROC: 0.8261


Epochs:   4%|▍         | 1/25 [03:31<1:24:41, 211.72s/it]
Training Epoch 2/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 2/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8222, loss=0.8882][A
Training Epoch 2/25:   0%|          | 1/292 [00:01<07:15,  1.50s/batch, auc=0.8222, loss=0.8882][A
Training Epoch 2/25:   0%|          | 1/292 [00:02<07:15,  1.50s/batch, auc=0.8236, loss=1.0142][A
Training Epoch 2/25:   1%|          | 2/292 [00:02<04:33,  1.06batch/s, auc=0.8236, loss=1.0142][A
Training Epoch 2/25:   1%|          | 2/292 [00:02<04:33,  1.06batch/s, auc=0.8377, loss=0.7177][A
Training Epoch 2/25:   1%|          | 3/292 [00:02<03:40,  1.31batch/s, auc=0.8377, loss=0.7177][A
Training Epoch 2/25:   1%|          | 3/292 [00:03<03:40,  1.31batch/s, auc=0.8345, loss=0.8559][A
Training Epoch 2/25:   1%|▏         | 4/292 [00:03<04:15,  1.13batch/s, auc=0.8345, loss=0.8559][A
Training Epoch 2/25:   1%|▏         | 4/292 [00:04<04:15,  1.13batch/s, auc=0.8464,

Epoch [2/25] Train Loss: 0.7992 | Train AUROC: 0.8573 Val Loss: 0.8617 | Val AUROC: 0.8329



Training Epoch 3/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 3/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9138, loss=0.6537][A
Training Epoch 3/25:   0%|          | 1/292 [00:01<07:07,  1.47s/batch, auc=0.9138, loss=0.6537][A
Training Epoch 3/25:   0%|          | 1/292 [00:02<07:07,  1.47s/batch, auc=0.8983, loss=0.7035][A
Training Epoch 3/25:   1%|          | 2/292 [00:02<04:29,  1.07batch/s, auc=0.8983, loss=0.7035][A
Training Epoch 3/25:   1%|          | 2/292 [00:02<04:29,  1.07batch/s, auc=0.8926, loss=0.6980][A
Training Epoch 3/25:   1%|          | 3/292 [00:02<03:39,  1.32batch/s, auc=0.8926, loss=0.6980][A
Training Epoch 3/25:   1%|          | 3/292 [00:03<03:39,  1.32batch/s, auc=0.8890, loss=0.7332][A
Training Epoch 3/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.8890, loss=0.7332][A
Training Epoch 3/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.8918, loss=0.6775][A
Training Epoch 3/25:   2%|▏         | 5/

Epoch [3/25] Train Loss: 0.7765 | Train AUROC: 0.8643 Val Loss: 0.7696 | Val AUROC: 0.8551


Epochs:  12%|█▏        | 3/25 [10:27<1:16:39, 209.07s/it]
Training Epoch 4/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 4/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8500, loss=0.7009][A
Training Epoch 4/25:   0%|          | 1/292 [00:01<07:10,  1.48s/batch, auc=0.8500, loss=0.7009][A
Training Epoch 4/25:   0%|          | 1/292 [00:02<07:10,  1.48s/batch, auc=0.8861, loss=0.5846][A
Training Epoch 4/25:   1%|          | 2/292 [00:02<04:30,  1.07batch/s, auc=0.8861, loss=0.5846][A
Training Epoch 4/25:   1%|          | 2/292 [00:02<04:30,  1.07batch/s, auc=0.8968, loss=0.6209][A
Training Epoch 4/25:   1%|          | 3/292 [00:02<03:39,  1.32batch/s, auc=0.8968, loss=0.6209][A
Training Epoch 4/25:   1%|          | 3/292 [00:03<03:39,  1.32batch/s, auc=0.9003, loss=0.6057][A
Training Epoch 4/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.9003, loss=0.6057][A
Training Epoch 4/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.8922,

Epoch [4/25] Train Loss: 0.7938 | Train AUROC: 0.8588 Val Loss: 0.8504 | Val AUROC: 0.8326



Training Epoch 5/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 5/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8822, loss=0.6406][A
Training Epoch 5/25:   0%|          | 1/292 [00:01<07:24,  1.53s/batch, auc=0.8822, loss=0.6406][A
Training Epoch 5/25:   0%|          | 1/292 [00:02<07:24,  1.53s/batch, auc=0.9068, loss=0.6609][A
Training Epoch 5/25:   1%|          | 2/292 [00:02<04:37,  1.05batch/s, auc=0.9068, loss=0.6609][A
Training Epoch 5/25:   1%|          | 2/292 [00:03<04:37,  1.05batch/s, auc=0.8496, loss=1.1461][A
Training Epoch 5/25:   1%|          | 3/292 [00:03<04:51,  1.01s/batch, auc=0.8496, loss=1.1461][A
Training Epoch 5/25:   1%|          | 3/292 [00:03<04:51,  1.01s/batch, auc=0.8546, loss=0.7317][A
Training Epoch 5/25:   1%|▏         | 4/292 [00:03<03:59,  1.20batch/s, auc=0.8546, loss=0.7317][A
Training Epoch 5/25:   1%|▏         | 4/292 [00:04<03:59,  1.20batch/s, auc=0.8564, loss=0.8488][A
Training Epoch 5/25:   2%|▏         | 5/

Epoch [5/25] Train Loss: 0.7894 | Train AUROC: 0.8591 Val Loss: 0.7808 | Val AUROC: 0.8533



Training Epoch 6/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 6/25:   0%|          | 0/292 [00:02<?, ?batch/s, auc=0.6381, loss=1.4204][A
Training Epoch 6/25:   0%|          | 1/292 [00:02<09:55,  2.05s/batch, auc=0.6381, loss=1.4204][A
Training Epoch 6/25:   0%|          | 1/292 [00:02<09:55,  2.05s/batch, auc=0.8095, loss=0.6492][A
Training Epoch 6/25:   1%|          | 2/292 [00:02<05:38,  1.17s/batch, auc=0.8095, loss=0.6492][A
Training Epoch 6/25:   1%|          | 2/292 [00:03<05:38,  1.17s/batch, auc=0.7726, loss=1.0975][A
Training Epoch 6/25:   1%|          | 3/292 [00:03<05:25,  1.12s/batch, auc=0.7726, loss=1.0975][A
Training Epoch 6/25:   1%|          | 3/292 [00:04<05:25,  1.12s/batch, auc=0.8182, loss=0.6745][A
Training Epoch 6/25:   1%|▏         | 4/292 [00:04<04:19,  1.11batch/s, auc=0.8182, loss=0.6745][A
Training Epoch 6/25:   1%|▏         | 4/292 [00:04<04:19,  1.11batch/s, auc=0.8385, loss=0.6348][A
Training Epoch 6/25:   2%|▏         | 5/

Epoch [6/25] Train Loss: 0.7637 | Train AUROC: 0.8689 Val Loss: 0.9047 | Val AUROC: 0.8281



Training Epoch 7/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 7/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9292, loss=0.5164][A
Training Epoch 7/25:   0%|          | 1/292 [00:01<07:14,  1.49s/batch, auc=0.9292, loss=0.5164][A
Training Epoch 7/25:   0%|          | 1/292 [00:02<07:14,  1.49s/batch, auc=0.8778, loss=0.7880][A
Training Epoch 7/25:   1%|          | 2/292 [00:02<04:32,  1.06batch/s, auc=0.8778, loss=0.7880][A
Training Epoch 7/25:   1%|          | 2/292 [00:03<04:32,  1.06batch/s, auc=0.8384, loss=1.1915][A
Training Epoch 7/25:   1%|          | 3/292 [00:03<04:48,  1.00batch/s, auc=0.8384, loss=1.1915][A
Training Epoch 7/25:   1%|          | 3/292 [00:03<04:48,  1.00batch/s, auc=0.8406, loss=0.8236][A
Training Epoch 7/25:   1%|▏         | 4/292 [00:03<03:57,  1.21batch/s, auc=0.8406, loss=0.8236][A
Training Epoch 7/25:   1%|▏         | 4/292 [00:04<03:57,  1.21batch/s, auc=0.8090, loss=1.4562][A
Training Epoch 7/25:   2%|▏         | 5/

Epoch [7/25] Train Loss: 0.7642 | Train AUROC: 0.8674 Val Loss: 0.8869 | Val AUROC: 0.8234



Training Epoch 8/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 8/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.7759, loss=0.8509][A
Training Epoch 8/25:   0%|          | 1/292 [00:01<09:41,  2.00s/batch, auc=0.7759, loss=0.8509][A
Training Epoch 8/25:   0%|          | 1/292 [00:02<09:41,  2.00s/batch, auc=0.8025, loss=1.0093][A
Training Epoch 8/25:   1%|          | 2/292 [00:02<05:33,  1.15s/batch, auc=0.8025, loss=1.0093][A
Training Epoch 8/25:   1%|          | 2/292 [00:03<05:33,  1.15s/batch, auc=0.8401, loss=0.7219][A
Training Epoch 8/25:   1%|          | 3/292 [00:03<04:13,  1.14batch/s, auc=0.8401, loss=0.7219][A
Training Epoch 8/25:   1%|          | 3/292 [00:03<04:13,  1.14batch/s, auc=0.8493, loss=0.7528][A
Training Epoch 8/25:   1%|▏         | 4/292 [00:03<03:36,  1.33batch/s, auc=0.8493, loss=0.7528][A
Training Epoch 8/25:   1%|▏         | 4/292 [00:04<03:36,  1.33batch/s, auc=0.8493, loss=0.7529][A
Training Epoch 8/25:   2%|▏         | 5/

Epoch [8/25] Train Loss: 0.7623 | Train AUROC: 0.8676 Val Loss: 0.8212 | Val AUROC: 0.8485



Training Epoch 9/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 9/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9049, loss=0.8760][A
Training Epoch 9/25:   0%|          | 1/292 [00:01<07:01,  1.45s/batch, auc=0.9049, loss=0.8760][A
Training Epoch 9/25:   0%|          | 1/292 [00:02<07:01,  1.45s/batch, auc=0.8260, loss=1.2571][A
Training Epoch 9/25:   1%|          | 2/292 [00:02<05:55,  1.23s/batch, auc=0.8260, loss=1.2571][A
Training Epoch 9/25:   1%|          | 2/292 [00:03<05:55,  1.23s/batch, auc=0.8537, loss=0.6831][A
Training Epoch 9/25:   1%|          | 3/292 [00:03<04:26,  1.09batch/s, auc=0.8537, loss=0.6831][A
Training Epoch 9/25:   1%|          | 3/292 [00:03<04:26,  1.09batch/s, auc=0.8639, loss=0.5387][A
Training Epoch 9/25:   1%|▏         | 4/292 [00:03<03:43,  1.29batch/s, auc=0.8639, loss=0.5387][A
Training Epoch 9/25:   1%|▏         | 4/292 [00:04<03:43,  1.29batch/s, auc=0.8757, loss=0.7197][A
Training Epoch 9/25:   2%|▏         | 5/

Epoch [9/25] Train Loss: 0.7459 | Train AUROC: 0.8744 Val Loss: 0.8045 | Val AUROC: 0.8425



Training Epoch 10/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 10/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8247, loss=0.7759][A
Training Epoch 10/25:   0%|          | 1/292 [00:01<09:29,  1.96s/batch, auc=0.8247, loss=0.7759][A
Training Epoch 10/25:   0%|          | 1/292 [00:03<09:29,  1.96s/batch, auc=0.7552, loss=0.7507][A
Training Epoch 10/25:   1%|          | 2/292 [00:03<06:56,  1.44s/batch, auc=0.7552, loss=0.7507][A
Training Epoch 10/25:   1%|          | 2/292 [00:03<06:56,  1.44s/batch, auc=0.8682, loss=0.5622][A
Training Epoch 10/25:   1%|          | 3/292 [00:03<04:58,  1.03s/batch, auc=0.8682, loss=0.5622][A
Training Epoch 10/25:   1%|          | 3/292 [00:04<04:58,  1.03s/batch, auc=0.8367, loss=0.8503][A
Training Epoch 10/25:   1%|▏         | 4/292 [00:04<05:01,  1.05s/batch, auc=0.8367, loss=0.8503][A
Training Epoch 10/25:   1%|▏         | 4/292 [00:05<05:01,  1.05s/batch, auc=0.8699, loss=0.6020][A
Training Epoch 10/25:   2%|▏  

Epoch [10/25] Train Loss: 0.7503 | Train AUROC: 0.8723 Val Loss: 0.9163 | Val AUROC: 0.8399



Training Epoch 11/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 11/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9339, loss=0.5309][A
Training Epoch 11/25:   0%|          | 1/292 [00:01<07:10,  1.48s/batch, auc=0.9339, loss=0.5309][A
Training Epoch 11/25:   0%|          | 1/292 [00:02<07:10,  1.48s/batch, auc=0.9097, loss=0.8681][A
Training Epoch 11/25:   1%|          | 2/292 [00:02<04:33,  1.06batch/s, auc=0.9097, loss=0.8681][A
Training Epoch 11/25:   1%|          | 2/292 [00:02<04:33,  1.06batch/s, auc=0.9167, loss=0.4875][A
Training Epoch 11/25:   1%|          | 3/292 [00:02<03:41,  1.31batch/s, auc=0.9167, loss=0.4875][A
Training Epoch 11/25:   1%|          | 3/292 [00:03<03:41,  1.31batch/s, auc=0.9107, loss=0.7551][A
Training Epoch 11/25:   1%|▏         | 4/292 [00:03<03:16,  1.46batch/s, auc=0.9107, loss=0.7551][A
Training Epoch 11/25:   1%|▏         | 4/292 [00:03<03:16,  1.46batch/s, auc=0.9203, loss=0.4542][A
Training Epoch 11/25:   2%|▏  

Epoch [11/25] Train Loss: 0.7627 | Train AUROC: 0.8684 Val Loss: 0.8663 | Val AUROC: 0.8409



Training Epoch 12/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 12/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8305, loss=0.6763][A
Training Epoch 12/25:   0%|          | 1/292 [00:01<07:28,  1.54s/batch, auc=0.8305, loss=0.6763][A
Training Epoch 12/25:   0%|          | 1/292 [00:02<07:28,  1.54s/batch, auc=0.8967, loss=0.6491][A
Training Epoch 12/25:   1%|          | 2/292 [00:02<04:38,  1.04batch/s, auc=0.8967, loss=0.6491][A
Training Epoch 12/25:   1%|          | 2/292 [00:03<04:38,  1.04batch/s, auc=0.8818, loss=0.8860][A
Training Epoch 12/25:   1%|          | 3/292 [00:03<04:52,  1.01s/batch, auc=0.8818, loss=0.8860][A
Training Epoch 12/25:   1%|          | 3/292 [00:03<04:52,  1.01s/batch, auc=0.9043, loss=0.5150][A
Training Epoch 12/25:   1%|▏         | 4/292 [00:03<03:59,  1.20batch/s, auc=0.9043, loss=0.5150][A
Training Epoch 12/25:   1%|▏         | 4/292 [00:04<03:59,  1.20batch/s, auc=0.8964, loss=0.6729][A
Training Epoch 12/25:   2%|▏  

Epoch [12/25] Train Loss: 0.7600 | Train AUROC: 0.8688 Val Loss: 0.8264 | Val AUROC: 0.8420



Training Epoch 13/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 13/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9253, loss=0.6106][A
Training Epoch 13/25:   0%|          | 1/292 [00:01<07:31,  1.55s/batch, auc=0.9253, loss=0.6106][A
Training Epoch 13/25:   0%|          | 1/292 [00:02<07:31,  1.55s/batch, auc=0.9232, loss=0.7516][A
Training Epoch 13/25:   1%|          | 2/292 [00:02<04:42,  1.03batch/s, auc=0.9232, loss=0.7516][A
Training Epoch 13/25:   1%|          | 2/292 [00:02<04:42,  1.03batch/s, auc=0.8946, loss=0.7332][A
Training Epoch 13/25:   1%|          | 3/292 [00:02<03:46,  1.28batch/s, auc=0.8946, loss=0.7332][A
Training Epoch 13/25:   1%|          | 3/292 [00:03<03:46,  1.28batch/s, auc=0.8893, loss=0.5853][A
Training Epoch 13/25:   1%|▏         | 4/292 [00:03<03:20,  1.44batch/s, auc=0.8893, loss=0.5853][A
Training Epoch 13/25:   1%|▏         | 4/292 [00:03<03:20,  1.44batch/s, auc=0.8742, loss=0.9437][A
Training Epoch 13/25:   2%|▏  

Epoch [13/25] Train Loss: 0.7484 | Train AUROC: 0.8722 Val Loss: 0.8738 | Val AUROC: 0.8415



Training Epoch 14/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 14/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8780, loss=0.6251][A
Training Epoch 14/25:   0%|          | 1/292 [00:01<07:06,  1.47s/batch, auc=0.8780, loss=0.6251][A
Training Epoch 14/25:   0%|          | 1/292 [00:02<07:06,  1.47s/batch, auc=0.7972, loss=1.4682][A
Training Epoch 14/25:   1%|          | 2/292 [00:02<05:57,  1.23s/batch, auc=0.7972, loss=1.4682][A
Training Epoch 14/25:   1%|          | 2/292 [00:03<05:57,  1.23s/batch, auc=0.8227, loss=0.6154][A
Training Epoch 14/25:   1%|          | 3/292 [00:03<04:27,  1.08batch/s, auc=0.8227, loss=0.6154][A
Training Epoch 14/25:   1%|          | 3/292 [00:04<04:27,  1.08batch/s, auc=0.7740, loss=1.2250][A
Training Epoch 14/25:   1%|▏         | 4/292 [00:04<04:42,  1.02batch/s, auc=0.7740, loss=1.2250][A
Training Epoch 14/25:   1%|▏         | 4/292 [00:05<04:42,  1.02batch/s, auc=0.7836, loss=0.9097][A
Training Epoch 14/25:   2%|▏  

Epoch [14/25] Train Loss: 0.7403 | Train AUROC: 0.8752 Val Loss: 0.7846 | Val AUROC: 0.8651



Training Epoch 15/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 15/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9424, loss=0.5013][A
Training Epoch 15/25:   0%|          | 1/292 [00:01<07:16,  1.50s/batch, auc=0.9424, loss=0.5013][A
Training Epoch 15/25:   0%|          | 1/292 [00:02<07:16,  1.50s/batch, auc=0.8077, loss=1.2990][A
Training Epoch 15/25:   1%|          | 2/292 [00:02<06:01,  1.25s/batch, auc=0.8077, loss=1.2990][A
Training Epoch 15/25:   1%|          | 2/292 [00:03<06:01,  1.25s/batch, auc=0.8203, loss=0.7508][A
Training Epoch 15/25:   1%|          | 3/292 [00:03<05:37,  1.17s/batch, auc=0.8203, loss=0.7508][A
Training Epoch 15/25:   1%|          | 3/292 [00:04<05:37,  1.17s/batch, auc=0.7870, loss=1.1652][A
Training Epoch 15/25:   1%|▏         | 4/292 [00:04<05:25,  1.13s/batch, auc=0.7870, loss=1.1652][A
Training Epoch 15/25:   1%|▏         | 4/292 [00:05<05:25,  1.13s/batch, auc=0.7655, loss=1.4737][A
Training Epoch 15/25:   2%|▏  

Epoch [15/25] Train Loss: 0.7552 | Train AUROC: 0.8708 Val Loss: 0.8006 | Val AUROC: 0.8517



Training Epoch 16/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 16/25:   0%|          | 0/292 [00:02<?, ?batch/s, auc=0.6591, loss=1.0362][A
Training Epoch 16/25:   0%|          | 1/292 [00:02<09:51,  2.03s/batch, auc=0.6591, loss=1.0362][A
Training Epoch 16/25:   0%|          | 1/292 [00:02<09:51,  2.03s/batch, auc=0.8013, loss=0.7547][A
Training Epoch 16/25:   1%|          | 2/292 [00:02<05:37,  1.16s/batch, auc=0.8013, loss=0.7547][A
Training Epoch 16/25:   1%|          | 2/292 [00:03<05:37,  1.16s/batch, auc=0.8502, loss=0.7169][A
Training Epoch 16/25:   1%|          | 3/292 [00:03<04:15,  1.13batch/s, auc=0.8502, loss=0.7169][A
Training Epoch 16/25:   1%|          | 3/292 [00:03<04:15,  1.13batch/s, auc=0.8815, loss=0.5444][A
Training Epoch 16/25:   1%|▏         | 4/292 [00:03<03:37,  1.32batch/s, auc=0.8815, loss=0.5444][A
Training Epoch 16/25:   1%|▏         | 4/292 [00:04<03:37,  1.32batch/s, auc=0.8881, loss=0.5882][A
Training Epoch 16/25:   2%|▏  

Epoch [16/25] Train Loss: 0.7398 | Train AUROC: 0.8754 Val Loss: 0.7911 | Val AUROC: 0.8518



Training Epoch 17/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 17/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8042, loss=0.6345][A
Training Epoch 17/25:   0%|          | 1/292 [00:01<07:10,  1.48s/batch, auc=0.8042, loss=0.6345][A
Training Epoch 17/25:   0%|          | 1/292 [00:02<07:10,  1.48s/batch, auc=0.8829, loss=0.5656][A
Training Epoch 17/25:   1%|          | 2/292 [00:02<04:31,  1.07batch/s, auc=0.8829, loss=0.5656][A
Training Epoch 17/25:   1%|          | 2/292 [00:02<04:31,  1.07batch/s, auc=0.8786, loss=0.6471][A
Training Epoch 17/25:   1%|          | 3/292 [00:02<03:39,  1.32batch/s, auc=0.8786, loss=0.6471][A
Training Epoch 17/25:   1%|          | 3/292 [00:03<03:39,  1.32batch/s, auc=0.8843, loss=0.5769][A
Training Epoch 17/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.8843, loss=0.5769][A
Training Epoch 17/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.8640, loss=0.9069][A
Training Epoch 17/25:   2%|▏  

Epoch [17/25] Train Loss: 0.7363 | Train AUROC: 0.8766 Val Loss: 0.8617 | Val AUROC: 0.8472



Training Epoch 18/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 18/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9495, loss=0.5734][A
Training Epoch 18/25:   0%|          | 1/292 [00:01<07:25,  1.53s/batch, auc=0.9495, loss=0.5734][A
Training Epoch 18/25:   0%|          | 1/292 [00:02<07:25,  1.53s/batch, auc=0.9520, loss=0.5813][A
Training Epoch 18/25:   1%|          | 2/292 [00:02<04:37,  1.05batch/s, auc=0.9520, loss=0.5813][A
Training Epoch 18/25:   1%|          | 2/292 [00:02<04:37,  1.05batch/s, auc=0.9381, loss=0.6654][A
Training Epoch 18/25:   1%|          | 3/292 [00:02<03:43,  1.29batch/s, auc=0.9381, loss=0.6654][A
Training Epoch 18/25:   1%|          | 3/292 [00:03<03:43,  1.29batch/s, auc=0.9325, loss=0.5413][A
Training Epoch 18/25:   1%|▏         | 4/292 [00:03<03:18,  1.45batch/s, auc=0.9325, loss=0.5413][A
Training Epoch 18/25:   1%|▏         | 4/292 [00:04<03:18,  1.45batch/s, auc=0.9236, loss=0.7511][A
Training Epoch 18/25:   2%|▏  

Epoch [18/25] Train Loss: 0.7536 | Train AUROC: 0.8713 Val Loss: 0.8554 | Val AUROC: 0.8465



Training Epoch 19/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 19/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9509, loss=0.5588][A
Training Epoch 19/25:   0%|          | 1/292 [00:01<07:22,  1.52s/batch, auc=0.9509, loss=0.5588][A
Training Epoch 19/25:   0%|          | 1/292 [00:02<07:22,  1.52s/batch, auc=0.8542, loss=0.9332][A
Training Epoch 19/25:   1%|          | 2/292 [00:02<06:05,  1.26s/batch, auc=0.8542, loss=0.9332][A
Training Epoch 19/25:   1%|          | 2/292 [00:03<06:05,  1.26s/batch, auc=0.8872, loss=0.5770][A
Training Epoch 19/25:   1%|          | 3/292 [00:03<04:30,  1.07batch/s, auc=0.8872, loss=0.5770][A
Training Epoch 19/25:   1%|          | 3/292 [00:03<04:30,  1.07batch/s, auc=0.8896, loss=0.7847][A
Training Epoch 19/25:   1%|▏         | 4/292 [00:03<03:47,  1.27batch/s, auc=0.8896, loss=0.7847][A
Training Epoch 19/25:   1%|▏         | 4/292 [00:04<03:47,  1.27batch/s, auc=0.8870, loss=0.7931][A
Training Epoch 19/25:   2%|▏  

Epoch [19/25] Train Loss: 0.7218 | Train AUROC: 0.8817 Val Loss: 0.8109 | Val AUROC: 0.8612



Training Epoch 20/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 20/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9253, loss=0.5391][A
Training Epoch 20/25:   0%|          | 1/292 [00:01<07:08,  1.47s/batch, auc=0.9253, loss=0.5391][A
Training Epoch 20/25:   0%|          | 1/292 [00:02<07:08,  1.47s/batch, auc=0.9237, loss=0.5158][A
Training Epoch 20/25:   1%|          | 2/292 [00:02<04:30,  1.07batch/s, auc=0.9237, loss=0.5158][A
Training Epoch 20/25:   1%|          | 2/292 [00:02<04:30,  1.07batch/s, auc=0.9215, loss=0.5300][A
Training Epoch 20/25:   1%|          | 3/292 [00:02<03:39,  1.32batch/s, auc=0.9215, loss=0.5300][A
Training Epoch 20/25:   1%|          | 3/292 [00:03<03:39,  1.32batch/s, auc=0.9197, loss=0.6840][A
Training Epoch 20/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.9197, loss=0.6840][A
Training Epoch 20/25:   1%|▏         | 4/292 [00:03<03:15,  1.47batch/s, auc=0.9036, loss=0.7347][A
Training Epoch 20/25:   2%|▏  

Epoch [20/25] Train Loss: 0.7343 | Train AUROC: 0.8778 Val Loss: 0.9163 | Val AUROC: 0.8346



Training Epoch 21/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 21/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8630, loss=0.8470][A
Training Epoch 21/25:   0%|          | 1/292 [00:01<07:30,  1.55s/batch, auc=0.8630, loss=0.8470][A
Training Epoch 21/25:   0%|          | 1/292 [00:02<07:30,  1.55s/batch, auc=0.8814, loss=0.5694][A
Training Epoch 21/25:   1%|          | 2/292 [00:02<04:39,  1.04batch/s, auc=0.8814, loss=0.5694][A
Training Epoch 21/25:   1%|          | 2/292 [00:02<04:39,  1.04batch/s, auc=0.9071, loss=0.5107][A
Training Epoch 21/25:   1%|          | 3/292 [00:02<03:44,  1.29batch/s, auc=0.9071, loss=0.5107][A
Training Epoch 21/25:   1%|          | 3/292 [00:03<03:44,  1.29batch/s, auc=0.9067, loss=0.6118][A
Training Epoch 21/25:   1%|▏         | 4/292 [00:03<03:18,  1.45batch/s, auc=0.9067, loss=0.6118][A
Training Epoch 21/25:   1%|▏         | 4/292 [00:03<03:18,  1.45batch/s, auc=0.9210, loss=0.4516][A
Training Epoch 21/25:   2%|▏  

Epoch [21/25] Train Loss: 0.7103 | Train AUROC: 0.8845 Val Loss: 0.7979 | Val AUROC: 0.8515



Training Epoch 22/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 22/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9458, loss=0.4885][A
Training Epoch 22/25:   0%|          | 1/292 [00:01<07:04,  1.46s/batch, auc=0.9458, loss=0.4885][A
Training Epoch 22/25:   0%|          | 1/292 [00:02<07:04,  1.46s/batch, auc=0.9271, loss=0.5454][A
Training Epoch 22/25:   1%|          | 2/292 [00:02<04:29,  1.08batch/s, auc=0.9271, loss=0.5454][A
Training Epoch 22/25:   1%|          | 2/292 [00:02<04:29,  1.08batch/s, auc=0.9082, loss=0.8075][A
Training Epoch 22/25:   1%|          | 3/292 [00:02<03:39,  1.32batch/s, auc=0.9082, loss=0.8075][A
Training Epoch 22/25:   1%|          | 3/292 [00:03<03:39,  1.32batch/s, auc=0.9120, loss=0.6081][A
Training Epoch 22/25:   1%|▏         | 4/292 [00:03<03:15,  1.48batch/s, auc=0.9120, loss=0.6081][A
Training Epoch 22/25:   1%|▏         | 4/292 [00:03<03:15,  1.48batch/s, auc=0.9129, loss=0.6658][A
Training Epoch 22/25:   2%|▏  

Epoch [22/25] Train Loss: 0.7172 | Train AUROC: 0.8812 Val Loss: 1.0042 | Val AUROC: 0.8304



Training Epoch 23/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 23/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.8707, loss=0.7651][A
Training Epoch 23/25:   0%|          | 1/292 [00:01<06:59,  1.44s/batch, auc=0.8707, loss=0.7651][A
Training Epoch 23/25:   0%|          | 1/292 [00:01<06:59,  1.44s/batch, auc=0.8791, loss=0.6018][A
Training Epoch 23/25:   1%|          | 2/292 [00:01<04:26,  1.09batch/s, auc=0.8791, loss=0.6018][A
Training Epoch 23/25:   1%|          | 2/292 [00:03<04:26,  1.09batch/s, auc=0.8599, loss=0.7837][A
Training Epoch 23/25:   1%|          | 3/292 [00:03<04:45,  1.01batch/s, auc=0.8599, loss=0.7837][A
Training Epoch 23/25:   1%|          | 3/292 [00:04<04:45,  1.01batch/s, auc=0.8737, loss=0.7074][A
Training Epoch 23/25:   1%|▏         | 4/292 [00:04<04:54,  1.02s/batch, auc=0.8737, loss=0.7074][A
Training Epoch 23/25:   1%|▏         | 4/292 [00:05<04:54,  1.02s/batch, auc=0.7975, loss=2.2745][A
Training Epoch 23/25:   2%|▏  

Epoch [23/25] Train Loss: 0.7392 | Train AUROC: 0.8757 Val Loss: 0.8715 | Val AUROC: 0.8519



Training Epoch 24/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 24/25:   0%|          | 0/292 [00:02<?, ?batch/s, auc=0.8622, loss=0.7558][A
Training Epoch 24/25:   0%|          | 1/292 [00:02<09:44,  2.01s/batch, auc=0.8622, loss=0.7558][A
Training Epoch 24/25:   0%|          | 1/292 [00:02<09:44,  2.01s/batch, auc=0.9196, loss=0.5156][A
Training Epoch 24/25:   1%|          | 2/292 [00:02<05:34,  1.15s/batch, auc=0.9196, loss=0.5156][A
Training Epoch 24/25:   1%|          | 2/292 [00:03<05:34,  1.15s/batch, auc=0.9013, loss=0.8114][A
Training Epoch 24/25:   1%|          | 3/292 [00:03<04:14,  1.14batch/s, auc=0.9013, loss=0.8114][A
Training Epoch 24/25:   1%|          | 3/292 [00:03<04:14,  1.14batch/s, auc=0.8821, loss=1.0003][A
Training Epoch 24/25:   1%|▏         | 4/292 [00:03<03:36,  1.33batch/s, auc=0.8821, loss=1.0003][A
Training Epoch 24/25:   1%|▏         | 4/292 [00:04<03:36,  1.33batch/s, auc=0.8747, loss=0.8879][A
Training Epoch 24/25:   2%|▏  

Epoch [24/25] Train Loss: 0.7430 | Train AUROC: 0.8743 Val Loss: 0.8473 | Val AUROC: 0.8522



Training Epoch 25/25:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 25/25:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.9353, loss=0.5840][A
Training Epoch 25/25:   0%|          | 1/292 [00:01<07:25,  1.53s/batch, auc=0.9353, loss=0.5840][A
Training Epoch 25/25:   0%|          | 1/292 [00:02<07:25,  1.53s/batch, auc=0.8903, loss=0.7287][A
Training Epoch 25/25:   1%|          | 2/292 [00:02<04:37,  1.05batch/s, auc=0.8903, loss=0.7287][A
Training Epoch 25/25:   1%|          | 2/292 [00:02<04:37,  1.05batch/s, auc=0.9171, loss=0.5747][A
Training Epoch 25/25:   1%|          | 3/292 [00:02<03:43,  1.29batch/s, auc=0.9171, loss=0.5747][A
Training Epoch 25/25:   1%|          | 3/292 [00:03<03:43,  1.29batch/s, auc=0.9172, loss=0.6650][A
Training Epoch 25/25:   1%|▏         | 4/292 [00:03<03:18,  1.45batch/s, auc=0.9172, loss=0.6650][A
Training Epoch 25/25:   1%|▏         | 4/292 [00:03<03:18,  1.45batch/s, auc=0.9080, loss=0.6500][A
Training Epoch 25/25:   2%|▏  

Epoch [25/25] Train Loss: 0.7092 | Train AUROC: 0.8841 Val Loss: 0.8701 | Val AUROC: 0.8360





# Model Testing

In [28]:
from sklearn.metrics import confusion_matrix, accuracy_score

# Tag used to name the saved prediction / summary files; `rate` comes from an earlier cell.
testpath = 'gca-sex-only-intra-r={}'.format(rate)
def evaluate_model(model, dataloader, criterion, device, name=None, augment=True, test_data="rsna"):
    """Evaluate a binary classifier on ``dataloader`` and report test metrics.

    Runs ``model`` in eval mode without gradients, optionally passing each batch
    through ``gca.reconstruct`` first, and prints loss / AUROC / accuracy /
    false-negative rate (FNR) using a 0.5 probability cut-off.

    Args:
        model: network whose output for a batch is ``(B, 1)`` logits
            (the code applies ``sigmoid`` and ``squeeze(1)``).
        dataloader: iterable of ``(images, labels, sex)`` batches.
        criterion: loss applied to the raw outputs and ``(B, 1)`` float labels.
        device: device each batch is moved to.
        name: if given, per-image probabilities are written to
            ``../results/tests/{name}_pred.csv`` alongside the ``path`` column of
            ``../splits/{test_data}_test.csv``; if ``None`` nothing is saved.
        augment: when True, images are reconstructed with the global ``gca``
            object before classification (``gca`` is created outside this cell).
        test_data: split prefix used to locate the test CSV.

    Returns:
        Tuple ``(avg_loss, auc, acc, fnr)``.
    """
    save_dir = "../results/tests/"
    model.eval()
    test_loss, all_outputs, all_labels = 0.0, [], []

    with torch.no_grad():
        for images, labels, _sex in dataloader:
            images = images.to(device)
            labels = labels.to(device).float().unsqueeze(1)
            if augment:
                images = gca.reconstruct(images)

            outputs = model(images)
            test_loss += criterion(outputs, labels).item()

            all_outputs.extend(torch.sigmoid(outputs).squeeze(1).cpu().numpy())
            all_labels.extend(labels.squeeze(1).cpu().numpy())

    avg_loss = test_loss / len(dataloader)
    # Labels come back as floats (0.0 / 1.0); cast so sklearn sees clean int labels.
    y_true = np.asarray(all_labels).astype(int)
    preds = (np.asarray(all_outputs) > 0.5).astype(int)

    auc = roc_auc_score(y_true, all_outputs)
    acc = accuracy_score(y_true, preds)
    # labels=[0, 1] guarantees a 2x2 matrix ([[TN, FP], [FN, TP]]) even if one
    # class is absent from both labels and predictions.
    tn, fp, fn, tp = confusion_matrix(y_true, preds, labels=[0, 1]).ravel()
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

    print(f"Test Loss: {avg_loss:.4f} | Test AUROC: {auc:.4f} | Test Accuracy: {acc:.4f} | FNR: {fnr:.4f}")

    if name is not None:
        os.makedirs(save_dir, exist_ok=True)
        # NOTE(review): predictions are paired with `path` positionally, which
        # assumes the dataloader iterates the test CSV in file order (shuffle=False).
        df = pd.DataFrame(pd.read_csv(f'../splits/{test_data}_test.csv')['path'])
        df['Pneumonia_pred'] = all_outputs
        df.to_csv(f'{save_dir}{name}_pred.csv', index=False)

    return avg_loss, auc, acc, fnr

In [29]:
# Evaluate on the test split and save predictions under the `testpath` tag.
evaluate_model(model, test_loader, criterion, device, name=testpath)

Test Loss: 2.2337 | Test AUROC: 0.7502 | Test Accuracy: 0.7790 | FNR: 0.5588


# Analyze Subgroup Performance (Sex and Age)

In [30]:
import numpy as np
import pandas as pd
from sklearn import metrics
from tqdm.auto import tqdm
import os
import argparse
import json
import ast 

# Number of experimental trials (not referenced by the analysis cells below).
num_trials: int = 5

In [31]:
# Metrics
def __threshold(y_true, y_pred):
    """Return the ROC-curve threshold that maximises Youden's J (TPR - FPR)."""
    fpr, tpr, candidates = metrics.roc_curve(y_true, y_pred)
    youden_j = tpr - fpr
    best_idx = np.nanargmax(youden_j)
    return candidates[best_idx]

def __metrics_binary(y_true, y_pred, threshold):
    """Compute binary-classification metrics at a fixed decision threshold.

    Args:
        y_true: array-like of ground-truth labels (0/1).
        y_pred: array-like of predicted scores; a sample counts as positive
            when its score is strictly greater than ``threshold``.
        threshold: decision threshold applied to ``y_pred``.

    Returns:
        Tuple ``(auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp)``.
        Any rate whose denominator is zero is ``np.nan``. ``auroc`` is
        ``np.nan`` when ``roc_auc_score`` rejects the input with ``ValueError``
        (e.g. only one class present in ``y_true``).
    """
    y_pred = np.asarray(y_pred)
    y_pred_t = (y_pred > threshold).astype(int)
    try:
        auroc = metrics.roc_auc_score(y_true, y_pred)
    except ValueError:
        # Undefined AUROC (e.g. single-class subgroup) -- report NaN, but let
        # genuinely unexpected errors propagate instead of silently hiding them.
        auroc = np.nan
    # labels=[0, 1] keeps the matrix 2x2 even for single-class subgroups.
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred_t, labels=[0, 1]).ravel()

    def _rate(num, den):
        # Ratio that is NaN rather than a ZeroDivision/inf when den == 0.
        return num / den if den != 0 else np.nan

    tpr, fnr = _rate(tp, tp + fn), _rate(fn, tp + fn)
    tnr, fpr = _rate(tn, tn + fp), _rate(fp, tn + fp)
    ppv = _rate(tp, tp + fp)
    npv, fomr = _rate(tn, fn + tn), _rate(fn, fn + tn)
    return auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp

In [32]:
def __analyze_aim_2(model, test_data, name, target_sex=None, target_age=None, augmentation=False):
    """Compute overall and per-subgroup metrics for one saved prediction file.

    Reads ground truth from ``../splits/{test_data}_test.csv`` (columns used:
    ``path``, ``Pneumonia_RSNA``, ``Sex``, ``Age_group``) and predictions from
    ``../results/tests/{name}_pred.csv`` (columns ``path``, ``Pneumonia_pred``).

    A single Youden's-J threshold is chosen on the full test set and reused for
    every subgroup. NOTE(review): the threshold is tuned on the test data itself,
    which makes thresholded metrics optimistic.

    ``model`` and ``augmentation`` are accepted for interface compatibility but
    do not affect the result.

    Returns:
        List of 18 rows ``[target_sex, target_age, trial, rate, dem_sex, dem_age,
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]``: overall,
        per sex, per age group, then per sex x age group. Unused demographic
        fields are ``np.nan``.
    """
    trial, rate = 0, 0  # fixed placeholders written into every row
    sexes = ['M', 'F']
    age_groups = ['0-20', '20-40', '40-60', '60-80', '80+']

    y_true = pd.read_csv(f'../splits/{test_data}_test.csv')
    y_pred = pd.read_csv(f'../results/tests/{name}_pred.csv')
    threshold = __threshold(y_true['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values)

    def _subgroup_row(dem_sex, dem_age, subset):
        # Predictions are selected by path membership and paired positionally
        # with `subset`, i.e. both files are assumed to share row order.
        preds = y_pred[y_pred['path'].isin(subset['path'])]
        scores = __metrics_binary(subset['Pneumonia_RSNA'].values, preds['Pneumonia_pred'].values, threshold)
        return [target_sex, target_age, trial, rate, dem_sex, dem_age, *scores]

    overall = __metrics_binary(y_true['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values, threshold)
    results = [[target_sex, target_age, trial, rate, np.nan, np.nan, *overall]]

    for dem_sex in sexes:
        results.append(_subgroup_row(dem_sex, np.nan, y_true[y_true['Sex'] == dem_sex]))
    for dem_age in age_groups:
        results.append(_subgroup_row(np.nan, dem_age, y_true[y_true['Age_group'] == dem_age]))
    for dem_sex in sexes:
        for dem_age in age_groups:
            mask = (y_true['Sex'] == dem_sex) & (y_true['Age_group'] == dem_age)
            results.append(_subgroup_row(dem_sex, dem_age, y_true[mask]))
    return results
  
def analyze_aim_2(model, test_data, name, augmentation=False):
    """Build the subgroup metrics table for ``name`` and save it as CSV.

    Calls ``__analyze_aim_2`` for the prediction file tagged ``name`` and writes
    the sorted table to ``../results/analyze/{name}_summary.csv``.

    Args:
        model: forwarded unchanged (unused by the analysis).
        test_data: split prefix, e.g. ``"rsna"``.
        name: tag of the prediction file ``../results/tests/{name}_pred.csv``
            and of the output summary file.
        augmentation: forwarded unchanged.
    """
    # Use the `name` argument (previously the global `testpath` was used here).
    results = __analyze_aim_2(model, test_data, name, None, None, augmentation=augmentation)
    columns = ['target_sex', 'target_age', 'trial', 'rate', 'dem_sex', 'dem_age', 'auroc', 'tpr',
               'fnr', 'tnr', 'fpr', 'ppv', 'npv', 'fomr', 'tn', 'fp', 'fn', 'tp']
    # Build directly from the row lists: np.array() on mixed str/float rows can
    # coerce every value to a string, breaking numeric columns and sorting.
    df = pd.DataFrame(results, columns=columns).sort_values(['target_sex', 'target_age', 'trial', 'rate'])

    save_dir = "../results/analyze/"
    os.makedirs(save_dir, exist_ok=True)
    df.to_csv(f'{save_dir}{name}_summary.csv', index=False)

In [33]:
# Summarise subgroup metrics for the predictions saved under `testpath`.
analyze_aim_2(model="densenet", test_data="rsna", name=testpath, augmentation=False)

# Save Model

In [72]:
# Persist only the learned weights (state_dict) of the trained classifier.
ckpt_dir = "models"
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(model.state_dict(), f"{ckpt_dir}/model.pth")

# Load Saved Model

In [73]:
# Rebuild the model with the same architecture used for training
model = CustomModel(base_model_name='densenet', num_classes=1).to(device)  # or 'resnet' if used

# Load weights; map_location places tensors on `device` so the checkpoint also
# loads when the GPU it was saved from is not available.
model.load_state_dict(torch.load("models/model.pth", map_location=device))
model.eval()  # disable dropout / use running BatchNorm stats for evaluation

CustomModel(
  (base_model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momen

In [74]:
# Test split in file order (shuffle=False) so predictions line up with the CSV rows.
test_csv = '../splits/rsna_test.csv'
test_ds = CustomDataset(test=True, csv_file=test_csv)
test_loader = create_dataloader(test_ds, shuffle=False, batch_size=64)

In [75]:
# evaluate_model requires a result tag and prints the summary line itself;
# the previous call omitted `name` and unpacked its None return, crashing the cell.
evaluate_model(model, test_loader, criterion, device, testpath)

Test Loss: 0.8701 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778
Test Loss: 0.8701 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778
