In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

# Reload imported modules automatically before each cell runs, so that edits to
# the project helpers (data_utils, train_utils, model_utils) are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Select the image-loading backend torchvision should use.
# NOTE(review): requires the `accimage` package to be installed — confirm.
set_image_backend('accimage')

In [2]:
# load sample annotations pickle
# NOTE(review): absolute, user-specific path; the file is read through a
# project helper whose name suggests a pickle — only load trusted files.
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'

# The helper returns five values; the second and third are not used here.
# Later cells look a name up in batch_all and use that position to index
# sa_trains / sa_vals.
(batch_all, _, _,
 sa_trains, sa_vals) = data_utils.load_COAD_train_val_sa_pickle(
    pickle_file=pickle_file,
    return_all_cancers=True,
    split_in_two=True,
)

In [3]:
# initialize Datasets
# containers that the dataset-building cells fill in
train_sets, val_sets = [], []

# where the tiles live and which magnification to request from the dataset class
root_dir = '/n/mounted-data-drive/'
magnification = '10.0'

# image transforms applied to training and validation tiles respectively
train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

# names used both as sub-directories of root_dir and as lookup keys in batch_all;
# the first list feeds the training datasets, the second the validation datasets
train_cancers = ['COAD', 'BRCA', 'READ_10x', 'LUSC_10x', 'BLCA_10x', 'LUAD_10x', 'STAD_10x', 'HNSC_10x']
val_cancers = ['UCEC', 'LIHC_10x', 'KIRC_10x']

In [4]:
# Build one tile-level training dataset per entry of train_cancers.
# The list is (re)created here so that re-running this cell replaces the
# datasets instead of appending a second copy of each one.
train_sets = []
for cancer in train_cancers:
    # progress indicator: names are printed on one line as they are processed
    print(cancer, end=' ')
    # batch_all.index(cancer) gives the position of this name in batch_all;
    # the same position is used to pick its entry from sa_trains
    train_set = data_utils.TCGADataset_tiles(sa_trains[batch_all.index(cancer)],
                                             root_dir + cancer + '/',
                                             transform=train_transform,
                                             magnification=magnification,
                                             batch_type='tile')
    train_sets.append(train_set)

COAD BRCA READ_10x LUSC_10x BLCA_10x LUAD_10x STAD_10x HNSC_10x 

In [4]:
# Build one tile-level validation dataset per entry of val_cancers.
# The list is (re)created here so that re-running this cell replaces the
# datasets instead of appending a second copy of each one.
val_sets = []
for cancer in val_cancers:
    # progress indicator: names are printed on one line as they are processed
    print(cancer, end=' ')
    # batch_all.index(cancer) gives the position of this name in batch_all;
    # the same position is used to pick its entry from sa_vals
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(cancer)],
                                           root_dir + cancer + '/',
                                           transform=val_transform,
                                           magnification=magnification,
                                           batch_type='tile',
                                           # NOTE(review): presumably makes the dataset expose a
                                           # tile-to-sample mapping for slide-level scoring — confirm
                                           return_jpg_to_sample=True)
    val_sets.append(val_set)

UCEC LIHC_10x KIRC_10x 

In [None]:
# Shuffled training loader over all training datasets, which are handed
# together to the project's ConcatDataset wrapper.
batch_size_train = 100
train_loader = DataLoader(
    data_utils.ConcatDataset(*train_sets),
    batch_size=batch_size_train,
    num_workers=20,
    pin_memory=True,
    shuffle=True,
)

In [5]:
# batch size shared by every validation loader
batch_size_val = 100
# NOTE(review): the single combined validation loader below is disabled, so the
# name `val_loader` is never defined by this notebook; validation loaders are
# created per dataset in `val_loaders` instead.
#val_loader = torch.utils.data.DataLoader(data_utils.ConcatDataset(*val_sets, return_jpg_to_sample=True), 
#                                        batch_size=batch_size_val, 
#                                        shuffle=True, 
#                                        num_workers=20, 
#                                        pin_memory=True)

In [8]:
# one loader per validation dataset, all configured identically
val_loaders = []
for val_set in val_sets:
    val_loaders.append(
        DataLoader(val_set,
                   batch_size=batch_size_val,
                   shuffle=True,
                   num_workers=20,
                   pin_memory=True)
    )
# displayed as the cell output: how many validation loaders were created
len(val_loaders)

3

In [7]:
# Batches per pass: the training loader, and all validation loaders combined.
# `val_loader` (singular) only exists in commented-out code, so the per-dataset
# `val_loaders` list is used; otherwise this cell fails on a fresh kernel.
len(train_loader), sum(len(loader) for loader in val_loaders)

