In [1]:
import os
import sys
import torch
import torch.nn as nn
import accimage
from PIL import Image
from imageio import imread
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms, set_image_backend, get_image_backend
import data_utils
import train_utils
import numpy as np
import pandas as pd
import pickle
import torch.nn.functional as F
from collections import Counter

%reload_ext autoreload
%autoreload 2

In [2]:
# https://github.com/pytorch/accimage
set_image_backend('accimage')
get_image_backend()

# set root dir for images
root_dir = '/n/mounted-data-drive/COAD/'

In [3]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(), normalize])

sa_train, sa_val = data_utils.process_WGD_data()

train_set = data_utils.TCGADataset_tiles(sa_train, root_dir, transform=transform)
val_set = data_utils.TCGADataset_tiles(sa_val, root_dir, transform=transform)

In [47]:
# set weights for random sampling of tiles such that batches are class balanced
counts = [c[1] for c in sorted(Counter(train_set.all_labels).items())]
counts

[129243, 44701, 36028]

In [48]:
Counter(train_set.all_labels)

Counter({2: 36028, 1: 44701, 0: 129243})

In [51]:
weights = 1.0 / np.array(counts, dtype=float) * 1e3
reciprocal_weights =[]
for index in range(len(train_set)):
    reciprocal_weights.append(weights[train_set.all_labels[index]])
weights

array([0.00773736, 0.02237086, 0.02775619])

In [50]:
reciprocal_weights[0], train_set.all_labels[0]

(0.027756189630287552, 2)

In [35]:
np.argwhere(np.array(train_set.all_labels) == 1)

array([[  2046],
       [  2047],
       [  2048],
       ...,
       [240807],
       [240808],
       [240809]])

In [4]:
batch_size = 128
sampler = torch.utils.data.sampler.WeightedRandomSampler(reciprocal_weights, len(reciprocal_weights), replacement=True)
train_loader = DataLoader(train_set, batch_size=batch_size, pin_memory=True, sampler=sampler, num_workers=12)
valid_loader = DataLoader(val_set, batch_size=batch_size, pin_memory=True, num_workers=12)
len(train_set) / batch_size, len(val_set) / batch_size

(1921.515625, 436.890625)

In [23]:
weights

array([0.00625806, 0.01160631])

In [9]:
wgd_path = 'COAD_WGD_TABLE.xls'
wgd_raw = pd.read_excel(wgd_path)
wgd_raw.head(3)

Unnamed: 0,Sample,Type,AneuploidyScore(AS),AS_del,AS_amp,Genome_doublings,Leuk,Purity,Stroma,Stroma_notLeukocyte,Stroma_notLeukocyte_Floor,SilentMutationspeMb,Non-silentMutationsperMb
0,TCGA-5M-AATE-01,COAD,20,7,13,1,0.080152,0.65,0.35,0.269848,0.269848,1.20409,2.552671
1,TCGA-A6-2683-01,COAD,20,11,9,1,0.012109,0.85,0.15,0.137891,0.137891,1.812046,5.617341
2,TCGA-AA-A01T-01,COAD,22,17,5,1,0.045103,0.71,0.29,0.244897,0.244897,0.683307,2.525263


[(0, 260), (1, 173)]

In [14]:
counts = np.array(list(Counter(wgd_raw.Genome_doublings).values()))
counts / np.sum(counts)

array([0.39953811, 0.60046189])

## WGD Dev

In [None]:
# def process_WGD_data()
root_dir = '/n/mounted-data-drive/COAD/'
wgd_path = 'COAD_WGD_TABLE.xls'
wgd_raw = pd.read_excel(wgd_path)
wgd_raw.head(3)

In [None]:
sample_name = wgd_raw['Sample'][0]
name_len = len(sample_name)
sample_name, name_len

In [None]:
coad_full_name = os.listdir(root_dir)
coad_full_name[0]

In [None]:
coad_img = np.array([v[0:name_len] for v in coad_full_name])
wgd_raw.shape, len(coad_img), coad_img[5], coad_full_name[5]

In [None]:
coad_both = np.intersect1d(coad_img, wgd_raw.Sample)
coad_both.size

In [None]:
sample_names = []
for sample in coad_both:
    key = np.argwhere(coad_img == sample).squeeze()
    if key.size != 0:
        sample_names.append(coad_full_name[key][:-4])

In [None]:
wgd_raw.set_index('Sample', inplace=True)
reorder = np.random.permutation(len(sample_names))
idx = int(np.floor(len(sample_names)*0.8))
train = reorder[:idx]
val = reorder[idx:]

In [None]:
sample_annotations = {}
sample_names = np.array(sample_names)
for sample_name in sample_names[train]:
    sample_annotations[sample_name] = wgd_raw.loc[sample_name[0:name_len], 'Genome_doublings']
sample_annotations_train = sample_annotations

In [None]:
sample_annotations = {}
sample_names = np.array(sample_names)
for sample_name in sample_names[val]:
    sample_annotations[sample_name] = wgd_raw.loc[sample_name[0:name_len], 'Genome_doublings']
sample_annotations_val = sample_annotations
#return sample_annotations_train, sample_annotations_val

## Fat Network

In [None]:
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Linear(2048,2,bias=True)#resnet18: 2048, resnet50: 8192, resnet152: 8192
resnet.cuda()

In [None]:
learning_rate = 1e-2
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, min_lr=1e-6)

In [None]:
for e in range(10):
    if e % 1 == 0:
        print('---------- LR: {0:0.5f} ----------'.format(optimizer.state_dict()['param_groups'][0]['lr']))
    train_utils.embedding_training_loop(e, train_loader, resnet, criterion, optimizer)
    val_loss = train_utils.embedding_validation_loop(e, valid_loader, resnet, criterion, dataset='Val', scheduler=scheduler)

In [None]:
torch.save(resnet.state_dict(),'test.pt')

## Archive

In [39]:
sa_train, sa_val = data_utils.process_MSI_data()

train_set = data_utils.TCGADataset_tiles(sa_train, root_dir, transform=transform)
val_set = data_utils.TCGADataset_tiles(sa_val, root_dir, transform=transform)

In [41]:
Counter(train_set.all_labels)

Counter({2: 36028, 1: 44701, 0: 129243})

In [42]:
weights = 1.0/np.array(list(Counter(train_set.all_labels).values()),dtype=float)*1e3
weights

array([0.02775619, 0.02237086, 0.00773736])

In [43]:
reciprocal_weights =[]
for index in range(len(train_set)):
    reciprocal_weights.append(weights[train_set.all_labels[index]])

In [45]:
reciprocal_weights[-5:], train_set.all_labels[-5:]

([0.022370864186483524,
  0.022370864186483524,
  0.022370864186483524,
  0.022370864186483524,
  0.022370864186483524],
 [1, 1, 1, 1, 1])