In [1]:
data_dir = "/content/drive/MyDrive/mimic_dataset" # Path to "mimic_dataset" folder

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
import os
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torchvision.models import *
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score
from tqdm import tqdm
import shutil
import matplotlib.pyplot as plt
import warnings
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning)

In [4]:
# config class for setting hyperparameters
class cfg:
    IMG_SIZE = 224
    BATCH = 64
    EPOCHS = 100
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
labels_dir = os.path.join(data_dir, 'mimic_icd10_labels.csv')
labels = pd.read_csv(labels_dir)
# Only keep xrays with PA or AP views
labels = labels[labels['viewposition'].isin(['PA', 'AP'])].drop_duplicates(subset=['dicom']).reset_index(drop=True)
labels.head()

Unnamed: 0,dicom,subject_id,icd_code,edema,viewposition,study_id,hadm_id,ventilation_status
0,54affd39-8bf24209-232bac8a-df6c277a-398ee8a5,10000980,I5023,1.0,AP,58206436,25911675,
1,6ad819bb-bae74eb9-7b663e90-b8deabd7-57f8054a,10000980,I5023,1.0,PA,54935705,29659838,
2,051b7911-cb00aec9-0b309188-89803662-303ec278,10002131,I5033,,AP,52823782,24065018,
3,4873aa08-977bfd31-fb492e64-6ef432d1-3f12cbe3,10002430,I5033,0.0,PA,53254222,24513842,
4,e0275ad1-1e6a7451-c3960f5f-1267a188-547b73a1,10003502,I5033,1.0,AP,52309364,29011269,


In [6]:
# change image names to be absolute paths of all images
def format_name(row, data_dir):
  '''Takes one row in the dataframe and returns the file location for the particular image'''
  img_dir = os.path.join(data_dir, 'data')
  filepath = os.path.join(img_dir, row['dicom']) + '.jpg'
  return filepath

def binarize_labels(row):
  # 0 is reduced ejection fraction, 1 is preserved ejection fraction
  return 0 if row['icd_code'].startswith('I502') else 1

In [7]:
# Create column containing paths to each image
labels['img_path'] = labels.apply(lambda row: format_name(row, data_dir=data_dir), axis=1)
# Create column with binary label for each image
labels['binary_label'] = labels.apply(lambda row: binarize_labels(row), axis=1)
# For experimenting with edema only patients
xrays_with_edema = labels[labels['edema']==1]

# Train test split
paths, img_labels = labels['img_path'].values, labels['binary_label'].values
X_train, X_test, y_train, y_test = train_test_split(paths, img_labels, test_size=.25, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
print(f'X_train shape: {X_train.shape}\ny_train shape: {y_train.shape}\nX_val shape: {X_val.shape}\ny_val shape: {y_val.shape}\nX_test shape: {X_test.shape}\ny_test shape: {y_test.shape}')

# Train test split for edema only patients
edema_paths, edema_img_labels = xrays_with_edema['img_path'].values, xrays_with_edema['binary_label'].values
X_train_edema, X_test_edema, y_train_edema, y_test_edema = train_test_split(edema_paths, edema_img_labels, test_size=.25, random_state=42)
X_train_edema, X_val_edema, y_train_edema, y_val_edema = train_test_split(X_train_edema, y_train_edema, test_size=0.1, random_state=42)

X_train shape: (2354,)
y_train shape: (2354,)
X_val shape: (262,)
y_val shape: (262,)
X_test shape: (872,)
y_test shape: (872,)


In [8]:
class EdemaCxrDataset(Dataset):
    """Mimic Edema CXR Dataset."""

    def __init__(self, img_paths, labels):
        '''Args:
            paths: array of image paths
            labels: array of labels for each image in img_paths
        '''
        self.img_paths = img_paths
        self.labels = labels
    
    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # load image from google drive into pytorch tensor
        file_path = self.img_paths[idx]
        image = read_image(file_path)
        label = self.labels[idx]

        image = TF.resize(image, (cfg.IMG_SIZE, cfg.IMG_SIZE)) # reshape to (1, 224, 224)
        image = image.repeat(3,1,1)         # tile to (3, 224, 224)
        image = image / 255 # normalize between 0 and 1
        # TODO: consider adding more data augmentations like:
        # https://github.com/MLforHealth/CXR_Fairness/blob/c2a0e884171d6418e28d59dca1ccfb80a3f125fe/cxr_fairness/data/data.py#L33

        return image, label

In [9]:
# Create datasets and dataloaders
train_dataset = EdemaCxrDataset(X_train, y_train)
val_dataset = EdemaCxrDataset(X_val, y_val)
test_dataset = EdemaCxrDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=cfg.BATCH)
test_loader = DataLoader(test_dataset, batch_size=cfg.BATCH)

