In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [2]:
# define global variables
classification = 'WGD'
magnification = '10.0'
output_shape = 1
device = torch.device('cuda', 0)
root_dir = '/n/mounted-data-drive/'

## Prep Images - DONE

In [3]:
# get image file paths
batch_one = ['COAD', 'BRCA', 'UCEC']
batch_two_orig = ['BLCA', 'KIRC', 'READ', 'HNSC', 'LUSC', 'LIHC', 'LUAD', 'STAD']
if magnification == '10.0':
    batch_two = [b + '_10x' for b in batch_two_orig]
elif magnification == '5.0':
    batch_two = [b + '_5x' for b in batch_two_orig]

In [4]:
# get sample annotations
# NOTE: ONLY FOR WGD
wgd_path = 'ALL_WGD_TABLE.xlsx'
wgd_raw = pd.read_excel(wgd_path)
#wgd_raw.head(3)

batch_all_orig = batch_one + batch_two_orig
wgd_filtered = wgd_raw.loc[wgd_raw['Type'].isin(batch_all_orig)]
#wgd_filtered.head(3)

wgd_filtered.loc[wgd_filtered['Genome_doublings'].values == 2, 'Genome_doublings'] = 1

wgd_filtered.set_index('Sample', inplace=True)
#wgd_filtered.head(3)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  self.obj[item] = s


In [6]:
# get sample annotations for all cancer types
# split samples into two sets of train/val
sa_trains1 = []
sa_vals1 = []
sa_trains2 = []
sa_vals2 = []
batch_all = batch_one + batch_two

print('Num Samples with Images and Labels:')
for cancer in batch_all:
    sa_train1, sa_val1, sa_train2, sa_val2 = data_utils.process_WGD_data(root_dir='/n/mounted-data-drive/', 
                                                                         cancer_type=cancer, 
                                                                         wgd_path=None, 
                                                                         split_in_two=True, 
                                                                         print_overlap=True, 
                                                                         wgd_raw=wgd_filtered)
    sa_trains1.append(sa_train1)
    sa_vals1.append(sa_val1)
    sa_trains2.append(sa_train2)
    sa_vals2.append(sa_val2)

Num Samples with Images and Labels:
COAD      Num Images:   433  Num Labels:   433  Overlap:   406
BRCA      Num Images: 1,054  Num Labels: 1,048  Overlap:   998
UCEC      Num Images:   505  Num Labels:   517  Overlap:   477
BLCA_10x  Num Images:   387  Num Labels:   402  Overlap:   377
KIRC_10x  Num Images:   508  Num Labels:   483  Overlap:   459
READ_10x  Num Images:   157  Num Labels:   155  Overlap:   143
HNSC_10x  Num Images:   365  Num Labels:   512  Overlap:   351
LUSC_10x  Num Images:   479  Num Labels:   482  Overlap:   460
LIHC_10x  Num Images:   365  Num Labels:   362  Overlap:   351
LUAD_10x  Num Images:   466  Num Labels:   503  Overlap:   448
STAD_10x  Num Images:   373  Num Labels:   427  Overlap:   358


In [7]:
# save sample annotations in a pickle
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
with open(pickle_file, 'wb') as f: 
    pickle.dump([batch_all, sa_trains1, sa_vals1, sa_trains2, sa_vals2], f)

## Start Here

In [3]:
# load sample annotations pickle
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
batch_all, _, _, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(pickle_file=pickle_file, 
                                                                               return_all_cancers=True, 
                                                                               split_in_two=True)

In [15]:
# initialize Datasets
train_sets = []
val_sets = []

#train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

train_cancers = ['COAD', 'BRCA', 'READ_10x', 'LUSC_10x', 'BLCA_10x', 'LUAD_10x', 'STAD_10x', 'HNSC_10x']
val_cancers = ['UCEC', 'LIHC_10x', 'KIRC_10x']

In [5]:
for i in range(len(train_cancers)):
    print(train_cancers[i], end=' ')
    train_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(train_cancers[i])], 
                                             root_dir + train_cancers[i] + '/', 
                                             transform=val_transform, 
                                             magnification=magnification, 
                                             batch_type='tile')
    train_sets.append(train_set)

COAD BRCA READ_10x LUSC_10x BLCA_10x LUAD_10x STAD_10x HNSC_10x 

In [6]:
for j in range(len(val_cancers)):
    print(val_cancers[j], end=' ')
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(val_cancers[j])], 
                                           root_dir + val_cancers[j] + '/', 
                                           transform=val_transform, 
                                           magnification=magnification, 
                                           batch_type='tile')
    val_sets.append(val_set)

UCEC LIHC_10x KIRC_10x 

In [None]:
print('Num Tiles:')
for cancer, tset, vset in zip(batch_all, train_sets, val_sets):
    print('{0:<8}  Train: {1:>10,d}              Val: {2:>8,d}'.format(cancer, tset.__len__(), vset.__len__()))
    print('          Train: (0) {0:0.4f}, (1) {1:0.4f}  Val: (0) {2:0.4f} (1) {3:0.4f}'.format(np.mean(np.array(tset.all_labels) == 0),
                                                                                              np.mean(np.array(tset.all_labels) == 1),
                                                                                              np.mean(np.array(vset.all_labels) == 0),
                                                                                              np.mean(np.array(vset.all_labels) == 1)))

## Prep Model

