In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

# Load the IPython autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Select 'accimage' as the torchvision image-loading backend.
set_image_backend('accimage')

In [2]:
# Read the pickled sample annotations. `batch_all` is later searched by cancer
# name, and `sa_trains` / `sa_vals` are indexed with the resulting position.
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
batch_all, _, _, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(
    pickle_file=pickle_file,
    return_all_cancers=True,
    split_in_two=True,
)

In [3]:
# Where the tiles live and how they are preprocessed.
root_dir = '/n/mounted-data-drive/'
magnification = '10.0'
train_transform, val_transform = train_utils.transform_train, train_utils.transform_validation

# Cancer types used as meta-training tasks and as held-out validation tasks.
train_cancers = ['COAD', 'BRCA', 'READ_10x', 'LUSC_10x', 'BLCA_10x', 'LUAD_10x', 'STAD_10x', 'HNSC_10x']
val_cancers = ['UCEC', 'LIHC_10x', 'KIRC_10x']

# Containers filled by the dataset-construction cells below.
train_sets, val_sets = [], []

In [4]:
# Build one tile-level Dataset per training cancer type.
# The list is rebuilt from scratch here so that re-running this cell does not
# append a second copy of every dataset.
train_sets = []
for cancer in train_cancers:
    print(cancer, end=' ')
    # Sample annotations are looked up by the cancer's position in batch_all.
    train_set = data_utils.TCGADataset_tiles(sa_trains[batch_all.index(cancer)],
                                             root_dir + cancer + '/',
                                             transform=train_transform,
                                             magnification=magnification,
                                             batch_type='tile')
    train_sets.append(train_set)

COAD BRCA READ_10x LUSC_10x BLCA_10x LUAD_10x STAD_10x HNSC_10x 

In [5]:
# Build one tile-level Dataset per validation cancer type.
# The list is rebuilt from scratch here so that re-running this cell does not
# append a second copy of every dataset.
val_sets = []
for cancer in val_cancers:
    print(cancer, end=' ')
    # Sample annotations are looked up by the cancer's position in batch_all.
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(cancer)],
                                           root_dir + cancer + '/',
                                           transform=val_transform,
                                           magnification=magnification,
                                           batch_type='tile')
    val_sets.append(val_set)

UCEC LIHC_10x KIRC_10x 

In [6]:
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together so that item ``i`` holds one sample from each.

    Each wrapped dataset must return an ``(input, label)`` pair from
    ``__getitem__``. Item ``i`` of this dataset is a pair
    ``(inputs, labels)`` where ``inputs`` stacks the i-th input of every
    dataset along a new leading dimension and ``labels`` concatenates the
    flattened i-th labels in the same order.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __len__(self):
        # Only indices valid for every dataset can be served, hence the minimum.
        # With no datasets at all the result is an empty dataset, not an error.
        return min((len(d) for d in self.datasets), default=0)

    def __getitem__(self, i):
        # Fetch each sample exactly once: indexing a dataset may load and
        # transform data, so calling d[i] separately for the input and for the
        # label would do that work twice.
        samples = [d[i] for d in self.datasets]
        inputs = torch.stack([sample[0] for sample in samples])
        # as_tensor accepts Python scalars, numpy values and tensors alike and
        # avoids the copy/warning torch.tensor() produces for tensor labels.
        labels = torch.cat([torch.as_tensor(sample[1]).reshape(-1) for sample in samples])
        return inputs, labels

In [7]:
# Both loaders share the same settings; each batch holds `batch_size` items,
# and every item carries one tile per cancer type (see ConcatDataset).
batch_size = 100
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)

train_loader = torch.utils.data.DataLoader(ConcatDataset(*train_sets), **loader_kwargs)
val_loader = torch.utils.data.DataLoader(ConcatDataset(*val_sets), **loader_kwargs)

In [8]:
class Identity(nn.Module):
    """Parameter-free module whose forward pass returns its input untouched."""

    def forward(self, x):
        return x

In [9]:
class FeedForward(torch.nn.Module):
    """Two-layer perceptron: dropout -> linear -> ReLU -> dropout -> linear.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of outputs (raw, no final activation).
        initial_vals: optional sequence of four tensors
            ``[linear1.weight, linear1.bias, linear2.weight, linear2.bias]``
            used to initialise the layers. The tensors are copied, so later
            in-place updates of this module's parameters (e.g. an optimizer
            step) never modify the caller's tensors.
        dropout: dropout probability applied to the input and to the hidden
            activations.
    """

    def __init__(self, input_size, hidden_size, output_size, initial_vals=None, dropout=0.0):
        super(FeedForward, self).__init__()
        self.d = nn.Dropout(dropout)
        self.m = nn.ReLU()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

        if initial_vals is not None:
            # nn.Parameter(t) shares storage with t; clone so that several
            # models built from the same initial values stay independent.
            self.linear1.weight = torch.nn.Parameter(initial_vals[0].detach().clone())
            self.linear1.bias = torch.nn.Parameter(initial_vals[1].detach().clone())
            self.linear2.weight = torch.nn.Parameter(initial_vals[2].detach().clone())
            self.linear2.bias = torch.nn.Parameter(initial_vals[3].detach().clone())

    def forward(self, inputs):
        """Return the raw outputs of the second linear layer for ``inputs``."""
        hidden = self.m(self.linear1(self.d(inputs)))
        output = self.linear2(self.d(hidden))
        return output

