In [1]:
import copy
import math
import os
import random
import sys
import time
from abc import ABC, abstractmethod
from copy import deepcopy

import numpy as np
import openslide
import PIL
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.models as torchmodels
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu

#make extra directories importable (presumably where the pydmed package lives) =====
list_pathstoadd = ["../"]
for path in list_pathstoadd:
    if path in sys.path:
        continue  #already on the search path: do not add a duplicate entry
    sys.path.append(path)
import pydmed
from pydmed.utils.data import *
import pydmed.lightdl
from pydmed.lightdl import *

In [2]:
#use the first GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [3]:
#make dataset (section 1 of tutorial) ===================
rootdir = "../NonGit/Data/"
list_relativedirs = ["1.svs", "2.svs", "3.svs", "4.svs", "5.svs"]
list_relativedirs.sort()
#make a list of patients; each patient's uniqueid is its position in the sorted list
#(enumerate is O(n) and, unlike list.index, gives duplicate filenames distinct ids)
list_patients = []
for idx_patient, fname in enumerate(list_relativedirs):
    new_patient = Patient(
        int_uniqueid=idx_patient,
        dict_records={
            "H&E": Record(rootdir, fname, {"resolution": "40x"}),
            "HER2-status": np.random.randint(0, 4),  #TODO:set real labels (placeholder in {0,1,2,3})
        })
    list_patients.append(new_patient)
#make the dataset
dataset = pydmed.utils.data.Dataset("myHER2dataset", list_patients)

In [4]:
def otsu_getpoint_from_foreground(fname_wsi, scale_thumbnail=0.01, width_targetpatch=5000):
    '''
    Returns a random level-0 point (x, y) lying on the Otsu foreground of a slide.
    The foreground is computed on a small grayscale thumbnail: pixels whose
    intensity is <= the Otsu threshold (i.e. the darker ones) are foreground.
    A border margin of about `width_targetpatch` level-0 pixels is excluded on
    every side, so that a patch centred on the returned point stays inside the slide.
    Inputs:
        - fname_wsi: path of a slide file that openslide can open.
        - scale_thumbnail: requested thumbnail size relative to the level-0 size.
        - width_targetpatch: width (in level-0 pixels) of the excluded border margin.
    Output:
        - x, y: int level-0 coordinates (x is the column, y is the row).
    Raises:
        - ValueError: if no foreground pixel remains after removing the margin.
    '''
    #read a thumbnail and release the slide handle =========================
    osimage = openslide.OpenSlide(fname_wsi)
    try:
        W, H = osimage.dimensions
        size_thumbnail = (int(scale_thumbnail*W), int(scale_thumbnail*H))
        np_thumbnail = np.array(osimage.get_thumbnail(size_thumbnail))[:, :, 0:3]
    finally:
        osimage.close()
    #get_thumbnail keeps the aspect ratio, so the thumbnail may be slightly smaller
    #than requested: measure the real per-axis scale instead of assuming scale_thumbnail.
    h_thumbnail, w_thumbnail = np_thumbnail.shape[0], np_thumbnail.shape[1]
    scale_x, scale_y = w_thumbnail / W, h_thumbnail / H
    #extract the foreground =========================
    np_gray = rgb2gray(np_thumbnail)
    foreground = np_gray <= threshold_otsu(np_gray)
    #apply the padding on foreground
    pad_y = int(width_targetpatch * scale_y)
    pad_x = int(width_targetpatch * scale_x)
    if pad_y > 0:  #guard: the slice [-0:] would cover the whole axis
        foreground[:pad_y, :] = False
        foreground[-pad_y:, :] = False
    if pad_x > 0:
        foreground[:, :pad_x] = False
        foreground[:, -pad_x:] = False
    #select a random point =========================
    i_oneindices, j_oneindices = np.nonzero(foreground)
    if i_oneindices.size == 0:
        raise ValueError(
            "No foreground pixel found in {} after removing the border margin.".format(fname_wsi))
    n = np.random.randint(i_oneindices.size)
    #map thumbnail (row, col) back to level-0 (x, y)
    x = int(j_oneindices[n] / scale_x)
    y = int(i_oneindices[n] / scale_y)
    return x, y
    
