In [2]:
import sys
sys.path.append('../')

import os
import gc
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [3]:
# Data locations: tiles are read from <root_dir><CANCER>/ further down, and
# pickle_file holds the train/val sample split the WGD model was built on.
root_dir = data_utils.root_dir_all
pickle_file = '/n/tcga_models/resnet18_WGD_10x_sa.pkl'

# Tile preprocessing: the validation transform is used for every split,
# at 10x magnification.
transform_val = train_utils.transform_validation
magnification = '10.0'

# DataLoader settings (num_workers=0 loads batches in the main process).
batch_size, n_workers = 256, 0

In [8]:
# Load the COAD-only train/validation split stored in `pickle_file`.
sa_trains, sa_vals = data_utils.load_COAD_train_val_sa_pickle(
    pickle_file=pickle_file,
    return_all_cancers=False,
    split_in_two=False,
)

In [9]:
# One tile dataset per cancer type. Both splits use the validation transform
# because the tiles are only embedded by a frozen network, never trained on.
train_cancers = ['COAD']
val_cancers = ['COAD']

train_sets = []
for cancer in train_cancers:
    train_set = data_utils.TCGADataset_tiles(
        sa_trains,
        root_dir + cancer + '/',
        transform=transform_val,
        magnification=magnification,
        return_jpg_to_sample=True,
    )
    train_sets.append(train_set)

val_sets = []
for cancer in val_cancers:
    val_set = data_utils.TCGADataset_tiles(
        sa_vals,
        root_dir + cancer + '/',
        transform=transform_val,
        magnification=magnification,
        return_jpg_to_sample=True,
    )
    val_sets.append(val_set)

In [10]:
# Build the loaders over *all* per-cancer datasets.
# Wrapping the leftover loop variables (train_set / val_set) would silently
# keep only the last cancer type once more than one is listed.
# ConcatDataset over a single dataset yields the same items in the same order,
# so the single-COAD run is unchanged.
# shuffle=False keeps batches in dataset order. The embeddings, labels and
# jpg-to-sample ids saved later then line up with the dataset.
train_loader = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset(train_sets),
                                           batch_size=batch_size, shuffle=False,
                                           num_workers=n_workers, pin_memory=False)

val_loader = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset(val_sets),
                                         batch_size=batch_size, shuffle=False,
                                         num_workers=n_workers, pin_memory=False)

In [11]:
# Restore the trained WGD resnet18, then strip its classification head so the
# network outputs pooled features (embeddings) instead of the WGD logit.
output_size = 1
state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
resnet = models.resnet18(pretrained=False)

# Load on CPU; the model is moved to the GPU in a later cell.
saved_state = torch.load(state_dict_file, map_location='cpu')

# Size the temporary head from the checkpoint itself so strict load_state_dict
# succeeds. resnet18's pooled features are 512 wide (resnet.fc.in_features).
# The previously hard-coded 2048 is the resnet50 width, and it only loaded if
# the checkpoint happened to carry such a head. The head is discarded below,
# so its size does not affect the embeddings.
head_out, head_in = saved_state['fc.weight'].shape
assert head_out == output_size, (
    'checkpoint head has {} outputs, expected {}'.format(head_out, output_size))
resnet.fc = nn.Linear(head_in, head_out)
resnet.load_state_dict(saved_state)

# Replace the head with an identity so forward() returns the pooled features.
resnet.fc = model_utils.Identity()

# Freeze all weights: this notebook only extracts features.
for p in resnet.parameters():
    p.requires_grad = False

In [12]:
# Move the frozen feature extractor to the GPU and put it in eval mode, so
# BatchNorm uses its running statistics. Module.cuda() and Module.eval() both
# return the module itself, so the chain is still the cell's displayed value.
resnet.cuda().eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [13]:
# Embed every validation tile with the frozen resnet.
# Saves (embeddings, labels, jpg-to-sample ids) in loader order.
emb_all = []
lab_all = []
jpg_all = []
print(len(val_loader))  # number of batches
with torch.no_grad():
    for idx, (tiles, labels, jpgs) in enumerate(val_loader):
        embedding = resnet(tiles.cuda())
        # Keep results on the CPU so GPU memory does not grow with the dataset.
        emb_all.append(embedding.cpu())
        lab_all.append(labels)
        jpg_all.append(jpgs)
        if idx % 10 == 0:
            print(idx, end=' ')

