In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [2]:
# load sample annotations pickle
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
batch_all, _, _, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(pickle_file=pickle_file, 
                                                                               return_all_cancers=True, 
                                                                               split_in_two=True)

In [3]:
# initialize Datasets
train_sets = []
val_sets = []

magnification = '10.0'
root_dir = '/n/mounted-data-drive/'
train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

train_cancers = ['COAD', 'BRCA', 'READ_10x', 'LUSC_10x', 'BLCA_10x', 'LUAD_10x', 'STAD_10x', 'HNSC_10x']
val_cancers = ['UCEC', 'LIHC_10x', 'KIRC_10x']

In [4]:
for i in range(len(train_cancers)):
    print(train_cancers[i], end=' ')
    train_set = data_utils.TCGADataset_tiles(sa_trains[batch_all.index(train_cancers[i])], 
                                             root_dir + train_cancers[i] + '/', 
                                             transform=train_transform, 
                                             magnification=magnification, 
                                             batch_type='tile')
    train_sets.append(train_set)

COAD BRCA READ_10x LUSC_10x BLCA_10x LUAD_10x STAD_10x HNSC_10x 

In [5]:
for j in range(len(val_cancers)):
    print(val_cancers[j], end=' ')
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(val_cancers[j])], 
                                           root_dir + val_cancers[j] + '/', 
                                           transform=val_transform, 
                                           magnification=magnification, 
                                           batch_type='tile')
    val_sets.append(val_set)

UCEC LIHC_10x KIRC_10x 

In [6]:
batch_size_train = 100
train_loader = torch.utils.data.DataLoader(data_utils.ConcatDataset(*train_sets), 
                                           batch_size=batch_size_train, 
                                           shuffle=True, 
                                           num_workers=20, 
                                           pin_memory=True)
batch_size_val = 100
val_loader = torch.utils.data.DataLoader(data_utils.ConcatDataset(*val_sets), 
                                        batch_size=batch_size_val, 
                                        shuffle=True, 
                                        num_workers=20, 
                                        pin_memory=True)

In [7]:
# model args
state_dict_file = '/n/tcga_models/resnet18_WGD_all_10x.pt'
device = torch.device('cuda', 0)
input_size = 2048
hidden_size = 512
output_size = 1

