# README
Run this notebook to create the dataset for mosaic training.

In [1]:
from torch.utils.data import Dataset as BaseDataset
import albumentations as albu

import random
import os
import torch
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from tqdm import tqdm

# np.random.seed(42)
# random.seed(42)

In [2]:
def get_patch_label(file):
    """Parse the class label embedded in a patch file name.

    The label is the bracketed group of integers in the file's base name,
    e.g. ``1031280-2300-27920-[1 0 0 1].png`` yields ``[1, 0, 0, 1]``.
    Positions follow the order
    [a: Tumor epithelial (TE), b: Necrosis (NEC), c: Lymphocyte (LYM),
    d: Tumor-associated stroma (TAS)].

    Args:
        file: File name or path, as ``str`` or ``pathlib.Path``.

    Returns:
        The label as a list of ints, one entry per class.

    Raises:
        ValueError: If the base name has no ``[...]`` group or the group
            contains anything other than integers.
    """
    # Only the base name is inspected so that brackets in a parent directory
    # name cannot be mistaken for the label.
    name = Path(file).name

    # Same selection rule as before: text between the first ']' and the last
    # '[' that precedes it.
    head, closing, _ = name.partition(']')
    _, opening, inner = head.rpartition('[')
    if not closing or not opening:
        raise ValueError(f"no '[...]' label group found in file name: {name!r}")

    # Parse the entries explicitly instead of calling eval() on text taken
    # from the file system; entries may be separated by spaces and/or commas.
    try:
        return [int(token) for token in inner.replace(',', ' ').split()]
    except ValueError as err:
        raise ValueError(f"invalid label {inner!r} in file name: {name!r}") from err

In [3]:
def get_one_label(train_image_dir):
    """Collect the PNG patches whose label marks exactly one tissue class.

    Every ``*.png`` directly inside ``train_image_dir`` is visited in sorted
    order and its label is read with ``get_patch_label``. Patches whose label
    sums to 1 are grouped by the position of the 1
    (0: tumor, 1: necrosis, 2: lymphocyte, 3: stroma). The size of each group
    is printed.

    Args:
        train_image_dir: Directory object providing ``glob`` (e.g. a
            ``pathlib.Path``).

    Returns:
        Tuple ``(only_tum_list, only_nec_list, only_lym_list, only_tas_list)``
        of lists of patch paths.
    """
    # One bucket per class, in label order.
    buckets = ([], [], [], [])

    for png_path in sorted(train_image_dir.glob('*.png')):
        onehot = get_patch_label(png_path)
        if np.sum(onehot) != 1:
            continue
        # File the patch under the first of the four positions that holds a 1.
        for class_idx, bucket in enumerate(buckets):
            if onehot[class_idx] == 1:
                bucket.append(png_path)
                break

    only_tum_list, only_nec_list, only_lym_list, only_tas_list = buckets

    print('only_tum_num:', len(only_tum_list))
    print('only_nec_num:', len(only_nec_list))
    print('only_lym_num:', len(only_lym_list))
    print('only_tas_num:', len(only_tas_list))

    return only_tum_list, only_nec_list, only_lym_list, only_tas_list


----

