# Load required libraries

In [3]:
import os

# Limit CUDA to the device with index 2. This runs before torch is imported
# in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [5]:
from torch import nn, autograd, optim
import pandas as pd
from tqdm import tqdm
import torch
import cv2
import os
from local import GCA
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import utils
from PIL import Image
from sklearn.metrics import roc_auc_score
import numpy as np

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): later cells (training loop, evaluate_model) reference `gca`;
# enable the line below (requires '../hyperplanes.pt') to use GCA.
# gca = GCA(device=device, h_path='../hyperplanes.pt')

# Define Pneumonia Classifier

In [39]:
import torch
from torchvision import models
import torch.nn as nn

class CustomModel(nn.Module):
    """Image classifier: a pretrained torchvision backbone plus a small MLP head.

    The backbone's own classifier layer is replaced by ``nn.Identity`` so that
    it emits features, which are then passed through
    ``fc1 -> ReLU -> Dropout(0.3) -> fc2``.

    Args:
        base_model_name: ``'densenet'`` (DenseNet-121) or ``'resnet'`` (ResNet-50).
        num_classes: number of raw output scores produced per sample.

    Raises:
        ValueError: if ``base_model_name`` is neither ``'densenet'`` nor ``'resnet'``.
    """

    def __init__(self, base_model_name, num_classes=1):
        super().__init__()

        # Reject unknown backbones up front, before any weights are loaded.
        if base_model_name not in ('densenet', 'resnet'):
            raise ValueError("Model not supported. Choose 'densenet' or 'resnet'")

        if base_model_name == 'densenet':
            backbone = models.densenet121(pretrained=True)
            feature_dim = backbone.classifier.in_features
            backbone.classifier = nn.Identity()  # drop the original classifier
        else:
            backbone = models.resnet50(pretrained=True)
            feature_dim = backbone.fc.in_features
            backbone.fc = nn.Identity()  # drop the original classifier
        self.base_model = backbone

        # Classification head. Attribute names and registration order are part
        # of the checkpoint format (state_dict keys), so they must not change.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(feature_dim, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the raw (pre-sigmoid) scores of shape ``(batch, num_classes)``."""
        features = self.base_model(x)

        # Only pool/flatten when the backbone hands back a 4D feature map;
        # already-flat features go straight to the head.
        if isinstance(features, torch.Tensor) and features.dim() == 4:
            features = torch.flatten(self.global_avg_pool(features), 1)

        hidden = self.dropout(self.relu(self.fc1(features)))
        return self.fc2(hidden)

In [40]:
# Instantiate the model and move it to the GPU when one is available
# (falls back to the CPU instead of failing on machines without CUDA).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CustomModel(base_model_name='densenet')
model.to(device)

CustomModel(
  (base_model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momen

# Load RSNA Dataset

In [41]:
import pandas as pd
import numpy as np
import cv2
import os
import torchvision.transforms as transforms



# Sanity-check batch: run the first training images through the training
# preprocessing pipeline.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    # Random shift of up to 50% of the width/height; scale=None means no zoom.
    transforms.RandomAffine(degrees=0, translate=(0.5, 0.5), scale=None),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
])
path, rsna_csv = "../../../CXR/datasets/rsna", "../splits/trial_0/train.csv"
df = pd.read_csv(rsna_csv)