In [8]:
def train_local(step, tiles, labels, resnet, local_models, alpha=0.01, criterion=nn.BCEWithLogitsLoss()):
    resnet.eval()
    idx = int(tiles.shape[0] / 2)
    num_tasks = int(tiles.shape[1])
    
    # grads storage    
    grads = [torch.zeros(p.shape).cuda() for p in local_models[0].parameters()]

    #t = torch.randint(num_tasks, (1,)).item()
    for t in range(num_tasks):
        # first forward pass, step
        net = local_models[t]
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr = alpha)

        inputs = tiles[:idx,t,:,:,:]
        embed = resnet(inputs)
        output = net(embed)
        loss = criterion(output, labels[:idx,t].unsqueeze(1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        # second forward pass, store grads
        inputs = tiles[idx:,t,:,:,:]
        embed = resnet(inputs)
        output = net(embed)
        loss = criterion(output, labels[idx:,t].unsqueeze(1))
        loss.backward()
        grads[0] = grads[0] + net.linear1.weight.grad.data
        grads[1] = grads[1] + net.linear1.bias.grad.data
        grads[2] = grads[2] + net.linear2.weight.grad.data
        grads[3] = grads[3] + net.linear2.bias.grad.data
        optimizer.zero_grad()
        
    if step % 50 == 0:
        output = (output.contiguous().view(-1) > 0.5).float().detach().cpu().numpy()
        labels = labels[idx:,t].contiguous().view(-1).float().detach().cpu().numpy()
        acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(labels, output)
        print('Step: {0}, Train NLL: {1:0.4f}, Acc: {2:0.4f}, By Label: {3}'.format(step, loss, acc, tile_acc_by_label))

    return grads, local_models

In [9]:
def train_global(theta_global, model_global, grads, eta=0.01):
    theta_global = [theta_global[i] - (eta * grads[i]) for i in range(len(theta_global))]
    
    model_global.linear1.weight = torch.nn.Parameter(theta_global[0])
    model_global.linear1.bias = torch.nn.Parameter(theta_global[1])
    model_global.linear2.weight = torch.nn.Parameter(theta_global[2])
    model_global.linear2.bias = torch.nn.Parameter(theta_global[3])

    return theta_global, model_global

In [10]:
def run_validation(e, resnet, model_global, val_loader, criterion=nn.BCEWithLogitsLoss()):
    resnet.eval()
    model_global.eval()
    
    total_loss = 0
    all_output = []
    all_labels = []
    
    for step, (batch,labels) in enumerate(val_loader):
        batch_size = batch.shape[0]
        num_tasks = batch.shape[1]
        labels = labels.cuda().transpose(0,1).reshape(batch_size * num_tasks, 1).float()
        inputs = batch.cuda().transpose(0,1).reshape(batch_size * num_tasks, 3, 256, 256)
        
        embed = resnet(inputs)
        output = model_global(embed)
        loss = criterion(output, labels)
        
        output = (output.contiguous().view(-1) > 0.5).float().detach().cpu().numpy()
        labels = labels.contiguous().view(-1).float().detach().cpu().numpy()
        
        total_loss += loss.detach().cpu().numpy()
        all_output.extend(output)
        all_labels.extend(labels)
    
        if step % 50 == 0:
            acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(labels, output)
            print('Step: {0}, Val NLL: {1:0.4f}, Acc: {2:0.4f}, By Label: {3}'.format(step, loss, acc, tile_acc_by_label))
                
    acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(all_labels, all_output)
    print('Epoch: {0}, Val NLL: {1:0.4f}, Acc: {2:0.4f}, By Label: {3}'.format(e, loss, acc, tile_acc_by_label))
    return loss, acc

In [11]:
# initialize trained resnet
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_size, bias=True)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)

# freeze layers
resnet.fc = model_utils.Identity()
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False

In [12]:
# initialize theta_global
model_global = model_utils.FeedForward(input_size, hidden_size, output_size).cuda()
theta_global = []
for p in model_global.parameters():
    theta_global.append(torch.randn(list(p.shape)).cuda())
    
model_global.linear1.weight = torch.nn.Parameter(theta_global[0])
model_global.linear1.bias = torch.nn.Parameter(theta_global[1])
model_global.linear2.weight = torch.nn.Parameter(theta_global[2])
model_global.linear2.bias = torch.nn.Parameter(theta_global[3])

# initialize local models, set theta_local = theta_global    
local_models = []
for i in range(len(train_cancers)):
    local_models.append(model_utils.FeedForward(input_size, hidden_size, output_size, theta_global).cuda()) 

In [13]:
# train params
num_epochs = 1000
alpha = 1e-4
eta = 1e-4
patience = 1
factor = 0.1
patience_count = 0
previous_loss = 1e8

In [None]:
# train meta-learner
for e in range(num_epochs):
    # reduce LR on plateau
    if patience_count > patience:
        alpha = factor * alpha
        eta = factor * eta
        patience_count = 0
        print('--- LR DECAY --- Alpha: {0:0.8f}, Eta: {1:0.8f}'.format(alpha, eta))
    
    for step, (tiles, labels) in enumerate(train_loader):  
        tiles, labels = tiles.cuda(), labels.cuda().float()           
        grads, local_models = train_local(step, tiles, labels, resnet, local_models, alpha = alpha)
        theta_global, model_global = train_global(theta_global, model_global, grads, eta = eta)
        for i in range(len(local_models)):
            local_models[i].update_params(theta_global)
            
    loss, acc = run_validation(e, resnet, model_global, val_loader)
    
    if loss > previous_loss:
        patience_count += 1
    else:
        patience_count = 0
        
    previous_loss = loss

