In [1]:
import os
import sys
import torch
import accimage
import numpy as np
import pandas as pd
from PIL import Image
from imageio import imread
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms, set_image_backend, get_image_backend
import torch.nn.functional as F
import pickle
import train_utils
import data_utils
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend
set_image_backend('accimage')
from data_utils import *
import train_utils

# IPython autoreload extension, mode 2, so edits to imported project modules
# are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Checkpoint whose weights are loaded into the ResNet-18 further down.
state_dict_file = '/n/tcga_models/resnet18_WGD_all_10x.pt'
# Constructor arguments for the Attention head defined below; output_size is
# also the width of the temporary nn.Linear used while loading the checkpoint.
input_size = 2048
hidden_size = 512
output_size = 1



class Attention(nn.Module):
    """Attention pooling over a bag of feature vectors plus a linear classifier.

    The attention scores are normalised with a softmax along dim 0 and the
    input is summed along dim 0, so dim 0 of the input is treated as the
    "bag" axis (presumably one row per tile -- confirm against the caller).

    Args:
        input_size: feature width of every row of the input.
        hidden_size: width of the hidden attention projection.
        output_size: width of the attention score produced by ``w``.
            NOTE(review): ``a * h`` in ``forward`` relies on broadcasting,
            which looks like it only works for ``output_size == 1`` -- confirm.
        gated: when equal to ``True`` the tanh branch is multiplied by a
            sigmoid gate computed by ``U``; otherwise ``U`` is unused.
    """

    def __init__(self, input_size, hidden_size, output_size, gated=False):
        super().__init__()
        self.input_size, self.hidden_size, self.output_size = input_size, hidden_size, output_size
        self.gated = gated

        # Sub-modules are created in the same order as before so that the
        # parameter ordering and random initialisation stay the same.
        self.V = nn.Linear(input_size, hidden_size)
        self.U = nn.Linear(input_size, hidden_size)
        self.w = nn.Linear(hidden_size, output_size)
        self.sigm = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.sm = nn.Softmax(dim=0)
        # Final classifier applied to the pooled feature vector; emits one logit.
        self.linear_layer = nn.Linear(input_size, 1)

    def forward(self, h):
        """Return ``(logits, a)``: classifier output and the attention weights."""
        hidden = self.tanh(self.V(h))
        # Explicit comparison kept on purpose: only a flag equal to True gates.
        if self.gated == True:
            hidden = hidden * self.sigm(self.U(h))
        a = self.sm(self.w(hidden))

        # Attention-weighted sum of the rows of h.
        pooled = (a * h).sum(dim=0)
        return self.linear_layer(pooled), a


# initialize trained resnet
resnet = models.resnet18(pretrained=False)
# Temporary classifier with the same shape as the one stored in the checkpoint,
# so that load_state_dict() finds matching keys.
# NOTE(review): stock torchvision resnet18 feeds 512 features into `fc`; the
# 2048 here presumably matches how the checkpoint was trained -- confirm.
resnet.fc = nn.Linear(input_size, output_size, bias=True)
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
# map_location='cpu' keeps the tensors on the CPU until the explicit move below.
saved_state = torch.load(state_dict_file, map_location='cpu')
resnet.load_state_dict(saved_state)
device = torch.device('cuda',1)

# Freeze the backbone: no gradients ...
for p in resnet.parameters():
    p.requires_grad = False
# ... and eval mode, otherwise the BatchNorm layers keep updating their running
# statistics on every forward pass even though their weights are frozen.
resnet.eval()

# Replace the classifier with the trainable attention-pooling head. It is
# attached after eval() so the head itself stays in its default (train) mode.
attend_and_pool = Attention(input_size, hidden_size, output_size)
resnet.fc = attend_and_pool
resnet.to(device)


# Only the parameters of the new head are optimised.
optim = torch.optim.Adam(resnet.fc.parameters(), lr = 1e-4)
train_cancers = ['READ_10x']
val_cancers = ['READ_10x']

root_dir = '/n/mounted-data-drive/'
magnification = '10.0'
criterion=nn.BCEWithLogitsLoss()

