In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils
import models

%reload_ext autoreload
%autoreload 2


from sklearn.metrics import roc_auc_score

def compute_auc(valid_loader, net, criterion, pool_fn):
    net.eval()
    total_loss = 0
    labels = []
    preds = []
    logits = []
    with torch.no_grad():
        for slide,label in valid_loader:
            slide.squeeze_()
            slide, label = slide.cuda(), label.cuda()
            output = net(slide)
            pool = pool_fn(output).unsqueeze(0)
            output = net.classification_layer(pool)
            loss = criterion(output, label)
            logits.append(output[0][1].detach().cpu().numpy())
            total_loss += loss.detach().cpu().numpy()
            labels.extend(label.float().cpu().numpy())
            preds.append(torch.argmax(output).float().detach().cpu().numpy())
    
        acc = np.mean(np.array(labels) == np.array(preds))
    logits = np.stack(logits)
    labels = np.array(labels)
    auc = roc_auc_score(labels, logits)
    return auc

In [2]:
train = data_utils.COAD_dataset(data_utils.COAD_TRAIN)
train_loader = torch.utils.data.DataLoader(train, batch_size=1, shuffle=True , pin_memory=True)

valid = data_utils.COAD_dataset(data_utils.COAD_VALID)
valid_loader = torch.utils.data.DataLoader(valid, batch_size=1, shuffle=True , pin_memory=True)

# Mean pooling benchmark

In [3]:
n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout=0.5
net = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
net = net.cuda()

In [4]:
lr = 0.0001
weight_decay = 0.0005
def pool_fn(x):
    v = torch.mean(x,0)
    return v
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

In [5]:
best_loss = 1e8
for e in range(100):
    train_utils.embedding_training_loop(e, train_loader, net, criterion, optimizer,pool_fn)
    loss = train_utils.embedding_validation_loop(e, valid_loader, net, criterion,pool_fn)
    print('LR = {}'.format(optimizer.state_dict()['param_groups'][0]['lr']))
    if loss < best_loss:
        torch.save(net.state_dict(),'convnet_mean_embed.pt')
        best_loss = loss
        print('WROTE MODEL')

Epoch: 0, Train NLL: 34.3639
Epoch: 0, Val NLL: 34.9153, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 1, Train NLL: 33.9725
Epoch: 1, Val NLL: 35.0570, Val Acc: 0.4400
LR = 0.0001
Epoch: 2, Train NLL: 34.0094
Epoch: 2, Val NLL: 34.9143, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 3, Train NLL: 33.8674
Epoch: 3, Val NLL: 35.0414, Val Acc: 0.4400
LR = 0.0001
Epoch: 4, Train NLL: 33.8154
Epoch: 4, Val NLL: 34.8646, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 5, Train NLL: 33.9091
Epoch: 5, Val NLL: 34.8800, Val Acc: 0.4400
LR = 0.0001
Epoch: 6, Train NLL: 33.7869
Epoch: 6, Val NLL: 34.9257, Val Acc: 0.4400
LR = 0.0001
Epoch: 7, Train NLL: 34.1032
Epoch: 7, Val NLL: 34.8426, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 8, Train NLL: 33.7169
Epoch: 8, Val NLL: 35.3528, Val Acc: 0.4400
LR = 0.0001
Epoch: 9, Train NLL: 33.6053
Epoch: 9, Val NLL: 34.8635, Val Acc: 0.4400
LR = 0.0001
Epoch: 10, Train NLL: 33.9253
Epoch: 10, Val NLL: 34.9161, Val Acc: 0.4400
LR = 0.0001
Epoch: 11, Trai

Epoch: 91, Train NLL: 7.3497
Epoch: 91, Val NLL: 21.8068, Val Acc: 0.8600
LR = 0.0001
Epoch: 92, Train NLL: 6.2891
Epoch: 92, Val NLL: 20.3960, Val Acc: 0.8600
LR = 0.0001
Epoch: 93, Train NLL: 6.6234
Epoch: 93, Val NLL: 19.8022, Val Acc: 0.8400
LR = 0.0001
Epoch: 94, Train NLL: 8.7323
Epoch: 94, Val NLL: 18.8535, Val Acc: 0.8400
LR = 0.0001
Epoch: 95, Train NLL: 6.8555
Epoch: 95, Val NLL: 19.4858, Val Acc: 0.8400
LR = 0.0001
Epoch: 96, Train NLL: 6.0016
Epoch: 96, Val NLL: 20.4090, Val Acc: 0.8400
LR = 0.0001
Epoch: 97, Train NLL: 4.8574
Epoch: 97, Val NLL: 22.4974, Val Acc: 0.8400
LR = 0.0001
Epoch: 98, Train NLL: 5.9782
Epoch: 98, Val NLL: 20.3884, Val Acc: 0.8800
LR = 0.0001
Epoch: 99, Train NLL: 4.9569
Epoch: 99, Val NLL: 21.0085, Val Acc: 0.8800
LR = 0.0001