# The output paths use dedicated names rather than re-binding `pickle_file`.
# `pickle_file` holds the split pickle that the data-loading cell reads, and
# re-binding it made re-running that cell load the wrong file.
val_tmp_file = 'tmp.val.pkl'
# Raw per-batch dump, written before concatenation.
with open(val_tmp_file, 'wb') as f:
    pickle.dump([emb_all, lab_all, jpg_all], f)

val_embeddings = torch.cat(emb_all, dim=0)
val_labels = torch.cat(lab_all)
val_jpgs_to_slide = torch.cat(jpg_all)
assert len(val_embeddings) == len(val_labels) == len(val_jpgs_to_slide)

val_output_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(val_output_file, 'wb') as f:
    # Protocol 4 supports very large objects (> 4 GiB).
    pickle.dump([val_embeddings, val_labels, val_jpgs_to_slide], f, protocol=4)

838
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 710 720 730 740 750 760 770 780 790 800 810 820 830 

In [14]:
# Embed every training tile with the frozen resnet.
# Saves (embeddings, labels, jpg-to-sample ids) in loader order.
emb_all = []
lab_all = []
jpg_all = []
print(len(train_loader))  # number of batches
with torch.no_grad():
    for idx, (tiles, labels, jpgs) in enumerate(train_loader):
        embedding = resnet(tiles.cuda())
        # Keep results on the CPU so GPU memory does not grow with the dataset.
        emb_all.append(embedding.cpu())
        lab_all.append(labels)
        jpg_all.append(jpgs)
        if idx % 10 == 0:
            print(idx, end=' ')

# The output paths use dedicated names rather than re-binding `pickle_file`.
# `pickle_file` holds the split pickle that the data-loading cell reads, and
# re-binding it made re-running that cell load the wrong file.
train_tmp_file = 'tmp.train.pkl'
# Raw per-batch dump, written before concatenation.
with open(train_tmp_file, 'wb') as f:
    pickle.dump([emb_all, lab_all, jpg_all], f)

train_embeddings = torch.cat(emb_all, dim=0)
train_labels = torch.cat(lab_all)
train_jpgs_to_slide = torch.cat(jpg_all)
assert len(train_embeddings) == len(train_labels) == len(train_jpgs_to_slide)

train_output_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(train_output_file, 'wb') as f:
    # Protocol 4 supports very large objects (> 4 GiB).
    pickle.dump([train_embeddings, train_labels, train_jpgs_to_slide], f, protocol=4)

3809
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 710 720 730 740 750 760 770 780 790 800 810 820 830 840 850 860 870 880 890 900 910 920 930 940 950 960 970 980 990 1000 1010 1020 1030 1040 1050 1060 1070 1080 1090 1100 1110 1120 1130 1140 1150 1160 1170 1180 1190 1200 1210 1220 1230 1240 1250 1260 1270 1280 1290 1300 1310 1320 1330 1340 1350 1360 1370 1380 1390 1400 1410 1420 1430 1440 1450 1460 1470 1480 1490 1500 1510 1520 1530 1540 1550 1560 1570 1580 1590 1600 1610 1620 1630 1640 1650 1660 1670 1680 1690 1700 1710 1720 1730 1740 1750 1760 1770 1780 1790 1800 1810 1820 1830 1840 1850 1860 1870 1880 1890 1900 1910 1920 1930 1940 1950 1960 1970 1980 1990 2000 2010 2020 2030 2040 2050 2060 2070 2080 2090 2100 2110 2120 2130 2140 2150 2160 2170 2180 2190 2200 2210 2220 2230 2240 2250 2260 2270 2280 2290 2300 2310 2320 2