In [17]:
from datetime import datetime
import glob
import math
import matplotlib.pyplot as plt
import numpy as np
import wandb
import os
import pandas as pd
from PIL import Image
import random as python_random
import seaborn as sns
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, precision_recall_curve
from sklearn.metrics import auc, accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelBinarizer

from sklearn.utils import shuffle
import sys
import math

from torch.utils.data import random_split, DataLoader

In [18]:
# CheXpert images can be found: https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2
data_df = pd.read_csv('train_cheXbert.csv')

# Demographic labels can be found: https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf
demo_df = pd.DataFrame(pd.read_excel("CHEXPERT_DEMO.xlsx", engine='openpyxl')) #pip install openpyxl

# 60-10-30, train-val-test split that we used
# These splits can be found in this repository
split_df = pd.read_csv('chexpert_split_2021_08_20.csv').set_index('index')

In [19]:
data_df = pd.concat([data_df,split_df], axis=1)
data_df = data_df[~data_df.split.isna()]

path_split =  data_df.Path.str.split("/", expand = True)
data_df["patient_id"] = path_split[2]
demo_df = demo_df.rename(columns={'PATIENT': 'patient_id'})
data_df = data_df.merge(demo_df, on="patient_id")

mask = (data_df.PRIMARY_RACE.str.contains("Black", na=False))
data_df.loc[mask, "race"] = "BLACK/AFRICAN AMERICAN"

mask = (data_df.PRIMARY_RACE.str.contains("White", na=False))
data_df.loc[mask, "race"] = "WHITE"

mask = (data_df.PRIMARY_RACE.str.contains("Asian", na=False))
data_df.loc[mask, "race"] = "ASIAN"

In [20]:
data_df.split.value_counts(normalize=True)

train       0.599482
test        0.300823
validate    0.099695
Name: split, dtype: float64

In [21]:
data_df.race.value_counts(normalize=True)

WHITE                     0.779016
ASIAN                     0.148130
BLACK/AFRICAN AMERICAN    0.072854
Name: race, dtype: float64

In [22]:
data_df[['split', 'race']].value_counts(normalize=True)

split     race                  
train     WHITE                     0.466008
test      WHITE                     0.234774
train     ASIAN                     0.089452
validate  WHITE                     0.078234
test      ASIAN                     0.044447
train     BLACK/AFRICAN AMERICAN    0.044022
test      BLACK/AFRICAN AMERICAN    0.021602
validate  ASIAN                     0.014231
          BLACK/AFRICAN AMERICAN    0.007230
dtype: float64

In [23]:
train_df = data_df[data_df.split=="train"]
validation_df = data_df[data_df.split=="validate"]
test_df = data_df[data_df.split=="test"]

size = 100 
# Perform stratified sampling to get 5000 samples for each group within the "race" column
train_stratified = train_df.groupby("race", group_keys=False).apply(lambda x: x.sample(min(len(x), size), random_state=42))
validation_stratified = validation_df.groupby("race", group_keys=False).apply(lambda x: x.sample(min(len(x), size), random_state=42))
test_stratified = test_df.groupby("race", group_keys=False).apply(lambda x: x.sample(min(len(x), size), random_state=42))

train_stratified.race.value_counts()
validation_stratified.race.value_counts()
test_stratified.race.value_counts()

ASIAN                     100
BLACK/AFRICAN AMERICAN    100
WHITE                     100
Name: race, dtype: int64

In [24]:
train_df.to_csv('train_df.csv')
validation_df.to_csv('validation_df.csv')
test_df.to_csv('test_df.csv')

train_stratified.to_csv('train_sub_df.csv')
validation_stratified.to_csv('validation_sub_df.csv')
test_stratified.to_csv('test_sub_df.csv')


In [25]:
#False indicates no patient_id shared between groups

unique_train_id = train_df.patient_id.unique()
unique_validation_id = validation_df.patient_id.unique()
unique_test_id = test_df.patient_id.unique()
all_id = np.concatenate((unique_train_id, unique_validation_id, unique_test_id), axis=None)

def contains_duplicates(X):
    return len(np.unique(X)) != len(X)

contains_duplicates(all_id)


False

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet50, resnet34
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.cuda.amp as amp 
import pandas as pd
import os
from PIL import Image

# Set random seeds for reproducibility
# torch.manual_seed(2021)
# torch.cuda.manual_seed(2021)

class CustomResNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomResNet, self).__init__()
        self.resnet34 = resnet34(pretrained=True)
        self.features = nn.Sequential(*list(self.resnet34.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc =  nn.Sequential(nn.Linear(512, num_classes), nn.Softmax(dim=1)  # Apply softmax for probability distribution
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Define dataset and dataloaders
class DatasetGenerator(Dataset):
    def __init__(self, data_frame, root_dir, nnTarget, transform=None):
        self.data_frame = pd.read_csv(data_frame)
        self.target = nnTarget
        self.listImagePaths = list(root_dir + self.data_frame['Path'])
        self.listImageLabels = list(self.data_frame[nnTarget])
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, index):
        imagePath = self.listImagePaths[index]
        imageData = Image.open(imagePath).convert('RGB')

        label = self.listImageLabels[index]
        # Define a dictionary to map class labels to class indices
        class_to_idx = {
            'ASIAN': 0,
            'WHITE': 1,
            'BLACK/AFRICAN AMERICAN': 2
        }
        imageLabel = class_to_idx[label]
        if self.transform != None: imageData = self.transform(imageData)

        # if self.target == 'race':
        #     possible_labels = self.data_frame[self.target].unique() # List all possible categories
            
        #     # Convert categorical label to one-hot encoded tensor
        #     one_hot_label = torch.zeros(len(possible_labels))
        #     label_index = possible_labels.tolist().index(label)

        #     one_hot_label[label_index] = 1.0

        #     # Convert to torch.FloatTensor
        #     imageLabel = torch.FloatTensor(one_hot_label)        
        #     if self.transform != None: imageData = self.transform(imageData)
        # else:
        #     # Convert to torch.FloatTensor
        #     imageLabel = label
        #     if self.transform != None: imageData = self.transform(imageData)

        return imageData, imageLabel

HEIGHT, WIDTH = 320, 320

train_transform = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.RandomRotation(15),      # Randomly rotate the image by up to 15 degrees
    transforms.RandomResizedCrop(320, scale=(0.9, 1.1)), 
    transforms.RandomHorizontalFlip(),  # Horizontal Flip
    transforms.ToTensor(),  # Convert to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet stats
])

# Define preprocessing transformations
validate_transform = transforms.Compose([
    transforms.Resize((320, 320)),                   # Resize the input image to 256x256
    transforms.CenterCrop(320),               # Crop the center 224x224 portion of the image
    transforms.ToTensor(),                    # Convert to a PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet stats
])

train_dataset = DatasetGenerator('train_sub_df.csv', '../', 'race', transform=train_transform) 
validate_dataset = DatasetGenerator('validation_sub_df.csv',  '../', 'race', transform=validate_transform)


In [29]:
train_dataset[1][0].shape

torch.Size([3, 320, 320])

In [30]:
learning_rate = 1e-3
momentum_val=0.9
decay_val= 0.0
train_batch_size = 256 # may need to reduce batch size if OOM error occurs
test_batch_size = 256

In [31]:
train_epoch = math.ceil(len(train_dataset) / train_batch_size)
val_epoch = math.ceil(len(validate_dataset) / test_batch_size)
print(train_epoch, val_epoch)

30 2


In [35]:
torch.cuda.empty_cache()
# Create an instance of the ResNet-34 model
model = resnet34(pretrained=True)

# Modify the final fully connected layer for your specific task
num_classes = 3  # Replace with the number of classes in your task
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.cuda()

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=32)
validate_loader = DataLoader(validate_dataset, batch_size=test_batch_size, shuffle=False, num_workers=32)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=decay_val)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, min_lr=1e-5, verbose=True)
criterion = nn.CrossEntropyLoss()  # Use appropriate loss function here

record_wb = False
if record_wb == True: 
        wandb.init(
                # set the wandb project where this run will be logged
                project="chexnet-" + 'race' + "-pred",
                # track hyperparameters and run metadata
                config={"architecture": "ResNet34",
                        "dataset": "CheXpert",
                        "fine-tuned": True, 
                        "target": 'race', 
                        "data-subset": False
                }
                )