In [2]:
class TCGA_random_tiles_sampler(Dataset):
    """Slide-level TCGA dataset returning a random subset of a slide's tiles.

    One item corresponds to one sample (slide): a stacked batch of at most
    ``tile_batch_size`` tiles, the sample label, and the (x, y) grid
    coordinates parsed from the tile file names.

    NOTE(review): the dataset-building cells of this notebook call
    ``data_utils.TCGA_random_tiles_sampler`` rather than this class -- confirm
    which definition is meant to be used.
    """

    # Tiles whose height or width is below this are padded up to it.
    MIN_TILE_SIDE = 256

    def __init__(self, sample_annotations, root_dir, transform=None, loader=default_loader, 
                 magnification='5.0', tile_batch_size = 256):
        """
        Args:
            sample_annotations (dict): maps sample name -> label.
            root_dir (string): directory that contains, for every sample,
                ``<sample>.svs/<sample>_files/<magnification>/`` with the tiles.
            transform (callable, optional): applied to every loaded tile.
                ``__getitem__`` reads ``image.shape`` afterwards, so it
                presumably has to return a channels-first tensor -- confirm.
            loader (callable): takes a file path and returns the image.
            magnification (string): name of the tile sub-directory to read.
            tile_batch_size (int): maximum number of tiles returned per item.

        Raises:
            RuntimeError: if a sample's tile directory contains no files.
            ValueError: if a tile file name is not of the form ``X_Y.<ext>``.
        """
        self.sample_names = list(sample_annotations.keys())
        self.sample_labels = list(sample_annotations.values())
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader
        self.magnification = magnification
        self.tile_batch_size = tile_batch_size

        # os.path.join gives the same path as the previous string concatenation
        # when root_dir ends with '/', and also works when it does not.
        self.img_dirs = [os.path.join(self.root_dir, sample_name + '.svs',
                                      sample_name + '_files', self.magnification)
                         for sample_name in self.sample_names]
        # os.listdir returns entries in arbitrary order; sort so that tile
        # indices (and therefore seeded sampling) are reproducible.
        self.jpegs = [sorted(os.listdir(img_dir)) for img_dir in self.img_dirs]

        # Flat, tile-level views over all samples.
        self.all_jpegs = []
        self.all_labels = []
        self.jpg_to_sample = []
        # Per sample: (n_tiles, 2) integer tensor of tile coordinates.
        self.coords = []
        for idx, (im_dir, label, tile_names) in enumerate(zip(self.img_dirs, self.sample_labels, self.jpegs)):
            if not tile_names:
                raise RuntimeError('no tiles found in ' + im_dir)
            sample_coords = []
            for jpeg in tile_names:
                self.all_jpegs.append(os.path.join(im_dir, jpeg))
                self.all_labels.append(label)
                self.jpg_to_sample.append(idx)
                # File names look like 'X_Y.jpeg'; strip whatever extension is there.
                x, y = os.path.splitext(jpeg)[0].split('_')
                sample_coords.append([int(x), int(y)])
            self.coords.append(torch.tensor(sample_coords, dtype=torch.long))

    def __len__(self):
        ''' number of slides: jpegs is a list of lists '''
        return len(self.jpegs)

    def __getitem__(self, idx):
        """Return ``(slide, label, coords)`` for sample ``idx``.

        ``slide`` stacks the selected tiles, ``coords`` holds the matching rows
        of the sample's coordinate tensor. When the sample has more than
        ``tile_batch_size`` tiles a random subset of that size is drawn,
        otherwise all tiles are returned in listing order.
        """
        tile_names = self.jpegs[idx]
        n_tiles = len(tile_names)
        if n_tiles > self.tile_batch_size:
            idxs = torch.randperm(n_tiles)[:self.tile_batch_size]
        else:
            idxs = torch.arange(n_tiles)

        tiles_batch = []
        for tile_num in idxs.tolist():
            path = os.path.join(self.img_dirs[idx], tile_names[tile_num])
            image = self.loader(path)

            if self.transform is not None:
                image = self.transform(image)
            # Edge tiles can be smaller than the rest; pad so they can be stacked.
            if image.shape[1] < self.MIN_TILE_SIDE or image.shape[2] < self.MIN_TILE_SIDE:
                image = pad_tensor_up_to(image, self.MIN_TILE_SIDE, self.MIN_TILE_SIDE, channels_last=False)
            tiles_batch.append(image)

        # create batch of random tiles
        slide = torch.stack(tiles_batch)

        label = self.sample_labels[idx]
        coords = self.coords[idx][idxs]
        return slide, label, coords

