In [None]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

# Reload edited modules automatically before each cell runs, so changes to the
# local helper modules (data_utils, train_utils) are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

# Select the image-loading backend used by torchvision.
# NOTE(review): 'accimage' is an optional package — confirm it is installed in this environment.
set_image_backend('accimage')

In [None]:
# Load the train / validation splits through the project helper.
# NOTE(review): the path points at a .pkl file, so the helper presumably unpickles it —
# only load pickles from a trusted source (unpickling can execute arbitrary code).
# NOTE(review): absolute path is machine-specific; consider a configurable data directory.
sa_train, sa_val = data_utils.load_COAD_train_val_sa_pickle('/n/tcga_models/resnet18_WGD_10x_sa.pkl')

In [None]:
# Settings forwarded to data_utils.TCGADataset_tiles when the train / val datasets are built.
# NOTE(review): absolute path is machine-specific; consider a configurable data directory.
root_dir = '/n/mounted-data-drive/COAD/'
# Passed as the dataset's `magnification` argument (a string, not a float).
magnification = '10.0'
# Passed as the dataset's `batch_type` argument — semantics are defined in data_utils; TODO confirm.
batch_type = 'tile'

In [None]:
# Training data: the project transform is applied by the dataset, and the loader
# yields one (shuffled) dataset item per batch.
train_transform = train_utils.transform_train
train_set = data_utils.TCGADataset_tiles(
    sa_train,
    root_dir,
    transform=train_transform,
    magnification=magnification,
    batch_type=batch_type,
)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