Step: 0, Train NLL: 492.5553, Acc: 0.3200, By Label: 0: 0.0571, 1: 0.9333
Step: 50, Train NLL: 163.4610, Acc: 0.5400, By Label: 0: 0.5757, 1: 0.4705
Step: 100, Train NLL: 100.3759, Acc: 0.6600, By Label: 0: 0.8, 1: 0.45
Step: 150, Train NLL: 72.8824, Acc: 0.7200, By Label: 0: 0.8275, 1: 0.5714


In [29]:
loss, acc = run_validation(e, resnet, model_global, val_loader) # batch size = 200, lr = 0.1

Step: 0, Val NLL: 0.7332, Acc: 0.7100, By Label: 0: 0.9330, 1: 0.0526
Step: 50, Val NLL: 0.7045, Acc: 0.7067, By Label: 0: 0.9398, 1: 0.1071
Step: 100, Val NLL: 0.6924, Acc: 0.7100, By Label: 0: 0.9571, 1: 0.1333
Step: 150, Val NLL: 0.7229, Acc: 0.7167, By Label: 0: 0.9452, 1: 0.0987
Step: 200, Val NLL: 0.7237, Acc: 0.6800, By Label: 0: 0.9377, 1: 0.0879
Step: 250, Val NLL: 0.7034, Acc: 0.7033, By Label: 0: 0.9369, 1: 0.0384
Step: 300, Val NLL: 0.6934, Acc: 0.6867, By Label: 0: 0.9336, 1: 0.1011
Step: 350, Val NLL: 0.7325, Acc: 0.6833, By Label: 0: 0.9436, 1: 0.0459
Step: 400, Val NLL: 0.7169, Acc: 0.6967, By Label: 0: 0.9248, 1: 0.1379
Step: 450, Val NLL: 0.7155, Acc: 0.7533, By Label: 0: 0.9511, 1: 0.16
Step: 500, Val NLL: 0.7533, Acc: 0.7233, By Label: 0: 0.9372, 1: 0.1038
Step: 550, Val NLL: 0.7105, Acc: 0.7033, By Label: 0: 0.9209, 1: 0.1529
Step: 600, Val NLL: 0.7235, Acc: 0.6867, By Label: 0: 0.9086, 1: 0.0864
Step: 650, Val NLL: 0.6936, Acc: 0.7033, By Label: 0: 0.9311, 1: 0.09

In [27]:
loss, acc = run_validation(e, resnet, model_global, val_loader) # batch_size = 100, lr = 0.01

Step: 0, Val NLL: 0.9906, Acc: 0.6233, By Label: 0: 0.7336, 1: 0.2676
Step: 50, Val NLL: 1.1043, Acc: 0.6100, By Label: 0: 0.7370, 1: 0.2988
Step: 100, Val NLL: 0.9727, Acc: 0.6633, By Label: 0: 0.7782, 1: 0.3417
Step: 150, Val NLL: 1.1521, Acc: 0.6033, By Label: 0: 0.7162, 1: 0.3176
Step: 200, Val NLL: 1.2545, Acc: 0.5800, By Label: 0: 0.7017, 1: 0.1944
Step: 250, Val NLL: 1.2396, Acc: 0.6133, By Label: 0: 0.7534, 1: 0.2345
Step: 300, Val NLL: 1.3489, Acc: 0.5767, By Label: 0: 0.7242, 1: 0.2093
Step: 350, Val NLL: 0.8989, Acc: 0.6733, By Label: 0: 0.7973, 1: 0.2876
Step: 400, Val NLL: 1.1073, Acc: 0.5800, By Label: 0: 0.6950, 1: 0.2467
Step: 450, Val NLL: 1.1416, Acc: 0.5933, By Label: 0: 0.7104, 1: 0.2658
Step: 500, Val NLL: 1.1372, Acc: 0.6200, By Label: 0: 0.7268, 1: 0.3452
Step: 550, Val NLL: 1.0899, Acc: 0.6467, By Label: 0: 0.7899, 1: 0.2592
Step: 600, Val NLL: 1.0686, Acc: 0.6200, By Label: 0: 0.7314, 1: 0.3333
Step: 650, Val NLL: 1.2468, Acc: 0.6033, By Label: 0: 0.7053, 1: 0.