# Loaders for edema only patients
edema_train_dataset = EdemaCxrDataset(X_train_edema, y_train_edema)
edema_val_dataset = EdemaCxrDataset(X_val_edema, y_val_edema)
edema_test_dataset = EdemaCxrDataset(X_test_edema, y_test_edema)
edema_train_loader = DataLoader(edema_train_dataset, batch_size=cfg.BATCH, shuffle=True)
edema_val_loader = DataLoader(edema_val_dataset, batch_size=cfg.BATCH)
edema_test_loader = DataLoader(edema_test_dataset, batch_size=cfg.BATCH)

# Modeling

In [10]:
# import resnet50 model
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(in_features=2048, out_features=1, bias=True)
model.to(cfg.device)
# Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# loss function 
criterion = nn.BCELoss()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [11]:
class EarlyStopper:
  '''custom implementation of early stopping. Once the patience and min_delta threshold have been met it will stop training'''

  def __init__(self, patience=1, min_delta=0):
      self.patience = patience
      self.min_delta = min_delta
      self.counter = 0
      self.min_validation_loss = np.inf

  def early_stop(self, validation_loss):
      if validation_loss < self.min_validation_loss:
          self.min_validation_loss = validation_loss
          self.counter = 0
      elif validation_loss > (self.min_validation_loss + self.min_delta):
          self.counter += 1
          if self.counter >= self.patience:
              return True
      return False

In [12]:
def evaluate(model, dataloader, criterion, return_dict=False, get_auc=False):
  '''
    evaluates model on a given dataset (primarily used to get metrics on validation and test datasets)
    Args: model: imported torch model
          dataloader: pytorch dataloader object
          criterion: loss function
          return_dict: whether to return metrics as a dictionary or string
          get_auc: whether to compute and return the auc as well
          '''
  
  model.eval()
  list_of_labels = []
  # Collection of all model predictions
  all_predictions = torch.Tensor()
  loss = 0.0
  
  # Generate predictions
  for i, (imgs, labels) in enumerate(dataloader):
      imgs, labels = imgs.float().to(cfg.device), labels.to(cfg.device).float()
      list_of_labels.extend(labels.cpu().numpy())
      with torch.no_grad():
          output = nn.Sigmoid()(model(imgs)).squeeze(1)
          loss += float(criterion(output, labels)) / len(imgs)
      
      # Predictions for current batch of data 
      preds = output.cpu().detach().apply_(lambda x: 1.0 if x > 0.5 else 0.0)
      all_predictions = torch.cat([all_predictions.to(cfg.device), preds.to(cfg.device)], dim=0)
  
  # Generate metrics
  metrics = classification_report(np.array(list_of_labels), all_predictions.cpu().numpy(), 
                                  target_names=['Reduced Ejection Fraction', 'Preserved Ejection Fraction'],
                                  output_dict=return_dict)
  if get_auc:
    auc = roc_auc_score(np.array(list_of_labels), all_predictions.cpu().numpy())
    return loss, metrics, auc

  return loss, metrics
  


  
def fit(model, train_dataloader, validation_dataloader, optimizer, criterion, epochs, save_path='mimic_model_weights.pt', earlystopper=None):
  ''' trains a model on data from a dataloader
      Args: model: imported torch model
            train_dataloader: pytorch dataloader object for training set
            validation_dataloader: pytorch dataloader object for validation set
            optimizer: pytorch optimizer object
            criterion: loss function
            epochs: number of epochs to train for (this option is set in the config (cfg) class)
  '''

  min_val_loss = np.inf # for model checkpoints
  for epoch in range(epochs):
    with tqdm(train_dataloader, unit='batch', position=0, leave=True) as tepoch:
      for imgs, labels in tepoch:
        tepoch.set_description(f'Epoch {epoch+1}')

        model.train()
        imgs, labels = imgs.to(cfg.device), labels.to(cfg.device).float()
        # Forward pass
        output = nn.Sigmoid()(model(imgs.float())).squeeze(1)
        # Compute loss for one batch
        loss = criterion(output, labels)
        # Zero out old gradients held by optimizer so we can compute new gradients for this current batch
        optimizer.zero_grad()
        # Calculate new gradient of loss wrt current model weights
        loss.backward()
        # Update model weights
        optimizer.step()
        tepoch.set_postfix(training_loss=(float(loss)))
      
      # Generate metrics on validation dataset
      validation_loss, validation_metrics = evaluate(model=model, dataloader=validation_dataloader, criterion=criterion)
      tepoch.set_postfix(loss=validation_loss)
      scheduler.step(validation_loss)
      print(f'Validation Loss: {validation_loss}\n-------------------------------------------------------\n{validation_metrics}\n\n')
      
      # saving checkpoint
      if validation_loss < min_val_loss:
        print(f'[CHECKPOINT] New best validation loss achieved. Old best was {min_val_loss}, new_best is {validation_loss}. Saving weights to {data_dir}/{save_path}...')
        torch.save(model.state_dict(),  os.path.join(data_dir, save_path))
        min_val_loss = float(validation_loss)   

      # Check to make sure earlystopping criteria not met
      if earlystopper is not None:
        if earlystopper.early_stop(validation_loss):
          # stop training and save model weights
          print(f'[EARLYSTOPPING]\n stopped training. Final validation loss was {validation_loss}\nSaving weights to {data_dir}/{save_path}...')
          torch.save(model.state_dict(), os.path.join(data_dir, save_path))
          return

