In [1]:
import copy
import math
import os
import random
import sys
import time
from abc import ABC, abstractmethod
from copy import deepcopy

import matplotlib.patches as mpatches
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import openslide
import PIL
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.models as torchmodels
from matplotlib.collections import PatchCollection
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu

# Make the parent directory importable so that `pydmed` can be imported below;
# each entry is appended at most once.
list_pathstoadd = ["../"]
for path in list_pathstoadd:
    if path in sys.path:
        continue
    sys.path.append(path)
import pydmed
from pydmed.utils.data import *
import pydmed.lightdl
from pydmed.lightdl import *

In [2]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [3]:
#make dataset (section 1 of tutorial) ===================
# One `Patient` per whole-slide image under `rootdir`; the patient id is the
# position of its file in the sorted file list.
rootdir = "../NonGit/Data/"
list_relativedirs = sorted(["1.svs", "2.svs", "3.svs", "4.svs", "5.svs"])
# Placeholder labels drawn uniformly from {0, 1, 2, 3}. A dedicated, seeded
# generator keeps them reproducible across runs without reseeding numpy's
# global RNG (which the chunk collectors use for random crops).
# TODO: set real HER2 labels.
rng_placeholder_labels = np.random.RandomState(0)
#make a list of patients
list_patients = []
for idx_patient, fname in enumerate(list_relativedirs):
    new_patient = Patient(
        int_uniqueid=idx_patient,
        dict_records={
            "H&E": Record(rootdir, fname, {"resolution": "40x"}),
            "HER2-status": rng_placeholder_labels.randint(0, 4),
        })
    list_patients.append(new_patient)
#make the dataset
dataset = pydmed.utils.data.Dataset("myHER2dataset", list_patients)

In [4]:
def otsu_getpoint_from_foreground(fname_wsi, scale_thumbnail=0.01, width_targetpatch=5000):
    '''
    Select a random foreground point of a whole-slide image.

    The foreground is found by Otsu-thresholding a grayscale thumbnail of the
    slide. Thumbnail pixels at or below the threshold (the darker class,
    presumably tissue on a bright background) are foreground. A border of
    `width_targetpatch` level-0 pixels is excluded on every side, so the
    returned point is at least roughly that far from each edge.

    Inputs:
        - fname_wsi: path of a slide readable by `openslide.OpenSlide`.
        - scale_thumbnail: thumbnail size relative to the level-0 dimensions.
        - width_targetpatch: width (level-0 pixels) of the excluded border.
    Output:
        - (x, y): level-0 coordinates of the selected point.
    Raises:
        - ValueError: if no foreground pixel remains after excluding the border.
    '''
    #extract the foreground =========================
    # The slide handle is only needed for the thumbnail; close it right away.
    with openslide.OpenSlide(fname_wsi) as osimage:
        W, H = osimage.dimensions
        size_thumbnail = (int(scale_thumbnail*W), int(scale_thumbnail*H))
        pil_thumbnail = osimage.get_thumbnail(size_thumbnail)
    np_thumbnail = rgb2gray(np.array(pil_thumbnail)[:, :, 0:3])
    thresh = threshold_otsu(np_thumbnail)
    foreground = np_thumbnail <= thresh
    h_thumb, w_thumb = foreground.shape
    #apply the padding on foreground
    # The bottom/right margins are indexed from the end explicitly: a slice like
    # `[-w_padding:]` would select the WHOLE axis when w_padding == 0.
    w_padding = int(width_targetpatch * scale_thumbnail)
    if w_padding > 0:
        foreground[:w_padding, :] = False
        foreground[max(h_thumb - w_padding, 0):, :] = False
        foreground[:, :w_padding] = False
        foreground[:, max(w_thumb - w_padding, 0):] = False
    #select a random point =========================
    rows, cols = np.nonzero(foreground)
    if rows.size == 0:
        raise ValueError(
            "No foreground pixel found in {} after excluding a border of {} pixels."
            .format(fname_wsi, width_targetpatch))
    n = np.random.randint(rows.size)
    i_selected, j_selected = int(rows[n]), int(cols[n])
    # `get_thumbnail` preserves the aspect ratio, so the thumbnail may be
    # smaller than requested along one axis: map back with the actual
    # per-axis scale instead of the nominal `scale_thumbnail`.
    x = int(j_selected * W / w_thumb)
    y = int(i_selected * H / h_thumb)
    return x, y
    