(1411, 1199)

In [9]:
# model args
# layer sizes handed to the FeedForward models: input -> hidden -> output
input_size, hidden_size, output_size = 2048, 512, 1

# GPU the ResNet feature extractor is moved to
device = torch.device('cuda', 0)

# saved state dict loaded into the ResNet-18
state_dict_file = '/n/tcga_models/resnet18_WGD_all_10x.pt'

In [10]:
# initialize trained resnet
resnet = models.resnet18(pretrained=False)
# The classification head is replaced before loading so that the state dict's
# `fc` entries match in shape.
# NOTE(review): 2048 input features presumably reflects the tile size /
# torchvision version used when the checkpoint was trained — confirm.
resnet.fc = nn.Linear(2048, output_size, bias=True)
# map_location keeps every tensor on the CPU while loading; the model is moved
# to the GPU explicitly below
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)

# freeze layers
# drop the trained head so the network returns the features that fed into it
resnet.fc = model_utils.Identity()
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False
# requires_grad=False stops gradient updates, but BatchNorm layers in training
# mode still normalise with batch statistics and update their running stats.
# Switch to eval mode so the frozen extractor is deterministic per input.
# NOTE(review): harmless if train_utils already sets the mode itself — confirm.
resnet.eval()

In [11]:
# initialize theta_global
model_global = model_utils.FeedForward(input_size, hidden_size, output_size).cuda()

# one standard-normal tensor per model parameter, created on the CPU and then
# moved to the GPU, in the order model_global.parameters() yields them
theta_global = [torch.randn(list(param.shape)).cuda() for param in model_global.parameters()]

# install the sampled tensors as the global model's parameters; this relies on
# parameters() yielding linear1.weight, linear1.bias, linear2.weight, linear2.bias
# in that order
model_global.linear1.weight, model_global.linear1.bias = (
    torch.nn.Parameter(theta_global[0]),
    torch.nn.Parameter(theta_global[1]),
)
model_global.linear2.weight, model_global.linear2.bias = (
    torch.nn.Parameter(theta_global[2]),
    torch.nn.Parameter(theta_global[3]),
)

# initialize local models, set theta_local = theta_global
# (one local model per training cancer type, each constructed from theta_global)
local_models = [
    model_utils.FeedForward(input_size, hidden_size, output_size, theta_global).cuda()
    for _ in range(len(train_cancers))
]

In [12]:
# train params
num_epochs = 1000

# step sizes: alpha is passed to maml_train_local, eta to maml_train_global
alpha = eta = 0.1

# plateau schedule: once patience_count exceeds `patience`, both step sizes are
# multiplied by `factor`
factor = 0.1
patience = 1

# running state of the schedule, updated after every epoch of the training loop
patience_count = 0
previous_loss = 1e8

In [17]:
# Run the project's validation routine for the global model over all
# per-dataset validation loaders, scoring with binary cross-entropy on logits.
# NOTE(review): the first argument is presumably the epoch index — confirm.
e = 0
val_loss, tile_acc, slide_acc = train_utils.maml_validate_all(
    e,
    resnet,
    model_global,
    val_loaders,
    criterion=nn.BCEWithLogitsLoss(),
)

Step: 0, Val NLL: 212.1946, Acc: 0.4300, By Label: 0: 0.45, 1: 0.35
Step: 100, Val NLL: 191.5525, Acc: 0.4800, By Label: 0: 0.4729, 1: 0.5
Step: 200, Val NLL: 208.2188, Acc: 0.4400, By Label: 0: 0.4096, 1: 0.5882
Step: 300, Val NLL: 219.0420, Acc: 0.4200, By Label: 0: 0.4197, 1: 0.4210
Step: 400, Val NLL: 219.8965, Acc: 0.3900, By Label: 0: 0.3513, 1: 0.5
Step: 500, Val NLL: 222.8416, Acc: 0.4400, By Label: 0: 0.4675, 1: 0.3478
Step: 600, Val NLL: 182.6078, Acc: 0.4800, By Label: 0: 0.4324, 1: 0.6153
Step: 700, Val NLL: 217.5059, Acc: 0.4000, By Label: 0: 0.3783, 1: 0.4615
Step: 800, Val NLL: 185.4231, Acc: 0.4200, By Label: 0: 0.4166, 1: 0.4285
Step: 900, Val NLL: 219.0046, Acc: 0.4900, By Label: 0: 0.4729, 1: 0.5384
Step: 1000, Val NLL: 198.4301, Acc: 0.4200, By Label: 0: 0.3918, 1: 0.5
Step: 1100, Val NLL: 198.5161, Acc: 0.4300, By Label: 0: 0.4342, 1: 0.4166
Step: 1200, Val NLL: 198.9433, Acc: 0.4900, By Label: 0: 0.4125, 1: 0.8
Step: 1300, Val NLL: 187.8407, Acc: 0.4500, By Label:

