In [1]:
import ast
import os
import shutil

import numpy as np
from PIL import Image
from tqdm import tqdm

# Run identifier; interpolated into the generated output directory names
# ("..._run{run}") so each generation round writes to its own folders.
run = 5

In [2]:
def get_patch_label(filename):
    """Parse the class label embedded in a patch filename.

    The label is the text between the last ``[`` and the following ``]``, written
    in one of three spellings:

    * ``[0, 0, 0, 1]`` -- comma separated
    * ``[0 0 0 1]``    -- whitespace separated
    * ``[0001]``       -- one digit per class

    Args:
        filename: patch file name (or path) containing a bracketed label.

    Returns:
        list: the label values, e.g. ``[0, 0, 0, 1]``.

    Raises:
        ValueError: if the bracketed text is not a list of literals / digits
            (``ast.literal_eval`` may also raise SyntaxError for garbage text).
    """
    inner = filename.split('[')[-1].split(']')[0]
    if ',' in inner:
        # [0, 0, 0, 1]
        # literal_eval accepts only Python literals, so text taken from a filename
        # can never execute code (plain eval() would run "[__import__('os')...]").
        return ast.literal_eval('[' + inner + ']')
    if any(ch.isspace() for ch in inner):
        # [0 0 0 1] -- split() also tolerates repeated / leading / trailing blanks,
        # which the old replace(' ', ',') turned into a SyntaxError ("0,,0").
        return ast.literal_eval('[' + ','.join(inner.split()) + ']')
    # [0001] -- one digit per class
    return [int(ch) for ch in inner]

def create_data(train_data):
    """Bucket the single-label training patches in ``train_data`` by class.

    A patch is kept only when its filename label (see ``get_patch_label``) sums
    to exactly 1; multi-label patches are skipped.

    Args:
        train_data: directory containing the training patches.

    Returns:
        tuple of four lists of file paths -- (tum, nec, lym, tas) -- each in
        sorted filename order.
    """
    buckets = ([], [], [], [])  # tum, nec, lym, tas -- same order as the label slots
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the
    # returned lists (and anything that indexes into them) reproducible.
    for name in sorted(os.listdir(train_data)):
        big_label = get_patch_label(name)
        if np.sum(big_label) != 1:
            continue
        # Put the path in the bucket of the first of the four slots equal to 1.
        # Labels whose only positive value lies elsewhere are dropped rather than
        # raising IndexError on a short label.
        for slot, bucket in zip(big_label[:4], buckets):
            if slot == 1:
                bucket.append(os.path.join(train_data, name))
                break

    return buckets

def get_onelabel_mask(category, scale):
    """Build a constant ``scale x scale`` uint8 mask for a single-class tile.

    The fill value is the class index: ``"tum"`` -> 0, ``"nec"`` -> 1,
    ``"lym"`` -> 2; any other category (including ``"tas"``) falls back to 3.
    """
    fill_value = next(
        (index for index, name in enumerate(("tum", "nec", "lym")) if category == name),
        3,
    )
    return np.full((scale, scale), fill_value, dtype=np.uint8)


In [3]:
train_img_dir = 'data/LUAD-HistoSeg/train'
category_list = ['tum', 'nec', 'lym', 'tas']
only_tum_list, only_nec_list, only_lym_list, only_tas_list = create_data(train_img_dir)

# Flat RGB palette for the 'P'-mode masks: entry k is the colour of mask value k.
palette = [
    205, 51, 51,       # 0: Tumor epithelial (TE)
    0, 255, 0,         # 1: Necrosis (NEC)
    65, 105, 225,      # 2: Lymphocyte (LYM)
    255, 165, 0,       # 3: Tumor-associated stroma (TAS)
    255, 255, 255,     # 4: White background or exclude
]


### One Label Dataset

In [4]:
one_label_image_dir = "data/LUAD-HistoSeg/one_label_image/img"
one_label_mask_dir = "data/LUAD-HistoSeg/one_label_image/mask"
os.makedirs(one_label_image_dir, exist_ok=True)
os.makedirs(one_label_mask_dir, exist_ok=True)

