In [1]:
import sys
sys.path.append('../')

import os
import gc
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [2]:
pickle_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(pickle_file, 'rb') as f: 
    train_embeddings, train_labels, train_jpgs_to_slide = pickle.load(f)
    
pickle_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(pickle_file, 'rb') as f: 
    val_embeddings, val_labels, val_jpgs_to_slide = pickle.load(f)

In [3]:
def sample_gumbel(shape, eps=1e-20): 
    """Sample from Gumbel(0, 1)"""
    U = torch.rand(shape,dtype=torch.float32,device='cuda')
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature): 
    """ Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(logits.shape)
    return F.softmax( y / temperature,dim=1)

def gumbel_softmax(logits, temperature, hard=False):
    """
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probabilitiy distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        y = torch.argmax(logits,dim=1)
    return y

In [4]:
def pool_fn(x):
    #v,a = torch.max(x,0)
    v = torch.mean(x,0)
    return v

In [5]:
input_size = 2048
hidden_size = 2048
output_size_gen = 2
output_size_enc = 1

gen = model_utils.Generator(input_size, hidden_size, output_size_gen, dropout=0.5)
enc = model_utils.Encoder(input_size, hidden_size, output_size_enc, pool_fn, dropout=0.5)

In [6]:
step_size = 10
lamb1 = 0
lamb2 = 0
temp = 10
criterion = nn.BCEWithLogitsLoss()
params = list(enc.parameters()) + list(gen.parameters())
learning_rate = 1e-5
optimizer = torch.optim.Adam(params, lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, min_lr=1e-6)

In [7]:
n_samples_val = val_jpgs_to_slide.max()
n_samples_train = train_jpgs_to_slide.max()
idxs_train = np.linspace(0,n_samples_train,n_samples_train+1,dtype=int)
labels_to_idxs_train = np.concatenate([(train_labels[train_jpgs_to_slide==i]).unique().numpy() for i in idxs_train])
weights = 1/np.sum(labels_to_idxs_train==0),1/np.sum(labels_to_idxs_train==1)
sample_weight = [weights[l] for l in labels_to_idxs_train]
sample_weight = sample_weight/np.sum(sample_weight)

In [8]:
train_embeddings = train_embeddings.cuda()
val_embeddings = val_embeddings.cuda()

In [9]:
gen.cuda()
enc.cuda()

Encoder(
  (d): Dropout(p=0.5)
  (m): ReLU()
  (linear1): Linear(in_features=2048, out_features=2048, bias=True)
  (linear2): Linear(in_features=2048, out_features=1, bias=True)
)

In [10]:
lsm = nn.LogSoftmax(dim=1)

In [11]:
def initialize_loop_vars(step_size):
    logits_vec = torch.zeros((step_size+1,1)).cuda()
    labels_vec = torch.zeros_like(logits_vec).cuda()
    znorm_vec = torch.zeros_like(logits_vec).cuda()
    zdist_vec = torch.zeros_like(logits_vec).cuda()
    batch_idx = 0
    return logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx

In [12]:
def training_loop_rationales_gs(e, step_size, optimizer, gen, enc, pool_fn, train_embeddings, train_jpgs_to_slide, 
                                train_labels, criterion, n_samples, sample_weight, lamb1, lamb2, temp):
    gen.train()
    enc.train()
    
    logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx = initialize_loop_vars(step_size)
    track_loss = 0   
    track_omega = 0
    rat_tiles = 0
    total_tiles = 0
    
    idxs_train = np.linspace(0,n_samples,n_samples+1,dtype=int)
    idexs = np.random.choice(idxs_train,size=n_samples.numpy(),p=sample_weight)
    
    for idx in idexs:
        slide = train_embeddings[train_jpgs_to_slide==idx] 
        labels_vec[batch_idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
        
        preds = gen(slide) 
        logits = lsm(preds)
        sample = gumbel_softmax(logits, temperature=temp)
        rationale = slide * sample[:,1].unsqueeze(1)
        
        # predict class based on rationales
        logits = enc(rationale)
        logits_vec[batch_idx] = logits
        
        znorm = torch.sum(sample[:,1])
        znorm_vec[batch_idx] = znorm / sample.shape[0]
        
        rat_tiles += znorm.detach().cpu().numpy()
        total_tiles += sample.shape[0]
        
        zdist = torch.sum(torch.abs(sample[:-1,1] - sample[1:,1]))
        zdist_vec[batch_idx] = zdist / sample.shape[0]
        
        if batch_idx == step_size:
            # compute loss and regularization term
            omega = ((lamb1 * znorm_vec.sum()) + (lamb2 * zdist_vec.sum())) / step_size
            bceloss = criterion(logits_vec, labels_vec)
            loss = bceloss + omega
            loss.backward()
            optimizer.step()        
            optimizer.zero_grad()
            track_loss += bceloss.detach().cpu().numpy()
            track_omega += omega.detach().cpu().numpy()
            logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx = initialize_loop_vars(step_size)
        else:
            batch_idx += 1
            
    omega = ((lamb1 * znorm_vec.sum()) + (lamb2 * zdist_vec.sum())) / batch_idx
    bceloss = criterion(logits_vec, labels_vec)
    loss = bceloss + omega
    loss.backward()
    track_loss += bceloss.detach().cpu().numpy()
    track_loss = track_loss * (float(step_size) / float(n_samples))
    track_omega += omega.detach().cpu().numpy()
    track_omega = track_omega * (float(step_size) / float(n_samples))
    optimizer.step()
    optimizer.zero_grad()
       
    frac_tiles = rat_tiles / total_tiles
    print('Epoch: {0}, Train Loss: {1:0.4f}, Train Omega: {2:0.4f}, Frac of Tiles: {3:0.4f}'.format(e, track_loss, 
                                                                                                   track_omega, 
                                                                                                   frac_tiles))

In [13]:
def validation_loop_rationales_gs(e, step_size, scheduler, gen, enc, pool_fn, val_embeddings, val_jpgs_to_slide, 
                                  val_labels, criterion, n_samples, lamb1, lamb2, temp):
    gen.eval()
    enc.eval()
    
    logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx = initialize_loop_vars(n_samples)
    rat_tiles = 0
    total_tiles = 0
    
    idexs = np.linspace(0,n_samples,n_samples+1,dtype=int)
    with torch.no_grad():
        for idx in idexs:
            slide = val_embeddings[val_jpgs_to_slide==idx] # num_tiles x 2048
            labels_vec[idx] = val_labels[val_jpgs_to_slide==idx].unique().float().cuda()

            preds = gen(slide)
            sample = torch.argmax(preds, dim=1).float()
            rationale = slide * sample.unsqueeze(1)

            # predict class based on rationales
            logits = enc(rationale)
            logits_vec[idx] = logits

            znorm = torch.sum(sample)
            znorm_vec[idx] = znorm / sample.shape[0]

            rat_tiles += znorm.detach().cpu().numpy()
            total_tiles += sample.shape[0]

            zdist = torch.sum(torch.abs(sample[:-1] - sample[1:]))
            zdist_vec[idx] = zdist / sample.shape[0]

    # compute loss and regularization term
    omega = ((lamb1 * znorm_vec.sum()) + (lamb2 * zdist_vec.sum())) / n_samples
    bceloss = criterion(logits_vec, labels_vec)
    loss = bceloss + omega
    frac_tiles = rat_tiles / total_tiles
    
    mask = labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy()
    acc = np.mean(mask)
    acc_1 = np.mean(mask[labels_vec.cpu().numpy()==1.])
    acc_0 = np.mean(mask[labels_vec.cpu().numpy()==0.])
    print('Epoch: {0}, Val Loss: {1:0.4f}, Val Omega: {2:0.4f}, Frac of Tiles: {3:0.4f}, Acc: {4:0.4f}, By Label: 0: {5:0.4f}, 1: {6:0.4f}'.format(e, bceloss, omega, frac_tiles, acc, acc_0, acc_1))
    return loss, acc, frac_tiles

In [14]:
best_loss = 1e8
best_acc = 0.0
frac_threshold = 0.5
path_acc_gen = '/n/tcga_models/COAD_rationale_model_gen_5_10_acc.pt'
path_acc_enc = '/n/tcga_models/COAD_rationale_model_enc_5_10_acc.pt'
path_loss_gen = '/n/tcga_models/COAD_rationale_model_gen_5_10.pt'
path_loss_enc = '/n/tcga_models/COAD_rationale_model_enc_5_10.pt'
for e in range(500):
    training_loop_rationales_gs(e, step_size, optimizer, gen, enc, pool_fn, train_embeddings, train_jpgs_to_slide, 
                                train_labels, criterion, n_samples_train, sample_weight, lamb1, lamb2, temp)
    loss, acc, frac = validation_loop_rationales_gs(e, step_size, scheduler, gen, enc, pool_fn, val_embeddings, 
                                              val_jpgs_to_slide, val_labels, criterion, n_samples_val, lamb1, lamb2, temp)
    if loss < best_loss and frac < frac_threshold:
        torch.save(gen.state_dict(), path_loss_gen)
        torch.save(enc.state_dict(), path_loss_enc)
        best_loss = loss
        print('SAVED BEST LOSS')
    if acc > best_acc and frac < frac_threshold:
        torch.save(gen.state_dict(), path_acc_gen)
        torch.save(enc.state_dict(), path_acc_enc)
        best_acc = acc
        print('SAVED BEST ACC')
    if e > 30:
        lamb1 += 0.01
        temp -= 0.1
    temp = np.max([temp, 1.0])
    lamb1 = np.min([lamb1, 1.0])
    if e % 5 == 0:
        print('Lambda: {0:0.4f}, Temperature: {1:0.4f}, LR: {2:0.8f}'.format(lamb1, temp, optimizer.state_dict()['param_groups'][0]['lr']))

Epoch: 0, Train Loss: 0.6370, Train Omega: 0.0000, Frac of Tiles: 0.4729
Epoch: 0, Val Loss: 0.6723, Val Omega: 0.0000, Frac of Tiles: 0.2045, Acc: 0.5610, By Label: 0: 1.0000, 1: 0.0000
SAVED BEST LOSS
SAVED BEST ACC
Lambda: 0.0000, Temperature: 10.0000, LR: 0.00001000
Epoch: 1, Train Loss: 0.6085, Train Omega: 0.0000, Frac of Tiles: 0.5382
Epoch: 1, Val Loss: 0.7144, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6951, By Label: 0: 0.6304, 1: 0.7778
Epoch: 2, Train Loss: 0.5630, Train Omega: 0.0000, Frac of Tiles: 0.7100
Epoch: 2, Val Loss: 0.6193, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6585, By Label: 0: 0.9783, 1: 0.2500
Epoch: 3, Train Loss: 0.5200, Train Omega: 0.0000, Frac of Tiles: 0.8587
Epoch: 3, Val Loss: 0.5967, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6585, By Label: 0: 0.9565, 1: 0.2778
Epoch: 4, Train Loss: 0.4701, Train Omega: 0.0000, Frac of Tiles: 0.9329
Epoch: 4, Val Loss: 0.6243, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6829, By Label: 0:

Epoch: 41, Train Loss: 0.3182, Train Omega: 0.1017, Frac of Tiles: 0.9989
Epoch: 41, Val Loss: 0.7033, Val Omega: 0.1012, Frac of Tiles: 1.0000, Acc: 0.7317, By Label: 0: 0.7826, 1: 0.6667
Epoch: 42, Train Loss: 0.3198, Train Omega: 0.1119, Frac of Tiles: 0.9989
Epoch: 42, Val Loss: 0.7136, Val Omega: 0.1114, Frac of Tiles: 1.0000, Acc: 0.7317, By Label: 0: 0.7609, 1: 0.6944
Epoch: 43, Train Loss: 0.2983, Train Omega: 0.1221, Frac of Tiles: 0.9988
Epoch: 43, Val Loss: 0.7098, Val Omega: 0.1215, Frac of Tiles: 1.0000, Acc: 0.6829, By Label: 0: 0.8261, 1: 0.5000
Epoch: 44, Train Loss: 0.2979, Train Omega: 0.1323, Frac of Tiles: 0.9988
Epoch: 44, Val Loss: 0.7078, Val Omega: 0.1316, Frac of Tiles: 1.0000, Acc: 0.6829, By Label: 0: 0.8261, 1: 0.5000
Epoch: 45, Train Loss: 0.2932, Train Omega: 0.1425, Frac of Tiles: 0.9990
Epoch: 45, Val Loss: 0.7116, Val Omega: 0.1417, Frac of Tiles: 1.0000, Acc: 0.6707, By Label: 0: 0.8043, 1: 0.5000
Lambda: 0.1500, Temperature: 8.5000, LR: 0.00001000
Epo

Epoch: 81, Val Loss: 0.5866, Val Omega: 0.0536, Frac of Tiles: 0.1061, Acc: 0.6707, By Label: 0: 0.9130, 1: 0.3611
SAVED BEST LOSS
Epoch: 82, Train Loss: 0.3547, Train Omega: 0.1122, Frac of Tiles: 0.2205
Epoch: 82, Val Loss: 0.5786, Val Omega: 0.0710, Frac of Tiles: 0.1381, Acc: 0.6707, By Label: 0: 0.9130, 1: 0.3611
Epoch: 83, Train Loss: 0.3648, Train Omega: 0.1064, Frac of Tiles: 0.2089
Epoch: 83, Val Loss: 0.6024, Val Omega: 0.0385, Frac of Tiles: 0.0731, Acc: 0.6707, By Label: 0: 0.9130, 1: 0.3611
Epoch: 84, Train Loss: 0.3427, Train Omega: 0.1129, Frac of Tiles: 0.2137
Epoch: 84, Val Loss: 0.5899, Val Omega: 0.0456, Frac of Tiles: 0.0851, Acc: 0.6585, By Label: 0: 0.9130, 1: 0.3333
SAVED BEST LOSS
Epoch: 85, Train Loss: 0.3817, Train Omega: 0.1038, Frac of Tiles: 0.1936
Epoch: 85, Val Loss: 0.5761, Val Omega: 0.0719, Frac of Tiles: 0.1325, Acc: 0.6463, By Label: 0: 0.8696, 1: 0.3611
Lambda: 0.5500, Temperature: 4.5000, LR: 0.00001000
Epoch: 86, Train Loss: 0.3636, Train Omega: 0

Epoch: 121, Val Loss: 0.5563, Val Omega: 0.0691, Frac of Tiles: 0.0796, Acc: 0.6829, By Label: 0: 0.9130, 1: 0.3889
Epoch: 122, Train Loss: 0.3169, Train Omega: 0.0885, Frac of Tiles: 0.1068
Epoch: 122, Val Loss: 0.5572, Val Omega: 0.0662, Frac of Tiles: 0.0766, Acc: 0.6829, By Label: 0: 0.9130, 1: 0.3889
Epoch: 123, Train Loss: 0.3275, Train Omega: 0.0850, Frac of Tiles: 0.1026
Epoch: 123, Val Loss: 0.5480, Val Omega: 0.0486, Frac of Tiles: 0.0565, Acc: 0.6707, By Label: 0: 0.8913, 1: 0.3889
Epoch: 124, Train Loss: 0.3094, Train Omega: 0.0852, Frac of Tiles: 0.1026
Epoch: 124, Val Loss: 0.5558, Val Omega: 0.0589, Frac of Tiles: 0.0663, Acc: 0.6463, By Label: 0: 0.8478, 1: 0.3889
Epoch: 125, Train Loss: 0.2946, Train Omega: 0.0940, Frac of Tiles: 0.1132
Epoch: 125, Val Loss: 0.5458, Val Omega: 0.0538, Frac of Tiles: 0.0611, Acc: 0.6829, By Label: 0: 0.9130, 1: 0.3889
Lambda: 0.9500, Temperature: 1.0000, LR: 0.00001000
Epoch: 126, Train Loss: 0.3019, Train Omega: 0.0915, Frac of Tiles: 

Epoch: 162, Train Loss: 0.2736, Train Omega: 0.0780, Frac of Tiles: 0.0846
Epoch: 162, Val Loss: 0.5918, Val Omega: 0.0582, Frac of Tiles: 0.0607, Acc: 0.6951, By Label: 0: 0.8696, 1: 0.4722
Epoch: 163, Train Loss: 0.2214, Train Omega: 0.0886, Frac of Tiles: 0.0930
Epoch: 163, Val Loss: 0.6036, Val Omega: 0.0575, Frac of Tiles: 0.0602, Acc: 0.6829, By Label: 0: 0.8913, 1: 0.4167
Epoch: 164, Train Loss: 0.2648, Train Omega: 0.0707, Frac of Tiles: 0.0780
Epoch: 164, Val Loss: 0.6193, Val Omega: 0.0623, Frac of Tiles: 0.0635, Acc: 0.7195, By Label: 0: 0.8478, 1: 0.5556
Epoch: 165, Train Loss: 0.2800, Train Omega: 0.0722, Frac of Tiles: 0.0815
Epoch: 165, Val Loss: 0.5821, Val Omega: 0.0494, Frac of Tiles: 0.0527, Acc: 0.6829, By Label: 0: 0.9130, 1: 0.3889
Lambda: 1.0000, Temperature: 1.0000, LR: 0.00001000
Epoch: 166, Train Loss: 0.2622, Train Omega: 0.0764, Frac of Tiles: 0.0858
Epoch: 166, Val Loss: 0.5944, Val Omega: 0.0534, Frac of Tiles: 0.0550, Acc: 0.7317, By Label: 0: 0.8478, 1: 

KeyboardInterrupt: 