class WSIRandomBigchunkLoader(BigChunkLoader):
    '''
    Loads a list of square level-0 regions ("bigchunks") from a patient's H&E slide.
    Each region is either centred on a random Otsu-foreground point, or placed
    uniformly at random inside the slide.
    '''
    #settings ====
    NUM_BIGCHUNKS = 5             #number of bigchunks returned per call
    SIZE_BIGCHUNK = (1000, 1000)  #(width, height) of each bigchunk in level-0 pixels
    FLAG_USE_OTSU = True          #sample positions on the tissue foreground

    def extract_bigchunk(self, last_message_fromroot):
        '''
        Extract and return a list of bigchunks.
        Please note that in this function you have access to
        self.patient and self.const_global_info.
        Inputs:
            - last_message_fromroot: not needed in this sample notebook.
        Output:
            - a list of NUM_BIGCHUNKS BigChunk objects, each holding an RGB numpy
              array and its level-0 top-left corner in dict_info_of_bigchunk.
        '''
        wsi = self.patient.dict_records["H&E"]
        fname_wsi = os.path.join(wsi.rootdir, wsi.relativedir)
        w, h = self.SIZE_BIGCHUNK
        list_bigchunks = []
        #open the slide once for all bigchunks and always release the handle
        osimage = openslide.OpenSlide(fname_wsi)
        try:
            W, H = osimage.dimensions
            for _ in range(self.NUM_BIGCHUNKS):
                if self.FLAG_USE_OTSU:
                    center_x, center_y = otsu_getpoint_from_foreground(fname_wsi)
                    rand_x, rand_y = int(center_x-(w*0.5)), int(center_y-(h*0.5))
                else:
                    #+1: randint excludes its upper bound, but W-w / H-h are valid offsets
                    rand_x = int(np.random.randint(0, W-w+1))
                    rand_y = int(np.random.randint(0, H-h+1))
                pil_bigchunk = osimage.read_region((rand_x, rand_y), 0, (w, h))
                np_bigchunk = np.array(pil_bigchunk)[:, :, 0:3]  #drop the alpha channel
                bigchunk = BigChunk(data=np_bigchunk,
                                    dict_info_of_bigchunk={"x":rand_x, "y":rand_y},
                                    patient=self.patient)
                list_bigchunks.append(bigchunk)
        finally:
            osimage.close()
        self.log("in time {} a BigChunk loaded.\n".format(time.time()))
        return list_bigchunks

class WSIRandomSmallchunkCollector(SmallChunkCollector):
    '''
    Collects 224x224 smallchunks at random positions inside randomly chosen
    bigchunks, and applies a resize/colour-jitter/normalization transform to each.
    '''
    def __init__(self, *args, **kwargs):
        #grab privates
        #transform applied to every collected smallchunk
        #(the mean/std below are the standard ImageNet normalization values)
        self.tfms_onsmallchunkcollection = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ColorJitter(brightness=0,
                                               contrast=0,
                                               saturation=0.5,
                                               hue=[-0.1, 0.1]),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        ])
        super(WSIRandomSmallchunkCollector, self).__init__(*args, **kwargs)

    def extract_smallchunk(self, call_count, list_bigchunks, last_message_fromroot):
        '''
        Extract and return a smallchunk. Please note that in this function you have access to
        self.patient, self.const_global_info.
        Inputs:
            - list_bigchunks: the list of extracted bigchunks (each .data is an
              [H x W x 3] numpy array, as built by the bigchunk loader above).
            - Other arguments are not needed in this sample notebook.
        Output:
            - a SmallChunk whose data is the transformed [224 x 224 x 3] array
              (or the raw crop when no transform is set).
        '''
        bigchunk = list_bigchunks[np.random.randint(len(list_bigchunks))]
        H, W = bigchunk.data.shape[0], bigchunk.data.shape[1]
        w, h = 224, 224
        #+1: randint excludes its upper bound, but W-w / H-h are valid offsets
        rand_x = np.random.randint(0, W-w+1)
        rand_y = np.random.randint(0, H-h+1)
        np_smallchunk = bigchunk.data[rand_y:rand_y+h, rand_x:rand_x+w, :]
        #apply the transformation ===========
        if self.tfms_onsmallchunkcollection is not None:
            toret = self.tfms_onsmallchunkcollection(np_smallchunk)
            toret = toret.cpu().detach().numpy()  #[3 x 224 x 224]
            toret = np.transpose(toret, [1, 2, 0])  #[224 x 224 x 3]
        else:
            toret = np_smallchunk
        #wrap in SmallChunk
        smallchunk = SmallChunk(data=toret,
                                dict_info_of_smallchunk={"x":rand_x, "y":rand_y},
                                dict_info_of_bigchunk=bigchunk.dict_info_of_bigchunk,
                                patient=bigchunk.patient)
        return smallchunk

