In [1]:
import os
import sys
import torch
import accimage
import numpy as np
import pandas as pd
from PIL import Image
from imageio import imread
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms, set_image_backend, get_image_backend
import torch.nn.functional as F
import pickle
import train_utils
import data_utils
import os
import gc
import sys
import torch
import psutil
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn import metrics
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import models, set_image_backend
set_image_backend('accimage')
import data_utils
import train_utils
root_dir_coad = '/n/mounted-data-drive/COAD/'


In [2]:
sa_train, sa_val = data_utils.load_COAD_train_val_sa_pickle()
root_dir = '/n/mounted-data-drive/COAD/'
magnification = '10.0'
batch_type = 'slide'

train_transform = train_utils.transform_train


In [3]:
class TCGADataset_tiled_slides(Dataset):
    """
    TCGA slide dataset. Each slide is linked to its tiles via a label.
    """
    def __init__(self, sample_annotations, root_dir, transform=None, loader=data_utils.default_loader, magnification='5.0'):
        """
        Args:
            sample_annot (dict): dictionary of sample names and their respective labels.
            root_dir (string): directory containing all of the samples and their respective images.
            transform (callable, optional): optional transform to be applied on the images of a sample.
            loader specifies image backend: use accimage
            magnification: tile magnification
        """
        self.sample_names = list(sample_annotations.keys())
        self.sample_labels = list(sample_annotations.values())
        
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader
        self.magnification = magnification
        self.batch_type = batch_type
        self.img_dirs = [self.root_dir + sample_name + '.svs/' \
                         + sample_name + '_files/' + self.magnification for sample_name in self.sample_names]
        self.jpegs = [os.listdir(img_dir) for img_dir in self.img_dirs]
        self.all_jpegs = []
        self.all_labels = []
        self.jpg_to_sample = []
        self.coords = []
        
        for idx,(im_dir,label,l) in enumerate(zip(self.img_dirs,self.sample_labels,self.jpegs)):
            sample_coords = []
            for jpeg in l:
                # build tile dataset
                self.all_jpegs.append(im_dir+'/'+jpeg)
                # label for each tile
                self.all_labels.append(label)
                # tracks slide membership at a tile level
                self.jpg_to_sample.append(idx)
                # store tile coordinates
                x,y = jpeg[:-5].split('_') # 'X_Y.jpeg'
                x,y = int(x), int(y)
                self.coords.append(torch.tensor([x,y]))
            
    def __len__(self):
        return len(self.all_jpegs)
        
    def __getitem__(self, idx):
        image = self.loader(self.all_jpegs[idx])
        if self.transform is not None:
            image = self.transform(image)
        if image.shape[1] < 256 or image.shape[2] < 256:
            image = data_utils.pad_tensor_up_to(image,256,256,channels_last=False)
        return image, self.all_labels[idx], self.coords[idx],self.jpg_to_sample[idx]

In [4]:
train_set = TCGADataset_tiled_slides(sa_train, root_dir, transform=train_transform, magnification='5.0')
train_loader = DataLoader(train_set, batch_size=512, pin_memory=True, num_workers=20)

In [5]:
state_dict_file = '/n/tcga_models/resnet18_WGD_10x.pt'
device = torch.device('cuda', 0)
output_shape = 1

In [6]:
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(2048, output_shape, bias=True)
saved_state = torch.load(state_dict_file, map_location=lambda storage, loc: storage)
resnet.load_state_dict(saved_state)
resnet.fc = nn.Linear(2048, 2048, bias=False)
resnet.fc.weight.data = torch.eye(2048)
resnet.cuda(device=device)
for param in resnet.parameters():
    param.requires_grad = False

Linear(in_features=2048, out_features=1, bias=True)

In [10]:
def pool_fn(x):
    #v,a = torch.max(x,0)
    v = torch.mean(x,0)
    return v

class pool_and_classify(nn.Module):
    def __init__(self, pool_fn, n_inputs, n_outputs):
        super(pool_and_classify, self).__init__()
        self.fc = nn.Linear(n_inputs,n_outputs)
        self.pool = pool_fn
    def forward(self,h):
        pooled = self.pool(h)
        logits = self.fc(pooled)
        return logits
    