class WSIRandomBigchunkLoader(BigChunkLoader):
    '''
    Reads square level-0 regions ("bigchunks") from a patient's H&E slide.

    The `idx`-th bigchunk starts at a point on a circle of radius 10000 pixels
    around the slide centre, advancing 10 degrees per index. The index is read
    from the patient's checkpoint (key "checkpoint_for_bigchunk"), which is
    presumably advanced by the smallchunk collector between loads.
    '''
    # size (level-0 pixels) of every bigchunk
    width_bigchunk = 2000
    height_bigchunk = 2000

    def extract_bigchunk(self, last_message_fromroot):
        '''
        Extract and return a bigchunk.
        Please note that in this function you have access to
        self.patient and self.const_global_info.
        '''
        #read `idx_bigchunk` from checkpoint =======
        checkpoint = self.get_checkpoint()
        if checkpoint is None:
            idx_bigchunk = 0
        else:
            idx_bigchunk = checkpoint["checkpoint_for_bigchunk"]
        #extract bigchunk =======
        wsi = self.patient.dict_records["H&E"]
        fname_wsi = wsi.rootdir + wsi.relativedir
        w, h = self.width_bigchunk, self.height_bigchunk
        # Close the slide as soon as the region is read; otherwise every
        # bigchunk leaves one open slide handle behind.
        with openslide.OpenSlide(fname_wsi) as osimage:
            W, H = osimage.dimensions
            x, y = self.get_bigchunk_position(idx_bigchunk, W, H, w=w, h=h)
            pil_bigchunk = osimage.read_region((x, y), 0, (w, h))
        np_bigchunk = np.array(pil_bigchunk)[:, :, 0:3]  # drop the alpha channel
        return BigChunk(data=np_bigchunk,
                        dict_info_of_bigchunk={"x": x, "y": y},
                        patient=self.patient)

    def get_bigchunk_position(self, idx_bigchunk, W, H, w=0, h=0):
        '''
        Level-0 top-left position of the `idx_bigchunk`-th bigchunk.

        Inputs:
            - idx_bigchunk: index of the bigchunk; consecutive indices are 10 degrees
              apart on a circle of radius 10000 pixels around the slide centre.
            - W, H: level-0 slide dimensions.
            - w, h: size of the region that will be read; the position is clamped
              so that this region stays inside the slide.
        Output:
            - (x, y) integer level-0 coordinates.
        '''
        theta = idx_bigchunk*10.0*math.pi/180.0
        r = 10000.0
        x = int(W*0.5 + r*math.cos(theta))
        y = int(H*0.5 + r*math.sin(theta))
        # On slides smaller than about 2*r the circle leaves the image, and
        # read_region would return pixels outside the slide (empty/transparent in
        # openslide — NOTE(review): verify for the installed version). Clamp instead.
        x = min(max(x, 0), max(W - w, 0))
        y = min(max(y, 0), max(H - h, 0))
        return x, y
        
    
class WSIRandomSmallchunkCollector(SmallChunkCollector):
    '''
    Collects random 224 x 224 crops ("smallchunks") from a bigchunk, then
    resizes, colour-jitters and normalizes them.
    '''
    def __init__(self, *args, **kwargs):
        # Per-smallchunk transform: resize to 224x224, jitter saturation/hue, then
        # normalize with the commonly used ImageNet channel mean/std.
        self.tfms_onsmallchunkcollection = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ColorJitter(brightness=0,
                                               contrast=0,
                                               saturation=0.5,
                                               hue=[-0.1, 0.1]),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ])
        super(WSIRandomSmallchunkCollector, self).__init__(*args, **kwargs)

    def extract_smallchunk(self, call_count, bigchunk, last_message_fromroot):
        '''
        Extract and return a smallchunk. Please note that in this function you have access to
        self.bigchunk, self.patient, self.const_global_info.
        Inputs:
            - call_count: number of previous calls for the current bigchunk (0 on the first call).
            - bigchunk: the returned bigchunk; `bigchunk.data` is indexed as [row, col, channel].
        Output:
            - a `SmallChunk` whose data is the transformed crop as a [224 x 224 x 3] array
              (or the raw crop when no transform is set), with the crop's top-left
              offset inside the bigchunk in `dict_info_of_smallchunk`.
        '''
        if call_count == 0:  # the `SmallChunkCollector` has just started
            # Advance the patient's bigchunk index. The loader reads it as
            # "checkpoint_for_bigchunk" to pick the next bigchunk position.
            old_checkpoint = self.get_checkpoint()
            idx_previous = 0 if old_checkpoint is None \
                else old_checkpoint["checkpoint_for_bigchunk"]
            self.set_checkpoint({"checkpoint_for_bigchunk": idx_previous + 1})
        H, W = bigchunk.data.shape[0], bigchunk.data.shape[1]
        w, h = 224, 224
        # np.random.randint excludes its upper bound: `+ 1` lets the crop touch the
        # right/bottom border, and keeps `W == w` from raising.
        rand_x = np.random.randint(0, W - w + 1)
        rand_y = np.random.randint(0, H - h + 1)
        np_smallchunk = bigchunk.data[rand_y:rand_y+h, rand_x:rand_x+w, :]
        #apply the transformation ===========
        if self.tfms_onsmallchunkcollection is not None:
            toret = self.tfms_onsmallchunkcollection(np_smallchunk)
            toret = toret.cpu().detach().numpy()  #[3 x 224 x 224]
            toret = np.transpose(toret, [1, 2, 0])  #[224 x 224 x 3]
        else:
            toret = np_smallchunk
        #wrap in SmallChunk
        return SmallChunk(data=toret,
                          dict_info_of_smallchunk={"x": rand_x, "y": rand_y},
                          dict_info_of_bigchunk=bigchunk.dict_info_of_bigchunk,
                          patient=bigchunk.patient)
    