# Copy every single-label tile unchanged and pair it with a constant mask whose
# value is the class index (tum=0, nec=1, lym=2, tas=3).
for category, category_image_list in zip(category_list, [only_tum_list, only_nec_list, only_lym_list, only_tas_list]):
    print(category, len(category_image_list))
    for image_path in category_image_list:
        file_name = os.path.basename(image_path)
        dst_image = os.path.join(one_label_image_dir, file_name)
        dst_mask = os.path.join(one_label_mask_dir, file_name)
        # Resume support: skip only when BOTH outputs already exist, so a tile
        # whose mask was never written gets regenerated.
        if os.path.exists(dst_image) and os.path.exists(dst_mask):
            continue
        # shutil.copy needs no shell quoting and raises on failure; the former
        # os.system('cp ...') ignored its exit status.
        shutil.copy(image_path, dst_image)
        # NOTE(review): assumes every tile is 224x224 -- TODO confirm against the data.
        mask = get_onelabel_mask(category, 224)
        output_mask = Image.fromarray(np.uint8(mask), mode='P')
        output_mask.putpalette(palette)
        output_mask.save(dst_mask)


tum 1574
nec 787
lym 42
tas 2192


### Gridded Dataset

In [5]:
from torch.utils.data import Dataset as BaseDataset
import albumentations as albu
import random

In [6]:
class CropAndConcatDataset(BaseDataset):
    """Synthesise images by tiling random crops of single-label patches on a grid.

    Each sample is a ``patch_num x patch_num`` grid of ``patch_size``-pixel crops,
    each drawn uniformly at random from the single-label path lists built by
    ``create_data`` (read from the notebook namespace). Every pixel of a crop is
    labelled in the mask with the class index of its source patch (the position
    of the 1 in the filename label).

    Args:
        patch_num: number of tiles along each side of the grid.
        patch_size: side length in pixels of every tile.
        size: dataset length; defaults to the number of single-label patches.
    """

    def __init__(self, patch_num, patch_size, size=None):
        self.tum, self.nec, self.lym, self.tas = only_tum_list, only_nec_list, only_lym_list, only_tas_list
        self.single_type_images = self.tum + self.nec + self.lym + self.tas

        self.patch_num = patch_num
        self.patch_size = patch_size

        self.total_len = len(self.single_type_images) if size is None else size

        # Pad tiles smaller than patch_size, then take a random patch_size crop.
        self.crop_fn = albu.Compose([
            albu.PadIfNeeded(min_height=self.patch_size, min_width=self.patch_size),
            albu.RandomCrop(width=self.patch_size, height=self.patch_size)
        ])

        # NOTE(review): not applied anywhere in this class; kept so the attribute still exists.
        self.transform = albu.Compose([
            albu.Flip(),
            albu.RandomRotate90(),
        ])

    def __getitem__(self, i):
        # `i` is ignored: every call draws a fresh random grid (no per-index seeding).
        return self.create_one_image()

    def create_one_image(self):
        """Return one (image, mask) pair: uint8 arrays of shape (S, S, 3) and (S, S),
        where S = patch_num * patch_size."""
        side = self.patch_num * self.patch_size
        image = np.zeros((side, side, 3), dtype=np.uint8)
        mask = np.zeros((side, side), dtype=np.uint8)

        for i in range(self.patch_num):
            for j in range(self.patch_num):
                rows = slice(i * self.patch_size, (i + 1) * self.patch_size)
                cols = slice(j * self.patch_size, (j + 1) * self.patch_size)
                tile, tile_mask = self._random_tile()
                image[rows, cols] = tile
                mask[rows, cols] = tile_mask

        return image, mask

    def _random_tile(self):
        """Pick a random single-label patch and return a random (tile, tile_mask) crop."""
        tile_name = np.random.choice(self.single_type_images)
        label = get_patch_label(tile_name)
        assert sum(label) == 1, f"expected a single-label patch, got {label} for {tile_name}"
        # The context manager closes the file; convert('RGB') guarantees 3 channels
        # so grayscale/palette/RGBA PNGs still fit the (.., .., 3) image buffer.
        with Image.open(tile_name) as im:
            tile = np.asarray(im.convert('RGB'))
        # uint8 matches the mask buffer the crop is written into.
        tile_mask = np.full(tile.shape[:2], label.index(1), dtype=np.uint8)

        sample = self.crop_fn(image=tile, mask=tile_mask)
        return sample['image'], sample['mask']

    def __len__(self):
        return self.total_len

