In [1]:
import os
import sys
import torch
import torch.nn as nn
import accimage
from PIL import Image
from imageio import imread
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms, set_image_backend, get_image_backend
import data_utils
import train_utils
import numpy as np
import pandas as pd
import pickle
import torch.nn.functional as F
from collections import Counter

%reload_ext autoreload
%autoreload 2

In [2]:
# Switch torchvision's image-loading backend to accimage
# (https://github.com/pytorch/accimage).
set_image_backend('accimage')
# NOTE: this call is not the last expression in the cell, so its return value
# is neither displayed nor checked.
get_image_backend()

# Root directory for the images; passed to the dataset constructors below.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
root_dir = '/n/mounted-data-drive/COAD/'

In [3]:
# Per-tile preprocessing: convert the loaded image to a tensor, then standardise
# each RGB channel. The mean/std values match the ImageNet statistics that
# torchvision documents for its pretrained models.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

In [4]:
# Load the MSI data via the project helper; it returns a pair that is unpacked
# into the training and validation parts consumed by the dataset constructors.
# NOTE(review): the split logic lives in data_utils.process_MSI_data — confirm
# there whether the split is deterministic across runs.
sa_train, sa_val = data_utils.process_MSI_data()

In [5]:
# Build the tile-level training and validation datasets. Both read images from
# `root_dir` and apply the same tensor/normalisation transform.
train_set, val_set = (
    data_utils.TCGADataset_tiles(split, root_dir, transform=transform)
    for split in (sa_train, sa_val)
)

In [6]:
# Per-tile sampling weights, inversely proportional to how often the tile's
# label occurs in the training set, so that a weighted sampler can draw
# class-balanced batches.
#   counts  - number of tiles per label, in ascending label order
#   weights - matching inverse frequencies (scaled by 1e3)
label_counts = sorted(Counter(train_set.all_labels).items())
counts = [count for _, count in label_counts]
weights = 1.0 / np.array(counts, dtype=float) * 1e3

# Look each weight up by label *value* rather than by array position: positional
# indexing is only correct when the labels are exactly 0..K-1 (a negative label
# would silently wrap around, a gap would mis-weight or raise IndexError).
weight_by_label = {label: weight for (label, _), weight in zip(label_counts, weights)}
reciprocal_weights = [weight_by_label[train_set.all_labels[index]]
                      for index in range(len(train_set))]

In [7]:
batch_size = 256
# Sample tiles WITH replacement. Drawing len(reciprocal_weights) samples without
# replacement returns a (weighted) permutation of every tile, so each epoch
# would keep the dataset's natural class ratio: early batches over-represent
# the rare class, later ones contain almost none. With replacement, every draw
# follows the inverse-frequency weights, so each batch is class balanced in
# expectation.
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    reciprocal_weights,
    len(reciprocal_weights),
    replacement=True,
)
train_loader = DataLoader(train_set, batch_size=batch_size, pin_memory=True, sampler=sampler, num_workers=12)
#len(train_set) / batch_size

In [8]:
# Validation loader: no sampler is passed, so tiles are read in dataset order
# with the same batch size as training.
valid_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    pin_memory=True,
    num_workers=12,
)
#len(val_set) / batch_size

## Fat Network

In [9]:
# Pretrained ResNet-18 with its classification head swapped for a fresh linear
# layer producing 2 outputs.
# Flattened feature width feeding the head, as recorded by the original author:
# resnet18: 2048, resnet50: 8192, resnet152: 8192.
# NOTE(review): 2048 differs from the 512 in_features of a stock resnet18 head;
# presumably it reflects the tile size / pooling layer of the torchvision
# version in use — confirm before changing the architecture or input size.
fc_in_features, fc_out_features = 2048, 2
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Linear(in_features=fc_in_features, out_features=fc_out_features, bias=True)
resnet.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [10]:
# Optimiser: Adam over every model parameter, with the learning rate reduced on
# plateau (patience of 100 scheduler steps, never below 1e-6).
learning_rate = 1e-2
optimizer = torch.optim.Adam(resnet.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=100,
    min_lr=1e-6,
)

# Losses: the training criterion is averaged to a scalar, while the validation
# criterion keeps the unreduced element-wise values.
criterion_train = nn.BCEWithLogitsLoss(reduction='mean')
criterion_val = nn.BCEWithLogitsLoss(reduction='none')

In [11]:
# Keep a handle on the validation set's `jpg_to_sample` attribute; it is handed
# to the validation loop below.
# NOTE(review): presumably maps each tile (jpg) to its parent sample so that
# slide-level metrics can be pooled — confirm in data_utils.
jpg_to_sample = val_set.jpg_to_sample

In [14]:
# Run the epoch loop (currently a single, validation-only pass).
# The loss objects defined above are `criterion_train` and `criterion_val`; the
# bare name `criterion` is not defined anywhere in this notebook, so using it
# raises NameError on a fresh kernel.
# NOTE(review): criterion_val (reduction='none') is assumed to be what
# embedding_validation_loop expects — confirm against train_utils.
for e in range(1):
    print('---------- LR: {0:0.5f} ----------'.format(optimizer.param_groups[0]['lr']))
    #train_utils.embedding_training_loop(e, train_loader, resnet, criterion_train, optimizer)
    val_loss = train_utils.embedding_validation_loop(e, valid_loader, resnet, criterion_val, jpg_to_sample, dataset='Val', scheduler=scheduler)

---------- LR: 0.01000 ----------
Epoch: 0, Batch: 10, Val NLL: 0.7417
Epoch: 0, Batch: 20, Val NLL: 0.4874
Epoch: 0, Batch: 30, Val NLL: 0.6039
Epoch: 0, Batch: 40, Val NLL: 0.6599
Epoch: 0, Batch: 50, Val NLL: 0.5442
Epoch: 0, Batch: 60, Val NLL: 0.6161
Epoch: 0, Batch: 70, Val NLL: 0.6396
Epoch: 0, Batch: 80, Val NLL: 0.5284
Epoch: 0, Batch: 90, Val NLL: 0.6477
Epoch: 0, Batch: 100, Val NLL: 0.5444
Epoch: 0, Batch: 110, Val NLL: 0.5287
Epoch: 0, Batch: 120, Val NLL: 0.7853
Epoch: 0, Batch: 130, Val NLL: 0.5133
Epoch: 0, Batch: 140, Val NLL: 0.4423
Epoch: 0, Batch: 150, Val NLL: 0.7483
Epoch: 0, Batch: 160, Val NLL: 0.5924
Epoch: 0, Batch: 170, Val NLL: 0.9514
Epoch: 0, Batch: 180, Val NLL: 0.6526
Epoch: 0, Batch: 190, Val NLL: 0.6673
Epoch: 0, Avg Val NLL: 0.6453, Val Tile-Level Acc: 0.3620
Val Tile-Level Acc by Label: 0: 0.3510, 1: 0.6990, 2: 0.0049
Val Slide-Level Acc: Mean-Pooling: 0.3750, Max-Pooling: 0.1111


In [None]:
#torch.save(resnet.state_dict(),'test.pt')