def visualize_one_patient(patient, list_smallchunks, vis_scale=0.01, dir_output="Sample_2_Output"):
    '''
    Given all smallchunks collected for a specific patient, this function
    visualizes the patient and saves the figure as `<dir_output>/patient_<id>.eps`.
    The left panel shows the H&E thumbnail. The right panel shows the same
    thumbnail with each distinct bigchunk as a dashed rectangle and each
    smallchunk as a dot at its centre.
    Inputs:
        - patient: the patient under consideration, an instance of `utils.data.Patient`.
        - list_smallchunks: the list of all collected small chunks for the patient,
            a list whose elements are an instance of `lightdl.SmallChunk`.
        - vis_scale: thumbnail size relative to the level-0 slide dimensions.
        - dir_output: directory the figure is written to (created if missing).
    '''
    wsi = patient.dict_records["H&E"]
    fname_wsi = wsi.rootdir + wsi.relativedir
    # Only the thumbnail is needed from the slide; close the handle right away.
    with openslide.OpenSlide(fname_wsi) as opsimage:
        opsimageW, opsimageH = opsimage.dimensions
        W, H = int(opsimageW*vis_scale), int(opsimageH*vis_scale)
        pil_thumbnail = opsimage.get_thumbnail((W, H))
    plt.ioff()
    fig, axes = plt.subplots(1, 2, figsize=(2*10, 10))
    axes[0].imshow(pil_thumbnail)
    axes[0].axis('off')
    axes[0].set_title("patient {}, H&E [{} x {}]."
                      .format(patient.int_uniqueid, opsimageW, opsimageH))
    ax = axes[1]
    ax.imshow(pil_thumbnail)
    ax.axis('off')
    print("Visualizing patient {} with {} smallchunks"
          .format(patient, len(list_smallchunks)))
    list_colors = ['lawngreen', 'cyan', 'gold', 'greenyellow']
    # bigchunk / smallchunk sizes (level-0 pixels 2000 and 224) at thumbnail scale;
    # these must match the sizes used when the chunks were extracted
    w_bigchunk = h_bigchunk = int(2000*vis_scale)
    w_smallchunk = h_smallchunk = int(224*vis_scale)
    # A set gives O(1) membership tests; there can be tens of thousands of smallchunks.
    set_shownbigchunks = set()
    for smallchunk in list_smallchunks:
        info_big = smallchunk.dict_info_of_bigchunk
        info_small = smallchunk.dict_info_of_smallchunk
        #show the bigchunk (once per distinct position) ================
        x_big, y_big = int(info_big["x"]*vis_scale), int(info_big["y"]*vis_scale)
        if (x_big, y_big) not in set_shownbigchunks:
            rect = patches.Rectangle((x_big, y_big), w_bigchunk, h_bigchunk, linewidth=1,
                                     linestyle="--",
                                     edgecolor=random.choice(list_colors),
                                     facecolor='none', fill=False)
            ax.add_patch(rect)
            set_shownbigchunks.add((x_big, y_big))
        #show the smallchunk as a dot at its centre =====
        # smallchunk coordinates are offsets inside the bigchunk
        x = int(info_small["x"]*vis_scale + info_big["x"]*vis_scale)
        y = int(info_small["y"]*vis_scale + info_big["y"]*vis_scale)
        x_centre, y_centre = int(x+0.5*w_smallchunk), int(y+0.5*h_smallchunk)
        circle = patches.Circle((x_centre, y_centre), radius=w_smallchunk*0.05,
                                facecolor=random.choice(list_colors),
                                fill=True)
        ax.add_patch(circle)
    ax.set_title("patient {} (extracted big/small chunks)".format(patient.int_uniqueid), fontsize=20)
    os.makedirs(dir_output, exist_ok=True)
    fig.savefig(os.path.join(dir_output, "patient_{}.eps".format(patient.int_uniqueid)),
                bbox_inches='tight', format='eps')
    plt.close(fig)
    

