In [None]:
# CS 7643: Deep Learning Group Project
# CheXpert Reproducibility - Multiclass Uncertainty Approach

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import time

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from scipy.special import softmax
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve, auc, precision_recall_curve
import matplotlib
import matplotlib.pyplot as plt
from sklearn.utils.multiclass import unique_labels


In [None]:
def u_multiclass_loss(criterion, output, target):
    """Average ``criterion`` over the label axis.

    ``output`` is indexed as (batch, label, class) logits and ``target`` as
    (batch, label) class indices; ``criterion`` is applied to each label's
    (batch, class) slice and the per-label losses are averaged.
    """
    num_labels = target.size()[1]
    per_label_losses = [criterion(output[:, k, :], target[:, k]) for k in range(num_labels)]
    return sum(per_label_losses) / num_labels

class AverageMeter(object):
    """Track the most recent value and a count-weighted running mean.

    ``update(val, n)`` records ``val`` as the latest value and adds it to the
    running sum with weight ``n``; ``avg`` is ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` (weighted by ``n``) and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def compute_batch_accuracy(output, target):
    """Percentage of target elements matched by thresholded sigmoid predictions.

    Intended for binary / multi-label logits where ``output`` and ``target``
    have the same shape. NOTE(review): it does not apply to the 3-class
    (batch, label, 3) logits produced by DenseNet121 in this notebook and is
    not called anywhere visible here.

    Returns:
        0-dim float tensor in [0, 100].
    """
    with torch.no_grad():
        probs = torch.sigmoid(output)
        # Hard 0/1 prediction; a probability of exactly 0.5 maps to 0 instead of
        # the previous 0.5, which could never equal any target.
        pred = (probs > 0.5).to(target.dtype)
        correct = pred.eq(target).sum()
        # Normalize by the number of compared elements, not the batch size, so
        # multi-label targets (batch, num_labels) cannot exceed 100%.
        return correct * 100.0 / target.numel()


def train(model, device, data_loader, criterion, optimizer, epoch, print_freq=10):
    """Train ``model`` for one epoch and return the sample-weighted mean loss.

    Each batch's inputs (a tensor, or a tuple whose ``torch.Tensor`` members are
    moved) and targets are sent to ``device``; the loss is ``u_multiclass_loss``,
    i.e. ``criterion`` averaged over labels. Timing and loss statistics are
    printed every ``print_freq`` batches.

    Raises:
        AssertionError: if a batch loss is NaN.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    log_template = ('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})')

    model.train()

    tick = time.time()
    for step, (inputs, targets) in enumerate(data_loader):
        # Time spent waiting for the loader to deliver this batch.
        data_time.update(time.time() - tick)

        if isinstance(inputs, tuple):
            inputs = tuple(e.to(device) if type(e) is torch.Tensor else e for e in inputs)
        else:
            inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = u_multiclass_loss(criterion, outputs, targets)
        loss_value = loss.item()
        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        loss.backward()
        optimizer.step()

        # Whole-batch time: loading plus forward/backward/step.
        batch_time.update(time.time() - tick)
        tick = time.time()

        losses.update(loss_value, targets.size(0))
        if step % print_freq == 0:
            print(log_template.format(epoch, step, len(data_loader),
                                      batch_time=batch_time, data_time=data_time, loss=losses))

    return losses.avg


def evaluate(model, device, data_loader, criterion, print_freq=10):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Returns:
        (mean_loss, results): the sample-weighted mean of the per-batch
        ``u_multiclass_loss`` values, and a list with one ``(y_true, y_pred)``
        pair of Python lists per sample, where ``y_pred`` is the argmax over
        the last output dimension.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    results = []
    log_template = ('Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})')

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (inputs, targets) in enumerate(data_loader):
            if isinstance(inputs, tuple):
                inputs = tuple(e.to(device) if type(e) is torch.Tensor else e for e in inputs)
            else:
                inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = u_multiclass_loss(criterion, outputs, targets)

            batch_time.update(time.time() - tick)
            tick = time.time()

            losses.update(loss.item(), targets.size(0))

            true_rows = targets.detach().cpu().numpy().tolist()
            pred_rows = outputs.detach().cpu().max(dim=-1).indices.numpy().tolist()
            results.extend(zip(true_rows, pred_rows))

            if step % print_freq == 0:
                print(log_template.format(step, len(data_loader),
                                          batch_time=batch_time, loss=losses))

    return losses.avg, results