In [3]:
# NOTE(review): sa_train / sa_val are not referenced again in the visible cells.
sa_train, sa_val = data_utils.load_COAD_train_val_sa_pickle('/n/tcga_models/resnet18_WGD_v04_sa.pkl')

# Per-cancer annotations, delivered as two halves per split.
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
batch_all, sa_train1, sa_val1, sa_train2, sa_val2 = data_utils.load_COAD_train_val_sa_pickle(
    pickle_file=pickle_file,
    return_all_cancers=True,
    split_in_two=True,
)

# Merge the two halves of every split back into one dict per entry.
sa_trains = [dict(sa_train1[k], **sa_train2[k]) for k in range(len(sa_train1))]
sa_vals = [dict(sa_val1[k], **sa_val2[k]) for k in range(len(sa_val1))]

train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

magnification = '10.0'
root_dir = '/n/mounted-data-drive/'

# One dataset per requested cancer type; batch_all gives the position of each
# cancer inside sa_trains / sa_vals.
train_sets = []
for cancer in train_cancers:
    train_set = data_utils.TCGA_random_tiles_sampler(
        sa_trains[batch_all.index(cancer)],
        root_dir + cancer + '/',
        transform=train_transform,
        magnification=magnification,
        tile_batch_size=512,
    )
    train_sets.append(train_set)

val_sets = []
for cancer in val_cancers:
    val_set = data_utils.TCGA_random_tiles_sampler(
        sa_vals[batch_all.index(cancer)],
        root_dir + cancer + '/',
        transform=val_transform,
        magnification=magnification,
        tile_batch_size=512,
    )
    val_sets.append(val_set)



In [4]:
# Wrap every dataset collected in train_sets / val_sets, not just the last one
# left over in the loop variable; with a single cancer type this is equivalent.
# batch_size=1 because each dataset item is already a stacked batch of tiles.
train_loader = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset(train_sets),
                                           batch_size=1, shuffle=True, num_workers=16,
                                           pin_memory=False)

# Validation metrics do not depend on slide order, so no shuffling is needed.
val_loader = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset(val_sets),
                                         batch_size=1, shuffle=False, num_workers=16,
                                         pin_memory=False)

In [33]:
def training_loop_random_sampling(e,train_loader,device,criterion,resnet,optimizer,gradient_step_length=3,reporting_step_length=10):
    """Run one training epoch over slides with gradient accumulation.

    Each item of ``train_loader`` is one slide: ``(slide, label, coords)``
    with a leading batch dimension of size 1 on ``slide``. The gradients of
    ``resnet.fc`` are averaged over ``gradient_step_length`` slides before
    every optimizer step; a shorter group left over at the end of the epoch
    is averaged over its own size and applied as well.

    Args:
        e: epoch number, only used in the progress message.
        train_loader: iterable yielding ``(slide, label, coords)``.
        device: torch.device the data is moved to.
        criterion: loss called as ``criterion(logits, label.float())``; must
            return a scalar.
        resnet: model returning ``(logits, attention)``; its ``fc`` attribute
            holds the parameters being trained.
        optimizer: optimizer over the parameters of ``resnet.fc``.
        gradient_step_length: number of slides per optimizer step (>= 1).
        reporting_step_length: number of slides per progress message (>= 1).

    Raises:
        ValueError: if either step length is smaller than 1.
    """
    if gradient_step_length < 1 or reporting_step_length < 1:
        raise ValueError('gradient_step_length and reporting_step_length must be >= 1')

    head_params = list(resnet.fc.parameters())

    def apply_accumulated(n_slides):
        # backward() adds into .grad, so it holds a sum; turn it into a mean.
        for p in head_params:
            # Parameters that took no part in forward never receive a gradient.
            if p.grad is not None:
                p.grad.div_(n_slides)
        optimizer.step()
        optimizer.zero_grad()

    optimizer.zero_grad()
    # Kept on the device so that accumulating the loss does not force a sync.
    running_loss = torch.zeros((), device=device)
    pending_grads = 0
    pending_losses = 0
    for step, (slide, label, coords) in enumerate(train_loader):
        slide, label = slide.squeeze(0).to(device), label.to(device)
        logits, _ = resnet(slide)
        loss = criterion(logits, label.float())
        loss.backward()
        running_loss += loss.detach()
        pending_grads += 1
        pending_losses += 1

        if pending_grads == gradient_step_length:
            apply_accumulated(pending_grads)
            pending_grads = 0

        if pending_losses == reporting_step_length:
            # "Step" is the number of slides processed so far in this epoch.
            print('Epoch: {0}, Step: {1}, Train NLL: {2:0.4f}'.format(e, step + 1, running_loss.item()/pending_losses))
            running_loss.zero_()
            pending_losses = 0

    # Do not throw away the gradients of an incomplete final group.
    if pending_grads > 0:
        apply_accumulated(pending_grads)
    torch.cuda.empty_cache()