In [7]:
# 2 x 2 grid of 112-px crops -> 224 x 224 samples, 7200 of them.
gridded_dataset = CropAndConcatDataset(patch_num=2, patch_size=112, size=7200)
gridded_img_dir = f"data/LUAD-HistoSeg/gridded_image_2_112_run{run}/img"
gridded_mask_dir = f"data/LUAD-HistoSeg/gridded_image_2_112_run{run}/mask"
# Create the output folders only if they are not there yet.
for out_dir in (gridded_img_dir, gridded_mask_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

In [8]:
# # quick view
# from matplotlib import pyplot as plt
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 15))
# for i in range(25):
#     idx = np.random.randint(len(gridded_dataset))
#     img, _ = gridded_dataset[idx]
#     ax = axes[i // 5, i % 5]
#     ax.imshow(img)
#     ax.set_title(f" dims: {img.size} /{np.max(img)}")
# fig.tight_layout()

In [9]:
# Write sample i as img/{i:05d}.png plus a palettised mask/{i:05d}.png.
for i in tqdm(range(len(gridded_dataset))):
    sample_img, sample_mask = gridded_dataset[i]
    mask_png = Image.fromarray(np.uint8(sample_mask), mode='P')
    mask_png.putpalette(palette)
    Image.fromarray(sample_img).save(os.path.join(gridded_img_dir, f"{i:05d}.png"))
    mask_png.save(os.path.join(gridded_mask_dir, f"{i:05d}.png"))


  0%|          | 0/7200 [00:00<?, ?it/s]

100%|██████████| 7200/7200 [06:23<00:00, 18.80it/s]


### CutMix Dataset


