In [1]:
import sys
sys.path.append('../')

import os
import gc
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [2]:
pickle_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(pickle_file, 'rb') as f: 
    train_embeddings, train_labels, train_jpgs_to_slide = pickle.load(f)
    
pickle_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(pickle_file, 'rb') as f: 
    val_embeddings, val_labels, val_jpgs_to_slide = pickle.load(f)

In [3]:
def sample_gumbel(shape, eps=1e-20): 
    """Sample from Gumbel(0, 1)"""
    U = torch.rand(shape,dtype=torch.float32,device='cuda')
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature): 
    """ Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(logits.shape)
    return F.softmax( y / temperature,dim=1)

def gumbel_softmax(logits, temperature, hard=False):
    """
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probabilitiy distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        y = torch.argmax(logits,dim=1)
    return y

In [4]:
def pool_fn(x):
    #v,a = torch.max(x,0)
    v = torch.mean(x,0)
    return v

In [5]:
input_size = 2048
hidden_size = 2048
output_size_gen = 2
output_size_enc = 1

gen = model_utils.Generator(input_size, hidden_size, output_size_gen, dropout=0.5)
enc = model_utils.Encoder(input_size, hidden_size, output_size_enc, pool_fn, dropout=0.5)

In [6]:
step_size = 10
lamb1 = 0
lamb2 = 0
temp = 10
criterion = nn.BCEWithLogitsLoss()
params = list(enc.parameters()) + list(gen.parameters())
learning_rate = 1e-5
optimizer = torch.optim.Adam(params, lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, min_lr=1e-6)

In [7]:
n_samples_val = val_jpgs_to_slide.max()
n_samples_train = train_jpgs_to_slide.max()
idxs_train = np.linspace(0,n_samples_train,n_samples_train+1,dtype=int)
labels_to_idxs_train = np.concatenate([(train_labels[train_jpgs_to_slide==i]).unique().numpy() for i in idxs_train])
weights = 1/np.sum(labels_to_idxs_train==0),1/np.sum(labels_to_idxs_train==1)
sample_weight = [weights[l] for l in labels_to_idxs_train]
sample_weight = sample_weight/np.sum(sample_weight)

In [8]:
train_embeddings = train_embeddings.cuda()
val_embeddings = val_embeddings.cuda()

In [9]:
gen.cuda()
enc.cuda()

Encoder(
  (d): Dropout(p=0.5)
  (m): ReLU()
  (linear1): Linear(in_features=2048, out_features=2048, bias=True)
  (linear2): Linear(in_features=2048, out_features=1, bias=True)
)

In [10]:
lsm = nn.LogSoftmax(dim=1)

In [11]:
def initialize_loop_vars(step_size):
    logits_vec = torch.zeros((step_size+1,1)).cuda()
    labels_vec = torch.zeros_like(logits_vec).cuda()
    znorm_vec = torch.zeros_like(logits_vec).cuda()
    zdist_vec = torch.zeros_like(logits_vec).cuda()
    batch_idx = 0
    return logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx

In [12]:
def training_loop_rationales_gs(e, step_size, optimizer, gen, enc, pool_fn, train_embeddings, train_jpgs_to_slide, 
                                train_labels, criterion, n_samples, sample_weight, lamb1, lamb2, temp):
    gen.train()
    enc.train()
    
    logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx = initialize_loop_vars(step_size)
    track_loss = 0   
    track_omega = 0
    rat_tiles = 0
    total_tiles = 0
    
    idxs_train = np.linspace(0,n_samples,n_samples+1,dtype=int)
    idexs = np.random.choice(idxs_train,size=n_samples.numpy(),p=sample_weight)
    
    for idx in idexs:
        slide = train_embeddings[train_jpgs_to_slide==idx] 
        labels_vec[batch_idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
        
        preds = gen(slide) 
        logits = lsm(preds)
        sample = gumbel_softmax(logits, temperature=temp)
        rationale = slide * sample[:,1].unsqueeze(1)
        
        # predict class based on rationales
        logits = enc(rationale)
        logits_vec[batch_idx] = logits
        
        znorm = torch.sum(sample[:,1])
        znorm_vec[batch_idx] = znorm / sample.shape[0]
        
        rat_tiles += znorm.detach().cpu().numpy()
        total_tiles += sample.shape[0]
        
        zdist = torch.sum(torch.abs(sample[:-1,1] - sample[1:,1]))
        zdist_vec[batch_idx] = zdist / sample.shape[0]
        
        if batch_idx == step_size:
            # compute loss and regularization term
            omega = ((lamb1 * znorm_vec.sum()) + (lamb2 * zdist_vec.sum())) / step_size
            bceloss = criterion(logits_vec, labels_vec)
            loss = bceloss + omega
            loss.backward()
            optimizer.step()        
            optimizer.zero_grad()
            track_loss += bceloss.detach().cpu().numpy()
            track_omega += omega.detach().cpu().numpy()
            logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx = initialize_loop_vars(step_size)
        else:
            batch_idx += 1
            
    omega = ((lamb1 * znorm_vec.sum()) + (lamb2 * zdist_vec.sum())) / batch_idx
    bceloss = criterion(logits_vec, labels_vec)
    loss = bceloss + omega
    loss.backward()
    track_loss += bceloss.detach().cpu().numpy()
    track_loss = track_loss * (float(step_size) / float(n_samples))
    track_omega += omega.detach().cpu().numpy()
    track_omega = track_omega * (float(step_size) / float(n_samples))
    optimizer.step()
    optimizer.zero_grad()
       
    frac_tiles = rat_tiles / total_tiles
    print('Epoch: {0}, Train Loss: {1:0.4f}, Train Omega: {2:0.4f}, Frac of Tiles: {3:0.4f}'.format(e, track_loss, 
                                                                                                   track_omega, 
                                                                                                   frac_tiles))

In [17]:
def validation_loop_rationales_gs(e, step_size, scheduler, gen, enc, pool_fn, val_embeddings, val_jpgs_to_slide, 
                                  val_labels, criterion, n_samples, lamb1, lamb2, temp):
    gen.eval()
    enc.eval()
    
    logits_vec, labels_vec, znorm_vec, zdist_vec, batch_idx = initialize_loop_vars(n_samples)
    rat_tiles = 0
    total_tiles = 0
    
    idexs = np.linspace(0,n_samples,n_samples+1,dtype=int)
    with torch.no_grad():
        for idx in idexs:
            slide = val_embeddings[val_jpgs_to_slide==idx] # num_tiles x 2048
            labels_vec[idx] = val_labels[val_jpgs_to_slide==idx].unique().float().cuda()

            preds = gen(slide)
            sample = torch.argmax(preds, dim=1).float()
            rationale = slide * sample.unsqueeze(1)

            # predict class based on rationales
            logits = enc(rationale)
            logits_vec[idx] = logits

            znorm = torch.sum(sample)
            znorm_vec[idx] = znorm / sample.shape[0]

            rat_tiles += znorm.detach().cpu().numpy()
            total_tiles += sample.shape[0]

            zdist = torch.sum(torch.abs(sample[:-1] - sample[1:]))
            zdist_vec[idx] = zdist / sample.shape[0]

    # compute loss and regularization term
    omega = ((lamb1 * znorm_vec.sum()) + (lamb2 * zdist_vec.sum())) / n_samples
    bceloss = criterion(logits_vec, labels_vec)
    loss = bceloss + omega
    frac_tiles = rat_tiles / total_tiles
    
    mask = labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy()
    acc = np.mean(mask)
    acc_1 = np.mean(mask[labels_vec.cpu().numpy()==1.])
    acc_0 = np.mean(mask[labels_vec.cpu().numpy()==0.])
    print('Epoch: {0}, Val Loss: {1:0.4f}, Val Omega: {2:0.4f}, Frac of Tiles: {3:0.4f}, Acc: {4:0.4f}, By Label: 0: {5:0.4f}, 1: {6:0.4f}'.format(e, bceloss, omega, frac_tiles, acc, acc_0, acc_1))
    return loss, acc

In [18]:
for e in range(100):
    training_loop_rationales_gs(e, step_size, optimizer, gen, enc, pool_fn, train_embeddings, train_jpgs_to_slide, 
                                train_labels, criterion, n_samples_train, sample_weight, lamb1, lamb2, temp)
    loss, acc = validation_loop_rationales_gs(e, step_size, scheduler, gen, enc, pool_fn, val_embeddings, 
                                              val_jpgs_to_slide, val_labels, criterion, n_samples_val, lamb1, lamb2, temp)
    if e > 10:
        lamb1 += 0.01
        temp -= 0.1
    temp = np.max([temp, 1.0])
    lamb1 = np.min([lamb1, 1.0])
    if e % 5 == 0:
        print('Lambda: {0:0.4f}, Temperature: {1:0.4f}, LR: {2:0.8f}'.format(lamb1, temp, optimizer.state_dict()['param_groups'][0]['lr']))

Epoch: 0, Train Loss: 0.5549, Train Omega: 0.0000, Frac of Tiles: 0.7587
Epoch: 0, Val Loss: 0.6300, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6585, By Label: 0: 0.9130, 1: 0.3333
Lambda: 0.0000, Temperature: 10.0000, LR: 0.00001000
Epoch: 1, Train Loss: 0.4968, Train Omega: 0.0000, Frac of Tiles: 0.8939
Epoch: 1, Val Loss: 0.6100, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.5854, By Label: 0: 1.0000, 1: 0.0556
Epoch: 2, Train Loss: 0.4895, Train Omega: 0.0000, Frac of Tiles: 0.9403
Epoch: 2, Val Loss: 0.5999, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6829, By Label: 0: 0.9348, 1: 0.3611
Epoch: 3, Train Loss: 0.4565, Train Omega: 0.0000, Frac of Tiles: 0.9590
Epoch: 3, Val Loss: 0.5969, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6707, By Label: 0: 0.8696, 1: 0.4167
Epoch: 4, Train Loss: 0.4355, Train Omega: 0.0000, Frac of Tiles: 0.9730
Epoch: 4, Val Loss: 0.5806, Val Omega: 0.0000, Frac of Tiles: 1.0000, Acc: 0.6707, By Label: 0: 0.9130, 1: 0.3611
Epoch: 5, Tr

KeyboardInterrupt: 