In [4]:
class CropAndConcatDataset(BaseDataset):
    """Synthesise mosaic training samples from single-label patches.

    Each item is built in two stages:

    1. ``create_one_image`` tiles a ``patch_num`` x ``patch_num`` grid with
       random crops of randomly chosen single-label patches; the mask of each
       cell is filled with the class index of the patch it came from.
    2. ``create_mosaic`` augments four such grids, crops them to complementary
       sizes and stitches them into the four quadrants of one output image.

    Args:
        args: Object exposing ``train_dir``, the directory of ``*.png`` patches.
        patch_num: Number of tiles per side of the grid.
        patch_size: Side length, in pixels, of one tile.
        size: Value reported by ``__len__``; defaults to the number of
            single-label patches found.
    """

    def __init__(self, args, patch_num, patch_size, size=None):
        super().__init__()

        self.args = args
        self.train_dir = Path(args.train_dir)
        self.train_images = list(self.train_dir.glob('*.png'))
        # Per-class lists of patches whose label has exactly one active class.
        self.tum, self.nec, self.lym, self.tas = get_one_label(self.train_dir)
        self.single_type_images = self.tum + self.nec + self.lym + self.tas
        print(f'num of single type images: {len(self.single_type_images)}')

        self.patch_num = patch_num
        self.patch_size = patch_size

        self.total_len = len(self.single_type_images) if size is None else size

        # Pads tiles smaller than patch_size, then takes a random
        # patch_size x patch_size crop.
        self.crop_fn = albu.Compose([
            albu.PadIfNeeded(min_height=self.patch_size, min_width=self.patch_size),
            albu.RandomCrop(width=self.patch_size, height=self.patch_size)
        ])

        # Not used by the methods of this class; kept for external users.
        # NOTE(review): albu.Flip appears to be deprecated in recent
        # albumentations releases - confirm against the pinned version.
        self.transform = albu.Compose([
            albu.Flip(),
            albu.RandomRotate90(),
        ])

    def __getitem__(self, i):
        """Return one randomly generated ``(image, mask)`` mosaic.

        The index ``i`` does not select the content: every call draws fresh
        random patches. Generation is retried whenever an ``AssertionError``
        is raised while building the sample.

        Returns:
            ``image`` as a ``(H, W, 3)`` uint8 array and ``mask`` as a
            ``(H, W)`` uint8 array, with ``H = W = patch_num * patch_size``.
        """
        H = W = self.patch_num * self.patch_size

        while True:
            try:
                grids = [self.create_one_image() for _ in range(4)]  # each [H, W, C]
                (image_1, mask_1), (image_2, mask_2), (image_3, mask_3), (image_4, mask_4) = grids
                image, mask = self.create_mosaic(
                    H, W, image_1, mask_1, image_2, mask_2, image_3, mask_3, image_4, mask_4
                )
                break
            except AssertionError as e:
                print(e)

        return image, mask

    def create_one_image(self):
        """Build one grid image and its mask from random single-label tiles.

        Returns:
            ``image`` as a ``(H, W, 3)`` uint8 array and ``mask`` as a
            ``(H, W)`` uint8 array whose value in each grid cell is the
            position of the 1 in the label of the source patch.
        """
        ps = self.patch_size
        H = W = self.patch_num * ps
        image = np.zeros((H, W, 3), dtype=np.uint8)
        mask = np.zeros((H, W), dtype=np.uint8)

        for i in range(self.patch_num):
            for j in range(self.patch_num):
                tile_name = np.random.choice(self.single_type_images)
                label = get_patch_label(tile_name)
                # Kept as an assert: __getitem__ retries on AssertionError.
                assert sum(label) == 1

                # Close the file promptly and force three channels so that
                # RGBA / grayscale / palette PNGs fit the RGB canvas.
                with Image.open(tile_name) as tile_file:
                    tile = np.asarray(tile_file.convert('RGB'))

                class_idx = label.index(1)
                # uint8 matches the dtype of the destination mask.
                tile_mask = np.full(tile.shape[:2], class_idx, dtype=np.uint8)

                sample = self.crop_fn(image=tile, mask=tile_mask)

                rows = slice(i * ps, (i + 1) * ps)
                cols = slice(j * ps, (j + 1) * ps)
                image[rows, cols] = sample['image']
                mask[rows, cols] = sample['mask']

        return image, mask

    @staticmethod
    def _quadrant_transform(height, width, p=0.5):
        """Return the flip / shift-scale-rotate / random-crop pipeline for one quadrant."""
        return albu.Compose([
            albu.Flip(p=p),
            albu.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=p),
            albu.RandomCrop(height, width),
        ])

    def create_mosaic(self, H, W, image_1, mask_1, image_2, mask_2, image_3, mask_3, image_4, mask_4):
        """Stitch four augmented grids into the quadrants of one mosaic.

        A split point ``(h, w)`` is drawn with ``h`` in roughly 20-80% of ``H``
        and ``w`` in roughly 20-80% of ``W`` (both rounded up to even numbers).
        Inputs 1-4 are augmented, cropped to the size of the top-left,
        top-right, bottom-left and bottom-right quadrant respectively, and
        copied into place.

        Args:
            H, W: Height and width of the output mosaic.
            image_k, mask_k: Source image and mask for quadrant ``k``.

        Returns:
            ``image`` as a ``(H, W, 3)`` uint8 array and ``mask`` as a
            ``(H, W)`` uint8 array.
        """
        image = np.zeros((H, W, 3), dtype=np.uint8)
        mask = np.zeros((H, W), dtype=np.uint8)

        h, w = int(H * (random.random()*0.6+0.2)), int(W * (random.random()*0.6+0.2))
        # Round odd split coordinates up to the next even value.
        h += h % 2
        w += w % 2

        # (rows, cols, source image, source mask), processed in the order
        # top-left, top-right, bottom-left, bottom-right.
        quadrants = [
            (slice(0, h), slice(0, w), image_1, mask_1),
            (slice(0, h), slice(w, W), image_2, mask_2),
            (slice(h, H), slice(0, w), image_3, mask_3),
            (slice(h, H), slice(w, W), image_4, mask_4),
        ]

        for rows, cols, src_image, src_mask in quadrants:
            transform = self._quadrant_transform(
                height=rows.stop - rows.start,
                width=cols.stop - cols.start,
                p=0.8,
            )
            sample = transform(image=src_image, mask=src_mask)
            image[rows, cols, :] = sample['image']
            mask[rows, cols] = sample['mask']

        return image, mask

    def __len__(self):
        """Return the configured number of samples (``size`` or the patch count)."""
        return self.total_len

