In [1]:
import os
import sys
import torch
import torch.nn as nn
import accimage
from PIL import Image
from imageio import imread
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms, set_image_backend, get_image_backend
import data_utils
import train_utils
import numpy as np
import pandas as pd
import pickle
import torch.nn.functional as F
from collections import Counter

%reload_ext autoreload
%autoreload 2

In [2]:
# Select accimage (https://github.com/pytorch/accimage) as torchvision's image backend.
set_image_backend('accimage')
# A bare get_image_backend() call in the middle of a cell is a no-op (its value
# is neither displayed nor stored), so assert on it to make the check effective.
assert get_image_backend() == 'accimage'

# Root directory for the images. Defaults to the original mount point; set the
# COAD_ROOT_DIR environment variable to point at a different location without
# editing the notebook.
root_dir = os.environ.get('COAD_ROOT_DIR', '/n/mounted-data-drive/COAD/')

In [3]:
# Per-channel mean/std applied after ToTensor().
# NOTE(review): these equal the normalisation constants torchvision documents
# for its pretrained models — confirm that is the intent here.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Train / validation split as returned by data_utils.process_WGD_data()
sa_train, sa_val = data_utils.process_WGD_data()

# One TCGADataset_tiles per split, both rooted at root_dir and sharing the
# same preprocessing pipeline (train is constructed first, then val).
train_set, val_set = (
    data_utils.TCGADataset_tiles(split, root_dir, transform=transform)
    for split in (sa_train, sa_val)
)

In [9]:
# Label balance of the validation split (values of sa_val).
# The fractions are derived from the counts rather than typed in by hand, so
# they stay correct if the split changes.
val_label_counts = Counter(sa_val.values())
n_val_labels = sum(val_label_counts.values())
val_label_counts, val_label_counts[0] / n_val_labels, val_label_counts[1] / n_val_labels

(Counter({0: 52, 1: 30}), 0.6341463414634146, 0.36585365853658536)

In [4]:
# Number of training tiles per class, ordered by ascending class label.
# These counts are what the class-balancing sampler weights are derived from.
train_label_counts = Counter(train_set.all_labels)
counts = [train_label_counts[label] for label in sorted(train_label_counts)]
counts

[147478, 91018]

In [11]:
# Label balance over the validation set's all_labels.
# Fractions are computed from the Counter instead of hard-coded totals, so the
# cell stays correct when the validation set changes.
val_tile_counts = Counter(val_set.all_labels)
n_val_tiles = sum(val_tile_counts.values())
val_tile_counts, val_tile_counts[0] / n_val_tiles, val_tile_counts[1] / n_val_tiles

(Counter({0: 44196, 1: 20837}), 0.6795934371780481, 0.32040656282195196)

In [6]:
# Inverse class frequency, scaled by 1e3, indexed by class label
weights = (1.0 / np.array(counts, dtype=float)) * 1e3
# One sampling weight per training item, looked up through that item's label
reciprocal_weights = [
    weights[train_set.all_labels[i]]
    for i in range(len(train_set))
]
weights

array([0.00678067, 0.01098684])

In [7]:
batch_size = 256
# Settings shared by the training and validation loaders
loader_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=12)

# Draw as many indices as there are weights, with replacement, in proportion
# to each item's weight.
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights=reciprocal_weights,
    num_samples=len(reciprocal_weights),
    replacement=True,
)
train_loader = DataLoader(train_set, sampler=sampler, **loader_kwargs)
valid_loader = DataLoader(val_set, **loader_kwargs)

# Dataset sizes expressed in (fractional) batches
len(train_set) / batch_size, len(val_set) / batch_size

(931.625, 247.578125)

In [8]:
# Read the WGD spreadsheet and preview its first three rows
wgd_path = 'COAD_WGD_TABLE.xls'
wgd_raw = pd.read_excel(io=wgd_path)
wgd_raw.head(n=3)