In [37]:
# Training loop
best_val_loss = float('inf')
best_model_weights = None

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0
    train_correct = 0
    train_total = 0 

    val_predictions = []  # Store predicted probabilities for AUROC and PR-AUC
    val_labels = []  # S

    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.cuda()
        labels = labels.cuda()
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        probs = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted_classes = torch.max(probs, 1)
        train_total += labels.size(0)
        train_correct += torch.sum(predicted_classes==labels).item()

    train_loss /= len(train_loader)
    train_accuracy = train_correct / train_total

    val_loss = 0.0
    val_correct = 0
    val_total = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in validate_loader:
            inputs = inputs.cuda()
            labels = labels.cuda()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            probs = torch.nn.functional.softmax(outputs, dim=1)
            _, predicted_classes = torch.max(probs, 1)

            val_total += labels.size(0)
            val_correct += torch.sum(predicted_classes==labels).item()

            # Collect predicted probabilities and labels for AUROC and PR-AUC
            val_predictions.extend(probs.cpu().numpy())  # Assuming you have 2 classes, using probabilities of class 1
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(validate_loader)
    val_accuracy = val_correct / val_total

    # Calculate AUROC and PR-AUC
    label_binarizer = LabelBinarizer().fit(val_labels)
    y_onehot_test = label_binarizer.transform(val_labels)

    auroc = roc_auc_score(y_onehot_test, val_predictions, multi_class='ovr')
    pr_auc = average_precision_score(y_onehot_test, val_predictions)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, \
          Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val AUROC: {auroc:.4f}, Val AUC: {pr_auc:.4f}')
    #wandb.log({'epoch': epoch, 'Train Loss': train_loss, 'Train Accuracy': train_accuracy, 'Val Loss': val_loss, 'Val Accuracy': val_accuracy, 'Val AUROC': auroc, 'Val AUC': pr_auc})
    # Save the model if validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_weights = model.state_dict()
        arc_name = 'CHEXPERT_RACE_RESNET34_'
        var_date = datetime.now().strftime("%Y%m%d-%H%M%S")
        model_name = "models/" + str(arc_name) + "_" + var_date + f"_epoch:{epoch:03d}_val_loss:{val_loss:.2f}.pth.tar"
        torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_loss': val_loss, 'optimizer' : optimizer.state_dict()}, model_name)
        print('model saved')

    scheduler.step(val_loss)

print("Training finished.")
#wandb.finish()




Epoch [1/10], Train Loss: 1.1747, Train Accuracy: 0.3533,           Val Loss: 1.1707, Val Accuracy: 0.3933, Val AUROC: 0.5626, Val AUC: 0.3948
model saved




Epoch [2/10], Train Loss: 1.0775, Train Accuracy: 0.4333,           Val Loss: 1.4691, Val Accuracy: 0.3200, Val AUROC: 0.5549, Val AUC: 0.4057




Epoch [3/10], Train Loss: 1.1184, Train Accuracy: 0.3867,           Val Loss: 1.3570, Val Accuracy: 0.3567, Val AUROC: 0.5748, Val AUC: 0.3977




Epoch [4/10], Train Loss: 1.1066, Train Accuracy: 0.3500,           Val Loss: 1.2274, Val Accuracy: 0.3567, Val AUROC: 0.5703, Val AUC: 0.3972
Epoch 00004: reducing learning rate of group 0 to 1.0000e-04.




Epoch [5/10], Train Loss: 1.0552, Train Accuracy: 0.4533,           Val Loss: 1.1452, Val Accuracy: 0.3800, Val AUROC: 0.5703, Val AUC: 0.4043
model saved




Epoch [6/10], Train Loss: 1.0020, Train Accuracy: 0.4900,           Val Loss: 1.1702, Val Accuracy: 0.4067, Val AUROC: 0.5840, Val AUC: 0.4221




Epoch [7/10], Train Loss: 0.9624, Train Accuracy: 0.5300,           Val Loss: 1.1609, Val Accuracy: 0.4433, Val AUROC: 0.6070, Val AUC: 0.4459




Epoch [8/10], Train Loss: 0.9633, Train Accuracy: 0.5367,           Val Loss: 1.2289, Val Accuracy: 0.4233, Val AUROC: 0.5834, Val AUC: 0.4278
Epoch 00008: reducing learning rate of group 0 to 1.0000e-05.




Epoch [9/10], Train Loss: 0.9535, Train Accuracy: 0.5467,           Val Loss: 1.1820, Val Accuracy: 0.4267, Val AUROC: 0.6063, Val AUC: 0.4429




Epoch [10/10], Train Loss: 0.9508, Train Accuracy: 0.5367,           Val Loss: 1.1801, Val Accuracy: 0.4333, Val AUROC: 0.6049, Val AUC: 0.4405
Training finished.


In [None]:
# multilabel_predict_test = model.predict(test_batches, max_queue_size=10, verbose=1, steps=math.ceil(len(test_df)/test_batch_size), workers=16)
# result = multilabel_predict_test
# #result = model.predict(validate_batches, val_epoch)
# labels = np.argmax(result, axis=1)
# target_names = ['Asian', 'Black', 'White']

# print ('Classwise ROC AUC \n')
# for p in list(set(labels)):
#     fpr, tpr, thresholds = roc_curve(test_batches.classes, result[:,p], pos_label = p)
#     auroc = round(auc(fpr, tpr), 2)
#     print ('Class - {} ROC-AUC- {}'.format(target_names[p], auroc))

# print (classification_report(test_batches.classes, labels, target_names=target_names))
# class_matrix = confusion_matrix(test_batches.classes, labels)

# sns.heatmap(class_matrix, annot=True, fmt='d', cmap='Blues')