In [35]:
def pool_fn(x):
    v = torch.mean(x,0)
    return v

state_dict = torch.load('convnet_mean_embed.pt')
net.load_state_dict(state_dict)
net = net.cuda()
compute_auc(valid_loader, net, criterion,pool_fn)

0.9172077922077922

# Neural Attention Benchmark

In [39]:

input_size = 512
hidden_size = 512
output_size = 1
attn = models.Attention(input_size, hidden_size, output_size)
attn.cuda()
pool_fn = models.pool(attn)

n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout=0.5
net = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
net = net.cuda()

lr = 0.0001
weight_decay = 0.0005
criterion = nn.CrossEntropyLoss()
parameters = [p for p in net.parameters()]
parameters.extend([p for p in attn.parameters()])
optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)


In [40]:
best_loss = 1e8
for e in range(100):
    train_utils.embedding_training_loop(e, train_loader, net, criterion, optimizer,pool_fn)
    loss = train_utils.embedding_validation_loop(e, valid_loader, net, criterion,pool_fn)
    print('LR = {}'.format(optimizer.state_dict()['param_groups'][0]['lr']))
    if loss < best_loss:
        torch.save(net.state_dict(),'convnet_attention_embed.pt')
        torch.save(attn.state_dict(),'attention_pool_embed.pt')
        best_loss = loss
        print('WROTE MODEL')

Epoch: 0, Train NLL: 34.2017
Epoch: 0, Val NLL: 35.2774, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 1, Train NLL: 34.5259
Epoch: 1, Val NLL: 34.7277, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 2, Train NLL: 33.9126
Epoch: 2, Val NLL: 35.0359, Val Acc: 0.4400
LR = 0.0001
Epoch: 3, Train NLL: 33.9082
Epoch: 3, Val NLL: 34.9596, Val Acc: 0.4400
LR = 0.0001
Epoch: 4, Train NLL: 33.8870
Epoch: 4, Val NLL: 34.7221, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 5, Train NLL: 33.9272
Epoch: 5, Val NLL: 35.0649, Val Acc: 0.4400
LR = 0.0001
Epoch: 6, Train NLL: 33.8881
Epoch: 6, Val NLL: 34.7648, Val Acc: 0.4400
LR = 0.0001
Epoch: 7, Train NLL: 33.7409
Epoch: 7, Val NLL: 34.9427, Val Acc: 0.4400
LR = 0.0001
Epoch: 8, Train NLL: 33.6009
Epoch: 8, Val NLL: 34.6753, Val Acc: 0.4400
LR = 0.0001
WROTE MODEL
Epoch: 9, Train NLL: 33.8874
Epoch: 9, Val NLL: 34.7165, Val Acc: 0.4400
LR = 0.0001
Epoch: 10, Train NLL: 33.3413
Epoch: 10, Val NLL: 34.1933, Val Acc: 0.5600
LR = 0.0001
WROTE MODEL
Epo

Epoch: 92, Train NLL: 7.0301
Epoch: 92, Val NLL: 21.6809, Val Acc: 0.8400
LR = 0.0001
Epoch: 93, Train NLL: 5.0907
Epoch: 93, Val NLL: 20.2760, Val Acc: 0.8600
LR = 0.0001
Epoch: 94, Train NLL: 4.6036
Epoch: 94, Val NLL: 23.3069, Val Acc: 0.8400
LR = 0.0001
Epoch: 95, Train NLL: 12.0253
Epoch: 95, Val NLL: 20.7364, Val Acc: 0.8200
LR = 0.0001
Epoch: 96, Train NLL: 7.5517
Epoch: 96, Val NLL: 17.3303, Val Acc: 0.8800
LR = 0.0001
Epoch: 97, Train NLL: 5.3005
Epoch: 97, Val NLL: 18.8591, Val Acc: 0.8800
LR = 0.0001
Epoch: 98, Train NLL: 5.2294
Epoch: 98, Val NLL: 20.5604, Val Acc: 0.8600
LR = 0.0001
Epoch: 99, Train NLL: 3.4743
Epoch: 99, Val NLL: 19.1279, Val Acc: 0.8800
LR = 0.0001


In [41]:

state_dict = torch.load('convnet_attention_embed.pt')
net.load_state_dict(state_dict)
net = net.cuda()

input_size = 512
hidden_size = 512
output_size = 1
attn = models.Attention(input_size, hidden_size, output_size)
state_dict = torch.load('attention_pool_embed.pt')
attn.load_state_dict(state_dict)
attn.cuda()
pool_fn = models.pool(attn)


compute_auc(valid_loader, net, criterion,pool_fn)

0.9464285714285714

# Neural Rationales Benchmark

In [34]:
n_conv_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = 512
n_rnn_layers = 2
dropout=0.5
gen = models.Generator(n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=dropout)
#state_dict = torch.load('/home/sxchao/MSI_prediction/labeled_nuclei_project/generator.pt')
#gen.load_state_dict(state_dict)
gen = gen.cuda()