In [15]:
# train meta-learner
# Each epoch runs local updates on every training batch, folds the returned
# gradients into the global parameters, resets every local model to the new
# global parameters, then validates the global model.
# NOTE: alpha, eta, patience_count and previous_loss are notebook globals that
# this cell mutates; re-run the "train params" cell before restarting training.
for e in range(num_epochs):
    # reduce LR on plateau: triggered once the validation loss has increased
    # on more than `patience` consecutive epochs
    if patience_count > patience:
        alpha = factor * alpha
        eta = factor * eta
        patience_count = 0
        print('--- LR DECAY --- Alpha: {0:0.8f}, Eta: {1:0.8f}'.format(alpha, eta))

    for step, (tiles, labels) in enumerate(train_loader):
        tiles, labels = tiles.cuda(), labels.cuda().float()
        grads, local_models = train_utils.maml_train_local(step, tiles, labels, resnet, local_models, alpha=alpha)
        theta_global, model_global = train_utils.maml_train_global(theta_global, model_global, grads, eta=eta)
        # restart every local model from the freshly updated global parameters
        for local_model in local_models:
            local_model.update_params(theta_global)

    # Validate on the per-dataset loaders. The previous call used `val_loader`,
    # which is only present in commented-out code and raised NameError after
    # the first full epoch; `val_loaders` with maml_validate_all matches the
    # standalone validation cell in this notebook.
    loss, acc, mean_pool_acc = train_utils.maml_validate_all(e, resnet, model_global, val_loaders,
                                                             criterion=nn.BCEWithLogitsLoss())

    # count consecutive epochs whose loss is worse than the preceding epoch
    if loss > previous_loss:
        patience_count += 1
    else:
        patience_count = 0

    previous_loss = loss

Step: 0, Train NLL: 566.9117, Acc: 0.3200, By Label: 0: 0.0, 1: 1.0
Step: 50, Train NLL: 0.7116, Acc: 0.6800, By Label: 0: 1.0, 1: 0.0588
Step: 100, Train NLL: 0.6816, Acc: 0.7400, By Label: 0: 1.0, 1: 0.0
Step: 150, Train NLL: 0.7392, Acc: 0.7200, By Label: 0: 1.0, 1: 0.0
Step: 200, Train NLL: 0.7264, Acc: 0.7400, By Label: 0: 1.0, 1: 0.0


KeyboardInterrupt: 

## Archive

In [29]:
# NOTE(review): archived run kept for its recorded output. Neither
# `run_validation` nor `val_loader` is defined in the visible notebook, so this
# cell cannot be re-executed as-is. The logged accuracies are multiples of
# 1/300, so the batch size quoted in the trailing comment presumably refers to
# training rather than validation — confirm.
loss, acc = run_validation(e, resnet, model_global, val_loader) # batch size = 200, lr = 0.1, SGD

Step: 0, Val NLL: 0.7332, Acc: 0.7100, By Label: 0: 0.9330, 1: 0.0526
Step: 50, Val NLL: 0.7045, Acc: 0.7067, By Label: 0: 0.9398, 1: 0.1071
Step: 100, Val NLL: 0.6924, Acc: 0.7100, By Label: 0: 0.9571, 1: 0.1333
Step: 150, Val NLL: 0.7229, Acc: 0.7167, By Label: 0: 0.9452, 1: 0.0987
Step: 200, Val NLL: 0.7237, Acc: 0.6800, By Label: 0: 0.9377, 1: 0.0879
Step: 250, Val NLL: 0.7034, Acc: 0.7033, By Label: 0: 0.9369, 1: 0.0384
Step: 300, Val NLL: 0.6934, Acc: 0.6867, By Label: 0: 0.9336, 1: 0.1011
Step: 350, Val NLL: 0.7325, Acc: 0.6833, By Label: 0: 0.9436, 1: 0.0459
Step: 400, Val NLL: 0.7169, Acc: 0.6967, By Label: 0: 0.9248, 1: 0.1379
Step: 450, Val NLL: 0.7155, Acc: 0.7533, By Label: 0: 0.9511, 1: 0.16
Step: 500, Val NLL: 0.7533, Acc: 0.7233, By Label: 0: 0.9372, 1: 0.1038
Step: 550, Val NLL: 0.7105, Acc: 0.7033, By Label: 0: 0.9209, 1: 0.1529
Step: 600, Val NLL: 0.7235, Acc: 0.6867, By Label: 0: 0.9086, 1: 0.0864
Step: 650, Val NLL: 0.6936, Acc: 0.7033, By Label: 0: 0.9311, 1: 0.09

