In [1]:
import sys
sys.path.append('../')

import os
import gc
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [8]:
output_size = 1
state_dict_file = '/n/tcga_models/ResNet_10x_retrain_5-4-18_lr_log.pt'
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_size)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
for p in resnet.parameters():
    p.requires_grad = False

In [2]:
pickle_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(pickle_file, 'rb') as f: 
    train_embeddings,train_labels,train_jpgs_to_slide = pickle.load(f)
    
pickle_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(pickle_file, 'rb') as f: 
    val_embeddings,val_labels,val_jpgs_to_slide = pickle.load(f)

In [4]:
resnet_fc = resnet.fc

In [5]:
n_samples = train_jpgs_to_slide.max()+1
criterion = nn.BCEWithLogitsLoss()
logits_vec = torch.zeros((n_samples,1)).cuda()
labels_vec = torch.zeros_like(logits_vec).cuda()
train_embeddings = train_embeddings.cuda()
resnet_fc.cuda()

for idx in range(n_samples):
    with torch.no_grad():
        slide = train_embeddings[train_jpgs_to_slide==idx]
        labels_vec[idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(resnet_fc(slide))

In [24]:
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.5487804878048781

In [11]:
criterion = nn.BCEWithLogitsLoss()
criterion(logits_vec, labels_vec) * n_samples / 50

tensor(4.4817, device='cuda:0')

In [None]:
lnr_layer = nn.Linear(2048,2048)
lnr_layer_2 = nn.Linear(2048,1)
relu = nn.ReLU()
layers = [lnr_layer, relu, lnr_layer_2]
linear_layer = nn.Sequential(*layers)

In [3]:
#train_embeddings = train_embeddings.cuda()
net = model_utils.Attention(input_size=2048,hidden_size=2048,output_size=1)
criterion = nn.BCEWithLogitsLoss()
step_size = 50

In [4]:
net.cuda()

Attention(
  (V): Linear(in_features=2048, out_features=2048, bias=True)
  (w): Linear(in_features=2048, out_features=1, bias=True)
  (sigm): Sigmoid()
  (tanh): Tanh()
  (sm): Softmax()
  (linear_layer): Linear(in_features=2048, out_features=1, bias=True)
)

In [5]:
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=net.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=10,min_lr=1e-5)

In [6]:
n_samples_val = val_jpgs_to_slide.max()
n_samples_train = train_jpgs_to_slide.max()
idxs_train = np.linspace(0,n_samples_train,n_samples_train+1,dtype=int)
labels_to_idxs_train = np.concatenate([(train_labels[train_jpgs_to_slide==i]).unique().numpy() for i in idxs_train])
weights = 1/np.sum(labels_to_idxs_train==0),1/np.sum(labels_to_idxs_train==1)
sample_weight = [weights[l] for l in labels_to_idxs_train]
sample_weight = sample_weight/np.sum(sample_weight)

In [7]:
train_embeddings, val_embeddings = train_embeddings.cuda(),val_embeddings.cuda()

for e in range(300):
    train_utils.training_loop_pooled_embeddings(e,step_size,optimizer,net,train_embeddings,train_jpgs_to_slide
                                    ,train_labels,criterion,n_samples_train,sample_weight)
    
    loss,acc = train_utils.validation_loop_pooled_embeddings(e,scheduler,net,val_embeddings,val_jpgs_to_slide
                                    ,val_labels,criterion,n_samples_val)

Epoch: 0, Train NLL: 0.7045
Epoch: 0, Validation NLL: 0.6925, Total Acc: 0.573, Acc by label; diploid:1.000 WGD:0.000
Epoch: 1, Train NLL: 0.6791
Epoch: 1, Validation NLL: 0.7287, Total Acc: 0.549, Acc by label; diploid:0.915 WGD:0.057
Epoch: 2, Train NLL: 0.6506
Epoch: 2, Validation NLL: 0.7252, Total Acc: 0.561, Acc by label; diploid:0.787 WGD:0.257
Epoch: 3, Train NLL: 0.6445
Epoch: 3, Validation NLL: 0.6898, Total Acc: 0.573, Acc by label; diploid:1.000 WGD:0.000
Epoch: 4, Train NLL: 0.6016
Epoch: 4, Validation NLL: 0.7507, Total Acc: 0.610, Acc by label; diploid:0.638 WGD:0.571
Epoch: 5, Train NLL: 0.5620
Epoch: 5, Validation NLL: 0.6955, Total Acc: 0.622, Acc by label; diploid:0.979 WGD:0.143
Epoch: 6, Train NLL: 0.5789
Epoch: 6, Validation NLL: 0.7442, Total Acc: 0.561, Acc by label; diploid:0.596 WGD:0.514
Epoch: 7, Train NLL: 0.5241
Epoch: 7, Validation NLL: 0.7194, Total Acc: 0.634, Acc by label; diploid:0.915 WGD:0.257
Epoch: 8, Train NLL: 0.5269
Epoch: 8, Validation NLL: 0.