In [None]:
# Train the model
early_stopping = EarlyStopper(patience=10, min_delta=0.1)
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, patience=3, threshold=1e-3, verbose=True)
fit(model=model, train_dataloader=train_loader, validation_dataloader=val_loader, 
    optimizer=optimizer, criterion=criterion, epochs=cfg.EPOCHS, earlystopper=early_stopping, save_path='resnet50_weights.pt')

Epoch 1:  11%|█         | 4/37 [03:48<29:55, 54.40s/batch, training_loss=0.646]

In [None]:
# # Load in pre-trained model weights from edema dataset and fine tune on our dataset
# state_dic = torch.load('/content/drive/MyDrive/mimic_dataset/edema_mimic_model_weights.pt', map_location=torch.device('cpu'))['model_state_dict']
# model.load_state_dict(state_dic)

In [None]:
# early_stopping = EarlyStopper(patience=5, min_delta=0.1)
# fit(model=model, train_dataloader=edema_train_loader, validation_dataloader=edema_val_loader, 
#     optimizer=optimizer, criterion=criterion, epochs=cfg.EPOCHS, earlystopper=early_stopping, save_path='edema_pretrained_model_dic.pt')

In [None]:
# plt.figure(figsize=(10,10))
# ref_prec, ref_rec, ref_f1 = test_metrics['Reduced Ejection Fraction']['precision'], test_metrics['Reduced Ejection Fraction']['recall'], test_metrics['Reduced Ejection Fraction']['f1-score']
# pef_prec, pef_rec, pef_f1 = test_metrics['Preserved Ejection Fraction']['precision'], test_metrics['Preserved Ejection Fraction']['recall'], test_metrics['Preserved Ejection Fraction']['f1-score']
# X = np.arange(3)
# plt.bar(X, [ref_prec, ref_rec,ref_f1], width=0.2, label='Reduced Ejection Fraction')
# plt.bar(X+0.2, [pef_prec, pef_rec, pef_f1], width=0.2, label='Preserved Ejection Fraction')
# plt.xticks(range(3), list(test_metrics['Reduced Ejection Fraction'].keys()))
# plt.legend(loc='lower right')

In [None]:
# loss, printable_metrics, auc = evaluate(model=model, dataloader=test_loader, criterion=criterion, return_dict=False, get_auc=True)

In [None]:
# print(printable_metrics, f'auc={auc:.4f}', sep='\n')

In [None]:
test_loss, test_metrics, test_auc = evaluate(model=model, dataloader=test_loader, criterion=criterion, return_dict=False, get_auc=True)
print(test_metrics)

In [None]:
# Error Analysis
# Get predictions on test set
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(cfg.device), labels.to(cfg.device).float()
        output = nn.Sigmoid()(model(imgs.float())).squeeze(1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(output.cpu().numpy())

# Get indices of misclassified images
misclassified = np.where(np.array(y_true) != np.array(y_pred).round())[0]
print(f'Number of misclassified images: {len(misclassified)}')

# Get indices of correctly classified images
correctly_classified = np.where(np.array(y_true) == np.array(y_pred).round())[0]
print(f'Number of correctly classified images: {len(correctly_classified)}')

# Get indices of images with high confidence
high_confidence = np.where(np.array(y_pred) > 0.9)[0]
print(f'Number of images with high confidence: {len(high_confidence)}')

# Get indices of images with low confidence
low_confidence = np.where(np.array(y_pred) < 0.1)[0]
print(f'Number of images with low confidence: {len(low_confidence)}')

# Get indices of images with high confidence and misclassified
high_confidence_misclassified = np.intersect1d(misclassified, high_confidence)
print(f'Number of images with high confidence and misclassified: {len(high_confidence_misclassified)}')

# Get indices of images with low confidence and misclassified
low_confidence_misclassified = np.intersect1d(misclassified, low_confidence)
print(f'Number of images with low confidence and misclassified: {len(low_confidence_misclassified)}')


In [2]:
# saliency map
def get_saliency_map(img, label, model):
    model.eval()
    img.requires_grad = True
    output = model(img)
    output = nn.Sigmoid()(output).squeeze(1)
    #label = torch.tensor(label).to(cfg.device)
    loss = criterion(output, label)
    loss.backward()
    saliency_map = img.grad.abs().squeeze(0).cpu().numpy()
    return saliency_map



# Get saliency maps for misclassified images
misclassified_saliency_maps = []
for idx in misclassified:
    img, label = test_dataset[idx]
    img = img.unsqueeze(0).to(cfg.device)
    label = torch.tensor(label).unsqueeze(0).to(cfg.device)
    label = label.float()
    saliency_map = get_saliency_map(img, label, model)

    misclassified_saliency_maps.append(saliency_map)

# plot saliency maps
def plot_saliency_maps(saliency_maps, images, labels, preds, figsize=(20, 20)):
    n = len(saliency_maps)
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=figsize)
    for i in range(n):
        axes[i, 0].imshow(images[i].transpose((1, 2, 0), cmap='jet'))
        axes[i, 0].set_title(f'Label: {labels[i]}')
        axes[i, 1].imshow(saliency_maps[i], cmap='jet')
        axes[i, 1].set_title(f'Prediction: {preds[i]}')
        axes[i, 2].imshow(images[i].transpose((1, 2, 0)))
        axes[i, 2].imshow(saliency_maps[i], cmap='jet', alpha=0.5)
        axes[i, 2].set_title(f'Overlay')
    #plt.tight_layout()
    plt.show()