In [11]:
# Checkpoint of the trained ResNet and the GPU everything runs on.
state_dict_file = '/n/tcga_models/resnet18_WGD_all_10x.pt'
device = torch.device('cuda', 0)

# Layer sizes of the FeedForward head: embedding width -> hidden -> single logit.
input_size, hidden_size, output_size = 2048, 512, 1

In [37]:
def train_local(tiles, labels, resnet, theta_global, alpha = 0.01, criterion = nn.BCEWithLogitsLoss(),
                input_size = input_size, hidden_size = hidden_size, output_size = output_size):
    """Run one local adaptation step per task and sum the post-step gradients.

    For every task ``t`` a fresh FeedForward head is initialised from a copy
    of ``theta_global``. The head takes one SGD step (learning rate ``alpha``)
    on the first half of the batch; the gradient of the loss on the second
    half, evaluated at the updated parameters, is then added to the running
    total.

    Args:
        tiles: tensor indexed as ``tiles[sample, task, ...]``.
        labels: tensor indexed as ``labels[sample, task]``.
        resnet: feature extractor applied to the tiles of one task.
        theta_global: list of four tensors
            ``[linear1.weight, linear1.bias, linear2.weight, linear2.bias]``.
            It is not modified.
        alpha: learning rate of the local SGD step.
        criterion: loss applied to (logits, float targets).
        input_size, hidden_size, output_size: FeedForward layer sizes. The
            output is squeezed on its last dimension to match the per-sample
            labels, so ``output_size`` is expected to be 1.

    Returns:
        List of four tensors shaped like ``theta_global`` holding the summed
        second-half gradients over all tasks. All zeros when the batch has
        fewer than two samples, because it cannot be split into two halves.
    """
    idx = tiles.shape[0] // 2
    num_tasks = tiles.shape[1]
    device = tiles.device

    # grads storage
    grads = [torch.zeros_like(theta, device=device) for theta in theta_global]
    if idx == 0:
        return grads

    for t in range(num_tasks):
        # theta_local = theta_global, as an independent copy so the local SGD
        # step cannot leak into theta_global or into the other tasks' heads.
        local_init = [theta.detach().clone() for theta in theta_global]
        net = FeedForward(input_size, hidden_size, output_size, local_init).to(device)
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr = alpha)

        # first forward pass on the first half of the batch, then one SGD step
        optimizer.zero_grad()
        with torch.no_grad():
            # only the head is trained here; no graph is needed through resnet
            embed = resnet(tiles[:idx, t])
        output = net(embed).squeeze(-1)
        loss = criterion(output, labels[:idx, t].float())
        loss.backward()
        optimizer.step()

        # second forward pass on the second half with the updated head
        optimizer.zero_grad()
        with torch.no_grad():
            embed = resnet(tiles[idx:, t])
        output = net(embed).squeeze(-1)
        loss = criterion(output, labels[idx:, t].float())
        loss.backward()

        # accumulate in the same order as theta_global
        local_params = (net.linear1.weight, net.linear1.bias, net.linear2.weight, net.linear2.bias)
        for k, param in enumerate(local_params):
            grads[k] = grads[k] + param.grad.detach()

    return grads

In [38]:
def train_global(theta_global, model_global, grads, eta = 0.01):
    """Take one gradient-descent step on the global parameters.

    Each entry of ``theta_global`` is replaced by ``theta - eta * grad`` and
    the four resulting tensors are installed on ``model_global`` as
    ``linear1.weight``, ``linear1.bias``, ``linear2.weight`` and
    ``linear2.bias``.

    Returns:
        ``(theta_global, model_global)``: the list of updated tensors and the
        same model object with its parameters replaced.
    """
    stepped = [theta - (eta * grads[k]) for k, theta in enumerate(theta_global)]

    # Position in the list -> (layer, attribute) on the model.
    slots = (('linear1', 'weight'), ('linear1', 'bias'), ('linear2', 'weight'), ('linear2', 'bias'))
    for k, (layer_name, attr_name) in enumerate(slots):
        setattr(getattr(model_global, layer_name), attr_name, torch.nn.Parameter(stepped[k]))

    return stepped, model_global

