In [12]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [13]:
# define global variables
classification = 'WGD'
magnification = '10.0'
output_shape = 1
device = torch.device('cuda', 0)
root_dir = '/n/mounted-data-drive/'

## Prep Images - DONE

In [3]:
# get image file paths
batch_one = ['COAD', 'BRCA', 'UCEC']
batch_two_orig = ['BLCA', 'KIRC', 'READ', 'HNSC', 'LUSC', 'LIHC', 'LUAD', 'STAD']
if magnification == '10.0':
    batch_two = [b + '_10x' for b in batch_two_orig]
elif magnification == '5.0':
    batch_two = [b + '_5x' for b in batch_two_orig]

In [None]:
# get sample annotations
# NOTE: ONLY FOR WGD
wgd_path = 'ALL_WGD_TABLE.xlsx'
wgd_raw = pd.read_excel(wgd_path)
#wgd_raw.head(3)

batch_all_orig = batch_one + batch_two_orig
wgd_filtered = wgd_raw.loc[wgd_raw['Type'].isin(batch_all_orig)]
#wgd_filtered.head(3)

wgd_filtered.loc[wgd_filtered['Genome_doublings'].values == 2, 'Genome_doublings'] = 1

wgd_filtered.set_index('Sample', inplace=True)
#wgd_filtered.head(3)

In [None]:
# get sample annotations for all cancer types
# split samples into two sets of train/val
sa_trains1 = []
sa_vals1 = []
sa_trains2 = []
sa_vals2 = []
batch_all = batch_one + batch_two

print('Num Samples with Images and Labels:')
for cancer in batch_all:
    sa_train1, sa_val1, sa_train2, sa_val2 = data_utils.process_WGD_data(root_dir='/n/mounted-data-drive/', 
                                                                         cancer_type=cancer, 
                                                                         wgd_path=None, 
                                                                         split_in_two=True, 
                                                                         print_overlap=True, 
                                                                         wgd_raw=wgd_filtered)
    sa_trains1.append(sa_train1)
    sa_vals1.append(sa_val1)
    sa_trains2.append(sa_train2)
    sa_vals2.append(sa_val2)

In [None]:
# save sample annotations in a pickle
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
with open(pickle_file, 'wb') as f: 
    pickle.dump([batch_all, sa_trains1, sa_vals1, sa_trains2, sa_vals2], f)

## Start Here

In [14]:
# load sample annotations pickle
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
batch_all, _, _, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(pickle_file=pickle_file, 
                                                                               return_all_cancers=True, 
                                                                               split_in_two=True)

In [15]:
# initialize Datasets
train_sets = []
val_sets = []

train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

train_cancers = ['COAD', 'BRCA', 'READ_10x', 'LUSC_10x', 'BLCA_10x', 'LUAD_10x', 'STAD_10x', 'HNSC_10x']
val_cancers = ['UCEC', 'LIHC_10x', 'KIRC_10x']

In [16]:
for i in range(len(train_cancers)):
    print(train_cancers[i], end=' ')
    train_set = data_utils.TCGADataset_tiles(sa_trains[batch_all.index(train_cancers[i])], 
                                             root_dir + train_cancers[i] + '/', 
                                             transform=train_transform, 
                                             magnification=magnification, 
                                             batch_type='tile')
    train_sets.append(train_set)

COAD BRCA READ_10x LUSC_10x BLCA_10x LUAD_10x STAD_10x HNSC_10x 

In [17]:
for j in range(len(val_cancers)):
    print(val_cancers[j], end=' ')
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(val_cancers[j])], 
                                           root_dir + val_cancers[j] + '/', 
                                           transform=val_transform, 
                                           magnification=magnification, 
                                           batch_type='tile')
    val_sets.append(val_set)

UCEC LIHC_10x KIRC_10x 

