In [1]:
import os
import sys
import torch
import accimage
import numpy as np
import pandas as pd
from PIL import Image
from imageio import imread
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms, set_image_backend, get_image_backend
import torch.nn.functional as F
import pickle
import train_utils
import data_utils
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend
set_image_backend('accimage')
from data_utils import *
import train_utils
root_dir_coad = '/n/mounted-data-drive/COAD/'

%reload_ext autoreload
%autoreload 2

class Pool_and_classify(nn.Module):
    def __init__(self, pool_type, input_features,output_features ):
        super(Pool_and_classify, self).__init__()
        self.fc = nn.Linear(input_features,output_features)
        self.pool_type = pool_type
    def forward(self, x):
        if self.pool_type == 'mean':
            x = self.fc(torch.mean(x,dim=0))
        elif self.pool_type == 'max':
            x = self.fc(torch.max(x,dim=0))
        return x

In [2]:
class TCGA_random_tiles_sampler(Dataset):
    """TCGA dataset."""

    def __init__(self, sample_annotations, root_dir, transform=None, loader=default_loader, 
                 magnification='5.0', tile_batch_size = 256):
        """
        Args:
            sample_annot (dict): dictionary of sample names and their respective labels.
            root_dir (string): directory containing all of the samples and their respective images.
            transform (callable, optional): optional transform to be applied on the images of a sample.
        """
        self.sample_names = list(sample_annotations.keys())
        self.sample_labels = list(sample_annotations.values())
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader
        self.magnification = magnification
        self.img_dirs = [self.root_dir + sample_name + '.svs/' \
                         + sample_name + '_files/' + self.magnification for sample_name in self.sample_names]
        self.jpegs = [os.listdir(img_dir) for img_dir in self.img_dirs]
        self.all_jpegs = []
        self.all_labels = []
        self.jpg_to_sample = []
        self.coords = []
        self.tile_batch_size = tile_batch_size
        for idx,(im_dir,label,l) in enumerate(zip(self.img_dirs,self.sample_labels,self.jpegs)):
            sample_coords = []
            for jpeg in l:
                self.all_jpegs.append(im_dir+'/'+jpeg)
                self.all_labels.append(label)
                self.jpg_to_sample.append(idx)
                x,y = jpeg[:-5].split('_') # 'X_Y.jpeg'
                x,y = int(x), int(y)
                sample_coords.append(torch.tensor([x,y]))
            self.coords.append(torch.stack(sample_coords))
                
            
    def __len__(self):
        ''' number of slides: jpegs is a list of lists '''
        return len(self.jpegs)

    def __getitem__(self, idx):
        slide_tiles = []
        tiles_batch = []
        perm = torch.randperm(len(self.jpegs[idx]))
        
        if len(self.jpegs[idx]) > self.tile_batch_size:
            idxs = perm[:self.tile_batch_size]
        else: 
            idxs = range(len(self.jpegs[idx]))
            
        for tile_num in idxs:
            im = self.jpegs[idx][tile_num]
            path = self.img_dirs[idx] + '/' + im
            image = self.loader(path)
            
            if self.transform is not None:
                image = self.transform(image)
            if image.shape[1] < 256 or image.shape[2] < 256:
                image = pad_tensor_up_to(image,256,256,channels_last=False)
            tiles_batch.append(image)

        # create batch of random tiles
        slide = torch.stack(tiles_batch)

        label = self.sample_labels[idx]
        coords = torch.stack([self.coords[idx][i] for i in idxs])
        return slide, label, coords

In [3]:
sa_train, sa_val = data_utils.load_COAD_train_val_sa_pickle('/n/tcga_models/resnet18_WGD_v04_sa.pkl')
root_dir = '/n/mounted-data-drive/COAD/'
magnification = '5.0'
train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation


train_set = TCGA_random_tiles_sampler(sa_train, root_dir, transform=train_transform, magnification=magnification)
train_loader = DataLoader(train_set, batch_size=1, pin_memory=True, num_workers=12, shuffle=True)

val_set = TCGA_random_tiles_sampler(sa_val, root_dir, transform=val_transform, magnification=magnification, tile_batch_size=1500)
val_loader = DataLoader(val_set, batch_size=1, pin_memory=False, num_workers=12, shuffle=True)


In [4]:
device = torch.device('cuda',0)
resnet = models.resnet18(pretrained=True)
resnet.fc = Pool_and_classify(pool_type='mean',input_features=2048,output_features=1)
resnet.cuda(device)
learning_rate = 1e-4
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, min_lr=1e-8)


In [5]:
def training_loop_random_sampling(e,train_loader,device,criterion,resnet,optimizer):
    for step,(slide, label, coords) in enumerate(train_loader):
        optimizer.zero_grad()
        slide,label = slide.squeeze(0).cuda(device),label.cuda(device)
        logits = resnet(slide)
        loss = criterion(logits,label.float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step%30 == 0:
            print('Epoch: {0}, Step: {1}, Train NLL: {2:0.4f}'.format(e, step, loss.detach().cpu().numpy()))
    del slide, label, loss, logits
    torch.cuda.empty_cache()

In [6]:
for e in range(1):
    training_loop_random_sampling(e,train_loader,device,criterion,resnet,optimizer)

Epoch: 0, Step: 0, Train NLL: 0.8872
Epoch: 0, Step: 30, Train NLL: 0.6759
Epoch: 0, Step: 60, Train NLL: 0.9597
Epoch: 0, Step: 90, Train NLL: 0.9654
Epoch: 0, Step: 120, Train NLL: 0.6199
Epoch: 0, Step: 150, Train NLL: 0.8669
Epoch: 0, Step: 180, Train NLL: 0.4607
Epoch: 0, Step: 210, Train NLL: 0.6722
Epoch: 0, Step: 240, Train NLL: 0.5693
Epoch: 0, Step: 270, Train NLL: 0.3679
Epoch: 0, Step: 300, Train NLL: 0.6737


In [7]:
def validation_loop_for_random_sampler(e,val_loader,device,criterion,resnet,scheduler):
    pred_batch = []
    true_label = []
    torch.cuda.empty_cache()
    loss = torch.tensor(0.0,device=device)
    with torch.no_grad():
        for step,(slide, label, coords) in enumerate(val_loader):
            slide,label = slide.squeeze(0).cuda(device),label.cuda(device)
            logits = resnet(slide)
            loss += criterion(logits,label.float())
            pred_batch.append(torch.sigmoid(logits).detach().cpu().numpy()>0.5)
            true_label.append(label.detach().cpu().numpy())
            
            del slide,label,logits
            torch.cuda.empty_cache()
    scheduler.step(loss)
    pred_batch = np.array(pred_batch)
    true_label = np.array(true_label)
    acc = np.mean(pred_batch==true_label)
    acc_1 = np.mean(pred_batch[true_label==1])
    acc_0 = np.mean(1-pred_batch[true_label==0])
    loss = loss.detach().cpu().numpy()
    
    print('Epoch: {0}, Val Total NLL: {1:0.4f}, Val Accuracy: {2:0.2f} \
           Class Accuracy: WGD = {3:0.2f}, Diploid = {4:0.2f}'\
              .format(e,loss,acc,acc_1,acc_0))

In [8]:
validation_loop_for_random_sampler(0,val_loader,device,criterion,resnet,scheduler)

RuntimeError: CUDA out of memory. Tried to allocate 5.86 GiB (GPU 0; 15.78 GiB total capacity; 7.13 GiB already allocated; 4.73 GiB free; 913.49 MiB cached)