Unnamed: 0,Sample,Type,AneuploidyScore(AS),AS_del,AS_amp,Genome_doublings,Leuk,Purity,Stroma,Stroma_notLeukocyte,Stroma_notLeukocyte_Floor,SilentMutationspeMb,Non-silentMutationsperMb
0,TCGA-5M-AATE-01,COAD,20,7,13,1,0.080152,0.65,0.35,0.269848,0.269848,1.20409,2.552671
1,TCGA-A6-2683-01,COAD,20,11,9,1,0.012109,0.85,0.15,0.137891,0.137891,1.812046,5.617341
2,TCGA-AA-A01T-01,COAD,22,17,5,1,0.045103,0.71,0.29,0.244897,0.244897,0.683307,2.525263


In [9]:
# Fraction of rows per Genome_doublings value, in order of first appearance in
# the table (Counter preserves insertion order, it does not sort by label).
# Kept under its own name: rebinding `counts` here would overwrite the
# per-class training tile counts that the sampler weights are computed from,
# so re-running that cell afterwards would silently use the wrong numbers.
wgd_counts = np.array(list(Counter(wgd_raw.Genome_doublings).values()))
wgd_counts / np.sum(wgd_counts)

array([0.39953811, 0.60046189])

## Fat Network

In [10]:
# Pretrained torchvision ResNet-18 with its classifier replaced by a
# single-output linear head (one logit).
resnet = models.resnet18(pretrained=True)
# Input width of the replacement head.
# Original note: resnet18: 2048, resnet50: 8192, resnet152: 8192
# NOTE(review): 2048 is presumably the flattened size of the pooled feature map
# for this tile size / torchvision version rather than the stock resnet18 fc
# width — confirm before changing tile size or upgrading torchvision.
fc_in_features = 2048
resnet.fc = nn.Linear(in_features=fc_in_features, out_features=1, bias=True)
resnet.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [11]:
learning_rate = 1e-2

# Training loss is reduced to its mean; the validation criterion keeps the
# unreduced per-element losses.
criterion_train = nn.BCEWithLogitsLoss(reduction='mean')
criterion_val = nn.BCEWithLogitsLoss(reduction='none')

# Adam over every parameter of the network
optimizer = torch.optim.Adam(params=resnet.parameters(), lr=learning_rate)
# Plateau-based LR schedule, never going below 1e-6
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    min_lr=1e-6,
    patience=100,
)

In [12]:
# Alias the validation set's jpg_to_sample attribute so it can be passed to the
# validation loop.
# NOTE(review): presumably maps each tile image to its parent sample, which the
# slide-level pooling would need — confirm in data_utils.
jpg_to_sample = val_set.jpg_to_sample

In [17]:
n_epochs = 1
for epoch in range(n_epochs):
    # Report the learning rate currently set on the first parameter group
    current_lr = optimizer.param_groups[0]['lr']
    print('---------- LR: {0:0.5f} ----------'.format(current_lr))
    # The training pass is disabled here, so only the validation loop runs.
    #train_utils.embedding_training_loop(epoch, train_loader, resnet, criterion_train, optimizer)
    val_loss = train_utils.embedding_validation_loop(
        epoch, valid_loader, resnet, criterion_val, jpg_to_sample,
        dataset='Val', scheduler=scheduler, task='WGD',
    )

---------- LR: 0.01000 ----------
Epoch: 0, Batch: 100, Val NLL: 1.4290
Epoch: 0, Batch: 200, Val NLL: 0.3196
Epoch: 0, Avg Val NLL: 1.1962, Median Val NLL: 0.4365
------ Val Tile-Level Acc: 0.6253; By Label: 0: 0.9879, 1: 0.0108
------ Val Slide-Level Acc: Mean-Pooling: 0.5926, Max-Pooling: 0.4198


In [None]:
#torch.save(resnet.state_dict(),'test.pt')