slide_level_classification_layer = pool_and_classify(pool_fn,2048,1)
slide_level_classification_layer.cuda()

pool_and_classify(
  (fc): Linear(in_features=2048, out_features=1, bias=True)
)

In [11]:
e = 0
learning_rate = 1e-4
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(slide_level_classification_layer.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, min_lr=1e-8)

In [12]:
# track number of slides seen
p_update = torch.tensor(0.,device=device)
# store embeddings, labels, and memberships
embeddings = []
slide_labels = []
slide_membership = []
step = 0
batches = 0

for batch,labels,coords,idxs in train_loader:
    # get embeddings
    batch,labels,coords,idxs = batch.cuda(),labels.cuda(),coords.cuda(),idxs.cuda()
    if len(embeddings) == 0:
        current_slide = torch.min(idxs)
    # append each batched results
    embeddings.extend(resnet(batch))
    slide_membership.extend(idxs)
    slide_labels.extend(labels)
    
    p_update += 0.0001
    batches+=1
    
    if torch.rand(1,device=device) < p_update:
        slide_membership = torch.stack(slide_membership)
        slides = list(set(torch.squeeze(slide_membership).detach().cpu().numpy()))
        loss = torch.tensor(0.,device=device)
        embeddings = torch.stack(embeddings)
        slide_labels = torch.stack(slide_labels)
        for slide in slides:
            logits = slide_level_classification_layer(embeddings[slide_membership == slide,:])
            loss += criterion(logits,slide_labels[slide_membership == slide][0].float().view(-1))
        break    
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()    
        embeddings = []
        slide_labels = []
        print(step,batches,torch.max(slide_membership).detach().cpu().numpy())
        slide_membership = []
        step+=1
        batches = 0
        p_update = torch.tensor(0.,device=device)
        
slide_membership = torch.stack(slide_membership)
slides = list(set(torch.squeeze(slide_membership).detach().cpu().numpy()))
loss = torch.tensor(0.,device=device)
embeddings = torch.stack(embeddings)
slide_labels = torch.stack(slide_labels)
for slide in slides:
    pooled = pool_fn(embeddings[slide_membership == slide,:])
    logits = slide_level_classification_layer(pooled)
    loss += criterion(logits,slide_labels[slide_membership == slide][0].float().view(-1))

loss.backward()
optimizer.step()
optimizer.zero_grad()        

RuntimeError: both arguments to matmul need to be at least 1D, but they are 0D and 2D

In [39]:
list_em = [torch.masked_select(embeddings.transpose(1,0),slide_membership == slide).view(-1,2048) for slide in slides]

In [40]:
list_em[0]

tensor([[0.2498, 0.0647, 1.1720,  ..., 0.4731, 1.7007, 0.9100],
        [1.3068, 0.0257, 0.3159,  ..., 0.6631, 0.3141, 0.8234],
        [0.9543, 0.3433, 1.4579,  ..., 1.2796, 1.2526, 0.2178],
        ...,
        [1.1694, 0.3554, 1.0797,  ..., 1.0848, 1.6654, 0.1915],
        [0.0672, 0.7309, 2.1181,  ..., 1.2691, 1.3189, 0.8123],
        [1.8138, 0.7799, 0.5544,  ..., 0.1573, 0.4847, 1.3014]],
       device='cuda:0')

In [41]:
embeddings[slide_membership == slide,:]

tensor([[0.2498, 0.2957, 0.2364,  ..., 1.2197, 1.5712, 1.4878],
        [0.0647, 0.0634, 0.0708,  ..., 0.5148, 0.9572, 0.6616],
        [1.1720, 1.5456, 1.1372,  ..., 0.9125, 1.2380, 1.0903],
        ...,
        [0.2094, 0.2096, 0.2409,  ..., 0.1367, 0.1567, 0.1573],
        [0.0420, 0.0420, 0.0478,  ..., 0.4847, 0.6013, 0.4847],
        [0.5484, 0.5249, 0.3362,  ..., 1.4218, 1.4458, 1.3014]],
       device='cuda:0')

In [21]:
torch.sum(slide_membership == slide)

tensor(571, device='cuda:0')