Epoch: 69, Train NLL: 0.3034
Epoch: 69, Validation NLL: 0.9952, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 70, Train NLL: 0.3181
Epoch: 70, Validation NLL: 0.9953, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 71, Train NLL: 0.2988
Epoch: 71, Validation NLL: 0.9957, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 72, Train NLL: 0.2940
Epoch: 72, Validation NLL: 0.9968, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 73, Train NLL: 0.2966
Epoch: 73, Validation NLL: 0.9980, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 74, Train NLL: 0.2855
Epoch: 74, Validation NLL: 0.9988, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 75, Train NLL: 0.3153
Epoch: 75, Validation NLL: 0.9992, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 76, Train NLL: 0.2962
Epoch: 76, Validation NLL: 0.9992, Total Acc: 0.524, Acc by label; diploid:0.617 WGD:0.400
Epoch: 77, Train NLL: 0.2870
Epoch: 77, 

Epoch: 137, Train NLL: 0.2949
Epoch: 137, Validation NLL: 1.0277, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 138, Train NLL: 0.2848
Epoch: 138, Validation NLL: 1.0275, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 139, Train NLL: 0.2626
Epoch: 139, Validation NLL: 1.0272, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 140, Train NLL: 0.2669
Epoch: 140, Validation NLL: 1.0271, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 141, Train NLL: 0.2751
Epoch: 141, Validation NLL: 1.0277, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 142, Train NLL: 0.2788
Epoch: 142, Validation NLL: 1.0288, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 143, Train NLL: 0.2979
Epoch: 143, Validation NLL: 1.0296, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 144, Train NLL: 0.3162
Epoch: 144, Validation NLL: 1.0296, Total Acc: 0.549, Acc by label; diploid:0.617 WGD:0.457
Epoch: 145, Train NLL: 0

KeyboardInterrupt: 

In [129]:
idexs = np.random.choice(idxs_train,size=n_samples_train.numpy(),p=sample_weight)

In [130]:
np.mean(labels_to_idxs_train[idexs])

0.5232198142414861

In [118]:
train_labels

tensor([1, 1, 1,  ..., 1, 1, 1])

In [17]:
for idx in range(n_samples):
    with torch.no_grad():
        slide = train_embeddings[train_jpgs_to_slide==idx]
        labels_vec[idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(linear_layer(slide))
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.6049382716049383

In [19]:
del train_embeddings
torch.cuda.empty_cache()
val_embeddings = val_embeddings.cuda()

In [20]:
n_samples = val_jpgs_to_slide.max()+1
logits_vec = torch.zeros((n_samples,1)).cuda()
labels_vec = torch.zeros_like(logits_vec).cuda()
for idx in range(n_samples):
    with torch.no_grad():
        slide = val_embeddings[val_jpgs_to_slide==idx]
        labels_vec[idx] = val_labels[val_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(linear_layer(slide))
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.573170731707317

In [23]:
del val_embeddings
torch.cuda.empty_cache()

In [31]:
train_jpgs_to_slide

tensor([  0,   0,   0,  ..., 323, 323, 323])

In [12]:
[p for p in resnet.bn1.parameters()]

[Parameter containing:
 tensor([ 2.3519e-01,  2.6554e-01, -5.1096e-08,  5.1923e-01,  3.4404e-09,
          2.2344e-01,  4.2637e-01,  1.3153e-07,  2.5157e-01,  1.5152e-06,
          3.1649e-01,  2.5043e-01,  3.7927e-01,  1.0862e-05,  2.7462e-01,
          2.3937e-01,  2.4569e-01,  3.9352e-01,  4.6928e-01,  2.8979e-01,
          2.7206e-01,  2.7792e-01,  2.9089e-01,  2.0875e-01,  2.6028e-01,
          2.7721e-01,  2.9026e-01,  3.1510e-01,  3.8906e-01,  3.0260e-01,
          2.6793e-01,  2.1056e-01,  2.9101e-01,  3.3062e-01,  4.2879e-01,
          3.7352e-01,  7.4804e-08,  1.9000e-01,  1.4740e-08,  2.2349e-01,
          1.8079e-01,  2.4823e-01,  2.7352e-01,  2.5917e-01,  2.9366e-01,
          3.0109e-01,  2.2261e-01,  2.6257e-01,  2.2001e-08,  2.6461e-01,
          2.2089e-01,  2.8305e-01,  3.2998e-01,  2.2688e-01,  3.6608e-01,
          2.1172e-01,  2.3945e-01,  2.4885e-01,  5.2481e-01,  2.4803e-01,
          2.9450e-01,  2.6038e-01,  4.8347e-01,  2.6588e-01]),
 Parameter containing:
 te