In [27]:
# NOTE(review): archived run kept for its recorded output. Neither
# `run_validation` nor `val_loader` is defined in the visible notebook, so this
# cell cannot be re-executed as-is.
loss, acc = run_validation(e, resnet, model_global, val_loader) # batch_size = 100, lr = 0.01, SGD

Step: 0, Val NLL: 0.9906, Acc: 0.6233, By Label: 0: 0.7336, 1: 0.2676
Step: 50, Val NLL: 1.1043, Acc: 0.6100, By Label: 0: 0.7370, 1: 0.2988
Step: 100, Val NLL: 0.9727, Acc: 0.6633, By Label: 0: 0.7782, 1: 0.3417
Step: 150, Val NLL: 1.1521, Acc: 0.6033, By Label: 0: 0.7162, 1: 0.3176
Step: 200, Val NLL: 1.2545, Acc: 0.5800, By Label: 0: 0.7017, 1: 0.1944
Step: 250, Val NLL: 1.2396, Acc: 0.6133, By Label: 0: 0.7534, 1: 0.2345
Step: 300, Val NLL: 1.3489, Acc: 0.5767, By Label: 0: 0.7242, 1: 0.2093
Step: 350, Val NLL: 0.8989, Acc: 0.6733, By Label: 0: 0.7973, 1: 0.2876
Step: 400, Val NLL: 1.1073, Acc: 0.5800, By Label: 0: 0.6950, 1: 0.2467
Step: 450, Val NLL: 1.1416, Acc: 0.5933, By Label: 0: 0.7104, 1: 0.2658
Step: 500, Val NLL: 1.1372, Acc: 0.6200, By Label: 0: 0.7268, 1: 0.3452
Step: 550, Val NLL: 1.0899, Acc: 0.6467, By Label: 0: 0.7899, 1: 0.2592
Step: 600, Val NLL: 1.0686, Acc: 0.6200, By Label: 0: 0.7314, 1: 0.3333
Step: 650, Val NLL: 1.2468, Acc: 0.6033, By Label: 0: 0.7053, 1: 0.

In [20]:
# NOTE(review): archived run kept for its recorded output. Neither
# `run_validation` nor `val_loader` is defined in the visible notebook, so this
# cell cannot be re-executed as-is. The recorded ValueError reports
# "expected 2" although this line unpacks three values, so the error presumably
# comes from inside run_validation or from an earlier version of this line —
# confirm.
loss, acc, mean_pool_acc = run_validation(e, resnet, model_global, val_loader) # batch_size = 100, lr = 1e-4, SGD

Step: 0, Val NLL: 67.1556, Acc: 0.5233, By Label: 0: 0.5450, 1: 0.4615
Step: 50, Val NLL: 71.8051, Acc: 0.5000, By Label: 0: 0.4932, 1: 0.5189
Step: 100, Val NLL: 65.5996, Acc: 0.5000, By Label: 0: 0.4892, 1: 0.5373
Step: 150, Val NLL: 81.1949, Acc: 0.4800, By Label: 0: 0.4545, 1: 0.55
Step: 200, Val NLL: 70.8121, Acc: 0.5000, By Label: 0: 0.5090, 1: 0.4743
Step: 250, Val NLL: 62.1865, Acc: 0.6067, By Label: 0: 0.5844, 1: 0.6666
Step: 300, Val NLL: 67.4404, Acc: 0.4633, By Label: 0: 0.4511, 1: 0.4941
Step: 350, Val NLL: 65.1226, Acc: 0.4867, By Label: 0: 0.4841, 1: 0.4936
Step: 400, Val NLL: 72.8518, Acc: 0.4867, By Label: 0: 0.4633, 1: 0.5487
Step: 450, Val NLL: 66.6190, Acc: 0.5267, By Label: 0: 0.4837, 1: 0.6352
Step: 500, Val NLL: 71.3242, Acc: 0.4967, By Label: 0: 0.5142, 1: 0.4555
Step: 550, Val NLL: 76.1082, Acc: 0.4967, By Label: 0: 0.5022, 1: 0.48
Step: 600, Val NLL: 71.5169, Acc: 0.4967, By Label: 0: 0.4882, 1: 0.5172
Step: 650, Val NLL: 63.6326, Acc: 0.5133, By Label: 0: 0.5

ValueError: too many values to unpack (expected 2)