In [None]:
print('Num Tiles:')
for cancer, tset, vset in zip(batch_all, train_sets, val_sets):
    print('{0:<8}  Train: {1:>10,d}              Val: {2:>8,d}'.format(cancer, tset.__len__(), vset.__len__()))
    print('          Train: (0) {0:0.4f}, (1) {1:0.4f}  Val: (0) {2:0.4f} (1) {3:0.4f}'.format(np.mean(np.array(tset.all_labels) == 0),
                                                                                              np.mean(np.array(tset.all_labels) == 1),
                                                                                              np.mean(np.array(vset.all_labels) == 0),
                                                                                              np.mean(np.array(vset.all_labels) == 1)))

## Prep Model

In [18]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [19]:
# get model file paths
if classification == 'WGD':
    if magnification == '10.0':
        #sa_file = '/n/tcga_models/resnet18_WGD_10x_sa.pkl'
        #state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
        sa_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
        state_dict_file = '/n/tcga_models/resnet18_WGD_all_10x.pt'
    elif magnification == '5.0':
        #sa_file = '/n/tcga_models/resnet18_WGD_v04_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_WGD_v04.pt'
elif classification == 'MSI':
    if magnification == '10.0':
        #sa_file = '/n/tcga_models/resnet18_MSI_singlelabel_10x_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_MSI_singlelabel_10x.pt'
    elif magnification == '5.0':
        #sa_file = '/n/tcga_models/resnet18_MSI_singlelabel_v02_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_MSI_singlelabel_v02.pt'

In [20]:
# load embedding network

# alternative 1
#resnet = models.resnet18(pretrained=True)

# alternative 2
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_shape, bias=True)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)

# freeze layers
resnet.fc = Identity()
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False

# initialize fully-connected final layer 
final_embed_layer = nn.Linear(2048, 2048)
final_embed_layer.cuda()

Linear(in_features=2048, out_features=2048, bias=True)

## ConcatDataset

In [21]:
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __len__(self):
        return min(len(d) for d in self.datasets)
    
    def __getitem__(self, i):
        return torch.stack([d[i][0] for d in self.datasets]), torch.cat([torch.tensor(d[i][1]).view(-1) for d in self.datasets])

In [22]:
batch_size = 100
support_size = 10

In [23]:
train_loader = torch.utils.data.DataLoader(ConcatDataset(*train_sets), 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           num_workers=20, 
                                           pin_memory=True)

In [24]:
val_loader = torch.utils.data.DataLoader(ConcatDataset(*val_sets), 
                                         batch_size=batch_size, 
                                         shuffle=True, 
                                         num_workers=20, 
                                         pin_memory=True)

In [25]:
learning_rate = 1e-5
lsm = nn.LogSoftmax(dim=1)
criterion = nn.BCELoss()
#optimizer = torch.optim.Adam(resnet.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(final_embed_layer.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, verbose=True, min_lr=1e-8)

In [26]:
def fewshot_training_loop(train_loader, train_cancers, batch_size, resnet, final_embed_layer, criterion, optimizer):
    all_labels = []
    all_preds = []
    for step, (tiles, labels) in enumerate(train_loader):  
        labels = labels.cuda().float().transpose(0,1)    

        # flatten batch_size x num_cancer_types 
        batch = tiles.cuda().transpose(0,1).reshape(batch_size * len(train_cancers), 3, 256, 256)    

        # forward pass
        output = resnet(batch)

        # un-flatten num_cancer_types x batch_size
        cancers_by_feats = torch.stack(torch.chunk(output, len(train_cancers)))    

        # split feats, labels into support, query sets
        feats_support = cancers_by_feats[:, :support_size, :]
        feats_support = feats_support.reshape(support_size * len(train_cancers), 2048)
        feats_support = final_embed_layer(feats_support)
        feats_support = torch.stack(torch.chunk(feats_support, len(train_cancers)))    
        feats_query = cancers_by_feats[:, support_size:, :]    
        labels_support = labels[:,:support_size]
        labels_query = labels[:,support_size:]

        # get preds    
        scores = lsm(torch.bmm(feats_support, feats_query.transpose(1,2))).exp()
        preds = torch.bmm(labels_support.unsqueeze(1), scores).squeeze(1)
        clamped_preds = torch.clamp(preds, 0, 1)

        # calc loss, backprop, step    
        loss = criterion(clamped_preds, labels_query)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        clamped_preds = (clamped_preds.contiguous().view(-1) > 0.5).float().detach().cpu().numpy()
        labels_query = labels_query.contiguous().view(-1).float().detach().cpu().numpy()
        all_preds.extend(clamped_preds)
        all_labels.extend(labels_query)
        
        if step % 100 == 0:
            acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(labels_query, clamped_preds)
            print('Epoch: {0}, Step: {1}, Train NLL: {2:0.4f}, Acc: {3:04f}, By Label: {4}'.format(e, step, loss.detach().cpu().numpy(), acc, tile_acc_by_label))
        
    acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(all_labels, all_preds)
    print('Epoch: {0}, Train NLL: {1:0.4f}, Acc: {2:0.4f}, By Label: {3}'.format(e, loss.detach().cpu().numpy(), acc, tile_acc_by_label))

