In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

import copy

In [2]:
# setup
set_image_backend('accimage')
device = torch.device('cuda', 0)

# load sample annotations pickle
pickle_file = '/home/sxchao/MSI_prediction/tcga_project/tcga_wgd_sa_all.pkl'
batch_all, _, _, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(pickle_file=pickle_file,
                                                                               return_all_cancers=True, 
                                                                               split_in_two=True)
# normalize and tensorify jpegs
val_transform = train_utils.transform_validation

# initialize Datasets
val_sets = []
val_cancers = ['UCEC', 'LIHC_10x', 'KIRC_10x']
magnification = '10.0'
root_dir = '/n/mounted-data-drive/'
for j in range(len(val_cancers)):
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(val_cancers[j])], 
                                           root_dir + val_cancers[j] + '/', 
                                           transform=val_transform, 
                                           magnification=magnification, 
                                           batch_type='tile', 
                                           return_jpg_to_sample=False)
    val_sets.append(val_set)

# get DataLoaders    
batch_size_val = 400
n_workers = 16
val_loaders = [torch.utils.data.DataLoader(val_set, 
                                        batch_size=batch_size_val, 
                                        shuffle=True, 
                                        num_workers=n_workers, 
                                        pin_memory=True) for val_set in val_sets]

In [3]:
# model args
state_dict_file_resnet = '/n/tcga_models/resnet18_WGD_all_10x.pt'
state_dict_file_maml = '/n/tcga_models/maml_WGD_10x.pt'
input_size = 2048
hidden_size = 512
output_size = 1

# initialize trained resnet
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_size, bias=True)
saved_state = torch.load(state_dict_file_resnet, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)

# freeze layers
resnet.fc = model_utils.Identity()
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False
    
# initialize theta_global
    net = model_utils.FeedForward(input_size, hidden_size, output_size)
    saved_state = torch.load(state_dict_file_maml, map_location=lambda storage, loc: storage)
    net.load_state_dict(saved_state)
    net.cuda(device=device)

In [4]:
theta_global = [p.detach().clone() for p in net.parameters()]

In [5]:
resnet.eval()
num_steps = 10
#num_tasks = int(tiles.shape[1])
#for t in range(num_tasks):
all_losses = np.zeros((num_steps, len(val_loaders[0])))

In [10]:
alpha = 1e-2
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = alpha)

In [12]:
for batch_num,(tiles,labels) in enumerate(val_loaders[0]):
    print(batch_num, end=' ')
    
    tiles, labels = tiles.cuda(), labels.cuda().float()
    idx = int(tiles.shape[0] / 2)
    inputs_a = tiles[:idx,:,:,:]
    inputs_b = tiles[idx:,:,:,:]
    labels_a = labels[:idx].unsqueeze(1)
    labels_b = labels[idx:].unsqueeze(1)    
    
    for step in range(num_steps):        
        optimizer.zero_grad()
        if step == 0:
            net.eval()
            output_b = net(resnet(inputs_b))
            loss = criterion(output_b, labels_b)
            all_losses[step, batch_num] = loss.detach().cpu().item()
        else:
            # first forward pass, step 
            net.train()
            output_a = net(resnet(inputs_a))
            loss = criterion(output_a, labels_a)
            loss.backward()
            optimizer.step()            
            # second forward pass   
            net.eval()
            output_b = net(resnet(inputs_b))
            loss = criterion(output_b, labels_b)
            all_losses[step, batch_num] = loss.detach().cpu().item()
    print(all_losses[:,batch_num])
    #net.linear1.weight.data = theta_global[0].clone()
    #net.linear1.bias.data = theta_global[1].clone()
    #net.linear2.weight.data = theta_global[2].clone()
    #net.linear2.bias.data = theta_global[3].clone()
    net.update_params(theta_global)

0 [1.40835369 2.39327192 1.88396323 1.16007531 2.9691081  1.44426143
 2.31927657 1.47561824 1.64480376 2.49590039]
1 [2.91604567 2.35293484 1.48522079 1.24402201 1.47609496 1.37248039
 1.42844009 1.47775078 1.53149951 1.61077023]
2 [3.60560536 2.67700076 2.13805628 1.88240087 1.64428353 1.90511072
 1.53838599 1.32009089 1.27488601 1.29028821]
3 [3.43054748 2.74031782 2.99221635 2.16195297 2.35563302 2.15735507
 2.13523865 2.30686164 2.52127194 2.96156406]
4 [3.34314322 2.60716343 2.19643068 1.7887404  1.59041154 1.52957809
 1.43499696 1.43333352 1.56361604 1.75463784]
5 [2.84291625 1.94059026 1.56476462 1.87794495 1.52578425 1.29282808
 1.31949794 1.44589794 1.61176729 1.51044679]
6 [3.37569094 2.36453891 2.05386066 2.25720811 2.48291969 2.21124959
 2.17643833 2.50970149 3.1274159  3.1077466 ]
7 [3.92434812 3.7060132  2.56321526 1.97081971 1.68264306 1.80115235
 1.54789543 1.66056061 1.59020591 1.5105679 ]
8 [4.02207708 2.50514436 2.18152356 1.85807455 2.22588515 1.75629973
 1.68750608

71 [3.09142399 2.17179132 1.8532275  1.45565712 1.48311865 1.31265819
 1.28105462 1.27485251 1.33091807 1.38645113]
72 [2.83916974 1.93725979 2.33901358 1.74084878 1.99729228 1.35335124
 1.92469811 1.63140774 1.9698602  2.26310349]
73 [3.68423843 3.22030544 2.15728879 2.17349648 2.44706416 2.82055259
 2.21734619 1.9856838  1.99940765 2.33087611]
74 [4.00544071 3.31622481 3.56498098 2.44904518 2.4576354  2.56261969
 2.55616808 2.35766864 2.70202756 2.38553953]
75 [4.51596498 3.85257125 3.1304419  2.68331861 2.60273862 2.52270055
 2.49925137 2.41657186 2.47303438 2.58869839]
76 [4.0538826  3.19204712 2.67170811 2.07694983 2.05416059 1.62882507
 1.64636064 1.7064023  1.76038444 1.80221903]
77 [3.36139393 2.24078679 2.54530573 1.80162823 1.83813691 1.74567354
 2.08612919 1.97456634 2.48521614 2.3241539 ]
78 [3.45928097 3.48326159 3.4006896  2.56680965 2.24646664 2.59157801
 2.32789803 2.09280872 2.15281439 2.20636582]
79 [3.90829301 3.18331289 3.04693151 2.67625427 3.23664451 2.25808716
 1

KeyboardInterrupt: 