In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [3]:
# Notebook-wide configuration.
classification = 'WGD'            # which label the saved model was trained on ('WGD' or 'MSI')
magnification = '10.0'            # tile magnification, kept as a string
output_size = 1                   # width of the saved model's output layer
device = torch.device('cuda:0')   # first CUDA device

In [4]:
# get model file paths
# Saved ResNet-18 weights, keyed by (classification, magnification).
# Each checkpoint has a sample-annotation pickle with the same stem plus
# '_sa.pkl' (e.g. resnet18_WGD_10x_sa.pkl); those are not loaded here.
state_dict_files = {
    ('WGD', '10.0'): '/n/tcga_models/resnet18_WGD_10x.pt',
    ('WGD', '5.0'): '/n/tcga_models/resnet18_WGD_v04.pt',
    ('MSI', '10.0'): '/n/tcga_models/resnet18_MSI_singlelabel_10x.pt',
    ('MSI', '5.0'): '/n/tcga_models/resnet18_MSI_singlelabel_v02.pt',
}

# Fail loudly on an unsupported configuration. Falling through silently would
# leave `state_dict_file` undefined -- or, worse, still holding the value from
# a previous run of this cell with different settings.
if (classification, magnification) not in state_dict_files:
    raise ValueError(
        'No saved model for classification={!r}, magnification={!r}; expected one of {}'.format(
            classification, magnification, sorted(state_dict_files)))

state_dict_file = state_dict_files[(classification, magnification)]

In [None]:
# load embedding network and freeze layers
resnet = models.resnet18(pretrained=False)

# Give the network a head of width `output_size` before loading the weights.
# (The original referenced `output_shape`, which is never defined in this
# notebook; the config cell defines `output_size`.)
# NOTE(review): presumably (2048 -> output_size) is the head shape stored in
# the checkpoint -- confirm against the training code.
resnet.fc = nn.Linear(2048, output_size, bias=True)

# The map_location callable returns the storage unchanged, so the tensors are
# loaded onto the CPU; the model is moved to `device` below.
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)

# Replace the head with a bias-free identity map so that the forward pass
# returns the 2048-d input of `fc` unchanged, i.e. the embedding.
resnet.fc = nn.Linear(2048, 2048, bias=False)
resnet.fc.weight.data = torch.eye(2048)
resnet.cuda(device=device)

# Freeze every parameter of the embedding network.
for param in resnet.parameters():
    param.requires_grad = False

In [None]:
# Trainable 2048 -> 2048 linear layer (with bias), moved to the current CUDA device.
final_embed_layer = nn.Linear(in_features=2048, out_features=2048)
final_embed_layer = final_embed_layer.cuda()

In [6]:
# Image locations: the data root and the two batches of cancer-type names.
root_dir = '/n/mounted-data-drive/'
batch_one = ['COAD', 'BRCA', 'UCEC']
batch_two_orig = [
    'BLCA', 'KIRC', 'READ', 'HNSC',
    'LUSC', 'LIHC', 'LUAD', 'STAD',
]

# Batch-two names get a magnification suffix; batch-one names are used as-is.
if magnification == '10.0':
    batch_two = ['{}_10x'.format(name) for name in batch_two_orig]
elif magnification == '5.0':
    batch_two = ['{}_5x'.format(name) for name in batch_two_orig]

In [None]:
# get sample annotations
# NOTE: ONLY FOR WGD
wgd_path = 'ALL_WGD_TABLE.xlsx'
wgd_raw = pd.read_excel(wgd_path)

# Keep only the rows whose 'Type' is one of the cancer types used in this
# notebook. The explicit .copy() makes `wgd_filtered` an independent frame, so
# the assignment below does not raise SettingWithCopyWarning and `wgd_raw`
# stays untouched.
batch_all = batch_one + batch_two_orig
wgd_filtered = wgd_raw.loc[wgd_raw['Type'].isin(batch_all)].copy()

# Collapse the label to binary: a 'Genome_doublings' value of 2 is recoded as 1.
wgd_filtered.loc[wgd_filtered['Genome_doublings'] == 2, 'Genome_doublings'] = 1

# Index by sample identifier (non-inplace, so re-running the cell is safe).
wgd_filtered = wgd_filtered.set_index('Sample')

In [66]:
# get sample annotations for all cancer types
# Builds one (train, validation) sample-annotation pair per entry of
# `batch_all`, in the same order, so sa_trains[i] / sa_vals[i] belong to
# batch_all[i].
sa_trains = []
sa_vals = []
batch_all = batch_one + batch_two

for cancer in batch_all:
    # Use the shared `root_dir` instead of repeating the path literal, so the
    # data location is configured in exactly one place.
    sa_train, sa_val = data_utils.process_WGD_data(root_dir=root_dir, cancer_type=cancer,
                                                   wgd_path=None, wgd_raw=wgd_filtered)
    sa_trains.append(sa_train)
    sa_vals.append(sa_val)

In [69]:
# Persist [batch_all, sa_trains, sa_vals] as a single pickled list.
pickle_file = 'tcga_wgd_sa_all.pkl'
with open(pickle_file, 'wb') as pkl_out:
    pickle.dump([batch_all, sa_trains, sa_vals], pkl_out)

In [75]:
# Read the cancer-type list and the train / validation sample annotations
# back from `pickle_file`.
batch_all, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(
    pickle_file=pickle_file,
    return_all_cancers=True,
)

In [78]:
# Build one training and one validation tile dataset per cancer type.
train_sets = []
val_sets = []
# NOTE(review): this uses the un-suffixed `batch_two_orig` names, whereas the
# annotation-building cell used `batch_two` -- confirm the directory layout.
batch_all = batch_one + batch_two_orig

train_transform = train_utils.transform_train
val_transform = train_utils.transform_validation

for i in range(len(batch_all)):
    # Each cancer type has its own directory directly under root_dir.
    tile_dir = root_dir + batch_all[i] + '/'
    train_set = data_utils.TCGADataset_tiles(
        sa_trains[i], tile_dir,
        transform=train_transform,
        magnification=magnification,
        batch_type='tile',
    )
    val_set = data_utils.TCGADataset_tiles(
        sa_vals[i], tile_dir,
        transform=val_transform,
        magnification=magnification,
        batch_type='tile',
    )
    train_sets.append(train_set)
    val_sets.append(val_set)

FileNotFoundError: [Errno 2] No such file or directory: '/n/mounted-data-drive/UCEC/TCGA-D1-A0ZS-01Z-00-DX1.8021A060-3CA2-418E-AE16-C48E911F5C25.svs/TCGA-D1-A0ZS-01Z-00-DX1.8021A060-3CA2-418E-AE16-C48E911F5C25_files/10.0'

In [None]:
# Print the train / validation dataset sizes, one line per cancer type.
for tset, vset in zip(train_sets, val_sets):
    print(len(tset), len(vset))

In [2]:
# Few-shot learning hyper-parameters.
# Support-set size: number of training examples per task.
n_support = 5
# Query-set size: number of training examples per task.
n_query = 20
# Number of 'tasks' sampled from each cancer type.
n_task = 4