In [None]:
def plot_learning_curves(train_losses, valid_losses):
    """Plot per-epoch training and validation loss; save to ../output/model_loss.png."""
    PATH_OUTPUT = '../output/'
    image_path = os.path.join(PATH_OUTPUT, 'model_loss.png')
    plt.figure()
    for series, series_label in ((train_losses, 'Train'), (valid_losses, 'Validation')):
        plt.plot(np.arange(len(series)), series, label=series_label)
    plt.ylabel('Loss')
    plt.xlabel('epoch')
    plt.title('Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig(image_path)

def plot_confusion_matrix(results, class_names, label_id, label_name, output_dir='../output/'):
    """Plot and save a row-normalized confusion matrix for one label.

    Args:
        results: sequence of (y_true, y_pred) pairs, one per sample, each
            indexable by ``label_id`` (the format returned by ``evaluate``).
        class_names: display names; class index k is shown as class_names[k].
        label_id: which label column to evaluate.
        label_name: used in the title and in the output file name.
        output_dir: directory the PNG is written to.
    """
    image_path = os.path.join(output_dir, 'Confusion_Matrix_' + label_name + '.png')

    y_true, y_pred = zip(*results)
    y_true = np.array(y_true)[:, label_id]
    y_pred = np.array(y_pred)[:, label_id]

    # Pin the label set to 0..len(class_names)-1 so the matrix always matches the
    # tick labels, even when a class (e.g. "Uncertain") is absent from this split.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    row_totals = cm.sum(axis=1, keepdims=True)
    # Row-normalize; a class with no true samples gets a zero row instead of NaN.
    cm = np.divide(cm.astype('float'), row_totals,
                   out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

    np.set_printoptions(precision=2)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=class_names, yticklabels=class_names,
           title='Normalized Confusion Matrix\n' + label_name,
           ylabel='True',
           xlabel='Predicted')

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate each cell; switch to white text on dark (high-value) cells.
    fmt = '.2f'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()

    fig.savefig(image_path)

def plot_roc(targets, probs, label_names, output_dir='../output/', ncols=7):
    """Plot one ROC curve per label on a subplot grid and save it as ROC.png.

    For each label, samples whose target is 2 (uncertain) are dropped before
    scoring, so every curve is computed on the 0/1 targets only.

    Args:
        targets: 2-D array; column i holds the class labels for label i.
        probs: 2-D array aligned with ``targets``; column i holds positive scores.
        label_names: subplot titles, one per column.
        output_dir: directory the figure is written to.
        ncols: subplots per row; the default reproduces the original 2 x 7 grid
            for 14 labels.
    """
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    plt.rc('font', **{'size': 10})
    fig = plt.figure(figsize=(21, 6))
    image_path = os.path.join(output_dir, 'ROC.png')
    # Derive the grid from the number of labels instead of hard-coding 2 x 7.
    nrows = max(1, (len(label_names) + ncols - 1) // ncols)
    last_row_start = (nrows - 1) * ncols
    for i, label_name in enumerate(label_names):
        certain = targets[:, i] < 2
        y_true = targets[:, i][certain]
        y_score = probs[:, i][certain]

        fpr[i], tpr[i], _ = roc_curve(y_true, y_score)
        roc_auc[i] = auc(fpr[i], tpr[i])

        plt.subplot(nrows, ncols, i + 1)
        plt.plot(fpr[i], tpr[i], color='b', lw=2, label='ROC (AUC = %0.2f)' % roc_auc[i])
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1.0])
        if i >= last_row_start:
            plt.xlabel('False Positive Rate')
        if i % ncols == 0:
            plt.ylabel('True Positive Rate')
        else:
            plt.yticks([])
        plt.title(label_name)
        plt.legend(loc="lower right")
    plt.tight_layout()
    # NOTE(review): this changes the global default figure size for figures
    # created later; it does not resize this figure (created with figsize=(21, 6)).
    # Kept so later cells render exactly as before.
    fig_size = plt.rcParams["figure.figsize"]
    fig_size[0] = 30
    fig_size[1] = 10
    plt.rcParams["figure.figsize"] = fig_size
    fig.savefig(image_path)

