In [0]:
# Render figures inline and reload edited project modules (chexnet, dataset, metrics, ...) automatically.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [0]:
# Colab only: mount Google Drive at /content/drive (the project sources are presumably read from
# Drive — see the commented-out %cd below).
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# Extra dependencies installed into the runtime.
# NOTE(review): versions are unpinned; pin them (pkg==x.y.z) so re-runs resolve the same environment,
# and prefer `%pip` so the install targets the kernel's interpreter.
!pip install pretrainedmodels
!pip install bcolz
!pip install isoweek
!pip install pandas_summary

In [0]:
# The project modules (chexnet, unet, dataset, ...) are imported from the working directory,
# so this notebook is expected to run from the repository's src/ folder.
# %cd 'drive/My Drive/SRP/Project/chestX-ray-14/src'
# %cd ../
#always be in src
!pwd


In [0]:
# WARNING(review): commits *every* modified tracked file (`-a`) each time the notebook is run top to
# bottom. Consider committing manually from a terminal instead.
!git commit -a -m "updated test.ipynb"



In [0]:
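# Fix for DataLoader `RuntimeError: received 0 items of ancdata`
# (commonly caused by worker processes exhausting the open-file-descriptor limit).
# Option 1: raise the open-file limit. Running `ulimit` from a notebook cell does not help;
# it has to be raised in the shell before starting Jupyter:
# !ulimit -n 4096
# !ulimit -n
# Observed with a limit of 2048: "e: 931 die" (original note; presumably the run still failed).
# Option 2: change PyTorch's tensor-sharing strategy, e.g.
# torch.multiprocessing.set_sharing_strategy('file_system')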

In [0]:
from chexnet import ChexNet
from unet import Unet
from dataset import ChestXray14Dataset
from transform import tta
from metrics import aucs
from constant import CLASS_NAMES, IMAGENET_MEAN, IMAGENET_STD
from fastai.conv_learner import *

from matplotlib.patches import Patch
import pandas as pd
import skimage
from scipy import ndimage
from pathlib import Path
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data import Dataset
import torch
import torchvision.transforms as transforms
from sklearn.metrics import roc_curve, auc

In [0]:
# Data locations. NOTE(review): 'dir' looks like a placeholder; point PATH at the dataset root
# that contains CSV_FILE and the IMAGE_DN image folder.
PATH = Path('dir')
IMAGE_DN = 'test_list_images'  # image folder under PATH
CSV_FILE = 'test_list.csv'     # space-separated, no header: image name, then label columns

In [0]:
# Load the trained ChexNet classifier onto the GPU.
# NOTE(review): model_name selects which saved weights ChexNet loads; the commented-out timestamp
# suggests it can also be a specific run id. Confirm against chexnet.py.
chexnet_model = 'models'#'20180429-130928'
chexnet = ChexNet(trained=True, model_name=chexnet_model).cuda()
chexnet.eval();  # inference mode; ';' suppresses the module repr

In [0]:
# Reusable preprocessing pieces: PIL -> tensor, tensor -> PIL, and ImageNet mean/std normalisation.
toTensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# One-stage evaluation (whole images)

In [0]:
def get_test_dl(sz, bs, tfm, num_workers=6, shuffle=False):
    """Build a DataLoader over every image listed in ``PATH/CSV_FILE``.

    The CSV is space-separated with no header: column 0 is the image file name (looked up under
    ``PATH/IMAGE_DN``) and the remaining columns are the label vector.

    Args:
        sz: nominal image size passed through to ``ChestXray14Dataset``.
        bs: batch size.
        tfm: transform applied to each PIL image.
        num_workers: DataLoader worker processes (default 6, as before).
        shuffle: defaults to False. Evaluation metrics do not depend on order, and a fixed order
            keeps runs reproducible and lets predictions from different models be compared
            sample by sample.

    Returns:
        torch.utils.data.DataLoader yielding ``(image, target)`` batches.
    """
    df = pd.read_csv(PATH/CSV_FILE, header=None, sep=' ')
    image_names = df.iloc[:, 0].values
    labels = df.iloc[:, 1:].values
    dataset = ChestXray14Dataset(image_names, labels, tfm, PATH/IMAGE_DN, sz, percentage=1)
    return DataLoader(dataset, bs, shuffle=shuffle, num_workers=num_workers)

In [0]:
def test(model, dl, tta=False):
    """Run ``model`` over every batch of ``dl`` and print the mean and per-class AUROC.

    Args:
        model: multi-label classifier. It is switched to eval mode here, and its outputs are
            treated as logits (passed through ``torch.sigmoid`` before scoring).
        dl: DataLoader yielding ``(image, target)`` batches. With ``tta=True`` each image batch
            must be 5-D ``(bs, n_crops, c, h, w)``, e.g. from a FiveCrop/TenCrop transform.
        tta: if True, fold the crops into the batch and average the predictions over them.

    Returns:
        ``(targets, preds)``: concatenated CPU tensors. ``preds`` are the raw (pre-sigmoid) outputs.
    """
    # Inference only: eval() fixes BatchNorm/Dropout behaviour, and no_grad() skips building the
    # autograd graph (less GPU memory, faster).
    model.eval()
    targets, preds = [], []
    with torch.no_grad():
        for image, target in dl:
            if tta:
                bs, n_crops, c, h, w = image.shape
                image = image.view(-1, c, h, w)  # crops become extra batch entries
            pred = model(image.cuda())
            if tta:
                pred = pred.view(bs, n_crops, -1).mean(1)  # average over the crops of each sample
            targets.append(target.cpu())
            preds.append(pred.cpu())

    targets = torch.cat(targets)
    preds = torch.cat(preds)

    all_aucs = aucs(torch.sigmoid(preds), targets)
    avg_auc = torch.mean(all_aucs)
    print(f'The average AUROC is {avg_auc:.3}')
    for class_name, class_auc in zip(CLASS_NAMES, all_aucs):
        print(f'The AUROC of {class_name} is {class_auc:.3}')
    return targets, preds

## One crop

In [0]:
# Single-crop pipeline: shorter side resized to 224, converted to a tensor, ImageNet-normalised.
tfm = transforms.Compose([transforms.Resize(224), toTensor, normalize])

In [0]:
# Single-crop test loader (batch size 16) using the transform defined above.
dl = get_test_dl(224, 16, tfm)

In [0]:
# ChexNet on whole images, no TTA: prints the mean and per-class AUROC.
test(chexnet, dl)

## Five Crop 

In [0]:
# Five-crop TTA: resize, cut 5 crops of 224, then tensorise and normalise each crop and
# stack them into one (5, C, H, W) tensor per sample.
tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.FiveCrop(224),
    transforms.Lambda(lambda crops: torch.stack([normalize(toTensor(c)) for c in crops])),
])

In [0]:
# Five-crop TTA: each sample holds 5 crops, so 4 samples give 20 images per forward pass.
dl = get_test_dl(224, 4, tfm)
test(chexnet, dl, tta=True)

## Ten Crop

In [0]:
# Ten-crop TTA (five crops plus their horizontal flips): each crop is tensorised and normalised,
# then all are stacked into one (10, C, H, W) tensor per sample.
tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack([normalize(toTensor(c)) for c in crops])),
])

In [0]:
# Ten-crop TTA loader: 2 samples x 10 crops = 20 images per forward pass.
# (Dropped the stray bare `dl` expression, which only displayed the DataLoader repr.)
dl = get_test_dl(224, 2, tfm)

In [0]:

# Evaluate ChexNet with ten-crop TTA. The original call omitted the model argument, so `test`
# received the DataLoader as `model` and raised a TypeError for the missing `dl`.
test(chexnet, dl, tta=True)

# Two-stage evaluation (lung-segmented crops)

In [0]:
# Only for segmented case
class TestChestXray14Dataset(Dataset):
    '''
    Test images cropped to a precomputed bounding box (two-stage pipeline), based on the NIH split.

    Args:
        image_names: sequence of file names relative to ``path``.
        labels: per-image label vectors, aligned with ``image_names``.
        transform: callable applied to the cropped PIL image (skipped when falsy).
        path: pathlib.Path of the image directory.
        size: nominal image size, only exposed through ``sz`` (fastai compatibility).
        percentage: fraction of the dataset, taken from the start, that is exposed.
        segmented_dict: maps image name -> space-separated "left upper right lower" box string
            used with ``PIL.Image.crop``. When None, images are returned uncropped.
    '''

    def __init__(self, image_names, labels, transform, path, size, percentage=0.1, segmented_dict=None):
        self.labels = labels
        self.percentage = percentage
        self.size = size
        self.image_names = image_names
        self.path = path
        self.transform = transform
        self.segmented_dict = segmented_dict

    def __getitem__(self, index):
        name = self.image_names[index]
        # The context manager closes the file handle promptly (DataLoader workers can otherwise
        # hold many descriptors open). Source images are converted to 3-channel RGB.
        with Image.open(self.path/name) as src:
            image = src.convert('RGB')
        if self.segmented_dict is not None:
            # split() without an argument also tolerates repeated or trailing whitespace
            box = tuple(int(v) for v in self.segmented_dict[name].split())
            image = image.crop(box)
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        # len() works for numpy arrays (first dimension) as well as plain lists
        return int(len(self.image_names) * self.percentage)

    @property
    def sz(self):
        # fastai compatible: learn.summary()
        return self.size

In [0]:
def get_segmented_test_dl(sz, bs, tfm, segmented_file, num_workers=6):
    """Build an ordered test DataLoader whose images are cropped to precomputed boxes.

    Args:
        sz: nominal size, only reported through the dataset's ``sz`` property.
        bs: batch size.
        tfm: transform applied to each cropped PIL image.
        segmented_file: pickle under ``PATH`` holding a dict of image name -> box string
            (see ``TestChestXray14Dataset``).
        num_workers: DataLoader worker processes (default 6, as before).

    Returns:
        torch.utils.data.DataLoader over the whole test CSV, in file order.
    """
    df = pd.read_csv(PATH/CSV_FILE, header=None, sep=' ')
    image_names = df.iloc[:, 0].values
    labels = df.iloc[:, 1:].values
    # SECURITY: pickle.load can execute arbitrary code; only load files this project produced.
    with open(PATH/segmented_file, 'rb') as f:
        segmented_dict = pickle.load(f)
    dataset = TestChestXray14Dataset(image_names, labels, tfm, PATH/IMAGE_DN, sz,
                                     percentage=1, segmented_dict=segmented_dict)
    return DataLoader(dataset, bs, shuffle=False, num_workers=num_workers)

In [0]:
def test_two_stage(dl, tta=False, model=None):
    """Evaluate the second (classification) stage on lung-cropped images.

    The crop is done by the dataset behind ``dl`` (``TestChestXray14Dataset`` crops each image to
    a precomputed box), so scoring is exactly what ``test`` does. This wrapper only supplies the
    default classifier instead of duplicating the evaluation loop.

    Args:
        dl: DataLoader of ``(image, target)`` batches, e.g. from ``get_segmented_test_dl``.
        tta: average predictions over stacked crops (see ``test``).
        model: classifier to evaluate; defaults to the notebook-global ``chexnet``.

    Returns:
        ``(targets, preds)`` as returned by ``test``.
    """
    # NOTE(review): an earlier version segmented each image on the fly here (U-Net mask ->
    # bounding box -> crop). That step now appears to be precomputed into the pickle read by
    # get_segmented_test_dl.
    if model is None:
        model = chexnet
    return test(model, dl, tta=tta)

In [0]:
# Two-stage, single crop: ChexNet on lung-cropped images resized to exactly 224x224.
chexnet_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    toTensor,
    normalize
])
    
# NOTE(review): sz=256 is only reported through the dataset's `sz` property; the actual input
# size comes from chexnet_tfm (224).
dl = get_segmented_test_dl(256, 16, chexnet_tfm, 'cut_all.pickle')
test_two_stage(dl, tta=False)

In [0]:
# Two-stage with five-crop TTA: resize each lung crop to 256x256, take five 224 crops, and
# stack them (tensorised and normalised) into one (5, C, H, W) tensor per sample.
chexnet_tta_tfm = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.FiveCrop(224),
    transforms.Lambda(lambda crops: torch.stack([normalize(toTensor(c)) for c in crops])),
])

dl = get_segmented_test_dl(256, 4, chexnet_tta_tfm, 'cut_all.pickle')
test_two_stage(dl, tta=True)

# ROC analysis

### 1 stage vs 2 stage

In [0]:
# 1 stage: ChexNet on whole images, single 224 crop.
# The single-crop transform is rebuilt here: the global `tfm` was overwritten by the Ten Crop
# section, and feeding those 5-D batches to `test` without tta would fail in the model.
one_stage_tfm = transforms.Compose([
    transforms.Resize(224),
    toTensor,
    normalize
])
dl = get_test_dl(224, 16, one_stage_tfm)
one_stage_targets, one_stage_preds = test(chexnet, dl)
# class name -> (fpr, tpr, thresholds)
one_stage_roc = {
    name: roc_curve(to_np(one_stage_targets[:, i]), to_np(one_stage_preds[:, i]))
    for i, name in enumerate(CLASS_NAMES)
}

In [0]:
# 2 stage: ChexNet on lung-cropped images (boxes from cut_all.pickle), resized to 224x224.
chexnet_tfm = transforms.Compose([transforms.Resize((224, 224)), toTensor, normalize])

dl = get_segmented_test_dl(256, 16, chexnet_tfm, 'cut_all.pickle')
two_stage_targets, two_stage_preds = test_two_stage(dl, tta=False)
# class name -> (fpr, tpr, thresholds)
two_stage_roc = {
    name: roc_curve(to_np(two_stage_targets[:, col]), to_np(two_stage_preds[:, col]))
    for col, name in enumerate(CLASS_NAMES)
}

In [0]:
# Per-class ROC curves: one stage (red) vs two stage (blue). The grid cell after the last class
# holds the shared legend; any further cells are hidden.
fig, axes = plt.subplots(3, 5, figsize=(20, 14))
flat_axes = axes.ravel()
assert len(CLASS_NAMES) < flat_axes.size, 'grid needs one spare cell for the legend'

for ax, cn in zip(flat_axes, CLASS_NAMES):
    for roc, color in ((one_stage_roc, 'r'), (two_stage_roc, 'b')):
        fpr, tpr, _ = roc[cn]
        ax.plot(fpr, tpr, c=color, label=f'AUC={auc(fpr, tpr):.3}')
    ax.plot([0, 1], [0, 1], color='navy', linestyle='--')  # chance line
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_aspect('equal', 'box')
    ax.set_title(cn)
    ax.legend(loc='lower right')

# Empty lines act as legend handles without drawing anything.
legend_ax = flat_axes[len(CLASS_NAMES)]
legend_ax.plot([], [], c='r', label='one stage')
legend_ax.plot([], [], c='b', label='two stage')
legend_ax.legend(loc='center')
for ax in flat_axes[len(CLASS_NAMES):]:
    ax.set_axis_off()

fig.text(0.5, 0.1, '1-specificity', ha='center', fontsize=20)
fig.text(0.095, 0.5, 'sensitivity', va='center', rotation='vertical', fontsize=20)

plt.subplots_adjust(hspace=0.0001)
fig.savefig('two_stage_roc.png')

### ResNet-50 vs DenseNet-121

In [0]:
# Baseline for the comparison: load a ResNet-50 checkpoint.
# NOTE(review): this import belongs with the others at the top, and the absolute path is
# machine-specific (make it relative to a configurable data dir). torch.load unpickles the file,
# so only load trusted checkpoints.
from models.resnet import Resnet

MODEL_FILE = '/mnt/data/xray-thesis/models/resnet/resnet50/20180501-212649/model.path.tar'
d = torch.load(MODEL_FILE)
model_state = d['state_dict']  # the checkpoint is a dict; only its weights are used below

In [0]:
resnet = Resnet('resnet50').cuda()
resnet.load_state_dict(model_state)
# Inference mode, as done for chexnet. Without it, BatchNorm layers normalise with per-batch
# statistics (and update their running stats), and any Dropout stays active during evaluation.
resnet.eval();

In [0]:
# Single-crop test loader shared by both models. The transform is rebuilt here because the global
# `tfm` was overwritten by the Ten Crop section; its 5-D batches would break the non-TTA `test`.
one_crop_tfm = transforms.Compose([
    transforms.Resize(224),
    toTensor,
    normalize
])
dl = get_test_dl(224, 16, one_crop_tfm)

In [0]:
# ResNet-50: test-set predictions, then class name -> (fpr, tpr, thresholds)
resnet_targets, resnet_preds = test(resnet, dl)
resnet_roc = {
    class_name: roc_curve(to_np(resnet_targets[:, col]), to_np(resnet_preds[:, col]))
    for col, class_name in enumerate(CLASS_NAMES)
}

In [0]:
# DenseNet-121 (ChexNet): test-set predictions, then class name -> (fpr, tpr, thresholds)
densenet_targets, densenet_preds = test(chexnet, dl)
densenet_roc = {
    class_name: roc_curve(to_np(densenet_targets[:, col]), to_np(densenet_preds[:, col]))
    for col, class_name in enumerate(CLASS_NAMES)
}

In [0]:
# Per-class ROC curves: ResNet-50 (red) vs DenseNet-121 (blue). The grid cell after the last
# class holds the shared legend; any further cells are hidden.
fig, axes = plt.subplots(3, 5, figsize=(20, 14))
flat_axes = axes.ravel()
assert len(CLASS_NAMES) < flat_axes.size, 'grid needs one spare cell for the legend'

for ax, cn in zip(flat_axes, CLASS_NAMES):
    for roc, color in ((resnet_roc, 'r'), (densenet_roc, 'b')):
        fpr, tpr, _ = roc[cn]
        ax.plot(fpr, tpr, c=color, label=f'AUC={auc(fpr, tpr):.3}')
    ax.plot([0, 1], [0, 1], color='navy', linestyle='--')  # chance line
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_aspect('equal', 'box')
    ax.set_title(cn)
    ax.legend(loc='lower right')

# Empty lines act as legend handles without drawing anything.
legend_ax = flat_axes[len(CLASS_NAMES)]
legend_ax.plot([], [], c='r', label='resnet-50')
legend_ax.plot([], [], c='b', label='densenet-121')
legend_ax.legend(loc='center')
for ax in flat_axes[len(CLASS_NAMES):]:
    ax.set_axis_off()

fig.text(0.5, 0.1, '1-specificity', ha='center', fontsize=20)
fig.text(0.095, 0.5, 'sensitivity', va='center', rotation='vertical', fontsize=20)

plt.subplots_adjust(hspace=0.0001)
fig.savefig('res_dense_roc.png')