In [5]:
from argparse import Namespace

# Empty namespace; the generation loop assigns `train_dir` on it per run.
args = Namespace()
# Mosaic layout: a patch_num x patch_num grid of patch_size-pixel tiles
# (set to 2 and 112 for better performance).
patch_num, patch_size = 2, 112
# Number of mosaic images generated for each dataset.
N = 10 ** 4


In [6]:
from joblib import Parallel, delayed

# Generate one mosaic dataset per (subset size, run) combination.
for N_sample in [10, 20, 30, 40]:
    for run in range(1, 5):
        # Source patches for this subset, and the output folder for its mosaics.
        args.train_dir = f"./data/LUAD-HistoSeg/limit_N/one_label_N{N_sample}_run{run}"
        mosaic_data = Path(f"./data/LUAD-HistoSeg/limit_N/mosaic_{patch_num}_{patch_size}_N_{N_sample}_run{run}")
        img_dir = mosaic_data / 'img'
        mask_dir = mosaic_data / 'mask'
        img_dir.mkdir(parents=True, exist_ok=True)
        mask_dir.mkdir(parents=True, exist_ok=True)

        dataset = CropAndConcatDataset(args, patch_num=patch_num, patch_size=patch_size, size=N)

        def func(i):
            """Generate sample `i` and save its image and palettised mask as PNG."""
            mosaic, labels = dataset[i]
            picture = Image.fromarray(mosaic)

            # RGB colour for each mask index, flattened into a PIL palette.
            class_colors = [
                (205, 51, 51),      # Tumor epithelial (TE)
                (0, 255, 0),        # Necrosis (NEC)
                (65, 105, 225),     # Lymphocyte (LYM)
                (255, 165, 0),      # Tumor-associated stroma (TAS)
                (255, 255, 255),    # White background or exclude
            ]
            flat_palette = [channel for rgb in class_colors for channel in rgb]

            label_img = Image.fromarray(np.uint8(labels), mode='P')
            label_img.putpalette(flat_palette)

            # Image and mask share the same file name in their own folders.
            out_name = f'{patch_num}_{patch_size}_{i}.png'
            picture.save(mosaic_data / 'img' / out_name)
            label_img.save(mosaic_data / 'mask' / out_name)

        def print_error(value):
            """Print an error value (not called within this cell)."""
            print("error: ", value)

        # Fan the N samples out over 4 workers, with a progress bar.
        jobs = (delayed(func)(i) for i in tqdm(range(N)))
        Parallel(n_jobs=4)(jobs)