In [10]:
class CutMixDataset(BaseDataset):
    """Synthesise two-class images with CutMix on pairs of single-label patches.

    Each sample takes two distinct single-label patches, pastes a random box from
    the second onto the first, and returns the mixed image together with a
    per-pixel mask of class indices (the position of the 1 in each filename label).

    Args:
        patch_size: expected side length in pixels of every source patch.
        size: dataset length.
    """

    def __init__(self, patch_size=224, size=7200):
        self.patch_size = patch_size
        self.total_len = size
        # Single-label path lists built by create_data(), read from the notebook namespace.
        self.tum, self.nec, self.lym, self.tas = only_tum_list, only_nec_list, only_lym_list, only_tas_list
        self.single_type_images = self.tum + self.nec + self.lym + self.tas

    def __getitem__(self, i):
        # `i` is ignored: each call draws a fresh random pair and box.
        H = W = self.patch_size
        image_path, choice = np.random.choice(self.single_type_images, size=2, replace=False)
        image, mask = self._load_with_mask(image_path, H, W)
        mix_image, mix_mask = self._load_with_mask(choice, H, W)

        # Beta(1, 1) is Uniform(0, 1); the box covers roughly (1 - lam) of the area.
        lam = np.random.beta(1, 1)
        bbx1, bbx2, bby1, bby2 = self._get_cutmix_bbox(W, H, lam)
        image[bbx1:bbx2, bby1:bby2, :] = mix_image[bbx1:bbx2, bby1:bby2, :]
        mask[bbx1:bbx2, bby1:bby2] = mix_mask[bbx1:bbx2, bby1:bby2]
        # For a soft image-level label, the clipped box gives the exact weight:
        # lam_adj = 1 - (bbx2 - bbx1) * (bby2 - bby1) / (H * W).
        return image, mask

    def _load_with_mask(self, path, H, W):
        """Load ``path`` as a writable RGB array and build its constant class-index mask.

        Raises:
            ValueError: if the image is not H x W, which would misalign image and mask.
        """
        with Image.open(path) as im:
            image = np.array(im.convert('RGB'))
        if image.shape[:2] != (H, W):
            raise ValueError(f"{path}: expected a {H}x{W} patch, got {image.shape[0]}x{image.shape[1]}")
        label = get_patch_label(path)
        mask = np.full((H, W), label.index(1), dtype=np.uint8)
        return image, mask

    def _get_cutmix_bbox(self, W, H, lam):
        """Sample a CutMix box with sides ``sqrt(1 - lam)`` of H and W, clipped to the image.

        Returns:
            (bbx1, bbx2, bby1, bby2): row start/stop (along H) and column
            start/stop (along W).
        """
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        cx = np.random.randint(H)
        cy = np.random.randint(W)

        bbx1 = np.clip(cx - cut_h // 2, 0, H)
        bby1 = np.clip(cy - cut_w // 2, 0, W)
        bbx2 = np.clip(cx + cut_h // 2, 0, H)
        bby2 = np.clip(cy + cut_w // 2, 0, W)

        return bbx1, bbx2, bby1, bby2

    def __len__(self):
        return self.total_len

In [11]:
# 7200 CutMix samples of 224 x 224.
cutmix_dataset = CutMixDataset(patch_size=224, size=7200)
cutmix_img_dir = f"data/LUAD-HistoSeg/cutmix_image_run{run}/img"
cutmix_mask_dir = f"data/LUAD-HistoSeg/cutmix_image_run{run}/mask"
# Create the output folders only if they are not there yet.
for out_dir in (cutmix_img_dir, cutmix_mask_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

In [12]:
# # quick view
# from matplotlib import pyplot as plt
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 15))
# for i in range(25):
#     idx = np.random.randint(len(cutmix_dataset))
#     img, _ = cutmix_dataset[idx]
#     ax = axes[i // 5, i % 5]
#     ax.imshow(img)
#     ax.set_title(f" dims: {img.size} /{np.max(img)}")
# fig.tight_layout()

In [13]:
# Write sample i as img/{i:05d}.png plus a palettised mask/{i:05d}.png.
for i in tqdm(range(len(cutmix_dataset))):
    sample_img, sample_mask = cutmix_dataset[i]
    mask_png = Image.fromarray(np.uint8(sample_mask), mode='P')
    mask_png.putpalette(palette)
    Image.fromarray(sample_img).save(os.path.join(cutmix_img_dir, f"{i:05d}.png"))
    mask_png.save(os.path.join(cutmix_mask_dir, f"{i:05d}.png"))


100%|██████████| 7200/7200 [04:02<00:00, 29.72it/s]


### Mosaic Dataset (patch_num=1, patch_size=224)

In [14]:
class MosaicDataset(BaseDataset):
    """Synthesise four-quadrant mosaics from gridded single-label patches.

    Four base images are built as in ``create_one_image`` (a ``patch_num x
    patch_num`` grid of random ``patch_size`` crops). Each is augmented and
    randomly cropped to fill one quadrant of an output canvas of side
    ``patch_num * patch_size``; the split point is drawn per sample.

    Args:
        patch_num: tiles per side in each base image.
        patch_size: side length in pixels of each tile.
        size: dataset length; defaults to the number of single-label patches.
    """

    # Upper bound on how many times __getitem__ rebuilds a sample after an
    # AssertionError before raising instead of retrying indefinitely.
    max_attempts = 100

    def __init__(self, patch_num, patch_size, size=None):
        # Single-label path lists built by create_data(), read from the notebook namespace.
        self.tum, self.nec, self.lym, self.tas = only_tum_list, only_nec_list, only_lym_list, only_tas_list
        self.single_type_images = self.tum + self.nec + self.lym + self.tas

        self.patch_num = patch_num
        self.patch_size = patch_size

        self.total_len = len(self.single_type_images) if size is None else size

        # Pad tiles smaller than patch_size, then take a random patch_size crop.
        self.crop_fn = albu.Compose([
            albu.PadIfNeeded(min_height=self.patch_size, min_width=self.patch_size),
            albu.RandomCrop(width=self.patch_size, height=self.patch_size)
        ])

        # NOTE(review): not applied anywhere in this class; kept so the attribute still exists.
        self.transform = albu.Compose([
            albu.Flip(),
            albu.RandomRotate90(),
        ])

    def __getitem__(self, i):
        """Return a freshly sampled (image, mask) mosaic; ``i`` is ignored.

        Raises:
            RuntimeError: if ``max_attempts`` consecutive attempts raise AssertionError.
        """
        H = W = self.patch_num * self.patch_size
        last_error = None
        for attempt in range(self.max_attempts):
            try:
                (image_1, mask_1), (image_2, mask_2), (image_3, mask_3), (image_4, mask_4) = [
                    self.create_one_image() for _ in range(4)
                ]  # each image is [H, W, C]
                return self.create_mosaic(H, W, image_1, mask_1, image_2, mask_2,
                                          image_3, mask_3, image_4, mask_4)
            except AssertionError as e:
                print(e)
                last_error = e
        raise RuntimeError(
            f"MosaicDataset: no valid sample after {self.max_attempts} attempts"
        ) from last_error

    def create_one_image(self):
        """Return one (image, mask) pair: uint8 arrays of shape (S, S, 3) and (S, S),
        where S = patch_num * patch_size."""
        side = self.patch_num * self.patch_size
        image = np.zeros((side, side, 3), dtype=np.uint8)
        mask = np.zeros((side, side), dtype=np.uint8)

        for i in range(self.patch_num):
            for j in range(self.patch_num):
                rows = slice(i * self.patch_size, (i + 1) * self.patch_size)
                cols = slice(j * self.patch_size, (j + 1) * self.patch_size)
                tile, tile_mask = self._random_tile()
                image[rows, cols] = tile
                mask[rows, cols] = tile_mask

        return image, mask

    def _random_tile(self):
        """Pick a random single-label patch and return a random (tile, tile_mask) crop."""
        tile_name = np.random.choice(self.single_type_images)
        label = get_patch_label(tile_name)
        assert sum(label) == 1, f"expected a single-label patch, got {label} for {tile_name}"
        # The context manager closes the file; convert('RGB') guarantees 3 channels
        # so grayscale/palette/RGBA PNGs still fit the (.., .., 3) image buffer.
        with Image.open(tile_name) as im:
            tile = np.asarray(im.convert('RGB'))
        # uint8 matches the mask buffer the crop is written into.
        tile_mask = np.full(tile.shape[:2], label.index(1), dtype=np.uint8)

        sample = self.crop_fn(image=tile, mask=tile_mask)
        return sample['image'], sample['mask']

    def create_mosaic(self, H, W, image_1, mask_1, image_2, mask_2, image_3, mask_3, image_4, mask_4):
        """Stitch four augmented (image, mask) pairs into one H x W mosaic.

        The split point (h, w) is drawn per call: each lies in [20%, 80%) of its
        side and is rounded up to an even number. Quadrants: top-left from pair 1,
        top-right from pair 2, bottom-left from pair 3, bottom-right from pair 4.
        Each pair is flipped / shift-scale-rotated (each with p=0.8) and then
        randomly cropped to its quadrant's size.
        """
        def get_transforms(height, width, p=0.5):
            _transform = [
                albu.Flip(p=p),
                albu.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=p),
                albu.RandomCrop(height, width),
            ]
            return albu.Compose(_transform)

        image = np.zeros((H, W, 3), dtype=np.uint8)
        mask = np.zeros((H, W), dtype=np.uint8)

        h, w = int(H * (random.random()*0.6+0.2)), int(W * (random.random()*0.6+0.2))
        h += h % 2
        w += w % 2

        # Processed in the same order as before, so the RNG stream is unchanged.
        quadrants = (
            (image_1, mask_1, slice(0, h), slice(0, w)),
            (image_2, mask_2, slice(0, h), slice(w, W)),
            (image_3, mask_3, slice(h, H), slice(0, w)),
            (image_4, mask_4, slice(h, H), slice(w, W)),
        )
        for src_image, src_mask, rows, cols in quadrants:
            transform = get_transforms(height=rows.stop - rows.start, width=cols.stop - cols.start, p=0.8)
            sample = transform(image=src_image, mask=src_mask)
            image[rows, cols, :] = sample['image']
            mask[rows, cols] = sample['mask']

        return image, mask

    def __len__(self):
        return self.total_len

In [15]:
# 7200 mosaics, each built from four single 224-px crops.
mosaic_1_224_dataset = MosaicDataset(patch_num=1, patch_size=224, size=7200)
mosaic_1_224_img_dir = f"data/LUAD-HistoSeg/mosaic_1_224_run{run}/img"
mosaic_1_224_mask_dir = f"data/LUAD-HistoSeg/mosaic_1_224_run{run}/mask"
# Create the output folders only if they are not there yet.
for out_dir in (mosaic_1_224_img_dir, mosaic_1_224_mask_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    

In [16]:
# # quick view
# from matplotlib import pyplot as plt
# fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 15))
# for i in range(25):
#     idx = np.random.randint(len(mosaic_1_224_dataset))
#     img, _ = mosaic_1_224_dataset[idx]
#     ax = axes[i // 5, i % 5]
#     ax.imshow(img)
#     ax.set_title(f" dims: {img.size} /{np.max(img)}")
# fig.tight_layout()

In [17]:
# Write sample i as img/{i:05d}.png plus a palettised mask/{i:05d}.png.
for i in tqdm(range(len(mosaic_1_224_dataset))):
    sample_img, sample_mask = mosaic_1_224_dataset[i]
    mask_png = Image.fromarray(np.uint8(sample_mask), mode='P')
    mask_png.putpalette(palette)
    Image.fromarray(sample_img).save(os.path.join(mosaic_1_224_img_dir, f"{i:05d}.png"))
    mask_png.save(os.path.join(mosaic_1_224_mask_dir, f"{i:05d}.png"))


100%|██████████| 7200/7200 [07:56<00:00, 15.10it/s]