In [5]:
#make dataloader ==================
# Global settings handed to LightDL: number of bigchunk loaders, queue lengths,
# rescheduling interval and (here unset) core assignment.
const_global_info = {
    "num_bigchunkloaders": 3,
    "maxlength_queue_smallchunk": 200,
    "maxlength_queue_lightdl": 10000,
    "interval_resched": 5,
    "core-assignment": {
        "lightdl": None,
        "smallchunkloaders": None,
        "bigchunkloaders": None,
    },
}
tfms = torchvision.transforms.ToTensor()
dataloader = LightDL(
    dataset=dataset,
    type_bigchunkloader=WSIRandomBigchunkLoader,
    type_smallchunkcollector=WSIRandomSmallchunkCollector,
    const_global_info=const_global_info,
    batch_size=10,
    tfms=tfms,
)

In [6]:
#build the model and optimizer====================
# ImageNet-pretrained ResNet-18. Its 1000-way classification head is replaced
# by one with an output per HER2-status class (labels are drawn from 0..3).
NUM_HER2_CLASSES = 4
model = torchmodels.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, NUM_HER2_CLASSES)
# Move the model before building the optimizer so the optimizer sees the
# final (device-resident) parameters, including the new head.
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()




In [7]:
#train the model ============================
NUM_TRAIN_BATCHES = 10000
dataloader.start()
time.sleep(20)  # presumably gives the background loaders time to fill their queues
tstart = time.time()
batchcount = 0
# try/finally: pause the background loaders even if training raises or is interrupted.
try:
    while batchcount < NUM_TRAIN_BATCHES:
        # Named `..._inbatch` so the dataset-level `list_patients` is not overwritten.
        x, list_patients_inbatch, list_smallchunks = dataloader.get()
        # one label per batch element, taken from the patient it came from
        y = torch.from_numpy(np.array([patient.dict_records['HER2-status']
                                       for patient in list_patients_inbatch])).to(device)
        batchcount += 1
        optimizer.zero_grad()
        netout = model(x.to(device))
        loss = criterion(netout, y)
        loss.backward()
        optimizer.step()  # without this the weights are never updated
        if batchcount % 10 == 0:
            print("*************  batchcount = {} ************".format(batchcount))
finally:
    dataloader.pause_loading()

 loading initial bigchunks, please wait ....
  bigchunk 0 from 3. please wait ...

  bigchunk 1 from 3. please wait ...

  bigchunk 2 from 3. please wait ...

The initial loading of bigchunks took 3.431326150894165 seconds.
*************  batchcount = 10 ************
*************  batchcount = 20 ************
*************  batchcount = 30 ************
*************  batchcount = 40 ************
*************  batchcount = 50 ************
*************  batchcount = 60 ************
*************  batchcount = 70 ************
*************  batchcount = 80 ************
*************  batchcount = 90 ************
*************  batchcount = 100 ************
*************  batchcount = 110 ************
*************  batchcount = 120 ************
*************  batchcount = 130 ************
*************  batchcount = 140 ************
*************  batchcount = 150 ************
*************  batchcount = 160 ************
*************  batchcount = 170 ************
*************  batch

