<a href="https://colab.research.google.com/github/SNUH-AIeducation/SNUH-AI-Education-for-Clinicians/blob/master/Modality/WSI/patho_practice_pytorch_SNUH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Segmentation task in digital pathology**
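In this practice session we build a deep-learning segmentation model on digital pathology images.  
Before starting, download the required data and libraries.  
In [Runtime]-[Change runtime type], set the hardware accelerator to GPU.  
After all libraries are installed, be sure to click **[Runtime]-[Restart runtime]**.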

In [None]:
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1MMZMfZc1MBW3jjiRlS_Kc6XNgc7VNmeX' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1MMZMfZc1MBW3jjiRlS_Kc6XNgc7VNmeX" -O wsi.zip && rm -rf ~/cookies.txt

In [None]:
!unzip wsi.zip -d ./wsi

In [None]:
!apt-get install python3-openslide

In [None]:
!pip install --force-reinstall albumentations==1.0.3

In [None]:
!pip install segmentation-models-pytorch

In [None]:
!pip install imagecodecs

## Opening

- Step 1. HuBMAP - Hacking the kidney
    - Step 1-1. Dataset specifications
    - Step 1-2. Dataset structure
        - Step 1-2-1. Whole Slide Image(.tiff)
        - Step 1-2-2. Anatomical structure(.json)
        - Step 1-2-3. Glomeruli label(.json, RLE)
- Step 2. Data preprocessing: WSI patch generation
- Step 3. Building the dataset
- Step 4. Model training
- Step 5. Evaluation
    - Step 5-1. Patch-wise validation
    - Step 5-2. Whole image-wise validation
    - Step 5-3. IoU score

## **Step 1. HuBMAP - Hacking the kidney**
This practice uses Kaggle's 'HuBMAP Glomeruli FTU Segmentation Dataset'  
to build a segmentation model that delineates `Glomeruli` in digital pathology images.  
https://www.kaggle.com/c/hubmap-kidney-segmentation

![hubmap_results](https://www.kaggleusercontent.com/kf/48229620/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..0EfFqtR2KRpJ4P7fgHoVIg.Z78v2tme4qFKDCFJlrIjw4IUra9v5OoE0AyRZUQ5Sl6DGoCqpMSWDLrU3bxR1zaaS7a_0fRB58gQ7XE75Vk9vlTUO8qaVk8Exy3JW8th1wQC0hlgp1PyRhVfsbWobUv2-9qyTK1Ww5o287PAPh5HROW6xMpvBSyE6WxJUiX9lxWdoRjO9ttJBoMeMio7k_DH0N3Vwr1dCdTt1qpKIns0EACWj7Kx6qWoS9F4SuQlKoj4lQMFnI8p-uGE3S5pDCApbGpcUrRdLvnrL9N1XQbB2gXfnmndMtrFwt6-LDGWdl42bJlzq9Hur5RvDUXvMmWhq8SAb492xX-y-wIm51BBKv8v0lUvwfspt6K64BWggr-c60a8RIBE7ztQVnyDqgJEoPMd8doMf0U8Apjj_ifWSA8-LQvQklrYCLJHGrRxCN4DoEKhcqZ2RZ3tA9r86hQQr0wTLpb9XIyb8w6CJsmd11QF4XUdCl6N_rpdfZcO9xAmUwemB40iTNG_jbMH8GxDTwO0IGlPIfbo3Nv5ay4tfKtIAhdGqt9VkHoG-2vmnX23aL3Of_xLFgHU48HHFnWF7PpbM-qA201oab0UbOPn5_n0tfB1AAV0EoDwLY9w9wqyByTOpomTp_bjkddyKN8_7-CXSpFNs6TwWaG6Aow5p2TJ9kYWx2ZuInscM1Q2NOU.PJl_tL1VktTAUjTWPQUY-Q/__results___files/__results___24_0.png)

reference: https://www.kaggle.com/ihelon/hubmap-exploratory-data-analysis

## Step 1-1. Dataset specifications
This competition, hosted by the Human BioMolecular Atlas Program (HuBMAP) with NIH funding,  
aims to segment `functional tissue units (FTUs)` in pathology images.  
According to the challenge description, locating the size, shape, number, and position of FTUs within a tissue is medically meaningful.  
Among the FTUs, this competition targets normal glomeruli.

The dataset is organized as follows.
- 20 kidney tissue sections
    - 11 fresh frozen (FF) carboxymethylcellulose(CMC) embedded
    - 9 formalin fixed paraffin embedded (FFPE)
- Each sample has the following data:
    - PAS stain microscopy image(RGB-channel TIFF)
        - The histology stained image is saved as a 24 bit RGB .tif file.
        - The spatial resolution is 0.5 micron per pixel
    - Anatomical structure segmentation mask(JSON)
    - Glomeruli segmentation mask(JSON)
- Known clinical metadata

## Step 1-2. Dataset structure
The HuBMAP dataset is structured as follows.

- Anatomical structure(json)
- Glomeruli label(json)
- Whole Slide Image(tiff)

Clinical metadata such as race, sex, and age is also included.  
-> HuBMAP-20-dataset_information.csv

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import PIL.Image as Image
import tifffile as tiff

In [None]:
HuBMAP_PATH = "./wsi/HuBMAP/"

In [None]:
hubmap_data_info = pd.read_csv(os.path.join(HuBMAP_PATH, 'HuBMAP-20-dataset_information.csv'))
hubmap_data_info.tail()

### Step 1-2-1. Whole Slide Image(.tiff)
The kidney WSIs in this challenge are provided as TIFF files.  

In [None]:
DOWNSAMPLE_FACTOR = 16 # level4 resolution

In [None]:
# read the full-resolution image with tifffile, then downsample it by DOWNSAMPLE_FACTOR below
wsi_sample_lv4 = tiff.imread(os.path.join(HuBMAP_PATH, "0486052bb.tiff"))
wsi_sample_lv4 = cv2.resize(wsi_sample_lv4,
                            (wsi_sample_lv4.shape[1]//DOWNSAMPLE_FACTOR,wsi_sample_lv4.shape[0]//DOWNSAMPLE_FACTOR),
                            interpolation = cv2.INTER_AREA)
plt.figure(figsize=(15,15))
plt.imshow(wsi_sample_lv4)
plt.title("0486052bb.tiff")

### Step 1-2-2. Anatomical structure(.json)
The anatomical structures of each kidney tissue WSI are provided as a JSON file.

![Anatomical_structure](https://www.kaggleusercontent.com/kf/47173731/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..q-BBmw8iCLU5eus2kmodNQ.0L4ZTx8opqJ45okmbMud-ycfHAyPeW6HCa6ZtukTlNPjyZpfXHIo8nLq8kE_3xrxCVsTxrd9ap0KOjmKq4CeNroqY9kk4_gRPSoWzFKJ_rehW_2IDHBTpXRgxvEUlfQyqjgTxwECcUX8X4eatGlI7B92ybYNf0rTVM45nN73zPwTyNvZKFJVmOTYKTlxRmTvOOaETigy8f29MZLtPtvqEIIz8xOPz1F-A4wmyKKht4qUOszVC6EdO-K0x7t7RxCOmWQmADFpYn-0k4FTxlRNQCDgw3o8mfmjm0MIO81oMaIYmL1QmixNbft1h613haBCL0Ee8B5EP167-O_vDySa6ZGTZONOT5vaiq6t4IyZS3ojP6AEi30l006DZ46vEKmgiIIavgLrEbGnX0iIyNdxOti2vCOZprh6Bkh4_cAQC4dh8eMyP6rqx7UvCoVWgLMNfdIApgb2xOHnWcnw5jkptSFmWhVF_0_i78W8pSNrt2iVpdCO-weFBuIB5obKv-KX2Cr-Qi-LIrbM_soi2C4aUsDpSOfR20AaxPj_LXLNhncDMoRrJgnY8ydzBh_g-FdyM51bd43KxF7ePgdeB3kIjyMnjzwDswbqqr8_DNXXvYOlgPVFTFJHKPoEUyQACQEP-QHSOyyZt4RU86xfrKrBhEW-GsRd-scuxzNXw1ySP64.v-jC0l2WA_ttSh6_KTe8XA/__results___files/__results___2_1.png)

reference: https://www.kaggle.com/leahscherschel/dataset-details

In [None]:
import json

In [None]:
with open(os.path.join(HuBMAP_PATH, "0486052bb-anatomical-structure.json"), 'r') as f:
    structure_data_sample = json.load(f)
structure_data_sample # list of (x, y); x = 0 is the left edge and y = 0 is the top edge

In [None]:
# number and composition of the anatomical structures in the json
len(structure_data_sample)

In [None]:
for i in range(len(structure_data_sample)):
    print(structure_data_sample[i]["properties"]["classification"]["name"])

In [None]:
structure_data_sample[1] #dict in list 

The JSON file consists of coordinates; connecting consecutive coordinates yields the corresponding polygon.

![coordinate_to_poly](https://qph.fs.quoracdn.net/main-qimg-2a0bdb4c01da036091c068fcc4c4ff0a.webp)
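
The coordinate handling described above can be sketched with NumPy alone. The annotation below is a hypothetical toy record that only mimics the GeoJSON-like layout of the files (it is not taken from the dataset), and `scale_ring` and `shoelace_area` are illustrative helper names. The sketch scales level-0 coordinates by a downsample factor and checks that the enclosed area shrinks by the square of that factor.

```python
import numpy as np

# Hypothetical annotation in the same GeoJSON-like layout as the files above.
toy_annotation = {
    "geometry": {
        "type": "Polygon",
        "coordinates": [[[160, 320], [480, 320], [480, 640], [160, 640], [160, 320]]],
    },
    "properties": {"classification": {"name": "Cortex"}},
}

def scale_ring(annotation, factor):
    """Return the outer ring as an int32 (N, 2) array in downsampled pixels."""
    ring = np.array(annotation["geometry"]["coordinates"][0], dtype=np.float64)
    return np.int32(ring / factor)

def shoelace_area(ring):
    """Area enclosed by consecutive (x, y) vertices (shoelace formula)."""
    x = ring[:, 0].astype(np.float64)
    y = ring[:, 1].astype(np.float64)
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))

ring_lv0 = scale_ring(toy_annotation, 1)    # original coordinates
ring_lv4 = scale_ring(toy_annotation, 16)   # same factor as DOWNSAMPLE_FACTOR
print(ring_lv4.tolist())
print(shoelace_area(ring_lv0), shoelace_area(ring_lv4))
```

Dividing the coordinates by 16 divides the area by 256, which is why the polygon must be rescaled before it is drawn on the downsampled image.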

In [None]:
# Cortex
poly_coordi = structure_data_sample[1]['geometry']['coordinates'] #level0 coordinates 
poly_coordi = np.int32(np.array(poly_coordi) / DOWNSAMPLE_FACTOR) #level4 coordinates
cortex_lv4 = wsi_sample_lv4.copy()
cortex_lv4 = cv2.polylines(cortex_lv4, np.int32([poly_coordi]), True, (0, 255, 0), 5)

In [None]:
plt.figure(figsize=(15,15))
plt.imshow(cortex_lv4)
plt.title("0486052bb.tiff - Cortex")

### Step 1-2-3. Glomeruli label(.json, .csv)
The target glomeruli annotations are also provided as a JSON file.   
In addition, the glomeruli masks are compressed with Run Length Encoding (RLE) and stored in `train.csv`.
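
To make the RLE format concrete, here is a minimal sketch on a 3x3 toy mask. It assumes the same convention as the `rle_encode` / `rle_decode` helpers used later in this notebook: pixels are flattened column by column, and each run is written as a 1-indexed start followed by its length. The names `rle_encode_demo`, `rle_decode_demo`, and `toy_mask` are illustrative only.

```python
import numpy as np

def rle_encode_demo(mask):
    """Encode a binary mask as 'start length' pairs (column-major, 1-indexed)."""
    pixels = np.concatenate([[0], mask.T.flatten(), [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1  # positions where the value changes
    runs[1::2] -= runs[::2]                            # turn run ends into run lengths
    return " ".join(str(x) for x in runs)

def rle_decode_demo(rle, shape):
    """Rebuild the (height, width) mask from its RLE string."""
    values = np.asarray(rle.split(), dtype=int)
    starts, lengths = values[0::2] - 1, values[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1
    return flat.reshape(shape[::-1]).T

toy_mask = np.array([[0, 1, 0],
                     [0, 1, 1],
                     [0, 0, 1]], dtype=np.uint8)

encoded = rle_encode_demo(toy_mask)
decoded = rle_decode_demo(encoded, toy_mask.shape)
print(encoded)  # -> 4 2 8 2
```

Read column by column, the mask is `0 0 0 1 1 0 0 1 1`, so the two runs of ones start at positions 4 and 8 and each has length 2.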

In [None]:
mask_annotation_sample = json.load(open(os.path.join(HuBMAP_PATH, "0486052bb.json"), 'r'))
print(f"Number of glomerulus(mask): {len(mask_annotation_sample)}")

In [None]:
mask_lv4 = np.zeros(wsi_sample_lv4.shape[0:2])
cnt = 0
for anno in mask_annotation_sample:
    label_name = anno['properties']['classification']['name']
    if label_name != 'glomerulus':
        continue
    anno_coordi = anno['geometry']['coordinates']
    if len(anno_coordi) <= 1:
        # single polygon: scale the level-0 coordinates down and fill it
        anno_coordi = np.int32(np.array(anno_coordi) / DOWNSAMPLE_FACTOR)
        mask_lv4 = cv2.fillPoly(mask_lv4, anno_coordi, 1)
        cnt += 1
    else:
        # several parts: fill each part as its own polygon
        for part in anno_coordi:
            part = np.int32(np.array(part).squeeze() / DOWNSAMPLE_FACTOR).reshape(1, -1, 2)
            mask_lv4 = cv2.fillPoly(mask_lv4, part, 1)
        cnt += len(anno_coordi)
        
print(f"Added {cnt} glomerulus")

In [None]:
plt.figure(figsize=(15,15))
plt.imshow(wsi_sample_lv4)
plt.imshow(mask_lv4, alpha=0.3)
plt.title("0486052bb.tiff - glomerulus")

## Step 2. Data preprocessing: WSI patch generation

In [None]:
# reference code: https://www.kaggle.com/mariazorkaltseva/hubmap-seresnext50-unet-dice-loss
def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as a formatted string
    '''
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def rle_decode(mask_rle, shape):
    '''
    mask_rle: run-length as a formatted string (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape[::-1]).T

# New version
def rle_encode_less_memory(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as a formatted string
    This simplified method requires first and last pixel to be zero
    '''
    pixels = img.T.flatten()
    
    # This simplified method requires first and last pixel to be zero
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] -= runs[::2]
    
    return ' '.join(str(x) for x in runs)

def is_tile_contains_info(img, pixel_limits, content_threshold, expected_shape):
    """
    img: np.array
    pixel_limits: tuple
    content_threshold: float percents
    expected_shape: tuple
    """
    
    left_limit = np.prod(img > pixel_limits[0], axis=-1)
    right_limit =  np.prod(img < pixel_limits[1], axis=-1)

    if img.shape != expected_shape:
        return False, 0.

    percent_of_pixels = np.sum(left_limit*right_limit) / (img.shape[0] * img.shape[1])
    return  percent_of_pixels > content_threshold, percent_of_pixels

def extract_train_tiles(sample_img_path, rle_mask_sample, fname):
    """downsampling image and extract tiles with downsampled image
    """
    print(fname)
    sample_image = tiff.imread(sample_img_path)
        
    sample_mask = rle_decode(rle_mask_sample, sample_image.shape[0:2])
    print(f"Original Tiff image shape: {sample_image.shape}")
    
    pad0 = (REDUCE_RATE*TILE_SIZE - sample_image.shape[0]%(REDUCE_RATE*TILE_SIZE))%(REDUCE_RATE*TILE_SIZE)
    pad1 = (REDUCE_RATE*TILE_SIZE - sample_image.shape[1]%(REDUCE_RATE*TILE_SIZE))%(REDUCE_RATE*TILE_SIZE)
    
    sample_image = np.pad(sample_image,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                   constant_values=0)
    sample_mask = np.pad(sample_mask,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2]],
                  constant_values=0)
        
    sample_image = cv2.resize(sample_image,(sample_image.shape[1]//REDUCE_RATE,sample_image.shape[0]//REDUCE_RATE),
                             interpolation = cv2.INTER_AREA)
    
    sample_mask = cv2.resize(sample_mask,(sample_mask.shape[1]//REDUCE_RATE,sample_mask.shape[0]//REDUCE_RATE),
                             interpolation = cv2.INTER_AREA)
    
    print(f"Reduced Tiff image shape: {sample_image.shape}")
    
    tiles, masks = [], []
    for x in range(0,sample_image.shape[0],TILE_SIZE):
        for y in range(0,sample_image.shape[1],TILE_SIZE):
            sub_image = np.float32(sample_image[x:x+TILE_SIZE,y:y+TILE_SIZE])
            sub_mask = sample_mask[x:x+TILE_SIZE,y:y+TILE_SIZE]
            if is_tile_contains_info(sub_image, (50, 220), 0.7, (TILE_SIZE,TILE_SIZE, 3))[0]:
                tiles.append(sub_image)
                masks.append(sub_mask)
            else:
                continue
    # create the output directories (including TRAIN_SAVE_DIR itself) if they are missing
    os.makedirs(os.path.join(TRAIN_SAVE_DIR, "wsi"), exist_ok=True)
    os.makedirs(os.path.join(TRAIN_SAVE_DIR, "mask"), exist_ok=True)

    count = 0
    for tile,mask in zip(tiles,masks):
        cv2.imwrite(os.path.join(TRAIN_SAVE_DIR, "wsi", f"{fname}_{count:03}.png"), tile)
        cv2.imwrite(os.path.join(TRAIN_SAVE_DIR, "mask", f"{fname}_{count:03}.png"), mask)
        count += 1
            
    print(f"Length tiles", len(tiles))

def extract_test_tiles(sample_img_path, rle_mask_sample, fname):
    """padding + downsampling image and extract tiles with downsampled image
    """
    print(fname)
    sample_image = tiff.imread(sample_img_path)
    
    sample_mask = rle_decode(rle_mask_sample, sample_image.shape[0:2])
    print(f"Original Tiff image shape: {sample_image.shape}")
    
    pad0 = (REDUCE_RATE*TILE_SIZE - sample_image.shape[0]%(REDUCE_RATE*TILE_SIZE))%(REDUCE_RATE*TILE_SIZE)
    pad1 = (REDUCE_RATE*TILE_SIZE - sample_image.shape[1]%(REDUCE_RATE*TILE_SIZE))%(REDUCE_RATE*TILE_SIZE)
    
    sample_image = np.pad(sample_image,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],
                   constant_values=0)
    sample_mask = np.pad(sample_mask,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2]],
                  constant_values=0)
        
    sample_image = cv2.resize(sample_image,(sample_image.shape[1]//REDUCE_RATE,sample_image.shape[0]//REDUCE_RATE),
                             interpolation = cv2.INTER_AREA)
    
    sample_mask = cv2.resize(sample_mask,(sample_mask.shape[1]//REDUCE_RATE,sample_mask.shape[0]//REDUCE_RATE),
                             interpolation = cv2.INTER_AREA)
    
    print(f"Reduced Tiff image shape: {sample_image.shape}")
    
    tiles, masks = [], []
    for x in range(0,sample_image.shape[0],TILE_SIZE):
        for y in range(0,sample_image.shape[1],TILE_SIZE):
            sub_image = np.float32(sample_image[x:x+TILE_SIZE,y:y+TILE_SIZE])
            sub_mask = sample_mask[x:x+TILE_SIZE,y:y+TILE_SIZE]
            tiles.append(sub_image)
            masks.append(sub_mask)
    # create the output directories (including TEST_SAVE_DIR itself) if they are missing
    os.makedirs(os.path.join(TEST_SAVE_DIR, "wsi"), exist_ok=True)
    os.makedirs(os.path.join(TEST_SAVE_DIR, "mask"), exist_ok=True)

    count = 0
    for tile,mask in zip(tiles,masks):
        cv2.imwrite(os.path.join(TEST_SAVE_DIR, "wsi", f"{fname}_{count:03}.png"), tile)
        cv2.imwrite(os.path.join(TEST_SAVE_DIR, "mask", f"{fname}_{count:03}.png"), mask)
        count += 1
            
    print(f"Length tiles", len(tiles))

In [None]:
TRAIN_SAVE_DIR = "./wsi/train/"
TEST_SAVE_DIR = "./wsi/test/"
TILE_SIZE = 256
REDUCE_RATE = 4
hubmap_rle_info = pd.read_csv(os.path.join(HuBMAP_PATH, 'train.csv'))

In [None]:
train_img_path = os.path.join(HuBMAP_PATH, "0486052bb.tiff")
train_rle_str = hubmap_rle_info[hubmap_rle_info["id"]=="0486052bb"].encoding.values[0]
extract_train_tiles(train_img_path, train_rle_str, "0486052bb")

In [None]:
train_img_path = os.path.join(HuBMAP_PATH, "8242609fa.tiff")
train_rle_str = hubmap_rle_info[hubmap_rle_info["id"]=="8242609fa"].encoding.values[0]
extract_train_tiles(train_img_path, train_rle_str, "8242609fa")

In [None]:
test_img_path = os.path.join(HuBMAP_PATH, "aaa6a05cc.tiff")
test_rle_str = hubmap_rle_info[hubmap_rle_info["id"]=="aaa6a05cc"].encoding.values[0]
extract_test_tiles(test_img_path, test_rle_str, "aaa6a05cc")

## Step 3. Building the dataset

In [None]:
import torch
import albumentations as A

In [None]:
class HuBMAPDataset(torch.utils.data.Dataset):
    
    def __init__(
            self, 
            paths, 
            mode,
            augmentation=None,
            preprocessing=None,
    ):

        self.paths = paths
        self.mode = mode
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        
    def __getitem__(self, i):       
        if self.mode in ['train', 'val']:
            image = np.array(Image.open(self.paths[i][0]))
            mask = np.array(Image.open(self.paths[i][1]).convert('L'))
            mask = np.expand_dims(mask, axis=2)
        else:
            image = np.array(Image.open(self.paths[i]))

        if self.augmentation:
            if self.mode in ['train', 'val']:
                sample = self.augmentation(image=image, mask=mask)
                image, mask = sample['image'], sample['mask']
            else:
                sample = self.augmentation(image=image)
                image = sample['image']

        if self.preprocessing:
            if self.mode in ['train', 'val']:
                sample = self.preprocessing(image=image, mask=mask)
                image, mask = sample['image'], sample['mask']
            else:
                sample = self.preprocessing(image=image)
                image = sample['image']

        if self.mode in ['train', 'val']:
            return image, mask
        
        return image
        
    def __len__(self):
        return len(self.paths)

In [None]:
def get_training_augmentation():
    transform_list = []
    transform_list.append(A.RandomRotate90(p=.5))
    transform_list.append(A.HorizontalFlip(p=.5))
    transform_list.append(A.VerticalFlip(p=.5))
    transform_list.append(A.Transpose(p=.5))
    transform_list.append(A.ShiftScaleRotate(scale_limit=0.2, rotate_limit=0, shift_limit=0.2, border_mode=0, p=.5))
    transform_list.append(
        A.OneOf([
            A.RandomBrightness(limit=.2, p=1), 
            A.RandomContrast(limit=.2, p=1), 
            A.RandomGamma(p=1)
        ], p=.5)
    )
    transform_list.append(
        A.OneOf([
            A.Blur(blur_limit=3, p=1),
            A.MedianBlur(blur_limit=3, p=1)
        ], p=.1)
    )
    transform_list.append(
        A.OneOf([
            A.RandomContrast(p=1),
            A.HueSaturationValue(p=1)
        ], p=.9)
    )
    
    
    return A.Compose(transform_list)


def get_validation_augmentation():
    test_transform = [
    ]
    return A.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing():
    _transform = [
        A.Normalize(mean=(0.65459856,0.48386562,0.69428385), 
                       std=(0.15167958,0.23584107,0.13146145), 
                       max_pixel_value=255.0, always_apply=True, p=1.0),
        A.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return A.Compose(_transform)

In [None]:
train_img_path_ls = glob.glob('./wsi/train/wsi/*')
train_img_path_ls = [file for file in train_img_path_ls if file.endswith(".png")]
train_img_path_ls.sort()
train_mask_path_ls = glob.glob('./wsi/train/mask/*')
train_mask_path_ls = [file for file in train_mask_path_ls if file.endswith(".png")]
train_mask_path_ls.sort()
train_path_ls = list(zip(train_img_path_ls, train_mask_path_ls))

In [None]:
test_img_path_ls = glob.glob('./wsi/test/wsi/*')
test_img_path_ls = [file for file in test_img_path_ls if file.endswith(".png")]
test_img_path_ls.sort()
test_mask_path_ls = glob.glob('./wsi/test/mask/*')
test_mask_path_ls = [file for file in test_mask_path_ls if file.endswith(".png")]
test_mask_path_ls.sort()
test_paths = list(zip(test_img_path_ls, test_mask_path_ls))

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
train_paths, val_paths = train_test_split(train_path_ls,
                                        test_size=0.2,  
                                        random_state=0,
                                        shuffle=True)

In [None]:
aug_dataset = HuBMAPDataset(train_paths,
                              'train',
                              augmentation=get_training_augmentation(),
                             )

In [None]:
def visualize(**images):
    """Plot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.axis("off")
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

In [None]:
# check augmentation 
for i in range(10):
    image, mask = aug_dataset[3]
    visualize(image=image, mask=mask.squeeze(-1))

## Step 4. Model training
We use a library that collects segmentation models.  
segmentation models by qubvel  
https://github.com/qubvel/segmentation_models.pytorch

### Unet
The model used in this practice is Unet.  
As the name suggests, Unet is a U-shaped model   
consisting of a contracting path that captures the context of the input image    
and an expanding path that upsamples the extracted feature maps.  
The final output is, for each pixel, the probability that the target is present.

![Unet](https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FvLAIi%2FbtqBCmVoUFu%2FsaODCYPMvM5Siesq8s3gP1%2Fimg.png)

### Dice loss
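$
DL = 1 - {{2\sum_{i} p_{i1}y_{i1} +\gamma}\over{\sum_{i} p_{i1}^{2} + \sum_{i} y_{i1}^{2} +\gamma}}
$
The loss function used in this practice is the Dice loss, which turns the Dice coefficient into a loss.  
The Dice coefficient is computed by multiplying the overlap area between the prediction and the ground truth by 2  
and dividing by the sum of the prediction area and the ground-truth area.  
Here $p_{i1}$ is the predicted probability and $y_{i1}$ the ground-truth label of pixel $i$ for the target class, and $\gamma$ is a small smoothing constant.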

![Dice loss](https://i.stack.imgur.com/OsH4y.png)
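
As a rough illustration of how such a loss behaves, the sketch below implements a soft Dice loss in NumPy. It assumes the form with a sum of squares in the denominator and a smoothing constant `gamma`; the `DiceLoss` class from the library used in this notebook may use a different denominator or smoothing, so treat `soft_dice_loss` as a hypothetical helper rather than the library's implementation.

```python
import numpy as np

def soft_dice_loss(probs, target, gamma=1.0):
    """1 - (2 * sum(p * y) + gamma) / (sum(p^2) + sum(y^2) + gamma)."""
    probs = np.asarray(probs, dtype=np.float64).ravel()
    target = np.asarray(target, dtype=np.float64).ravel()
    intersection = np.sum(probs * target)
    denominator = np.sum(probs ** 2) + np.sum(target ** 2)
    return 1.0 - (2.0 * intersection + gamma) / (denominator + gamma)

target = np.array([[0, 1],
                   [1, 1]])

perfect = soft_dice_loss(target, target)                 # prediction equals the label
empty = soft_dice_loss(np.zeros_like(target), target)    # nothing predicted
partial = soft_dice_loss(np.array([[0.1, 0.9],
                                   [0.8, 0.4]]), target)  # soft, partly correct prediction
print(perfect, empty, partial)
```

A perfect prediction gives a loss of 0, an empty prediction gives 0.75 for this toy label (the smoothing term keeps it below 1), and the soft prediction lands in between.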

In [None]:
train_dataset = HuBMAPDataset(train_paths, 
                              'train',
                              augmentation=get_training_augmentation(), 
                              preprocessing=get_preprocessing()
                            )
valid_dataset = HuBMAPDataset(val_paths,
                              'val',
                              augmentation=get_validation_augmentation(), 
                              preprocessing=get_preprocessing()
                            )
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                            batch_size=16, 
                                            shuffle=True, 
                                            num_workers=4, 
                                            pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                                            batch_size=8, 
                                            shuffle=False, 
                                            num_workers=4, 
                                            pin_memory=True)

In [None]:
import segmentation_models_pytorch as smp

In [None]:
ENCODER = 'se_resnext50_32x4d'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

# create segmentation model (encoder_weights=None: the encoder starts without pretrained weights)
model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=None, 
    in_channels=3,
    classes=1, 
    activation=ACTIVATION,
)

In [None]:
loss = smp.utils.losses.DiceLoss()

metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

In [None]:
# create epoch runners 
# it is a simple loop of iterating over dataloader`s samples
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
save_path = "./ckpt/Unet_Resnext50"

try: 
    if not os.path.exists(save_path): 
        os.makedirs(save_path)
        print(f"New directory!: {save_path}")
except OSError: 
    print("Error: Failed to create the directory.")

In [None]:
# train model for 10 epochs
epoch = 10
max_score = 0

for i in range(epoch):
    
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, os.path.join(save_path, 'best_model01.pth'))
        print('New Record!')
    
    torch.save(model, os.path.join(save_path, 'final_model01.pth'))
        
    # learning-rate drop at epoch 100; never reached while epoch = 10
    if i == 100:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')

## Step 5. Evaluation

### Step 5-1. Patch-wise validation
Evaluation is carried out in two ways: patch-wise and whole slide-wise.  
First, we compute the mean Dice coefficient of the predicted masks patch by patch  
and visualize what the predictions actually look like.

In [None]:
test_img_path_ls = glob.glob('./wsi/test/wsi/*')
test_img_path_ls = [file for file in test_img_path_ls if file.endswith(".png")]
test_img_path_ls.sort()
test_mask_path_ls = glob.glob('./wsi/test/mask/*')
test_mask_path_ls = [file for file in test_mask_path_ls if file.endswith(".png")]
test_mask_path_ls.sort()
test_paths = list(zip(test_img_path_ls, test_mask_path_ls))

In [None]:
best_model = torch.load(os.path.join("./ckpt/Unet_Resnext50", 'best_model01.pth'))

In [None]:
# code reference: https://gist.github.com/gergf/acd8e3fd23347cb9e6dc572f00c63d79
def dice(true_mask, pred_mask, non_seg_score=1.0):
    """
        Computes the Dice coefficient.
        Args:
            true_mask : Array of arbitrary shape.
            pred_mask : Array with the same shape as true_mask.  
        
        Returns:
            A scalar representing the Dice coefficient between the two segmentations. 
        
    """
    assert true_mask.shape == pred_mask.shape

    true_mask = np.asarray(true_mask).astype(np.bool_)
    pred_mask = np.asarray(pred_mask).astype(np.bool_)

    # If both segmentations are all zero, the dice will be 1. (Developer decision)
    im_sum = true_mask.sum() + pred_mask.sum()
    if im_sum == 0:
        return non_seg_score

    # Compute Dice coefficient
    intersection = np.logical_and(true_mask, pred_mask)
    return 2. * intersection.sum() / im_sum

In [None]:
test_dataset = HuBMAPDataset(test_img_path_ls, 
                             'test',
                             preprocessing=get_preprocessing()
                        )

test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                                              batch_size=1, 
                                              shuffle=False, 
                                              num_workers=4)

test_dataset_vis = HuBMAPDataset(test_paths, 'val')

tot_patch_dice_coef = 0
best_model.eval()

for i in range(len(test_dataset)):
    
    image_vis, gt = test_dataset_vis[i] #.astype('uint8')
    image = test_dataset[i]
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    with torch.set_grad_enabled(False):
        pr_mask = best_model.predict(x_tensor)
        pr_mask = (pr_mask.squeeze().cpu().numpy().round().astype('uint8'))
    tot_patch_dice_coef += dice(gt.squeeze(), pr_mask)
print(f"Average of patch-wise dice coefficient: {tot_patch_dice_coef/len(test_dataset)}")

In [None]:
for i in range(30,50):
    
    image_vis, gt = test_dataset_vis[i] #.astype('uint8')
    image = test_dataset[i]
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    with torch.set_grad_enabled(False):
        pr_mask = best_model.predict(x_tensor)
        pr_mask = (pr_mask.squeeze().cpu().numpy().round().astype('uint8'))
        
    visualize(
        image=image_vis, 
        ground_truth = gt,
        predicted_mask=pr_mask
    )

For convenience, the evaluation below uses pre-trained weights.

In [None]:
best_model = torch.load(os.path.join("./wsi/best_ckpt", 'best_model.pth'))

In [None]:
test_dataset = HuBMAPDataset(test_img_path_ls, 
                             'test',
                             preprocessing=get_preprocessing()
                        )

test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                                              batch_size=1, 
                                              shuffle=False, 
                                              num_workers=4)

test_dataset_vis = HuBMAPDataset(test_paths, 'val')

tot_patch_dice_coef = 0
best_model.eval()

for i in range(len(test_dataset)):
    
    image_vis, gt = test_dataset_vis[i] #.astype('uint8')
    image = test_dataset[i]
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    with torch.set_grad_enabled(False):
        pr_mask = best_model.predict(x_tensor)
        pr_mask = (pr_mask.squeeze().cpu().numpy().round().astype('uint8'))
    tot_patch_dice_coef += dice(gt.squeeze(), pr_mask)
print(f"Average of patch-wise dice coefficient: {tot_patch_dice_coef/len(test_dataset)}")

In [None]:
for i in range(30,50):
    
    image_vis, gt = test_dataset_vis[i] #.astype('uint8')
    image = test_dataset[i]
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    with torch.set_grad_enabled(False):
        pr_mask = best_model.predict(x_tensor)
        pr_mask = (pr_mask.squeeze().cpu().numpy().round().astype('uint8'))
        
    visualize(
        image=image_vis, 
        ground_truth = gt,
        predicted_mask=pr_mask
    )

### Step 5-2. Whole image-wise validation
Predictions are produced patch by patch, so they have to be merged back into the Whole Slide Image.  
After merging, evaluation is performed on the whole image.
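
The split-and-merge logic can be sketched on a small array. The example assumes tiles are stored in row-major order (outer loop over rows, inner loop over columns), as in the tile extraction code of Step 2, and that the image is first padded to a multiple of tile size times reduction rate. The helper names and the sizes 1000 x 1500 are hypothetical.

```python
import numpy as np

def padded_reduced_shape(height, width, tile, reduce_rate):
    """Shape after padding to a multiple of tile * reduce_rate and downsampling."""
    unit = tile * reduce_rate
    pad_h = (unit - height % unit) % unit
    pad_w = (unit - width % unit) % unit
    return (height + pad_h) // reduce_rate, (width + pad_w) // reduce_rate

def split_into_tiles(image, tile):
    """Cut a 2D array into tile x tile patches, row by row."""
    tiles = []
    for row in range(0, image.shape[0], tile):
        for col in range(0, image.shape[1], tile):
            tiles.append(image[row:row + tile, col:col + tile])
    return tiles

def merge_tiles(tiles, height, width, tile):
    """Place the patches back in the same row-major order."""
    merged = np.zeros((height, width), dtype=tiles[0].dtype)
    k = 0
    for i in range(height // tile):
        for j in range(width // tile):
            merged[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile] = tiles[k]
            k += 1
    return merged

toy_shape = padded_reduced_shape(1000, 1500, tile=256, reduce_rate=4)
toy = np.arange(8 * 12).reshape(8, 12)
tiles = split_into_tiles(toy, 4)
restored = merge_tiles(tiles, 8, 12, 4)
print(toy_shape, len(tiles), np.array_equal(restored, toy))
```

Merging only reproduces the image if the patch list is sorted in the order in which the patches were cut, which is why the tile file names carry a zero-padded index.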

In [None]:
def extract_slide_tiles(file_paths, ID):
    def check_id(file_path):
        return os.path.splitext(os.path.basename(file_path))[0].split('_')[0]
    tile_ls = [path for path in file_paths if check_id(path)==ID]
    return sorted(tile_ls)

IDX = 'aaa6a05cc'
slide_paths = extract_slide_tiles(test_img_path_ls, IDX)

# padded and reduced shape of corresponding image
height, width = (4864, 3328)

In [None]:
from tqdm import tqdm

In [None]:
# one slide predictions
test_dataset = HuBMAPDataset(slide_paths, 
                             'test',
                             preprocessing=get_preprocessing()
                        )

test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

test_dataset_vis = HuBMAPDataset(slide_paths, 'test')

best_model.eval()

mask_preds = []
for i, image in enumerate(tqdm(test_dataloader)):
    image_vis = test_dataset_vis[i].astype('uint8')
    
    image = image.to(DEVICE)
    with torch.set_grad_enabled(False):
        mask_pred = best_model(image)
        mask_pred = mask_pred.squeeze().cpu().numpy().round().astype('uint8')
    mask_preds.append(np.expand_dims(mask_pred, axis=0))
    

mask_preds = np.concatenate(mask_preds)

In [None]:
# merge the patch-wise images
merge_image = np.zeros((height, width, 3))

k = 0
for i in range(0, height // TILE_SIZE):
    for j in range(0, width // TILE_SIZE):
        image = np.array(Image.open(slide_paths[k]))
        merge_image[i*TILE_SIZE:i*TILE_SIZE + TILE_SIZE, j*TILE_SIZE:j*TILE_SIZE + TILE_SIZE, :] = image
        k += 1

In [None]:
# merge the prediction masks
merge_mask = np.zeros((height, width))

k = 0
for i in range(0, height // TILE_SIZE):
    for j in range(0, width // TILE_SIZE):
        merge_mask[i*TILE_SIZE:i*TILE_SIZE + TILE_SIZE, j*TILE_SIZE:j*TILE_SIZE + TILE_SIZE] = mask_preds[k]
        k += 1

In [None]:
# load the ground-truth mask
sample_img_path = os.path.join(HuBMAP_PATH, "aaa6a05cc.tiff") 
rle_mask_sample = hubmap_rle_info[hubmap_rle_info["id"]==IDX].encoding.values[0]
sample_image = tiff.imread(sample_img_path)
gt_mask = rle_decode(rle_mask_sample, sample_image.shape[0:2])
pad0 = (REDUCE_RATE*TILE_SIZE - sample_image.shape[0]%(REDUCE_RATE*TILE_SIZE))%(REDUCE_RATE*TILE_SIZE)
pad1 = (REDUCE_RATE*TILE_SIZE - sample_image.shape[1]%(REDUCE_RATE*TILE_SIZE))%(REDUCE_RATE*TILE_SIZE)
gt_mask = np.pad(gt_mask,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2]],
              constant_values=0)
gt_mask = cv2.resize(gt_mask,(gt_mask.shape[1]//REDUCE_RATE,gt_mask.shape[0]//REDUCE_RATE),
                         interpolation = cv2.INTER_AREA)

In [None]:
print(f"Whole image-wise dice coefficient: {dice(gt_mask, merge_mask)}")

In [None]:
fig,ax = plt.subplots(1,2,figsize=(15,15))
ax[0].imshow(merge_image.astype('uint8'))
ax[0].imshow(gt_mask, cmap='coolwarm', alpha=0.5)
ax[0].set_title("ground truth")
ax[1].imshow(merge_image.astype('uint8'))
ax[1].imshow(merge_mask.astype('uint8'), cmap='coolwarm', alpha=0.5)
ax[1].set_title("prediction")

### Step 5-3. IoU score
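The IoU score is one of the metrics for evaluating segmentation models;  
it is the area of the intersection between the prediction and the ground truth divided by the area of their union.  
The score is 1 when the two match exactly and 0 when they do not overlap at all.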

![iou](https://www.pyimagesearch.com/wp-content/uploads/2016/09/iou_equation.png)

https://www.pyimagesearch.com/
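
A small worked example may help relate IoU to the Dice coefficient used earlier. The sketch below computes the foreground IoU and the two-class mean IoU (background and target averaged, which is what a confusion-matrix based implementation over labels 0 and 1 returns) on a toy pair of masks. `binary_iou` and `mean_iou` are illustrative names, and returning 1.0 for two empty masks is a convention chosen here.

```python
import numpy as np

def binary_iou(pred, target):
    """Intersection over union of two binary masks."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as perfect agreement
    return np.logical_and(pred, target).sum() / union

def mean_iou(pred, target):
    """Average of the foreground IoU and the background IoU."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    return 0.5 * (binary_iou(pred, target) + binary_iou(~pred, ~target))

target = np.array([[1, 1, 0, 0],
                   [1, 1, 0, 0]])
pred = np.array([[1, 0, 0, 0],
                 [1, 1, 1, 0]])

iou_fg = binary_iou(pred, target)     # 3 shared pixels / 5 pixels in the union
iou_mean = mean_iou(pred, target)
dice_fg = 2 * iou_fg / (1 + iou_fg)   # Dice and IoU are monotonically related
print(iou_fg, iou_mean, dice_fg)
```

On large slides where most pixels are background, the two-class mean can be noticeably higher than the foreground IoU, so it helps to state which of the two is being reported.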

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
def compute_iou(y_pred, y_true):
    # y_true and y_pred are flattened to 1D vectors
    y_pred = y_pred.flatten()
    y_true = y_true.flatten()
    current = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # compute mean iou
    intersection = np.diag(current)
    ground_truth_set = current.sum(axis=1)
    predicted_set = current.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection
    IoU = intersection / union.astype(np.float32)
    return np.mean(IoU)

In [None]:
compute_iou(merge_mask, gt_mask)