In [27]:
def fewshot_validation_loop(val_loader, val_cancers, batch_size, resnet, final_embed_layer, criterion):
    all_labels = []
    all_preds = []
    for step, (tiles, labels) in enumerate(val_loader):  
        labels = labels.cuda().float().transpose(0,1)    

        # flatten batch_size x num_cancer_types 
        batch = tiles.cuda().transpose(0,1).reshape(batch_size * len(val_cancers), 3, 256, 256)    

        # forward pass
        output = resnet(batch)

        # un-flatten num_cancer_types x batch_size
        cancers_by_feats = torch.stack(torch.chunk(output, len(val_cancers)))    

        # split feats, labels into support, query sets
        feats_support = cancers_by_feats[:, :support_size, :]
        feats_support = feats_support.reshape(support_size * len(val_cancers), 2048)
        feats_support = final_embed_layer(feats_support)
        feats_support = torch.stack(torch.chunk(feats_support, len(val_cancers)))    
        feats_query = cancers_by_feats[:, support_size:, :]    
        labels_support = labels[:,:support_size]
        labels_query = labels[:,support_size:]

        # get preds    
        scores = lsm(torch.bmm(feats_support, feats_query.transpose(1,2))).exp()
        preds = torch.bmm(labels_support.unsqueeze(1), scores).squeeze(1)
        clamped_preds = torch.clamp(preds, 0, 1)

        # calc loss
        loss = criterion(clamped_preds, labels_query)
        
        clamped_preds = (clamped_preds.contiguous().view(-1) > 0.5).float().detach().cpu().numpy()
        labels_query = labels_query.contiguous().view(-1).float().detach().cpu().numpy()
        all_preds.extend(clamped_preds)
        all_labels.extend(labels_query)
        
        if step % 100 == 0:
            acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(labels_query, clamped_preds)
            print('Epoch: {0}, Step: {1}, Val NLL: {2:0.4f}, Acc: {3:04f}, By Label: {4}'.format(e, step, loss.detach().cpu().numpy(), acc, tile_acc_by_label))
    
    acc, tile_acc_by_label = train_utils.calc_tile_acc_stats(all_labels, all_preds)
    print('Epoch: {0}, Val NLL: {1:0.4f}, Acc: {2:0.4f}, By Label: {3}'.format(e, loss.detach().cpu().numpy(), acc, tile_acc_by_label))
    
    if e > 1000:
        scheduler.step(loss)
     
    return total_loss, acc

In [28]:
for e in range(5000):
    fewshot_training_loop(train_loader, train_cancers, batch_size, resnet, final_embed_layer, criterion, optimizer)
    loss, acc = fewshot_validation_loop(val_loader, val_cancers, batch_size, resnet, final_embed_layer, criterion)

Epoch: 0, Step: 0, Train NLL: 10.5369, Acc: 0.516667, By Label: 0: 0.7595, 1: 0.2655
Epoch: 0, Step: 100, Train NLL: 14.1943, Acc: 0.548611, By Label: 0: 0.2905, 1: 0.7940
Epoch: 0, Step: 200, Train NLL: 11.8967, Acc: 0.569444, By Label: 0: 0.5726, 1: 0.5664
Epoch: 0, Step: 300, Train NLL: 12.0326, Acc: 0.572222, By Label: 0: 0.4441, 1: 0.6868


KeyboardInterrupt: 