In [6]:
def validation_loop_for_random_sampler(e,val_loader,device,criterion,resnet):
    """Evaluate the model on every slide of ``val_loader`` and print metrics.

    Prints the mean loss per slide, the overall accuracy at a 0.5 probability
    threshold, and the accuracy within label 1 (reported as "WGD") and label 0
    (reported as "Diploid"). A metric with nothing to average over is ``nan``.

    The model is switched to eval mode for the duration of the call and its
    previous mode is restored afterwards.

    Args:
        e: epoch number, only used in the printed message.
        val_loader: iterable yielding ``(slide, label, coords)`` with a
            leading batch dimension of size 1 on ``slide``.
        device: torch.device the data is moved to.
        criterion: loss called as ``criterion(logits, label.float())``.
        resnet: ``nn.Module`` returning ``(logits, attention)``.
    """
    def mean_or_nan(values):
        return float(np.mean(values)) if len(values) else float('nan')

    pred_batch = []
    true_label = []
    n_slides = 0
    torch.cuda.empty_cache()
    loss = torch.tensor(0.0,device=device)

    was_training = resnet.training
    resnet.eval()
    try:
        with torch.no_grad():
            for slide, label, coords in val_loader:
                slide,label = slide.squeeze(0).to(device),label.to(device)
                logits,_ = resnet(slide)
                loss += criterion(logits,label.float())
                pred_batch.append((torch.sigmoid(logits).cpu().numpy()>0.5).reshape(-1))
                true_label.append(label.cpu().numpy().reshape(-1))
                n_slides += 1

                # Release the slide before emptying the cache so its memory
                # can actually be returned.
                del slide,label,logits,_
                torch.cuda.empty_cache()
    finally:
        resnet.train(was_training)

    pred_batch = np.concatenate(pred_batch) if pred_batch else np.array([], dtype=bool)
    true_label = np.concatenate(true_label) if true_label else np.array([], dtype=np.int64)
    acc = mean_or_nan(pred_batch==true_label)
    acc_1 = mean_or_nan(pred_batch[true_label==1])
    acc_0 = mean_or_nan(~pred_batch[true_label==0])
    # Average over the number of slides seen, not over the last loop index.
    mean_loss = loss.item()/n_slides if n_slides else float('nan')

    print('Epoch: {0}, Val Mean NLL: {1:0.4f}, Val Accuracy: {2:0.2f} \
           Class Accuracy: WGD = {3:0.2f}, Diploid = {4:0.2f}'\
              .format(e,mean_loss,acc,acc_1,acc_0))

In [11]:
# Alternate one training epoch with one validation pass.
for epoch in range(100):
    training_loop_random_sampling(epoch, train_loader, device, criterion, resnet, optim)
    validation_loop_for_random_sampler(epoch, val_loader, device, criterion, resnet)