In [5]:
#make dataloader ==================
#NOTE(review): tfms is presumably applied by LightDL to each collected smallchunk — confirm in pydmed.lightdl
tfms = torchvision.transforms.ToTensor()
#None for every role: no explicit CPU-core assignment
dict_coreassignment = {"lightdl": None,
                       "smallchunkloaders": None,
                       "bigchunkloaders": None}
const_global_info = {
    "num_bigchunkloaders": 5,
    "maxlength_queue_smallchunk": 100,
    "maxlength_queue_lightdl": 10000,
    "interval_resched": 10,
    "core-assignment": dict_coreassignment,
}
dataloader = LightDL(
    dataset=dataset,
    type_bigchunkloader=WSIRandomBigchunkLoader,
    type_smallchunkcollector=WSIRandomSmallchunkCollector,
    const_global_info=const_global_info,
    batch_size=10,
    tfms=tfms,
)

In [6]:
#build the model and optimizer====================
NUM_HER2CLASSES = 4  #the dataset cell draws HER2-status labels from {0, 1, 2, 3}
model = torchmodels.resnet18(pretrained=True)
#the pretrained head predicts the 1000 ImageNet classes; replace it by a head for the HER2 classes
model.fc = torch.nn.Linear(model.fc.in_features, NUM_HER2CLASSES)
#move the model before building the optimizer, so the optimizer is given the on-device parameters
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()




In [7]:
#train the model ============================
NUM_BATCHES = 200     #number of training iterations
SECONDS_WARMUP = 20   #give the background loaders time to fill their queues
dataloader.start()
time.sleep(SECONDS_WARMUP)
tstart = time.time()
try:
    for batchcount in range(1, NUM_BATCHES + 1):
        #batch-local name: do not clobber the dataset's global list_patients
        x, list_patients_batch, _ = dataloader.get()
        #CrossEntropyLoss needs int64 targets (np.array's default int is int32 on some platforms)
        y = torch.from_numpy(np.array([patient.dict_records['HER2-status']
                                       for patient in list_patients_batch],
                                      dtype=np.int64)).to(device)
        optimizer.zero_grad()
        netout = model(x.to(device))
        loss = criterion(netout, y)
        loss.backward()
        optimizer.step()  #without this call the weights are never updated
        if (batchcount % 10) == 0:
            print("*************  batchcount = {} ************".format(batchcount))
finally:
    #stop background loading even if training fails or is interrupted
    dataloader.pause_loading()
print("training took {} seconds.".format(time.time() - tstart))

 loading initial bigchunks (please wait)
 bigchunk 0 from 5

 bigchunk 1 from 5

 bigchunk 2 from 5

 bigchunk 3 from 5

 bigchunk 4 from 5

The initial loading of bigchunks took 22.76441478729248 seconds.
*************  batchcount = 10 ************
*************  batchcount = 20 ************
*************  batchcount = 30 ************
*************  batchcount = 40 ************
*************  batchcount = 50 ************
*************  batchcount = 60 ************
*************  batchcount = 70 ************
*************  batchcount = 80 ************
*************  batchcount = 90 ************
*************  batchcount = 100 ************
*************  batchcount = 110 ************
*************  batchcount = 120 ************
*************  batchcount = 130 ************
*************  batchcount = 140 ************
*************  batchcount = 150 ************
*************  batchcount = 160 ************
*************  batchcount = 170 ************
*************  batchcount = 180 ******