In [None]:
# Validation data: same dataset class as training but with the validation transform,
# and the loader keeps the dataset order (no shuffling), one item per batch.
val_transform = train_utils.transform_validation
val_set = data_utils.TCGADataset_tiles(
    sa_val,
    root_dir,
    transform=val_transform,
    magnification=magnification,
    batch_type=batch_type,
)
valid_loader = DataLoader(
    dataset=val_set,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

In [None]:
def update_tile_shape(H_in, W_in, kernel_size, dilation = 1., padding = 0., stride = 1.):
    """Return the (height, width) produced by a sliding-window layer.

    Applies ``floor((L + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)``
    independently to the height and the width.

    Args:
        H_in: input height.
        W_in: input width.
        kernel_size: scalar window size (used for both axes).
        dilation: scalar spacing between window elements.
        padding: scalar padding added on each side.
        stride: scalar step of the window.

    Returns:
        Tuple ``(H_out, W_out)`` of Python ints.
    """
    def _out_length(length):
        # Same arithmetic for either axis; floor, then convert to a plain int.
        return int(np.floor((length + 2. * padding - dilation * (kernel_size-1) -1)/stride + 1))

    return _out_length(H_in), _out_length(W_in)

In [None]:
class Generator(nn.Module):
    """Per-tile selector: a small conv stack followed by a bidirectional LSTM.

    Each input tile is embedded by ``n_conv_layers`` blocks of Conv2d -> ReLU -> MaxPool2d(2, 2).
    The flattened embeddings of all tiles are then treated as ONE sequence (batch of 1) and fed
    to a bidirectional LSTM; a linear layer maps every LSTM output to 2 scores per tile.

    Args:
        n_conv_layers: number of conv blocks.
        kernel_size: sequence with one scalar kernel size per conv block.
        n_conv_filters: sequence with the number of output channels per conv block.
        hidden_size: LSTM hidden size (per direction).
        n_rnn_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers.
        dilation: scalar dilation applied to every convolution (converted to int).
        padding: scalar zero-padding applied to every convolution (converted to int).
        H_in, W_in: height / width of the input tiles, used to size the LSTM input.

    Raises:
        ValueError: if the conv stack would shrink the tile to an empty feature map.
    """
    def __init__(self, n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=0.5,
                dilation = 1., padding = 0, H_in = 32, W_in = 32):
        super(Generator, self).__init__()
        self.n_conv_layers = n_conv_layers
        self.kernel_size = kernel_size
        self.n_conv_filters = n_conv_filters
        self.hidden_size = hidden_size
        self.n_rnn_layers = n_rnn_layers
        self.conv_layers = []
        # ReLU and max-pooling hold no parameters, so one instance is shared by every block.
        self.m = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU()
        self.H_in, self.W_in = H_in, W_in

        # Conv2d needs integer padding / dilation; the defaults (0 and 1) keep the original behavior.
        dilation = int(dilation)
        padding = int(padding)

        in_channels = 3
        for layer in range(self.n_conv_layers):
            self.conv_layers.append(nn.Conv2d(in_channels, self.n_conv_filters[layer], self.kernel_size[layer],
                                              padding=padding, dilation=dilation))
            self.conv_layers.append(self.relu)
            self.conv_layers.append(self.m)
            # convolution
            self.H_in, self.W_in = update_tile_shape(self.H_in, self.W_in, kernel_size[layer],
                                                     dilation=dilation, padding=padding)
            # max pooling
            self.H_in, self.W_in = update_tile_shape(self.H_in, self.W_in, 2, stride = 2)
            in_channels = self.n_conv_filters[layer]
        if self.H_in <= 0 or self.W_in <= 0:
            raise ValueError('conv stack reduces a {0}x{1} tile to an empty feature map'.format(H_in, W_in))
        # Size of one flattened tile embedding: channels x remaining height x remaining width.
        in_channels = in_channels * self.H_in * self.W_in
        self.conv = nn.Sequential(*self.conv_layers)
        self.lstm = nn.LSTM(in_channels, self.hidden_size, self.n_rnn_layers, batch_first=True,
                            dropout=dropout, bidirectional=True)
        # Bidirectional LSTM concatenates both directions.
        in_channels = hidden_size * 2
        self.classification_layer = nn.Linear(in_channels, 2)

    def forward(self, x):
        """Score every tile.

        Args:
            x: tensor of tiles, first dimension is the number of tiles, 3 input channels.

        Returns:
            Tensor of shape (1, num_tiles, 2) with unnormalized scores per tile.
        """
        embed = self.conv(x)
        # All tiles form a single sequence: (batch=1, seq_len=num_tiles, features).
        embed = embed.view(1,x.shape[0],-1)
        self.lstm.flatten_parameters()
        output, hidden = self.lstm(embed)
        y = self.classification_layer(output)
        return y

    def zero_grad(self, set_to_none=False):
        """Sets gradients of all model parameters to zero.

        Args:
            set_to_none: if True, drop the gradient tensors instead of zeroing them
                (accepted for compatibility with ``nn.Module.zero_grad``).
        """
        for p in self.parameters():
            if p.grad is not None:
                if set_to_none:
                    p.grad = None
                else:
                    with torch.no_grad():
                        p.grad.zero_()

In [None]:
class ConvNet(nn.Module):
    """Tile encoder: conv blocks followed by fully connected layers.

    Each tile goes through ``n_conv_layers`` blocks of Conv2d -> ReLU -> MaxPool2d(2, 2) and then
    ``n_fc_layers`` blocks of Linear -> ReLU -> Dropout. ``forward`` returns these features only;
    ``classification_layer`` (Linear to 2 outputs) is created here but NOT applied in ``forward`` —
    calling code applies it after pooling the per-tile features.

    Args:
        n_conv_layers: number of conv blocks.
        n_fc_layers: number of fully connected blocks.
        kernel_size: sequence with one scalar kernel size per conv block.
        n_conv_filters: sequence with the number of output channels per conv block.
        hidden_size: sequence with the width of each fully connected block.
        dropout: dropout probability used after every fully connected block.
        dilation: scalar dilation applied to every convolution (converted to int).
        padding: scalar zero-padding applied to every convolution (converted to int).
        H_in, W_in: height / width of the input tiles, used to size the first Linear layer.

    Raises:
        ValueError: if the conv stack would shrink the tile to an empty feature map.
    """
    def __init__(self, n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=0.5,
                dilation = 1., padding = 0, H_in = 32, W_in = 32):
        super(ConvNet, self).__init__()
        self.n_conv_layers = n_conv_layers
        self.n_fc_layers = n_fc_layers
        self.kernel_size = kernel_size
        self.n_conv_filters = n_conv_filters
        self.hidden_size = hidden_size
        self.conv_layers = []
        self.fc_layers = []
        # Parameter-free layers: a single instance is shared by every block.
        self.m = nn.MaxPool2d(2, stride=2)
        self.n = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.H_in, self.W_in = H_in, W_in

        # Conv2d needs integer padding / dilation; the defaults (0 and 1) keep the original behavior.
        dilation = int(dilation)
        padding = int(padding)

        in_channels = 3
        for layer in range(self.n_conv_layers):
            self.conv_layers.append(nn.Conv2d(in_channels, self.n_conv_filters[layer], self.kernel_size[layer],
                                              padding=padding, dilation=dilation))
            self.conv_layers.append(self.relu)
            self.conv_layers.append(self.m)
            # convolution
            self.H_in, self.W_in = update_tile_shape(self.H_in, self.W_in, kernel_size[layer],
                                                     dilation=dilation, padding=padding)
            # max pooling
            self.H_in, self.W_in = update_tile_shape(self.H_in, self.W_in, 2, stride = 2)
            in_channels = self.n_conv_filters[layer]
        if self.H_in <= 0 or self.W_in <= 0:
            raise ValueError('conv stack reduces a {0}x{1} tile to an empty feature map'.format(H_in, W_in))
        # Size of one flattened tile embedding: channels x remaining height x remaining width.
        in_channels = in_channels * self.H_in * self.W_in
        for layer in range(self.n_fc_layers):
            self.fc_layers.append(nn.Linear(in_channels, self.hidden_size[layer]))
            self.fc_layers.append(self.relu)
            self.fc_layers.append(self.n)
            in_channels = self.hidden_size[layer]
        self.conv = nn.Sequential(*self.conv_layers)
        self.fc = nn.Sequential(*self.fc_layers)
        self.classification_layer = nn.Linear(in_channels, 2)

    def forward(self, x):
        """Encode tiles into feature vectors.

        Args:
            x: tensor of tiles, first dimension is the number of tiles, 3 input channels.

        Returns:
            Tensor with one row of fully-connected features per tile
            (``classification_layer`` is not applied here).
        """
        embed = self.conv(x)
        # Flatten each tile's feature map into a vector.
        embed = embed.view(x.shape[0],-1)
        y = self.fc(embed)
        return y

In [None]:
def sample_gumbel(shape, eps=1e-20, device='cuda'):
    """Sample from Gumbel(0, 1).

    Args:
        shape: shape of the noise tensor to draw.
        eps: small constant added inside both logarithms to avoid log(0).
        device: device on which the noise is created (defaults to 'cuda').

    Returns:
        float32 tensor of the requested shape with Gumbel(0, 1) noise.
    """
    U = torch.rand(shape,dtype=torch.float32,device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw a sample from the Gumbel-Softmax distribution.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: positive scalar; the softmax is taken over dim 1 of (logits + noise) / temperature

    Returns:
        [batch_size, n_class] tensor whose rows sum to 1.
    """
    # Create the noise on the same device as the logits so CPU and GPU inputs both work.
    y = logits + sample_gumbel(logits.shape, device=logits.device)
    return F.softmax( y / temperature,dim=1)

def gumbel_softmax(logits, temperature, hard=False):
    """
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: positive scalar
        hard: if True, take argmax of the sample, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        # One-hot encode the argmax of the *sample* (not of the noise-free logits).
        index = torch.argmax(y, dim=1, keepdim=True)
        y_hard = torch.zeros_like(y).scatter_(1, index, 1.0)
        # Straight-through estimator: forward value is y_hard, gradient is that of the soft sample y.
        y = (y_hard - y).detach() + y
    return y

In [None]:
def pool_fn(x):
    """Pool over the first dimension of `x` by taking the mean.

    Args:
        x: tensor whose dimension 0 is reduced.

    Returns:
        Tensor with dimension 0 averaged away.
    """
    # Max pooling is the alternative that was tried here: torch.max(x, 0) returns (values, indices).
    return x.mean(dim=0)

In [None]:
# --- Generator: conv stack + bidirectional LSTM that scores every tile ---
n_conv_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = 512
n_rnn_layers = 2
dropout = 0.5
gen = Generator(
    n_conv_layers=n_conv_layers,
    kernel_size=kernel_size,
    n_conv_filters=n_conv_filters,
    hidden_size=hidden_size,
    n_rnn_layers=n_rnn_layers,
    dropout=dropout,
).cuda()

# --- Encoder: conv stack + fully connected layers applied to the selected tiles ---
# Note that hidden_size is rebound here from an int (LSTM width) to a list (one width per FC layer).
n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout = 0.5
enc = ConvNet(
    n_conv_layers=n_conv_layers,
    n_fc_layers=n_fc_layers,
    kernel_size=kernel_size,
    n_conv_filters=n_conv_filters,
    hidden_size=hidden_size,
    dropout=dropout,
).cuda()

# --- Objective and optimization ---
# Weights of the two regularization terms (selected-tile count and neighbour disagreement).
lamb1 = 0
lamb2 = 0
# Gumbel-softmax temperature.
temp = 10
learning_rate = 1e-4
xent = nn.CrossEntropyLoss()
# One optimizer updates both networks; encoder parameters come first.
params = [*enc.parameters(), *gen.parameters()]
optimizer = torch.optim.Adam(params, lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, min_lr=1e-6)

In [None]:
# Tile edge length and step used by the unfold calls in the train / val loops;
# equal values give non-overlapping tiles.
size = 32
stride = 32
# Epoch counter: used in the log messages and to gate the LR scheduler in the val loop.
# It is not incremented by the loops themselves.
e = 0
# Log-softmax over the last dimension of the generator output (dimension 2).
lsm = nn.LogSoftmax(dim=2)

In [None]:
# train loop: one pass over train_loader, updating generator and encoder jointly
gen.train()
enc.train()

rat_tiles = 0      # soft number of selected tiles, summed over slides
total_tiles = 0    # number of tiles seen
total_xeloss = 0   # summed cross-entropy loss
total_omega = 0    # summed regularization term

for step,(slide,label) in enumerate(train_loader):
    slide,label = slide.squeeze(0).cuda(),label.cuda()
    # Cut the slide into size x size tiles.
    # NOTE(review): assumes slide is (3, H, W) after squeeze — TODO confirm against data_utils.
    # After the two unfolds the layout is (3, n_rows, n_cols, size, size); the channel axis has to be
    # moved next to the spatial axes BEFORE flattening, otherwise each "tile" would be built from
    # three single-channel patches instead of the three channels of one patch.
    tiles = slide.unfold(1,size,stride).unfold(2,size,stride)
    slide = tiles.permute(1,2,0,3,4).contiguous().view(-1,3,size,size) # num_patches x 3 x 32 x 32

    # generate tile rationales
    preds = gen(slide) # 1 x num_patches x 2
    logits = lsm(preds).squeeze(0) # num_patches x 2
    sample = gumbel_softmax(logits, temperature=temp)
    # Weight every tile by its (soft) probability of being selected; the weight is broadcast
    # over the channel and spatial axes of that tile.
    rationale = slide * sample[:,1].view(-1,1,1,1)

    # predict class based on rationales
    output = enc(rationale) # num_patches x 512
    pool = pool_fn(output) # 512
    y_hat = enc.classification_layer(pool) # 2

    # compute loss and regularization term
    znorm = torch.sum(sample[:,1])                               # how many tiles are selected
    zdist = torch.sum(torch.abs(sample[:-1,1] - sample[1:,1]))   # disagreement between consecutive tiles
    omega = ((lamb1 * znorm) + (lamb2 * zdist)) / sample.shape[0]
    xeloss = xent(y_hat.unsqueeze(0), label)
    loss = xeloss + omega

    # Clear stale gradients first so an interrupted earlier run cannot leak into this step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accumulate plain Python numbers: adding the tensors themselves would keep
    # autograd history and GPU memory alive across iterations.
    total_xeloss += xeloss.item()
    total_omega += omega.item()

    rat_tiles += znorm.item()
    total_tiles += float(sample.shape[0])

    if step % 100 == 0:
        print('Epoch: {0}, Step: {1}, Loss: {2:0.4f}, Omega: {3:0.4f}'.format(e, step, xeloss.item(),
                                                                              omega.item()))
frac_tiles = rat_tiles / total_tiles if total_tiles else 0
print('Epoch: {0}, Train Loss: {1:0.4f}, Train Omega: {2:0.4f}, Fraction of Tiles: {3:0.4f}'.format(e, total_xeloss,
                                                                                                   total_omega, frac_tiles))

In [None]:
# val loop: evaluate on valid_loader using hard (argmax) tile selection
gen.eval()
enc.eval()

rat_tiles = 0      # number of selected tiles, summed over evaluated slides
total_tiles = 0    # number of tiles in evaluated slides
total_loss = 0     # summed cross-entropy loss
labels = []
preds = []

# No gradients are needed for evaluation; skipping graph construction saves memory and time.
with torch.no_grad():
    for step,(slide,label) in enumerate(valid_loader):
        slide, label = slide.squeeze(0).cuda(), label.cuda()
        # Cut the slide into size x size tiles.
        # NOTE(review): assumes slide is (3, H, W) after squeeze — TODO confirm against data_utils.
        # The channel axis is moved next to the spatial axes before flattening so that every
        # tile keeps the three channels of the same patch.
        tiles = slide.unfold(1,size,stride).unfold(2,size,stride)
        slide = tiles.permute(1,2,0,3,4).contiguous().view(-1,3,size,size) # num_patches x 3 x 32 x 32

        prez = gen(slide) # 1 x num_patches x 2
        z = torch.argmax(prez, dim=2).squeeze(0) # num_patches, 1 = tile selected
        rationale = slide[z==1,:,:,:]
        znorm = z.sum().item()

        # Slides for which no tile is selected are skipped entirely: they contribute
        # neither to the loss nor to the accuracy nor to the tile counts.
        if znorm > 0:
            output = enc(rationale)
            pool = pool_fn(output)
            y_hat = enc.classification_layer(pool)

            loss = xent(y_hat.unsqueeze(0), label)
            total_loss += loss.item()

            rat_tiles += znorm
            total_tiles += float(z.shape[0])

            labels.extend(label.float().cpu().tolist())
            preds.append(float(torch.argmax(y_hat).item()))

            if step % 100 == 0:
                print('Epoch: {0}, Step: {1}, Loss: {2:0.4f}'.format(e, step, loss.item()))

# The plateau scheduler only starts tracking the validation loss after epoch 50.
if e > 50:
    scheduler.step(total_loss)

# Accuracy is undefined when no slide was evaluated.
acc = np.mean(np.array(labels) == np.array(preds)) if labels else float('nan')
frac_tiles = rat_tiles / total_tiles if total_tiles else 0
print('Epoch: {0}, Val Loss: {1:0.4f}, Val Acc: {2:0.4f}, Fraction of Tiles: {3:0.4f}, Total Tiles: {4}'.format(e, total_loss, acc, frac_tiles, total_tiles))
# A top-level `return` is a SyntaxError in a notebook cell; a bare final expression shows the values instead.
total_loss, frac_tiles, total_tiles