Epoch: 0, Step: 0, Train NLL: 0.4971
Epoch: 0, Step: 30, Train NLL: 9.5432
Epoch: 0, Step: 60, Train NLL: 11.7971
Epoch: 0, Step: 90, Train NLL: 15.6518
Epoch: 0, Val Total NLL: 40.6831, Val Accuracy: 0.38            Class Accuracy: WGD = 0.40, Diploid = 0.37
Epoch: 1, Step: 0, Train NLL: 0.1479
Epoch: 1, Step: 30, Train NLL: 17.3398
Epoch: 1, Step: 60, Train NLL: 9.5740
Epoch: 1, Step: 90, Train NLL: 13.1353
Epoch: 1, Val Total NLL: 30.3013, Val Accuracy: 0.41            Class Accuracy: WGD = 0.50, Diploid = 0.37
Epoch: 2, Step: 0, Train NLL: 0.4447
Epoch: 2, Step: 30, Train NLL: 11.6979
Epoch: 2, Step: 60, Train NLL: 13.0108
Epoch: 2, Step: 90, Train NLL: 16.3869
Epoch: 2, Val Total NLL: 40.2237, Val Accuracy: 0.41            Class Accuracy: WGD = 0.50, Diploid = 0.37
Epoch: 3, Step: 0, Train NLL: 0.4103
Epoch: 3, Step: 30, Train NLL: 11.5963
Epoch: 3, Step: 60, Train NLL: 10.4036
Epoch: 3, Step: 90, Train NLL: 9.5632
Epoch: 3, Val Total NLL: 35.9975, Val Accuracy: 0.38            Cl

KeyboardInterrupt: 

In [27]:
# Small development subset: the first 11 training slides of the cancer type the
# dev dataset is built for. The entry is looked up through batch_all, like the
# full datasets, instead of a hard-coded position in sa_trains.
n_dev_slides = 11
dev_annotations = sa_trains[batch_all.index(train_cancers[0])]
sa_dev_train = dict(list(dev_annotations.items())[:n_dev_slides])

In [28]:
# Dataset over the development subset only; replaces the full train_set.
train_set = data_utils.TCGA_random_tiles_sampler(
    sa_dev_train,
    f'{root_dir}{train_cancers[0]}/',
    transform=train_transform,
    magnification=magnification,
    tile_batch_size=512,
)


In [29]:
# Loader over the development dataset; one slide (a stacked batch of tiles) per step.
train_loader = DataLoader(
    train_set,
    batch_size=1,
    shuffle=True,
    num_workers=16,
    pin_memory=False,
)


In [34]:
# Training only (no validation) on the development subset, reporting every 3 slides.
for epoch in range(100):
    training_loop_random_sampling(
        epoch, train_loader, device, criterion, resnet, optim,
        gradient_step_length=3,
        reporting_step_length=3,
    )

Epoch: 0, Step: 3, Train NLL: 0.1772
Epoch: 0, Step: 6, Train NLL: 0.2087
Epoch: 0, Step: 9, Train NLL: 0.2479
Epoch: 1, Step: 3, Train NLL: 0.1943
Epoch: 1, Step: 6, Train NLL: 0.1124
Epoch: 1, Step: 9, Train NLL: 0.0910
Epoch: 2, Step: 3, Train NLL: 0.1909
Epoch: 2, Step: 6, Train NLL: 0.0399
Epoch: 2, Step: 9, Train NLL: 0.1268
Epoch: 3, Step: 3, Train NLL: 0.1708
Epoch: 3, Step: 6, Train NLL: 0.1538
Epoch: 3, Step: 9, Train NLL: 0.1201
Epoch: 4, Step: 3, Train NLL: 0.8201
Epoch: 4, Step: 6, Train NLL: 0.0482
Epoch: 4, Step: 9, Train NLL: 0.1233
Epoch: 5, Step: 3, Train NLL: 0.2890
Epoch: 5, Step: 6, Train NLL: 0.1268
Epoch: 5, Step: 9, Train NLL: 0.0539
Epoch: 6, Step: 3, Train NLL: 0.3106
Epoch: 6, Step: 6, Train NLL: 0.0617
Epoch: 6, Step: 9, Train NLL: 0.1161
Epoch: 7, Step: 3, Train NLL: 0.2309
Epoch: 7, Step: 6, Train NLL: 0.1127
Epoch: 7, Step: 9, Train NLL: 0.1715
Epoch: 8, Step: 3, Train NLL: 0.1389
Epoch: 8, Step: 6, Train NLL: 0.0622
Epoch: 8, Step: 9, Train NLL: 0.1609
E

KeyboardInterrupt: 