In [7]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [8]:
# get model file paths
if classification == 'WGD':
    if magnification == '10.0':
        #sa_file = '/n/tcga_models/resnet18_WGD_10x_sa.pkl'
        #state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
        sa_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
        state_dict_file = '/n/tcga_models/resnet18_WGD_all_10x.pt'
    elif magnification == '5.0':
        #sa_file = '/n/tcga_models/resnet18_WGD_v04_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_WGD_v04.pt'
elif classification == 'MSI':
    if magnification == '10.0':
        #sa_file = '/n/tcga_models/resnet18_MSI_singlelabel_10x_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_MSI_singlelabel_10x.pt'
    elif magnification == '5.0':
        #sa_file = '/n/tcga_models/resnet18_MSI_singlelabel_v02_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_MSI_singlelabel_v02.pt'

In [22]:
# load embedding network and freeze layers
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_shape, bias=True)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
resnet.fc = Identity()
#resnet.fc = nn.Linear(2048, 2048, bias=False)
#resnet.fc.weight.data=torch.eye(2048)
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False
#resnet.fc.weight.parameters.requires_grad = True
#resnet.fc.bias.parameters.requires_grad = True

# initialize fully-connected final layer 
final_embed_layer = nn.Linear(2048, 2048)
final_embed_layer.cuda()

Linear(in_features=2048, out_features=2048, bias=True)

## ConcatDataset

In [10]:
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return torch.stack([d[i][0] for d in self.datasets]), torch.cat([torch.tensor(d[i][1]).view(-1) for d in self.datasets])

    def __len__(self):
        return min(len(d) for d in self.datasets)

In [23]:
batch_size = 100
support_size = 20

In [12]:
train_loader = torch.utils.data.DataLoader(ConcatDataset(*train_sets), 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           num_workers=24, 
                                           pin_memory=True)

In [24]:
learning_rate = 1e-5
lsm = nn.LogSoftmax(dim=1)
criterion = nn.BCELoss()
#optimizer = torch.optim.Adam(resnet.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(final_embed_layer.parameters(), lr=learning_rate)

In [14]:
for step, (tiles, labels) in enumerate(train_loader):  
    labels = labels.cuda().float().transpose(0,1)    
    # flatten batch_size x num_cancer_types 
    batch = tiles.cuda().transpose(0,1).reshape(batch_size * len(train_cancers), 3, 256, 256)    
    break

In [27]:
for e in range(2000):    
    # forward pass
    output = resnet(batch)
    
    # un-flatten num_cancer_types x batch_size
    cancers_by_feats = torch.stack(torch.chunk(output, len(train_cancers)))    
        
    # split feats, labels into support, query sets
    feats_support = cancers_by_feats[:, :support_size, :]
    feats_support = feats_support.reshape(support_size * len(train_cancers), 2048)
    feats_support = final_embed_layer(feats_support)
    feats_support = torch.stack(torch.chunk(feats_support, len(train_cancers)))
    
    feats_query = cancers_by_feats[:, support_size:, :]
    labels_support = labels[:,:support_size]
    labels_query = labels[:,support_size:]
    
    # get preds    
    scores = lsm(torch.bmm(feats_support, feats_query.transpose(1,2))).exp()
    preds = torch.bmm(labels_support.unsqueeze(1), scores).squeeze(1)
    clamped_preds = torch.clamp(preds, 0, 1)
    
    # calc loss, backprop, step    
    loss = criterion(clamped_preds, labels_query)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    #if step % 5 == 0:
    #    print('Epoch: {0}, Step: {1}, Train NLL: {2:0.4f}'.format(e, step, loss.detach().cpu().numpy()))
    if e % 100 == 0:
        print('Epoch: {0}, Train NLL: {1:0.4f}'.format(e, loss.detach().cpu().numpy()))

Epoch: 500, Train NLL: 13.2893
Epoch: 510, Train NLL: 13.2892
Epoch: 520, Train NLL: 13.2890
Epoch: 530, Train NLL: 13.2888
Epoch: 540, Train NLL: 13.2887
Epoch: 550, Train NLL: 13.2885
Epoch: 560, Train NLL: 13.2884
Epoch: 570, Train NLL: 13.2882
Epoch: 580, Train NLL: 13.2880
Epoch: 590, Train NLL: 13.2878
Epoch: 600, Train NLL: 13.2876
Epoch: 610, Train NLL: 13.2875
Epoch: 620, Train NLL: 13.2873
Epoch: 630, Train NLL: 13.2871
Epoch: 640, Train NLL: 13.2869
Epoch: 650, Train NLL: 13.2867
Epoch: 660, Train NLL: 13.2865
Epoch: 670, Train NLL: 13.2863
Epoch: 680, Train NLL: 13.2861
Epoch: 690, Train NLL: 13.2859
Epoch: 700, Train NLL: 13.2857
Epoch: 710, Train NLL: 13.2855
Epoch: 720, Train NLL: 13.2853
Epoch: 730, Train NLL: 13.2851
Epoch: 740, Train NLL: 13.2848
Epoch: 750, Train NLL: 13.2846
Epoch: 760, Train NLL: 13.2844
Epoch: 770, Train NLL: 13.2842
Epoch: 780, Train NLL: 13.2839
Epoch: 790, Train NLL: 13.2837
Epoch: 800, Train NLL: 13.2835
Epoch: 810, Train NLL: 13.2832
Epoch: 8