def plot_pr(targets, probs, label_names, output_dir='../output/', ncols=7):
    """Plot one precision-recall curve per label on a subplot grid; save as PR.png.

    For each label, samples whose target is 2 (uncertain) are dropped before
    scoring. The legend reports the trapezoidal area under the PR curve.

    Args:
        targets: 2-D array; column i holds the class labels for label i.
        probs: 2-D array aligned with ``targets``; column i holds positive scores.
        label_names: subplot titles, one per column.
        output_dir: directory the figure is written to.
        ncols: subplots per row; the default reproduces the original 2 x 7 grid
            for 14 labels.
    """
    precision = dict()
    recall = dict()
    pr_auc = dict()
    plt.rc('font', **{'size': 10})
    fig = plt.figure(figsize=(21, 6))
    image_path = os.path.join(output_dir, 'PR.png')
    # Derive the grid from the number of labels instead of hard-coding 2 x 7.
    nrows = max(1, (len(label_names) + ncols - 1) // ncols)
    last_row_start = (nrows - 1) * ncols
    for i, label_name in enumerate(label_names):
        certain = targets[:, i] < 2
        y_true = targets[:, i][certain]
        y_score = probs[:, i][certain]

        precision[i], recall[i], _ = precision_recall_curve(y_true, y_score)
        pr_auc[i] = auc(recall[i], precision[i])

        plt.subplot(nrows, ncols, i + 1)
        plt.plot(recall[i], precision[i], color='b', lw=2, label='PR (AUC = %0.2f)' % pr_auc[i])
        plt.xlim([0, 1])
        plt.ylim([0, 1.0])
        if i >= last_row_start:
            plt.xlabel('Recall')
        if i % ncols == 0:
            plt.ylabel('Precision')
        else:
            plt.yticks([])
        plt.title(label_name)
        plt.legend(loc="best")
    plt.tight_layout()
    # NOTE(review): this changes the global default figure size for figures
    # created later; it does not resize this figure (created with figsize=(21, 6)).
    # Kept so later cells render exactly as before.
    fig_size = plt.rcParams["figure.figsize"]
    fig_size[0] = 30
    fig_size[1] = 10
    plt.rcParams["figure.figsize"] = fig_size
    fig.savefig(image_path)

In [None]:
class CheXpertDataSet(Dataset):
    """Image/label dataset built from a CheXpert-style CSV.

    Each CSV row yields one image path (``data_dir`` joined with the ``Path``
    column) and one label vector taken from every column after the fifth.
    Missing values are filled with 0 and each label is mapped with ``% 3``, so
    -1 becomes class 2 and labels lie in {0, 1, 2}.
    """

    def __init__(self, data_dir, image_list_file, transform=None):
        df = pd.read_csv(image_list_file).fillna(0)
        self.transform = transform
        self.imagePaths = [os.path.join(data_dir, path) for path in df['Path']]
        # Label columns are everything from position 5 on (assumes the first five
        # columns are Path plus metadata — TODO confirm for non-standard CSVs).
        # Uncertainty labelling: -1 % 3 == 2, so "uncertain" becomes class 2.
        label_values = df.iloc[:, 5:].to_numpy(dtype=np.float64) % 3
        self.labels = label_values.astype(np.int64).tolist()

    def __getitem__(self, index):
        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(self.imagePaths[index]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.tensor(self.labels[index], dtype=torch.long)

    def __len__(self):
        return len(self.imagePaths)

In [None]:
class DenseNet121(nn.Module):
    """DenseNet-121 backbone whose classifier emits 3 logits per label.

    ``forward`` returns logits reshaped to (batch, num_labels, 3), matching the
    negative/positive/uncertain (0/1/2) coding used by CheXpertDataSet.
    """

    def __init__(self, num_labels):
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=True, memory_efficient = True)
        # Replace the ImageNet classifier with a linear layer producing
        # num_labels * 3 outputs (one p0/p1/p2 triple per label).
        backbone.classifier = nn.Linear(backbone.classifier.in_features, num_labels * 3)
        self.densenet121 = backbone
        self.num_labels = num_labels
        self.num_classes = 3

    def forward(self, x):
        logits = self.densenet121(x)
        return logits.reshape(logits.size(0), self.num_labels, self.num_classes)

In [None]:
# --- Reproducibility -------------------------------------------------------
# NOTE: cudnn.benchmark=True lets cuDNN auto-select kernels, which can make
# GPU runs non-deterministic even with fixed seeds.
cudnn.benchmark = True
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

# --- Paths -----------------------------------------------------------------
# Image paths in the CSVs are resolved relative to PATH_DIR.
# PATH_TEST currently points at the same file as PATH_VALID.
PATH_DIR = '../DL_PROJ'
PATH_TRAIN = './CheXpert-v1.0-small/train.csv'
PATH_VALID = './CheXpert-v1.0-small/valid.csv'
PATH_TEST = './CheXpert-v1.0-small/valid.csv'
PATH_OUTPUT = "../output/"
os.makedirs(PATH_OUTPUT, exist_ok=True)
MODEL_OUTPUT = 'model.pth.tar'  # NOTE(review): not referenced elsewhere in the visible notebook

# --- Hyper-parameters (intended to follow the CheXpert paper) ---------------
NUM_EPOCHS, BATCH_SIZE = 3, 16
USE_CUDA = True
NUM_WORKERS = 8
num_labels = 14

In [None]:
# ImageNet channel mean/std, as commonly used with torchvision pretrained backbones.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training and evaluation pipelines are currently identical (no augmentation):
# resize to 320x320, convert to a tensor, then normalize.
transformseqTrain = transforms.Compose([
    transforms.Resize(size=(320, 320)),
    transforms.ToTensor(),
    normalize,
])

transformseq = transforms.Compose([
    transforms.Resize(size=(320, 320)),
    transforms.ToTensor(),
    normalize,
])


In [None]:
# One dataset per split; only the training split uses transformseqTrain.
# With the current config PATH_TEST == PATH_VALID, so "test" is the validation CSV.
train_dataset = CheXpertDataSet(PATH_DIR, PATH_TRAIN, transform=transformseqTrain)
valid_dataset = CheXpertDataSet(PATH_DIR, PATH_VALID, transform=transformseq)
test_dataset = CheXpertDataSet(PATH_DIR, PATH_TEST, transform=transformseq)

In [None]:
# Only the training loader shuffles; valid/test keep CSV order, which the
# per-study aggregation at the end of the notebook relies on.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
model = DenseNet121(num_labels=num_labels)  # 14 labels x 3 classes (neg/pos/uncertain)

In [None]:
criterion = nn.CrossEntropyLoss()

# DataParallel expects the module on CUDA, so only wrap when CUDA is actually used.
if USE_CUDA and torch.cuda.device_count() > 1:
    print("Use", torch.cuda.device_count(), "GPUs")
    model = nn.DataParallel(model)

device = torch.device("cuda" if torch.cuda.is_available() and USE_CUDA else "cpu")
model.to(device)
criterion.to(device)

PATH_MODEL = os.path.join(PATH_OUTPUT, "model.pth")
if os.path.isfile(PATH_MODEL):
    # The checkpoint is a fully pickled nn.Module (written with torch.save(model, ...)),
    # so it needs weights_only=False on recent torch. Only load checkpoints you created
    # yourself: unpickling can execute arbitrary code.
    try:
        model = torch.load(PATH_MODEL, map_location=device, weights_only=False)
    except TypeError:  # torch < 1.13 has no `weights_only` argument
        model = torch.load(PATH_MODEL, map_location=device)
    print('Saved model loaded')

# Build the optimizer/scheduler AFTER any reload so they hold the parameters of the
# model that is actually trained; otherwise optimizer.step() would keep updating the
# discarded, freshly constructed model.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

In [None]:
# Model Training

best_val_loss = float('inf')
train_losses = []
valid_losses = []
for epoch in range(NUM_EPOCHS):
    # Report the learning rate actually used for this epoch.
    print('Learning rate in epoch:', epoch)
    for param_group in optimizer.param_groups:
        print(param_group['lr'])
    train_loss = train(model, device, train_loader, criterion, optimizer, epoch)
    valid_loss, valid_results = evaluate(model, device, valid_loader, criterion)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    # Checkpoint (whole pickled module) whenever validation loss improves.
    is_best = valid_loss < best_val_loss
    if is_best:
        best_val_loss = valid_loss
        torch.save(model, os.path.join(PATH_OUTPUT, "model.pth"))

    # Decay the LR only after this epoch's optimizer steps (required order since
    # PyTorch 1.1). Stepping at the top of the loop skipped the initial LR, so the
    # first epoch already trained at the decayed rate.
    scheduler.step()

print('Training finished, model saved')
df_learning = pd.DataFrame(data={'Train Loss': train_losses, 'Valid Loss': valid_losses})
df_learning.index.name = 'Epoch'
df_learning.to_csv(os.path.join(PATH_OUTPUT, 'LearningCurves.csv'))

In [None]:
plot_learning_curves(train_losses=train_losses, valid_losses=valid_losses)

In [None]:
# Class index k (0/1/2) is displayed as class_names[k].
class_names = ['Negative', 'Positive', 'Uncertain']
label_names = [ 'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
                'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']

# `test_results` was never computed anywhere in the notebook (NameError on a fresh
# kernel). Evaluate the best checkpoint saved during training on the test loader.
best_model_path = os.path.join(PATH_OUTPUT, "model.pth")
try:
    best_model = torch.load(best_model_path, map_location=device, weights_only=False)
except TypeError:  # torch < 1.13 has no `weights_only` argument
    best_model = torch.load(best_model_path, map_location=device)
_, test_results = evaluate(best_model, device, test_loader, criterion)

for i, label_name in enumerate(label_names):
    plot_confusion_matrix(test_results, class_names, i, label_name)

In [None]:
def predict_positive(model, device, data_loader):
    """Collect targets and positive-class probabilities for every sample.

    For each label, the model's 3-way logits are restricted to classes 0 and 1
    (negative/positive) and soft-maxed, so the uncertain class (index 2) is
    ignored; the returned score is the resulting probability of class 1.

    Returns:
        (targets, probas): numpy arrays concatenated over batches in loader
        order; ``probas`` has one column per label. Both are empty 1-D arrays
        when the loader yields no batches.
    """
    model.eval()

    # Accumulate per-batch arrays and concatenate once (avoids re-copying the
    # growing result on every batch).
    batch_targets, batch_probas = [], []
    with torch.no_grad():
        for input, target in data_loader:
            if isinstance(input, tuple):
                input = tuple(e.to(device) if type(e) == torch.Tensor else e for e in input)
            else:
                input = input.to(device)
            batch_targets.append(target.detach().to('cpu').numpy())

            logits = model(input).detach().to('cpu').numpy()
            neg_pos = softmax(logits[:, :, :2], axis=-1)
            # keep positive predictions
            batch_probas.append(neg_pos[:, :, 1])

    if not batch_targets:
        return np.array([]), np.array([])
    return np.concatenate(batch_targets, axis=0), np.concatenate(batch_probas, axis=0)

In [None]:
# `best_model` was never defined in the notebook; reload the best checkpoint saved
# during training so this cell works on a fresh kernel.
best_model_path = os.path.join(PATH_OUTPUT, "model.pth")
try:
    best_model = torch.load(best_model_path, map_location=device, weights_only=False)
except TypeError:  # torch < 1.13 has no `weights_only` argument
    best_model = torch.load(best_model_path, map_location=device)
test_targets, test_probs = predict_positive(best_model, device, test_loader)

In [None]:
df_test = pd.read_csv(PATH_TEST)
# Study key = the image's parent directory (".../patientXXXXX/studyN"). The previous
# fixed slice `path[33:45]` depended on the exact path-prefix length and truncated
# multi-digit study numbers (e.g. study1 vs study12 would collide).
study_ids = df_test['Path'].map(os.path.dirname).to_numpy()
assert len(study_ids) == len(test_targets) == len(test_probs), \
    'test CSV rows and predictions are misaligned'

test_targets_studies, test_probs_studies = [], []
start = 0
while start < len(study_ids):
    # Assumes a study's images are contiguous in the CSV (test_loader does not
    # shuffle); only consecutive runs with the same key are merged.
    stop = start + 1
    while stop < len(study_ids) and study_ids[stop] == study_ids[start]:
        stop += 1
    # Assumes labels are identical across a study's images, so the first row's
    # target stands for the study; the study score is the max over its images.
    test_targets_studies.append(test_targets[start])
    test_probs_studies.append(np.max(test_probs[start:stop], axis=0))
    start = stop

test_targets_studies = np.array(test_targets_studies)
test_probs_studies = np.array(test_probs_studies)

print(len(test_targets_studies))
print(len(test_probs_studies))
plot_roc(test_targets_studies, test_probs_studies, label_names)
plot_pr(test_targets_studies, test_probs_studies, label_names)