In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

%reload_ext autoreload
%autoreload 2

In [2]:
device = torch.device('cuda', 1)
batch_size = 128
n_workers = 6
set_image_backend('accimage')
root_dir = '/n/mounted-data-drive/COAD/'

In [3]:
task = 'WGD'
output_shape = 1
PATH = '/n/tcga_models/resnet18_WGD_v03.pt'
_, sa_val = data_utils.process_WGD_data()

In [4]:
transform = train_utils.transform_validation
val_set = data_utils.TCGADataset_tiles(sa_val, root_dir, transform=transform, magnification='5.0')
jpg_to_sample = val_set.jpg_to_sample

In [5]:
valid_loader = DataLoader(val_set, batch_size=batch_size, pin_memory=False, num_workers=n_workers)

In [6]:
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Linear(2048, output_shape, bias=True)
saved_state = torch.load(PATH, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
resnet.cuda(device=device)
resnet.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [7]:
total_loss = 0
all_labels = []
all_preds = []
all_loss = []

if task == 'MSI':
    encoding = torch.tensor([[0,0],[1,0],[1,1]], device=device).float()
elif task == 'WGD' or task == 'MSI-SINGLE_LABEL':
    encoding = torch.tensor([[0],[1]], device=device).float()

In [8]:
dataset = 'Val'
criterion = nn.BCEWithLogitsLoss(reduction = 'none')

In [9]:
for idx,(batch,labels) in enumerate(valid_loader):
    batch, labels = batch.cuda(device=device), encoding[labels.cuda(device=device)]
    output = resnet(batch)
    loss = criterion(output, labels)

    total_loss += torch.sum(loss.detach().mean(dim=1)).cpu().numpy()
    all_labels.extend(torch.sum(labels, dim=1).float().cpu().numpy())
    all_preds.extend(torch.sum(torch.sigmoid(output) > 0.5, dim=1).float().detach().cpu().numpy())
    all_loss.extend(loss.detach().mean(dim=1).cpu().numpy())

    if idx % 100 == 0:
        print('Batch: {0}, {2} NLL: {1:0.4f}'.format(idx, torch.sum(loss.detach())/batch.shape[0], dataset))

Batch: 0, Val NLL: 0.4373
Batch: 100, Val NLL: 0.3075
Batch: 200, Val NLL: 0.0993
Batch: 300, Val NLL: 0.8528
Batch: 400, Val NLL: 0.5525


In [10]:
e = 0
acc = np.mean(np.array(all_labels) == np.array(all_preds))

d = {'label': all_labels, 'pred': all_preds, 'sample': jpg_to_sample}
df = pd.DataFrame(data = d)
df['correct_tile'] = df['label'] == df['pred']
df.groupby(['label'])['correct_tile'].mean()
tile_acc_by_label = ', '.join([str(i) + ': ' + str(float(df.groupby(['label'])['correct_tile'].mean()[i]))[:6] for i in range(encoding.shape[0])])

df2 = df.groupby(['sample'])['label','pred'].mean().round()
df2['correct_sample'] = df2['label'] == df2['pred']
mean_pool_acc = df2['correct_sample'].mean()

df3 = df.groupby(['sample'])['label','pred'].max()
df3['correct_sample'] = df3['label'] == df3['pred']
max_pool_acc = df3['correct_sample'].mean()

slide_acc_by_label = ', '.join([str(i) + ': ' + str(float(df2.groupby(['label'])['correct_sample'].mean()[i]))[:6] for i in range(encoding.shape[0])])

print('Epoch: {0}, Avg {3} NLL: {1:0.4f}, Median {3} NLL: {2:0.4f}'.format(e, total_loss/(float(idx+1) * batch_size), 
                                                                           np.median(all_loss), dataset))
print('------ {2} Tile-Level Acc: {0:0.4f}; By Label: {1}'.format(acc, tile_acc_by_label, dataset))
print('------ {2} Slide-Level Acc: Mean-Pooling: {0:0.4f}, Max-Pooling: {1:0.4f}'.format(mean_pool_acc, max_pool_acc, 
                                                                                         dataset))
print('------ {1} Slide-Level Acc (Mean-Pooling) By Label: {0}'.format(slide_acc_by_label, dataset))

Epoch: 0, Avg Val NLL: 0.3794, Median Val NLL: 0.1326
------ Val Tile-Level Acc: 0.8283; By Label: 0: 0.8690, 1: 0.7629
------ Val Slide-Level Acc: Mean-Pooling: 0.9512, Max-Pooling: 0.3902
------ Val Slide-Level Acc (Mean-Pooling) By Label: 0: 1.0, 1: 0.875


In [11]:
# tile-level ROC
fpr, tpr, thresholds = metrics.roc_curve(np.array(all_labels), np.array(all_preds)

SyntaxError: unexpected EOF while parsing (<ipython-input-11-9fff18cb5564>, line 2)

In [None]:
# slide-level ROC
