In [1]:
import sys
sys.path.append('../')

import os
import gc
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils
import model_utils

%reload_ext autoreload
%autoreload 2

set_image_backend('accimage')

In [2]:
output_size = 1
state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_size)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
for p in resnet.parameters():
    p.requires_grad = False

In [3]:
pickle_file = '/n/data_labeled_histopathology_images/COAD/train.pkl'
with open(pickle_file, 'rb') as f: 
    train_embeddings,train_labels,train_jpgs_to_slide = pickle.load(f)
    
pickle_file = '/n/data_labeled_histopathology_images/COAD/val.pkl'
with open(pickle_file, 'rb') as f: 
    val_embeddings,val_labels,val_jpgs_to_slide = pickle.load(f)

In [4]:
resnet_fc = resnet.fc

In [5]:
n_samples = train_jpgs_to_slide.max()+1
criterion = nn.BCEWithLogitsLoss()
logits_vec = torch.zeros((n_samples,1)).cuda()
labels_vec = torch.zeros_like(logits_vec).cuda()
train_embeddings = train_embeddings.cuda()
resnet_fc.cuda()

for idx in range(n_samples):
    with torch.no_grad():
        slide = train_embeddings[train_jpgs_to_slide==idx]
        labels_vec[idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(resnet_fc(slide))

In [6]:
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.808641975308642

In [7]:
n_samples = val_jpgs_to_slide.max()+1
criterion = nn.BCEWithLogitsLoss()
logits_vec = torch.zeros((n_samples,1)).cuda()
labels_vec = torch.zeros_like(logits_vec).cuda()
val_embeddings = val_embeddings.cuda()
resnet_fc.cuda()

for idx in range(n_samples):
    with torch.no_grad():
        slide = val_embeddings[val_jpgs_to_slide==idx]
        labels_vec[idx] = val_labels[val_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(resnet_fc(slide))

In [8]:
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.6097560975609756

In [9]:
criterion = nn.BCEWithLogitsLoss()
criterion(logits_vec, labels_vec) * n_samples / 50

tensor(0.9982, device='cuda:0')

In [None]:
lnr_layer = nn.Linear(2048,2048)
lnr_layer_2 = nn.Linear(2048,1)
relu = nn.ReLU()
layers = [lnr_layer, relu, lnr_layer_2]
linear_layer = nn.Sequential(*layers)

In [25]:
#train_embeddings = train_embeddings.cuda()
net = model_utils.Attention(input_size=2048,hidden_size=2048,output_size=1)
criterion = nn.BCEWithLogitsLoss()
step_size = 50

In [26]:
net.cuda()

Attention(
  (V): Linear(in_features=2048, out_features=2048, bias=True)
  (w): Linear(in_features=2048, out_features=1, bias=True)
  (sigm): Sigmoid()
  (tanh): Tanh()
  (sm): Softmax()
  (linear_layer): Linear(in_features=2048, out_features=1, bias=True)
)

In [27]:
learning_rate = 1e-5
optimizer = torch.optim.Adam(params=net.parameters(),lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=3,min_lr=1e-6)

In [28]:
n_samples_val = val_jpgs_to_slide.max()
n_samples_train = train_jpgs_to_slide.max()
idxs_train = np.linspace(0,n_samples_train,n_samples_train+1,dtype=int)
labels_to_idxs_train = np.concatenate([(train_labels[train_jpgs_to_slide==i]).unique().numpy() for i in idxs_train])
weights = 1/np.sum(labels_to_idxs_train==0),1/np.sum(labels_to_idxs_train==1)
sample_weight = [weights[l] for l in labels_to_idxs_train]
sample_weight = sample_weight/np.sum(sample_weight)

In [29]:
train_embeddings, val_embeddings = train_embeddings.cuda(),val_embeddings.cuda()
best_loss = 1e8
best_acc = 0
for e in range(300):
    train_utils.training_loop_pooled_embeddings(e,step_size,optimizer,net,train_embeddings,train_jpgs_to_slide
                                    ,train_labels,criterion,n_samples_train,sample_weight)
    
    loss,acc = train_utils.validation_loop_pooled_embeddings(e,scheduler,net,val_embeddings,val_jpgs_to_slide
                                    ,val_labels,criterion,n_samples_val)
    if loss < best_loss:
        torch.save(net.state_dict(), '/n/tcga_models/COAD_attention_model_reworked_5_8.pt')
        best_loss = loss
    if acc > best_acc:
        best_acc = acc
        torch.save(net.state_dict(), '/n/tcga_models/COAD_attention_model_reworked_5_8_acc.pt')

Epoch: 0, Train NLL: 0.7205
Epoch: 0, Validation NLL: 0.6953, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 1, Train NLL: 0.7094
Epoch: 1, Validation NLL: 0.6912, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 2, Train NLL: 0.6937
Epoch: 2, Validation NLL: 0.6877, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 3, Train NLL: 0.7000
Epoch: 3, Validation NLL: 0.6869, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 4, Train NLL: 0.6924
Epoch: 4, Validation NLL: 0.6865, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 5, Train NLL: 0.6859
Epoch: 5, Validation NLL: 0.6854, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 6, Train NLL: 0.6918
Epoch: 6, Validation NLL: 0.6833, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 7, Train NLL: 0.6911
Epoch: 7, Validation NLL: 0.6817, Total Acc: 0.561, Acc by label; diploid:1.000 WGD:0.000
Epoch: 8, Train NLL: 0.6781
Epoch: 8, Validation NLL: 0.

Epoch: 69, Train NLL: 0.5212
Epoch: 69, Validation NLL: 0.5991, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 70, Train NLL: 0.4857
Epoch: 70, Validation NLL: 0.5981, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 71, Train NLL: 0.4942
Epoch: 71, Validation NLL: 0.5975, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 72, Train NLL: 0.5247
Epoch: 72, Validation NLL: 0.5971, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 73, Train NLL: 0.5084
Epoch: 73, Validation NLL: 0.5972, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 74, Train NLL: 0.4891
Epoch: 74, Validation NLL: 0.5975, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 75, Train NLL: 0.5094
Epoch: 75, Validation NLL: 0.5977, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 76, Train NLL: 0.5332
Epoch: 76, Validation NLL: 0.5971, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 77, Train NLL: 0.5259
Epoch: 77, 

Epoch: 137, Train NLL: 0.4845
Epoch: 137, Validation NLL: 0.5898, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 138, Train NLL: 0.5121
Epoch: 138, Validation NLL: 0.5895, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 139, Train NLL: 0.4511
Epoch: 139, Validation NLL: 0.5891, Total Acc: 0.732, Acc by label; diploid:0.913 WGD:0.500
Epoch: 140, Train NLL: 0.4526
Epoch: 140, Validation NLL: 0.5897, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 141, Train NLL: 0.5023
Epoch: 141, Validation NLL: 0.5906, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 142, Train NLL: 0.4808
Epoch: 142, Validation NLL: 0.5914, Total Acc: 0.744, Acc by label; diploid:0.891 WGD:0.556
Epoch: 143, Train NLL: 0.5039
Epoch: 143, Validation NLL: 0.5923, Total Acc: 0.744, Acc by label; diploid:0.891 WGD:0.556
Epoch: 144, Train NLL: 0.4729
Epoch: 144, Validation NLL: 0.5927, Total Acc: 0.744, Acc by label; diploid:0.891 WGD:0.556
Epoch: 145, Train NLL: 0

Epoch: 204, Validation NLL: 0.5753, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 205, Train NLL: 0.4546
Epoch: 205, Validation NLL: 0.5742, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 206, Train NLL: 0.4725
Epoch: 206, Validation NLL: 0.5740, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 207, Train NLL: 0.4469
Epoch: 207, Validation NLL: 0.5736, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 208, Train NLL: 0.4517
Epoch: 208, Validation NLL: 0.5733, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 209, Train NLL: 0.4605
Epoch: 209, Validation NLL: 0.5729, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 210, Train NLL: 0.4930
Epoch: 210, Validation NLL: 0.5728, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 211, Train NLL: 0.4427
Epoch: 211, Validation NLL: 0.5731, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 212, Train NLL: 0.4440
Epoch: 212, Validation N

Epoch: 272, Train NLL: 0.4138
Epoch: 272, Validation NLL: 0.5622, Total Acc: 0.780, Acc by label; diploid:0.935 WGD:0.583
Epoch: 273, Train NLL: 0.4107
Epoch: 273, Validation NLL: 0.5624, Total Acc: 0.780, Acc by label; diploid:0.935 WGD:0.583
Epoch: 274, Train NLL: 0.4093
Epoch: 274, Validation NLL: 0.5622, Total Acc: 0.780, Acc by label; diploid:0.935 WGD:0.583
Epoch: 275, Train NLL: 0.4125
Epoch: 275, Validation NLL: 0.5625, Total Acc: 0.780, Acc by label; diploid:0.935 WGD:0.583
Epoch: 276, Train NLL: 0.4230
Epoch: 276, Validation NLL: 0.5630, Total Acc: 0.780, Acc by label; diploid:0.935 WGD:0.583
Epoch: 277, Train NLL: 0.4085
Epoch: 277, Validation NLL: 0.5636, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 278, Train NLL: 0.3978
Epoch: 278, Validation NLL: 0.5636, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 279, Train NLL: 0.4382
Epoch: 279, Validation NLL: 0.5638, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 280, Train NLL: 0

In [30]:
for e in range(300,600):
    train_utils.training_loop_pooled_embeddings(e,step_size,optimizer,net,train_embeddings,train_jpgs_to_slide
                                    ,train_labels,criterion,n_samples_train,sample_weight)
    
    loss,acc = train_utils.validation_loop_pooled_embeddings(e,scheduler,net,val_embeddings,val_jpgs_to_slide
                                    ,val_labels,criterion,n_samples_val)
    if loss < best_loss:
        torch.save(net.state_dict(), '/n/tcga_models/COAD_attention_model_reworked_5_8.pt')
        best_loss = loss
    if acc > best_acc:
        best_acc = acc
        torch.save(net.state_dict(), '/n/tcga_models/COAD_attention_model_reworked_5_8_acc.pt')

Epoch: 300, Train NLL: 0.4541
Epoch: 300, Validation NLL: 0.5582, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 301, Train NLL: 0.4867
Epoch: 301, Validation NLL: 0.5569, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 302, Train NLL: 0.4261
Epoch: 302, Validation NLL: 0.5560, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 303, Train NLL: 0.4259
Epoch: 303, Validation NLL: 0.5554, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 304, Train NLL: 0.4213
Epoch: 304, Validation NLL: 0.5549, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 305, Train NLL: 0.4507
Epoch: 305, Validation NLL: 0.5541, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 306, Train NLL: 0.4471
Epoch: 306, Validation NLL: 0.5542, Total Acc: 0.756, Acc by label; diploid:0.913 WGD:0.556
Epoch: 307, Train NLL: 0.4044
Epoch: 307, Validation NLL: 0.5545, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 308, Train NLL: 0

Epoch: 367, Validation NLL: 0.5478, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 368, Train NLL: 0.4078
Epoch: 368, Validation NLL: 0.5474, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 369, Train NLL: 0.4019
Epoch: 369, Validation NLL: 0.5474, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 370, Train NLL: 0.4354
Epoch: 370, Validation NLL: 0.5481, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 371, Train NLL: 0.4184
Epoch: 371, Validation NLL: 0.5485, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 372, Train NLL: 0.4439
Epoch: 372, Validation NLL: 0.5490, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 373, Train NLL: 0.4144
Epoch: 373, Validation NLL: 0.5490, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 374, Train NLL: 0.4145
Epoch: 374, Validation NLL: 0.5479, Total Acc: 0.768, Acc by label; diploid:0.913 WGD:0.583
Epoch: 375, Train NLL: 0.3953
Epoch: 375, Validation N

KeyboardInterrupt: 

In [130]:
np.mean(labels_to_idxs_train[idexs])

0.5232198142414861

In [118]:
train_labels

tensor([1, 1, 1,  ..., 1, 1, 1])

In [17]:
for idx in range(n_samples):
    with torch.no_grad():
        slide = train_embeddings[train_jpgs_to_slide==idx]
        labels_vec[idx] = train_labels[train_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(linear_layer(slide))
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.6049382716049383

In [19]:
del train_embeddings
torch.cuda.empty_cache()
val_embeddings = val_embeddings.cuda()

In [20]:
n_samples = val_jpgs_to_slide.max()+1
logits_vec = torch.zeros((n_samples,1)).cuda()
labels_vec = torch.zeros_like(logits_vec).cuda()
for idx in range(n_samples):
    with torch.no_grad():
        slide = val_embeddings[val_jpgs_to_slide==idx]
        labels_vec[idx] = val_labels[val_jpgs_to_slide==idx].unique().float().cuda()
        logits_vec[idx] = torch.mean(linear_layer(slide))
np.mean(labels_vec.cpu().numpy() == (logits_vec>0.5).float().cpu().numpy())

0.573170731707317

In [23]:
del val_embeddings
torch.cuda.empty_cache()

In [31]:
train_jpgs_to_slide

tensor([  0,   0,   0,  ..., 323, 323, 323])

In [12]:
[p for p in resnet.bn1.parameters()]

[Parameter containing:
 tensor([ 2.3519e-01,  2.6554e-01, -5.1096e-08,  5.1923e-01,  3.4404e-09,
          2.2344e-01,  4.2637e-01,  1.3153e-07,  2.5157e-01,  1.5152e-06,
          3.1649e-01,  2.5043e-01,  3.7927e-01,  1.0862e-05,  2.7462e-01,
          2.3937e-01,  2.4569e-01,  3.9352e-01,  4.6928e-01,  2.8979e-01,
          2.7206e-01,  2.7792e-01,  2.9089e-01,  2.0875e-01,  2.6028e-01,
          2.7721e-01,  2.9026e-01,  3.1510e-01,  3.8906e-01,  3.0260e-01,
          2.6793e-01,  2.1056e-01,  2.9101e-01,  3.3062e-01,  4.2879e-01,
          3.7352e-01,  7.4804e-08,  1.9000e-01,  1.4740e-08,  2.2349e-01,
          1.8079e-01,  2.4823e-01,  2.7352e-01,  2.5917e-01,  2.9366e-01,
          3.0109e-01,  2.2261e-01,  2.6257e-01,  2.2001e-08,  2.6461e-01,
          2.2089e-01,  2.8305e-01,  3.2998e-01,  2.2688e-01,  3.6608e-01,
          2.1172e-01,  2.3945e-01,  2.4885e-01,  5.2481e-01,  2.4803e-01,
          2.9450e-01,  2.6038e-01,  4.8347e-01,  2.6588e-01]),
 Parameter containing:
 te