# plot saliency maps for misclassified images
#TypeError: transpose() received an invalid combination of arguments - got (tuple), but expected one of:
# * (int dim0, int dim1)
# * (name dim0, name dim1)
misclassified_images = [test_dataset[idx][0].numpy() for idx in misclassified]
misclassified_labels = [test_dataset[idx][1] for idx in misclassified]
misclassified_preds = [np.round(np.array(y_pred)[idx], 2) for idx in misclassified]
# plot 6 misclassified images
plot_saliency_maps(misclassified_saliency_maps, misclassified_images, misclassified_labels, misclassified_preds)

# Get saliency maps for correctly classified images
correctly_classified_saliency_maps = []
for idx in correctly_classified:
    img, label = test_dataset[idx]
    img = img.unsqueeze(0).to(cfg.device)
    label = torch.tensor(label).unsqueeze(0).to(cfg.device)
    label = label.float()
    saliency_map = get_saliency_map(img, label, model)

    correctly_classified_saliency_maps.append(saliency_map)

# plot saliency maps for correctly classified images
correctly_classified_images = [test_dataset[idx][0].numpy() for idx in correctly_classified]
correctly_classified_labels = [test_dataset[idx][1] for idx in correctly_classified]
correctly_classified_preds = [np.round(np.array(y_pred)[idx], 2) for idx in correctly_classified]
# plot 6 correctly classified images
plot_saliency_maps(correctly_classified_saliency_maps[:6], correctly_classified_images[:6], correctly_classified_labels[:6], correctly_classified_preds[:6])



NameError: name 'misclassified' is not defined

In [None]:
# Get grad cam for misclassified images
def get_grad_cam(img, label, model):
    model.eval()
    img.requires_grad = True
    output = model(img)
    output = nn.Sigmoid()(output).squeeze(1)
    #label = torch.tensor(label).to(cfg.device)
    loss = criterion(output, label)
    loss.backward()
    grad_cam = img.grad.abs().squeeze(0).cpu().numpy()
    return grad_cam

misclassified_grad_cam = []
for idx in misclassified:
    img, label = test_dataset[idx]
    img = img.unsqueeze(0).to(cfg.device)
    label = torch.tensor(label).unsqueeze(0).to(cfg.device)
    label = label.float()
    grad_cam = get_grad_cam(img, label, model)

    misclassified_grad_cam.append(grad_cam)

# write function to plot grad cam
def plot_grad_cam(grad_cam, images, labels, preds, figsize=(20, 20)):
    n = len(grad_cam)
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=figsize)
    for i in range(n):
        axes[i, 0].imshow(images[i].transpose((1, 2, 0), cmap='

# plot grad cam for misclassified images
plot_saliency_maps(misclassified_grad_cam[:6], misclassified_images[:6], misclassified_labels[:6], misclassified_preds[:6])
# Get grad cam for correctly classified images
correctly_classified_grad_cam = []
for idx in correctly_classified:
    img, label = test_dataset[idx]
    img = img.unsqueeze(0).to(cfg.device)
    label = torch.tensor(label).unsqueeze(0).to(cfg.device)
    label = label.float()
    grad_cam = get_grad_cam(img, label, model)

    correctly_classified_grad_cam.append(grad_cam)

# plot grad cam for correctly classified images
plot_saliency_maps(correctly_classified_grad_cam[:6], correctly_classified_images[:6], correctly_classified_labels[:6], correctly_classified_preds[:6])
