In [2]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [4]:
# define global variables
classification = 'WGD'
magnification = '10.0'
output_size = 1
device = torch.device('cuda', 0)

## Prep Images

In [6]:
# get image file paths
root_dir = '/n/mounted-data-drive/'
batch_one = ['COAD', 'BRCA', 'UCEC']
batch_two_orig = ['BLCA', 'KIRC', 'READ', 'HNSC', 'LUSC', 'LIHC', 'LUAD', 'STAD']
if magnification == '10.0':
    batch_two = [b + '_10x' for b in batch_two_orig]
elif magnification == '5.0':
    batch_two = [b + '_5x' for b in batch_two_orig]

In [7]:
# get sample annotations
# NOTE: ONLY FOR WGD
wgd_path = 'ALL_WGD_TABLE.xlsx'
wgd_raw = pd.read_excel(wgd_path)
#wgd_raw.head(3)

batch_all_orig = batch_one + batch_two_orig
wgd_filtered = wgd_raw.loc[wgd_raw['Type'].isin(batch_all_orig)]
#wgd_filtered.head(3)

wgd_filtered.loc[wgd_filtered['Genome_doublings'].values == 2, 'Genome_doublings'] = 1

wgd_filtered.set_index('Sample', inplace=True)
#wgd_filtered.head(3)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  self.obj[item] = s


In [54]:
# get sample annotations for all cancer types
sa_trains = []
sa_vals = []
batch_all = batch_one + batch_two

print('Num Samples with Images and Labels:')
for cancer in batch_all:
    sa_train, sa_val = data_utils.process_WGD_data(root_dir='/n/mounted-data-drive/', cancer_type=cancer, 
                                                   wgd_path=None, wgd_raw = wgd_filtered)
    sa_trains.append(sa_train)
    sa_vals.append(sa_val)

Num Samples with Images and Labels:
COAD      Num Images:   433  Num Labels:   433  Overlap:   406
BRCA      Num Images: 1,054  Num Labels: 1,048  Overlap:   998
UCEC      Num Images:   505  Num Labels:   517  Overlap:   477
BLCA_10x  Num Images:   387  Num Labels:   402  Overlap:   377
KIRC_10x  Num Images:   508  Num Labels:   483  Overlap:   459
READ_10x  Num Images:   157  Num Labels:   155  Overlap:   143
HNSC_10x  Num Images:   365  Num Labels:   512  Overlap:   351
LUSC_10x  Num Images:   479  Num Labels:   482  Overlap:   460
LIHC_10x  Num Images:   365  Num Labels:   362  Overlap:   351
LUAD_10x  Num Images:   466  Num Labels:   503  Overlap:   448
STAD_10x  Num Images:   373  Num Labels:   427  Overlap:   358


In [18]:
# save sample annotations in a pickle
pickle_file = 'tcga_wgd_sa_all.pkl'
with open(pickle_file, 'wb') as f: 
        pickle.dump([batch_all, sa_trains, sa_vals], f)

In [19]:
# load sample annotations pickle
batch_all, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(pickle_file=pickle_file, 
                                                                         return_all_cancers=True)

In [None]:
# initialize Datasets
train_sets = []
val_sets = []

train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

for i in range(len(batch_all)):
    train_set = data_utils.TCGADataset_tiles(sa_trains[i], 
                                             root_dir + batch_all[i] + '/', 
                                             transform=train_transform, 
                                             magnification=magnification, 
                                             batch_type='tile')
    val_set = data_utils.TCGADataset_tiles(sa_vals[i], 
                                           root_dir + batch_all[i] + '/', 
                                           transform=val_transform, 
                                           magnification=magnification, 
                                           batch_type='tile')
    train_sets.append(train_set)
    val_sets.append(val_set)

In [56]:
print('Num Tiles:')
for cancer, tset, vset in zip(batch_all, train_sets, val_sets):
    print('{0:<8}  Train: {1:>10,d}              Val: {2:>8,d}'.format(cancer, tset.__len__(), vset.__len__()))
    print('          Train: (0) {0:0.4f}, (1) {1:0.4f}  Val: (0) {2:0.4f} (1) {3:0.4f}'.format(np.mean(np.array(tset.all_labels) == 0),
                                                                                              np.mean(np.array(tset.all_labels) == 1),
                                                                                              np.mean(np.array(vset.all_labels) == 0),
                                                                                              np.mean(np.array(vset.all_labels) == 1)))

Num Tiles:
COAD      Train:    942,176              Val:  247,402
          Train: (0) 0.6246, (1) 0.3754  Val: (0) 0.6002 (1) 0.3998
BRCA      Train:  1,891,942              Val:  511,924
          Train: (0) 0.5509, (1) 0.4491  Val: (0) 0.5445 (1) 0.4555
UCEC      Train:  1,570,751              Val:  417,222
          Train: (0) 0.8230, (1) 0.1770  Val: (0) 0.7520 (1) 0.2480
BLCA_10x  Train:  1,102,415              Val:  289,763
          Train: (0) 0.3706, (1) 0.6294  Val: (0) 0.4747 (1) 0.5253
KIRC_10x  Train:  1,175,609              Val:  323,656
          Train: (0) 0.7961, (1) 0.2039  Val: (0) 0.8145 (1) 0.1855
READ_10x  Train:    293,972              Val:   79,397
          Train: (0) 0.4715, (1) 0.5285  Val: (0) 0.4454 (1) 0.5546
HNSC_10x  Train:    985,504              Val:  219,335
          Train: (0) 0.6641, (1) 0.3359  Val: (0) 0.7502 (1) 0.2498
LUSC_10x  Train:  1,018,233              Val:  250,477
          Train: (0) 0.4718, (1) 0.5282  Val: (0) 0.3744 (1) 0.6256
LIHC_

## Prep Model

In [5]:
# get model file paths
if classification == 'WGD':
    if magnification == '10.0':
        #sa_file = '/n/tcga_models/resnet18_WGD_10x_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
    elif magnification == '5.0':
        #sa_file = '/n/tcga_models/resnet18_WGD_v04_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_WGD_v04.pt'
elif classification == 'MSI':
    if magnification == '10.0':
        #sa_file = '/n/tcga_models/resnet18_MSI_singlelabel_10x_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_MSI_singlelabel_10x.pt'
    elif magnification == '5.0':
        #sa_file = '/n/tcga_models/resnet18_MSI_singlelabel_v02_sa.pkl'
        state_dict_file = '/n/tcga_models/resnet18_MSI_singlelabel_v02.pt'

In [None]:
# load embedding network and freeze layers
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_shape, bias=True)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
resnet.fc = nn.Linear(2048, 2048, bias=False)
resnet.fc.weight.data=torch.eye(2048)
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False

# initialize fully-connected final layer 
final_embed_layer = nn.Linear(2048, 2048)
final_embed_layer.cuda()

In [None]:
# define few-shot learning params
n_support = 5 # number of training examples in the support set
n_query = 20 # number of training examples in the query set
n_task = 4 # number of 'tasks' to sample from each cancer type