*************  batchcount = 1770 ************
*************  batchcount = 1780 ************
*************  batchcount = 1790 ************
*************  batchcount = 1800 ************
*************  batchcount = 1810 ************
*************  batchcount = 1820 ************
*************  batchcount = 1830 ************
*************  batchcount = 1840 ************
*************  batchcount = 1850 ************
*************  batchcount = 1860 ************
*************  batchcount = 1870 ************
*************  batchcount = 1880 ************
*************  batchcount = 1890 ************
*************  batchcount = 1900 ************
*************  batchcount = 1910 ************
*************  batchcount = 1920 ************
*************  batchcount = 1930 ************
*************  batchcount = 1940 ************
*************  batchcount = 1950 ************
*************  batchcount = 1960 ************
*************  batchcount = 1970 ************
*************  batchcount = 1980 *

*************  batchcount = 3560 ************
*************  batchcount = 3570 ************
*************  batchcount = 3580 ************
*************  batchcount = 3590 ************
*************  batchcount = 3600 ************
*************  batchcount = 3610 ************
*************  batchcount = 3620 ************
*************  batchcount = 3630 ************
*************  batchcount = 3640 ************
*************  batchcount = 3650 ************
*************  batchcount = 3660 ************
*************  batchcount = 3670 ************
*************  batchcount = 3680 ************
*************  batchcount = 3690 ************
*************  batchcount = 3700 ************
*************  batchcount = 3710 ************
*************  batchcount = 3720 ************
*************  batchcount = 3730 ************
*************  batchcount = 3740 ************
*************  batchcount = 3750 ************
*************  batchcount = 3760 ************
*************  batchcount = 3770 *

*************  batchcount = 5350 ************
*************  batchcount = 5360 ************
*************  batchcount = 5370 ************
*************  batchcount = 5380 ************
*************  batchcount = 5390 ************
*************  batchcount = 5400 ************
*************  batchcount = 5410 ************
*************  batchcount = 5420 ************
*************  batchcount = 5430 ************
*************  batchcount = 5440 ************
*************  batchcount = 5450 ************
*************  batchcount = 5460 ************
*************  batchcount = 5470 ************
*************  batchcount = 5480 ************
*************  batchcount = 5490 ************
*************  batchcount = 5500 ************
*************  batchcount = 5510 ************
*************  batchcount = 5520 ************
*************  batchcount = 5530 ************
*************  batchcount = 5540 ************
*************  batchcount = 5550 ************
*************  batchcount = 5560 *

*************  batchcount = 7140 ************
*************  batchcount = 7150 ************
*************  batchcount = 7160 ************
*************  batchcount = 7170 ************
*************  batchcount = 7180 ************
*************  batchcount = 7190 ************
*************  batchcount = 7200 ************
*************  batchcount = 7210 ************
*************  batchcount = 7220 ************
*************  batchcount = 7230 ************
*************  batchcount = 7240 ************
*************  batchcount = 7250 ************
*************  batchcount = 7260 ************
*************  batchcount = 7270 ************
*************  batchcount = 7280 ************
*************  batchcount = 7290 ************
*************  batchcount = 7300 ************
*************  batchcount = 7310 ************
*************  batchcount = 7320 ************
*************  batchcount = 7330 ************
*************  batchcount = 7340 ************
*************  batchcount = 7350 *

*************  batchcount = 8930 ************
*************  batchcount = 8940 ************
*************  batchcount = 8950 ************
*************  batchcount = 8960 ************
*************  batchcount = 8970 ************
*************  batchcount = 8980 ************
*************  batchcount = 8990 ************
*************  batchcount = 9000 ************
*************  batchcount = 9010 ************
*************  batchcount = 9020 ************
*************  batchcount = 9030 ************
*************  batchcount = 9040 ************
*************  batchcount = 9050 ************
*************  batchcount = 9060 ************
*************  batchcount = 9070 ************
*************  batchcount = 9080 ************
*************  batchcount = 9090 ************
*************  batchcount = 9100 ************
*************  batchcount = 9110 ************
*************  batchcount = 9120 ************
*************  batchcount = 9130 ************
*************  batchcount = 9140 *

In [8]:
# Draw and save one figure per patient showing where bigchunks (dashed
# rectangles) and smallchunks (dots) were extracted during training.
dataloader.visualize(visualize_one_patient)

Visualizing patient utils.data.Patient with unique id: 0 with 22816 smallchunks
Visualizing patient utils.data.Patient with unique id: 1 with 17504 smallchunks
Visualizing patient utils.data.Patient with unique id: 2 with 16600 smallchunks
Visualizing patient utils.data.Patient with unique id: 3 with 22213 smallchunks
Visualizing patient utils.data.Patient with unique id: 4 with 20877 smallchunks