def _read_rgb(image_path):
    """Read an image with OpenCV and return it in RGB channel order.

    cv2.imread yields BGR and returns None (instead of raising) when the file
    cannot be read; the normalisation statistics above are given in RGB order,
    matching the PIL `.convert('RGB')` loading used by CustomDataset.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# Take at most 16 rows so a short CSV does not raise an IndexError.
batch = torch.stack([
    transform(_read_rgb(os.path.join(path, rel_path)))
    for rel_path in df["path"].iloc[:16]
])
batch = batch.to(device)



In [32]:
# Sanity check: forward the sample batch and show one raw score per image.
output = model(batch)
torch.squeeze(output, dim=1)

tensor([-0.0785, -0.1687, -0.0133, -0.3159, -0.0842,  0.0240, -0.1074,  0.0029,
         0.0185, -0.1260, -0.0999,  0.2493, -0.1520,  0.0149, -0.2938,  0.1621],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [6]:
# Load dataset
class CustomDataset(Dataset):
    """Image dataset backed by a CSV of image paths and pneumonia labels.

    The CSV must contain a ``path`` column. ``__getitem__`` and
    ``poison_labels`` additionally read ``Pneumonia_RSNA``; ``poison_labels``
    reads ``Sex`` / ``Age_group`` when the corresponding filter is given.

    Args:
        csv_file: path of the split CSV to load.
        augmentation: when True (default) and ``test`` is False, random flip and
            translation are applied; when False only the basic transforms are used.
        test_data: unused; kept so existing callers keep working.
        test: when True, always use the deterministic (non-augmenting) transforms.

    Raises:
        ValueError: if the CSV has no ``path`` column.
    """

    def __init__(self, csv_file, augmentation=True, test_data='rsna', test=False):
        self.df = pd.read_csv(csv_file)
        # Sanity checks
        if 'path' not in self.df.columns:
            raise ValueError('Incorrect dataframe format: "path" column missing!')

        # Store the caller's flag (it used to be overwritten with True).
        self.augmentation, self.test = augmentation, test
        self.transform = self.get_transforms()
        # Update image paths. An empty CSV has no first row to probe, so skip it.
        if len(self.df) > 0:
            if not os.path.exists(self.df['path'].iloc[0]):
                self.df['path'] = '../../../CXR/datasets/rsna/' + self.df['path']
            else:
                # NOTE(review): prefixing '../' onto a path that already resolves
                # looks suspicious -- confirm this is intended.
                self.df['path'] = '../' + self.df['path']

    def get_transforms(self):
        """Return the augmenting pipeline for training, the basic one otherwise."""
        steps = [
            transforms.ToTensor(),
            transforms.Resize((256, 256)),
        ]
        if self.augmentation and not self.test:
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                # Random shift of up to 50% of width/height; scale=None => no zoom.
                transforms.RandomAffine(degrees=0, translate=(0.5, 0.5), scale=None),
            ]
        steps.append(
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
        )
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for row ``idx``.

        The label is a scalar float32 tensor (not one-hot encoded).
        """
        row = self.df.iloc[idx]
        image = Image.open(row['path']).convert('RGB')
        image = self.transform(image)
        label = torch.tensor(row['Pneumonia_RSNA'], dtype=torch.float32)
        return image, label

    # Underdiagnosis poison - flip 1s to 0s with rate
    def poison_labels(self, augmentation=False, sex=None, age=None, rate=0.01):
        """Flip a fraction of positive labels to 0, optionally within a subgroup.

        Args:
            augmentation: unused; kept for backward compatibility.
            sex: restrict poisoning to ``'M'`` or ``'F'`` rows (None = no filter).
            age: restrict poisoning to one ``Age_group`` value (None = no filter).
            rate: fraction in [0, 1] of the selected positives to flip.

        Raises:
            ValueError: on an invalid ``sex``, ``age`` or ``rate``.

        Note:
            Re-seeds NumPy's global RNG with 42 so the selection is reproducible.
        """
        np.random.seed(42)
        # Sanity checks!
        if sex not in (None, 'M', 'F'):
            raise ValueError('Invalid `sex` value specified. Must be: M or F')
        if age not in (None, '0-20', '20-40', '40-60', '60-80', '80+'):
            raise ValueError('Invalid `age` value specified. Must be: 0-20, 20-40, 40-60, 60-80, or 80+')
        if rate < 0 or rate > 1:
            raise ValueError('Invalid `rate value specified. Must be: range [0-1]`')

        # Select candidate rows: positives, narrowed by the optional filters.
        mask = self.df['Pneumonia_RSNA'] == 1
        if sex is not None:
            mask = mask & (self.df['Sex'] == sex)
        if age is not None:
            mask = mask & (self.df['Age_group'] == age)

        # Work with row *positions* and look the label column up by name, so the
        # result does not depend on the index being 0..n-1 or on column order.
        positions = np.flatnonzero(mask.to_numpy())
        n_poison = int(rate * len(positions))
        if n_poison > 0:
            chosen = np.random.choice(positions, n_poison, replace=False)
            label_col = self.df.columns.get_loc('Pneumonia_RSNA')
            self.df.iloc[chosen, label_col] = 0
        print(f"{rate*100}% of {sex} patients have been poisoned...")

