In [1]:
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend

import data_utils
import train_utils

# Re-import edited project modules (data_utils, train_utils) automatically
# before each cell executes, so edits to them are picked up without a restart.
%reload_ext autoreload
%autoreload 2

# Have torchvision load images with accimage instead of PIL.
# NOTE(review): this requires the accimage package in the kernel environment — confirm.
set_image_backend('accimage')

In [2]:
# Pickled pair holding the training and validation sample annotations
# (sa_train / sa_val), which feed the TCGADataset_tiles instances below.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
pickle_file = '/n/tcga_models/resnet18_WGD_10x_sa.pkl'

with open(pickle_file, 'rb') as split_file:
    train_and_val = pickle.load(split_file)
sa_train, sa_val = train_and_val

In [3]:
# Tile dataset location and sampling options, shared by the train and
# validation datasets.
root_dir, magnification, batch_type = (
    '/n/mounted-data-drive/COAD/',
    '10.0',   # kept as a string; presumably matched against tile folder/key names — TODO confirm in data_utils
    'slide',  # appears to request one item per slide rather than per tile — TODO confirm in data_utils
)

In [4]:
train_transform = train_utils.transform_train
train_set = data_utils.TCGADataset_tiles(sa_train, root_dir, transform=train_transform, magnification=magnification, batch_type=batch_type)
# Shuffle slide order every epoch. DataLoader defaults to shuffle=False, which
# would feed the slides in the same fixed order (the order of sa_train) on
# every training epoch.
# TODO: seed torch (torch.manual_seed) in the config cell if epoch order must be reproducible.
train_loader = DataLoader(train_set, batch_size=1, shuffle=True, pin_memory=True, num_workers=8)

In [5]:
val_transform = train_utils.transform_validation
val_set = data_utils.TCGADataset_tiles(
    sa_val,
    root_dir,
    transform=val_transform,
    magnification=magnification,
    batch_type=batch_type,
)
# Validation keeps DataLoader's default sequential order (no shuffling needed).
valid_loader = DataLoader(val_set, batch_size=1, num_workers=8, pin_memory=True)

In [6]:
# Checkpoint of the tile-level ResNet-18 whose backbone is reused below.
state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
output_shape = 1  # out_features of the checkpoint's final linear layer
device = torch.device('cuda', 0)

In [7]:
# Rebuild the tile-level architecture so the checkpoint's keys and shapes line up.
# NOTE(review): in_features=2048 equals 512 channels x 2 x 2, which only holds if
# the pooling layer leaves a 2x2 map (e.g. the older fixed AvgPool2d(7) applied to
# 256x256 tiles). torchvision versions with adaptive pooling yield 512 features —
# confirm against the installed torchvision.
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(in_features=2048, out_features=output_shape, bias=True)

# Deserialize every storage onto the CPU, whatever device it was saved from.
saved_state = torch.load(
    state_dict_file,
    map_location=lambda storage, _location: storage,
)
resnet.load_state_dict(saved_state)

In [8]:
# Swap the 1-unit classifier for a bias-free identity map, so the final layer
# passes the 2048-d features through unchanged instead of producing a logit.
embedding_head = nn.Linear(2048, 2048, bias=False)
embedding_head.weight.data = torch.eye(2048)
resnet.fc = embedding_head
resnet.cuda(device=device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [9]:
# Freeze the backbone: only the slide-level head's parameters are handed to the
# optimizer, so the ResNet is used purely as a fixed feature extractor.
for param in resnet.parameters():
    param.requires_grad = False

# requires_grad=False blocks gradient updates but NOT BatchNorm running-stat
# updates, which still happen on every forward pass in train mode. Switch the
# frozen backbone to eval mode so its features stay fixed during training.
# NOTE(review): if train_utils.tcga_embedding_training_loop calls resnet.train(),
# that overrides this — confirm.
resnet.eval();

In [10]:
# Slide-level head: maps a pooled 2048-d slide embedding to a single logit
# (trained with BCEWithLogitsLoss). It is placed on the same `device` as the
# backbone, so both stay on one GPU if `device` is changed.
slide_level_classification_layer = nn.Linear(2048, 1)
slide_level_classification_layer.cuda(device=device)

Linear(in_features=2048, out_features=1, bias=True)

In [11]:
def pool_fn(x, reduction='mean'):
    """Aggregate per-tile embeddings into one slide-level embedding.

    Parameters
    ----------
    x : torch.Tensor
        Tile embeddings stacked along dimension 0 (presumably
        ``(n_tiles, 2048)`` — TODO confirm against train_utils).
    reduction : {'mean', 'max'}, optional
        How to reduce over dimension 0. Defaults to ``'mean'``, which
        preserves the original behaviour.

    Returns
    -------
    torch.Tensor
        ``x`` reduced over dimension 0.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of the supported options.
    """
    if reduction == 'mean':
        return torch.mean(x, 0)
    if reduction == 'max':
        # torch.max with a dim returns (values, indices); keep only the values.
        values, _ = torch.max(x, 0)
        return values
    raise ValueError("unknown reduction: {!r}".format(reduction))

In [12]:
# Only the slide-level head is optimized; the ResNet backbone is frozen.
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=slide_level_classification_layer.parameters(), lr=learning_rate)

# Reduce the LR when the monitored loss stops decreasing ('min' is the default
# mode). The scheduler is handed to the validation loop, so presumably it
# steps on validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, min_lr=1e-8)

# The head emits raw logits, so use the fused sigmoid + binary cross-entropy loss.
criterion = torch.nn.BCEWithLogitsLoss()

e = 0  # epoch index passed to the training/validation loops

In [15]:
# NOTE(review): the last recorded run of this cell raised
# "Expected 4-dimensional input ... got 3-dimensional input of size [3, 256, 256]",
# i.e. a single un-batched tile reached the backbone. The cause likely lies in how
# TCGADataset_tiles(batch_type='slide') output and this loop's batching interact —
# check data_utils / train_utils.
train_utils.tcga_embedding_training_loop(
    e,
    train_loader,
    resnet,                            # frozen feature extractor
    slide_level_classification_layer,  # trainable slide-level head
    criterion,
    optimizer,
    pool_fn,                           # tile -> slide aggregation
    tile_batch_size=800,
)

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 256, 256] instead

In [None]:
# Evaluate the slide-level head on the validation split; the scheduler is passed
# in, presumably so the loop can step it on the returned loss.
loss = train_utils.tcga_embedding_validation_loop(
    e,
    valid_loader,
    resnet,
    slide_level_classification_layer,
    criterion,
    pool_fn,
    tile_batch_size=800,
    scheduler=scheduler,
    dataset='Val',
)