n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout=0.5
enc = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
#state_dict = torch.load('/home/sxchao/MSI_prediction/labeled_nuclei_project/encoder.pt')
#enc.load_state_dict(state_dict)
end = enc.cuda()
def pool_fn(x):
    v = torch.mean(x,0)
    return v



lamb1 = 0
lamb2 = 0
xent = nn.CrossEntropyLoss()
learning_rate = 1e-4
temp = 5
params = list(enc.parameters()) + list(gen.parameters())
optimizer = torch.optim.Adam(params, lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, min_lr=1e-6)
total_val_tiles = 0

for slide,label in valid_loader:
    total_val_tiles += slide.shape[1]
    
print(total_val_tiles)



10703


In [36]:
#best_val_loss = 1e8 
#best_val_frac = 0.10
for e in range(500):
    train_utils.rationales_training_loop_GS(e, train_loader, gen, enc, pool_fn, lamb1, lamb2, xent, learning_rate, optimizer,temp)
    loss, frac_tiles, total_tiles = train_utils.rationales_validation_loop_GS(e, valid_loader, gen, enc, pool_fn, xent, scheduler)
    if e > 50:
        lamb1 += 0.001
        temp -= 0.25
    temp = np.max([temp,1])
    lamb1 = np.min([lamb1,0.4])
    if e % 5 == 0:
        print('========== Train Set ==========')
        _, _, _ = train_utils.rationales_validation_loop_GS(e, train_loader, gen, enc, pool_fn, xent, scheduler)
        print('Lambda: {0:0.7f}, LR: {1:0.7f}, Temperature: {2:0.7f}'.format(lamb1, optimizer.state_dict()['param_groups'][0]['lr'], temp))
    if loss < best_val_loss and frac_tiles < best_val_frac and total_tiles == total_val_tiles:
        best_val_loss = loss
        torch.save(gen.state_dict(),'generator.pt')
        torch.save(enc.state_dict(),'encoder.pt')
        print('WROTE MODEL!')
    #if frac_tiles < 0.9:
    #    break

Epoch: 0, Train Loss: 18.840653, Train Omega: 0.0000, Fraction of Tiles: 0.9847
Epoch: 0, Val Loss: 24.5958, Val Acc: 0.8200, Fraction of Tiles: 1.0000, Total Tiles: 10703.0
Epoch: 0, Val Loss: 18.0476, Val Acc: 0.8571, Fraction of Tiles: 1.0000, Total Tiles: 9659.0
Lambda: 0.0000000, LR: 0.0001000, Temperature: 5.0000000
Epoch: 1, Train Loss: 18.166777, Train Omega: 0.0000, Fraction of Tiles: 0.9857
Epoch: 1, Val Loss: 24.4054, Val Acc: 0.8000, Fraction of Tiles: 1.0000, Total Tiles: 10703.0
Epoch: 2, Train Loss: 19.524276, Train Omega: 0.0000, Fraction of Tiles: 0.9863
Epoch: 2, Val Loss: 25.2501, Val Acc: 0.7400, Fraction of Tiles: 1.0000, Total Tiles: 10703.0
Epoch: 3, Train Loss: 20.246536, Train Omega: 0.0000, Fraction of Tiles: 0.9866
Epoch: 3, Val Loss: 32.6224, Val Acc: 0.6600, Fraction of Tiles: 1.0000, Total Tiles: 10703.0


KeyboardInterrupt: 

In [12]:
n_conv_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = 512
n_rnn_layers = 2
dropout=0.5
gen = models.Generator(n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=dropout)
state_dict = torch.load('generator_retrain.pt')
gen.load_state_dict(state_dict)
gen = gen.cuda()

n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout=0.5
enc = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
state_dict = torch.load('encoder_retrain.pt')
enc.load_state_dict(state_dict)
enc = enc.cuda()


#compute_auc(valid_loader, net, criterion,pool_fn)

In [14]:
def compute_rationales_auc(valid_loader, gen, enc, pool_fn):
    gen.eval()
    enc.eval()
    
    labels = []
    logits = []

    for slide,label in valid_loader:
        slide, label = slide.squeeze(0).cuda(), label.cuda()

        prez = gen(slide)
        z = torch.argmax(prez, dim=2).squeeze(0)
        rationale = slide[z==1,:,:,:]

        output = enc(rationale)
        pool = pool_fn(output)
        y_hat = enc.classification_layer(pool)
        logits.append(y_hat[1].detach().cpu().numpy())
        labels.extend(label.float().cpu().numpy())

    
    auc = roc_auc_score(labels, logits)
    return auc

In [15]:
compute_rationales_auc(valid_loader, gen, enc, pool_fn)

0.9561688311688312

In [16]:
loss, frac_tiles, total_tiles = train_utils.rationales_validation_loop_GS(e, valid_loader, gen, enc, pool_fn, xent, scheduler)

Epoch: 47, Val Loss: 16.6692, Val Acc: 0.8800, Fraction of Tiles: 0.0126, Total Tiles: 10703.0