only_tum_num: 10
only_nec_num: 10
only_lym_num: 10
only_tas_num: 10
num of single type images: 40


100%|██████████| 10000/10000 [02:40<00:00, 62.34it/s]


only_tum_num: 10
only_nec_num: 10
only_lym_num: 10
only_tas_num: 10
num of single type images: 40


100%|██████████| 10000/10000 [02:38<00:00, 63.13it/s]


only_tum_num: 10
only_nec_num: 10
only_lym_num: 10
only_tas_num: 10
num of single type images: 40


100%|██████████| 10000/10000 [02:37<00:00, 63.47it/s]


only_tum_num: 10
only_nec_num: 10
only_lym_num: 10
only_tas_num: 10
num of single type images: 40


100%|██████████| 10000/10000 [02:38<00:00, 63.00it/s]


only_tum_num: 20
only_nec_num: 20
only_lym_num: 20
only_tas_num: 20
num of single type images: 80


100%|██████████| 10000/10000 [02:41<00:00, 62.11it/s]


only_tum_num: 20
only_nec_num: 20
only_lym_num: 20
only_tas_num: 20
num of single type images: 80


100%|██████████| 10000/10000 [02:40<00:00, 62.26it/s]


only_tum_num: 20
only_nec_num: 20
only_lym_num: 20
only_tas_num: 20
num of single type images: 80


100%|██████████| 10000/10000 [02:42<00:00, 61.68it/s]


only_tum_num: 20
only_nec_num: 20
only_lym_num: 20
only_tas_num: 20
num of single type images: 80


100%|██████████| 10000/10000 [02:42<00:00, 61.52it/s]


only_tum_num: 30
only_nec_num: 30
only_lym_num: 30
only_tas_num: 30
num of single type images: 120


100%|██████████| 10000/10000 [02:41<00:00, 61.84it/s]


only_tum_num: 30
only_nec_num: 30
only_lym_num: 30
only_tas_num: 30
num of single type images: 120


100%|██████████| 10000/10000 [02:40<00:00, 62.42it/s]


only_tum_num: 30
only_nec_num: 30
only_lym_num: 30
only_tas_num: 30
num of single type images: 120


100%|██████████| 10000/10000 [02:38<00:00, 63.06it/s]


only_tum_num: 30
only_nec_num: 30
only_lym_num: 30
only_tas_num: 30
num of single type images: 120


100%|██████████| 10000/10000 [02:40<00:00, 62.49it/s]


only_tum_num: 40
only_nec_num: 40
only_lym_num: 40
only_tas_num: 40
num of single type images: 160


100%|██████████| 10000/10000 [02:47<00:00, 59.66it/s]


only_tum_num: 40
only_nec_num: 40
only_lym_num: 40
only_tas_num: 40
num of single type images: 160


100%|██████████| 10000/10000 [02:44<00:00, 60.80it/s]


only_tum_num: 40
only_nec_num: 40
only_lym_num: 40
only_tas_num: 40
num of single type images: 160


100%|██████████| 10000/10000 [02:45<00:00, 60.32it/s]


only_tum_num: 40
only_nec_num: 40
only_lym_num: 40
only_tas_num: 40
num of single type images: 160


100%|██████████| 10000/10000 [02:42<00:00, 61.68it/s]