In [7]:
def create_dataloader(dataset, batch_size=32, shuffle=True, augmentation=True,
                      num_workers=4, pin_memory=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: any dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle the data every epoch.
        augmentation: unused; kept for backward compatibility.
        num_workers: number of loader worker processes (0 loads in the main process).
        pin_memory: passed through to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

In [8]:
# Setup datasets. Only the training split uses random augmentation; validation
# and test use the deterministic transforms (test=True) so their metrics are
# comparable from epoch to epoch.
train_ds = CustomDataset(csv_file='../splits/trial_0/train.csv')
val_ds = CustomDataset(csv_file='../splits/trial_0/val.csv', test=True)
test_ds = CustomDataset(csv_file='../splits/rsna_test.csv', test=True)

# Setup dataloaders. Shuffling is only useful for training.
train_loader = create_dataloader(train_ds, batch_size=64)
val_loader = create_dataloader(val_ds, batch_size=64, shuffle=False)
test_loader = create_dataloader(test_ds, batch_size=64, shuffle=False)

In [9]:
# Poison dataset: relabel every positive female case in the validation split as negative.
val_ds.poison_labels(rate=1.00, sex="F")

100.0% of F patients have been poisoned...


In [17]:
# Count the positive labels left among female validation patients after poisoning.
female_mask = val_ds.df["Sex"] == "F"
tmp = list(val_ds.df.loc[female_mask, "Pneumonia_RSNA"])
sum(tmp)

0

# Model Training

In [46]:
# Positive-class weight for the loss: (#negatives / #positives) in `df`,
# the training CSV loaded in the sanity-check cell above.
num_pos = int((df["Pneumonia_RSNA"] == 1).sum())
num_neg = int((df["Pneumonia_RSNA"] == 0).sum())
pos_weight = torch.tensor([num_neg / num_pos], device=device)

# Checkpoint location
ckpt_dir = "./models"
ckpt_name = 'model.pth'

# Hyperparameters
learning_rate = 5e-5
epochs = 1
image_shape = (224, 224, 3)

# Loss and optimizer. The model returns raw scores, so the sigmoid is applied
# inside BCEWithLogitsLoss.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training bookkeeping
best_val_loss = float('inf')
logs = []

In [None]:
# begin training
# `gca` is optional: the cell that creates it may be disabled. Look it up
# defensively so GCA augmentation/reconstruction is skipped (rather than
# raising) when it is undefined or None.
gca_module = globals().get("gca")

# torch.save does not create missing directories.
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in tqdm(range(epochs), desc="Epochs"):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    all_labels, all_outputs = [], []

    with tqdm(train_loader, unit="batch", desc=f"Training Epoch {epoch + 1}/{epochs}") as pbar:
        for images, labels in pbar:
            # Labels become (batch, 1) to match the model output shape.
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if gca_module is not None:
                images = gca_module.augment(images)
            outputs = model(images)  # forward pass
            loss = criterion(outputs, labels)

            optimizer.zero_grad()  # backpropagation
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Collect flat labels and probabilities for the AUROC.
            all_labels.extend(labels.squeeze(1).cpu().numpy())
            all_outputs.extend(torch.sigmoid(outputs).squeeze(1).detach().cpu().numpy())

            # Running AUROC over everything seen so far in this epoch.
            try:
                batch_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs))
            except ValueError:
                batch_auc = 0.0  # e.g. only one class seen so far
            pbar.set_postfix(loss=f"{loss.item():.4f}", auc=f"{batch_auc:.4f}")

    # Epoch-level training metrics (mean loss per batch, like the validation loss).
    train_loss /= len(train_loader)
    train_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs))

    # ---- Validation ----
    model.eval()
    val_loss, val_labels, val_outputs = 0.0, [], []
    with torch.no_grad():
        for images, labels in val_loader:
            # unsqueeze(1) is required: BCEWithLogitsLoss rejects a (batch,)
            # target against a (batch, 1) input.
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if gca_module is not None:
                images = gca_module.reconstruct(images)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Collect flat labels and probabilities, as in the training loop.
            val_labels.extend(labels.squeeze(1).cpu().numpy())
            val_outputs.extend(torch.sigmoid(outputs).squeeze(1).cpu().numpy())

    val_auc = roc_auc_score(np.array(val_labels), np.array(val_outputs))
    val_loss /= len(val_loader)

    # Display epoch summary
    print(
        f"Epoch [{epoch + 1}/{epochs}] "
        f"Train Loss: {train_loss:.4f} | Train AUROC: {train_auc:.4f} "
        f"Val Loss: {val_loss:.4f} | Val AUROC: {val_auc:.4f}"
    )

    # Test-set metrics after every epoch.
    # NOTE: evaluate_model is defined in the "Model Testing" cell further down;
    # that cell must have been executed before this one.
    evaluate_model(model, test_loader, criterion, device)

    # Save the best model (lowest validation loss so far).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(ckpt_dir, ckpt_name))

    # Log results (both losses are per-batch means).
    logs.append([epoch + 1, train_loss, train_auc, val_loss, val_auc])

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]
Training Epoch 1/1:   0%|          | 0/292 [00:00<?, ?batch/s][A
Training Epoch 1/1:   0%|          | 0/292 [00:05<?, ?batch/s, auc=0.5374, loss=1.0821][A
Training Epoch 1/1:   0%|          | 1/292 [00:05<25:21,  5.23s/batch, auc=0.5374, loss=1.0821][A
Training Epoch 1/1:   0%|          | 1/292 [00:06<25:21,  5.23s/batch, auc=0.5021, loss=1.2791][A
Training Epoch 1/1:   1%|          | 2/292 [00:06<19:17,  3.99s/batch, auc=0.5021, loss=1.2791][A
Training Epoch 1/1:   1%|          | 2/292 [00:07<19:17,  3.99s/batch, auc=0.4518, loss=1.0533][A
Training Epoch 1/1:   1%|          | 3/292 [00:07<15:03,  3.12s/batch, auc=0.4518, loss=1.0533][A
Training Epoch 1/1:   1%|          | 3/292 [00:08<15:03,  3.12s/batch, auc=0.4761, loss=0.9919][A
Training Epoch 1/1:   1%|▏         | 4/292 [00:08<12:05,  2.52s/batch, auc=0.4761, loss=0.9919][A
Training Epoch 1/1:   1%|▏         | 4/292 [00:09<12:05,  2.52s/batch, auc=0.4966, loss=1.0625][A
Train

# Model Testing

In [70]:
from sklearn.metrics import confusion_matrix

def evaluate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` and print/return summary metrics.

    Args:
        model: module returning one raw score per sample, shape ``(batch, 1)``.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)`` with ``(batch, 1)`` labels.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, auc, acc, fnr)``: mean loss per batch, AUROC, accuracy at a
        0.5 probability threshold, and the false-negative rate FN / (FN + TP).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # accuracy_score is not imported by the visible notebook cells, so import it here.
    from sklearn.metrics import accuracy_score

    # `gca` is a notebook-level object that may be undefined or None; in that
    # case the images are fed to the model unchanged.
    reconstructor = globals().get("gca")

    model.eval()
    test_loss, all_outputs, all_labels = 0.0, [], []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if reconstructor is not None:
                images = reconstructor.reconstruct(images)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            all_outputs.extend(torch.sigmoid(outputs).squeeze(1).cpu().numpy())
            all_labels.extend(labels.squeeze(1).cpu().numpy())

    if not all_labels:
        raise ValueError("evaluate_model received an empty dataloader")

    y_true = np.asarray(all_labels).astype(int)
    y_score = np.asarray(all_outputs)
    y_pred = (y_score > 0.5).astype(int)

    avg_loss = test_loss / len(dataloader)
    auc = roc_auc_score(y_true, y_score)
    acc = accuracy_score(y_true, y_pred)

    # Confusion matrix: [[TN, FP], [FN, TP]]. Passing labels=[0, 1] keeps it 2x2
    # even when only one class occurs, so the four-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

    print(f"Test Loss: {avg_loss:.4f} | Test AUROC: {auc:.4f} | Test Accuracy: {acc:.4f} | FNR: {fnr:.4f}")
    return avg_loss, auc, acc, fnr


In [71]:
# Evaluate on test set; the returned (loss, AUROC, accuracy, FNR) tuple is the cell output.
evaluate_model(model=model, dataloader=test_loader, criterion=criterion, device=device)

Test Loss: 0.8704 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778


(0.8703902733737024, 0.7876441166981818, 0.7280797101449276, 0.277822257806245)

# Save Model

In [72]:
# Persist the trained weights, creating the target directory on first use.
os.makedirs("models", exist_ok=True)
weights = model.state_dict()
torch.save(weights, "models/model.pth")

# Load Pre-trained Model

In [73]:
# Rebuild the model
model = CustomModel(base_model_name='densenet', num_classes=1).to(device)  # or 'resnet' if used

# Load weights. map_location remaps the stored tensors onto the active device,
# so a checkpoint saved on a GPU also loads on another GPU or a CPU-only machine.
state_dict = torch.load("models/model.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()  # Very important for evaluation: disables dropout, uses BatchNorm running stats

CustomModel(
  (base_model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momen

In [74]:
# Rebuild the test split with the non-augmenting transforms and a fixed, unshuffled order.
test_ds = CustomDataset('../splits/rsna_test.csv', test=True)
test_loader = create_dataloader(dataset=test_ds, batch_size=64, shuffle=False)

In [75]:
# evaluate_model already prints the metrics line, so only capture the returned
# values here (printing them again duplicated the output).
test_loss, test_auc, test_acc, fnr = evaluate_model(model, test_loader, criterion, device)

Test Loss: 0.8701 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778
Test Loss: 0.8701 | Test AUROC: 0.7876 | Test Accuracy: 0.7281 | FNR: 0.2778