In [39]:
def run_validation(e, resnet, model_global, val_loader, criterion = nn.BCEWithLogitsLoss()):
    """Evaluate the global head on every batch of ``val_loader``.

    Args:
        e: epoch number, used only in the printed summary.
        resnet: feature extractor applied to the flattened tiles.
        model_global: head producing one logit per tile; switched to eval mode.
        val_loader: iterable of ``(batch, labels)`` where ``batch`` is indexed
            ``[sample, task, ...]`` and ``labels`` is ``[sample, task]``.
        criterion: loss applied to (logits, float targets).

    Returns:
        ``(loss, acc)``: the loss averaged over batches as a Python float, and
        the fraction of tiles whose predicted class equals the label. Both are
        NaN when ``val_loader`` yields no batches.
    """
    model_global.eval()

    # Run on whatever device the head lives on (the GPU in this notebook).
    first_param = next(model_global.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0
    all_output = []
    all_labels = []

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for batch, labels in val_loader:
            if device is not None:
                batch, labels = batch.to(device), labels.to(device)
            batch_size = batch.shape[0]
            num_tasks = batch.shape[1]

            # Flatten (sample, task) into one task-major axis; inputs and
            # labels are transposed the same way so they stay aligned.
            labels = labels.transpose(0, 1).reshape(-1).float()
            inputs = batch.transpose(0, 1).reshape(batch_size * num_tasks, *batch.shape[2:])

            embed = resnet(inputs)
            output = model_global(embed).reshape(-1)
            loss = criterion(output, labels)

            total_loss += loss.item()
            num_batches += 1
            # The head emits logits, so probability 0.5 corresponds to logit 0.
            all_output.extend((output > 0).float().cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if num_batches:
        loss = total_loss / num_batches
        acc = float(np.mean(np.array(all_output) == np.array(all_labels)))
    else:
        loss, acc = float('nan'), float('nan')

    print('Epoch: {0}, Val NLL: {1:0.4f}, Val Acc: {2:0.4f}'.format(e, loss, acc))

    return loss, acc

In [40]:
# initialize trained resnet
# The checkpoint was saved with a single-output fc layer, so the architecture
# has to match before the weights can be loaded.
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_size, bias=True)
# map_location keeps the tensors on CPU while loading
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)

# freeze layers
# Replace the classifier with a pass-through so the network returns embeddings.
resnet.fc = Identity()
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False
# requires_grad=False does not stop BatchNorm from using batch statistics and
# updating its running averages; eval mode is needed for a fixed extractor.
resnet.eval()

In [41]:
# Global head plus one standard-normal tensor per parameter, in parameter order.
model_global = FeedForward(input_size, hidden_size, output_size).cuda()
theta_global = [torch.randn(list(p.shape)).cuda() for p in model_global.parameters()]

# Install the sampled tensors as the head's parameters:
# linear1.weight, linear1.bias, linear2.weight, linear2.bias.
param_slots = (('linear1', 'weight'), ('linear1', 'bias'), ('linear2', 'weight'), ('linear2', 'bias'))
for slot, (layer_name, attr_name) in enumerate(param_slots):
    setattr(getattr(model_global, layer_name), attr_name, torch.nn.Parameter(theta_global[slot]))

In [42]:
# Meta-training schedule.
num_epochs = 1000

# alpha: learning rate passed to train_local; eta: step size passed to train_global.
alpha, eta = 0.1, 0.1

# Plateau handling: after more than `patience` consecutive worse validation
# epochs both rates are multiplied by `factor`.
patience, factor = 3, 0.1
patience_count, previous_loss = 0, 1e8

In [None]:
# Meta-training: each epoch makes one pass over train_loader, then validates.
for e in range(num_epochs):
    # Shrink both learning rates once the validation loss has gone up for
    # more than `patience` epochs in a row.
    if patience_count > patience:
        alpha, eta = factor * alpha, factor * eta
        patience_count = 0
        print('--- LR DECAY --- Alpha: {0:0.8f}, Eta: {1:0.8f}'.format(alpha, eta))

    for step, (tiles, labels) in enumerate(train_loader):
        # progress indicator, printed for every step
        print(step, end=' ')
        tiles = tiles.cuda()
        labels = labels.cuda()
        # per-task local step -> summed gradients -> global update
        grads = train_local(tiles, labels, resnet, theta_global, alpha = alpha)
        theta_global, model_global = train_global(theta_global, model_global, grads, eta = eta)

    loss, acc = run_validation(e, resnet, model_global, val_loader)

    # Count consecutive epochs in which the loss rose relative to the previous epoch.
    patience_count = patience_count + 1 if loss > previous_loss else 0
    previous_loss = loss