In [1]:
import sys
sys.path.append('../')

import os
import gc
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [2]:
root_dir = data_utils.root_dir_all
pickle_file = '/n/tcga_wgd_sa_all_1.0.pkl'
transform_val = train_utils.transform_validation
magnification = '10.0'
batch_size = 256
n_workers = 0

In [3]:
batch_all, sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(pickle_file=pickle_file,
                                                                         return_all_cancers=True,
                                                                         split_in_two=False)
train_cancers = ['COAD']
val_cancers = ['COAD']
train_idxs = [batch_all.index(cancer) for cancer in train_cancers]    
val_idxs = [batch_all.index(cancer) for cancer in val_cancers]
train_sets = []
val_sets = []
for i in range(len(train_cancers)):
    train_set = data_utils.TCGADataset_tiles(sa_trains[batch_all.index(train_cancers[i])], 
                                         root_dir + train_cancers[i] + '/', 
                                         transform=transform_val, 
                                         magnification=magnification, return_jpg_to_sample=True)
    train_sets.append(train_set)    

for j in range(len(val_cancers)):
    val_set = data_utils.TCGADataset_tiles(sa_vals[batch_all.index(val_cancers[j])], 
                                       root_dir + val_cancers[j] + '/', 
                                       transform=transform_val, 
                                       magnification=magnification, return_jpg_to_sample=True)
    val_sets.append(val_set)

In [4]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=n_workers, 
                                           pin_memory=False)

val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=n_workers, 
                                         pin_memory=False)

In [5]:
output_size = 1
state_dict_file = '/n/tcga_models/ResNet_10x_retrain_5-4-18_lr_log.pt'
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_size)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
resnet.fc = model_utils.Identity()
for p in resnet.parameters():
    p.requires_grad = False

In [6]:
resnet.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [7]:
emb_all = []
lab_all = []
jpg_all = []

In [8]:
print(len(val_loader))
for idx,(tiles,labels,jpgs) in enumerate(val_loader):
    tiles = tiles.cuda()
    with torch.no_grad():
        embedding = resnet(tiles)
    emb_all.append(embedding.detach().cpu())
    lab_all.append(labels)
    jpg_all.append(jpgs)
    if idx % 10 == 0:
        print(idx, end=' ')

881
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 710 720 730 740 750 760 770 780 790 800 810 820 830 840 850 860 870 880 

In [9]:
pickle_file = 'tmp.val.pkl'
with open(pickle_file, 'wb') as f: 
        pickle.dump([emb_all, lab_all, jpg_all], f)

In [10]:
val_embeddings = torch.cat(emb_all,dim=0)
val_labels = torch.cat(lab_all)
val_jpgs_to_slide = torch.cat(jpg_all)
pickle_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(pickle_file, 'wb') as f: 
        pickle.dump([val_embeddings, val_labels, val_jpgs_to_slide], f,protocol=4)

In [None]:
pickle_file = 'tmp.train.pkl'
with open(pickle_file, 'wb') as f: 
        pickle.dump([emb_all, lab_all, jpg_all], f)

In [None]:
train_embeddings = torch.cat(emb_all,dim=0)
train_labels = torch.cat(lab_all)
train_jpgs_to_slide = torch.cat(jpg_all)

In [None]:
train_embeddings.shape[0],train_labels.shape[0],train_jpgs_to_slide.shape[0]

In [None]:
pickle_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(pickle_file, 'wb') as f: 
        pickle.dump([train_embeddings, train_labels, train_jpgs_to_slide], f,protocol=4)

In [None]:
del tiles,embedding,resnet
torch.cuda.empty_cache()

In [None]:
pickle_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(pickle_file, 'rb') as f: 
    train_embeddings,train_labels,train_jpgs_to_slide = pickle.load(f)
    
    
pickle_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(pickle_file, 'rb') as f: 
    val_embeddings,val_labels,val_jpgs_to_slide = pickle.load(f)

In [None]:
attn_pool = model_utils.Attention(input_size=2048,hidden_size=512,gated=False,output_size=1)

In [None]:
attn_pool.cuda()

In [None]:
idx = 0


In [None]:
train_embeddings = train_embeddings.cuda()

In [None]:
learning_rate = 1e-3
n_samples = train_jpgs_to_slide.max()+1
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=attn_pool.parameters(),lr=learning_rate)

In [None]:
step_size = 149
logits_vec = torch.zeros((step_size+1,1)).cuda()
labels_vec = torch.zeros_like(logits_vec).cuda()
batch_idx = 0
for idx in range(n_samples):
    slide = train_embeddings[train_jpgs_to_slide==idx]
    labels_vec[batch_idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
    logits_vec[batch_idx], _  = attn_pool(slide)
    if batch_idx == step_size:
        loss = criterion(logits_vec,labels_vec)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        logits_vec = torch.zeros((step_size+1,1)).cuda()
        labels_vec = torch.zeros_like(logits_vec).cuda()
        batch_idx = 0
    else:
        batch_idx += 1
        
loss = criterion(logits_vec,labels_vec)
loss.backward()
optimizer.step()
optimizer.zero_grad()
