# README
Run this notebook to generate the synthetic Bezier-mosaic dataset (blended image/mask pairs) used for mosaic training.

In [1]:
import utils
import random
import time
import os
from functools import wraps

import torch
from pathlib import Path
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from einops import rearrange

from tqdm import tqdm
import logging

from skimage import morphology

# np.random.seed(42)
# random.seed(42)

In [2]:
def show_mask(image):
    """Display a class-index mask using the segmentation colour palette.

    *image* is cast to ``uint8`` and interpreted as a palettised ("P") image:
    index 0 is coloured as tumor, 1 as stroma, 2 as normal and 3 as
    background; the remaining 252 palette entries are black.
    """
    class_colours = (
        (0, 64, 128),  # tumor
        (64, 128, 0),  # stroma
        (243, 152, 0),  # normal
        (255, 255, 255),  # background
    )
    palette = [channel for colour in class_colours for channel in colour]
    palette.extend([0] * (252 * 3))  # pad up to the full 256-entry palette
    indexed = Image.fromarray(np.uint8(image), mode='P')
    indexed.putpalette(palette)
    plt.imshow(indexed)
    plt.show()
    plt.close()
    
def show_background(image):
    """Display *image* in grayscale with the colour range fixed to 0..127."""
    display_options = {"cmap": 'gray', "vmin": 0, "vmax": 127}
    plt.imshow(image, **display_options)
    plt.show()
    plt.close()

def show_image(image):
    """Display an image with matplotlib.

    A ``torch.Tensor`` is treated as channel-first ``(c, h, w)`` and is
    rearranged to channel-last before plotting; anything else is passed to
    ``imshow`` unchanged.
    """
    channels_last = (
        rearrange(image, 'c h w -> h w c')
        if isinstance(image, torch.Tensor)
        else image
    )
    plt.imshow(channels_last)
    plt.show()
    plt.close()

def create_data(train_data):
    """Collect the single-class training images found in *train_data*.

    Every ``*.png`` directly inside *train_data* is classified with
    ``utils.is_tumor`` / ``utils.is_stroma`` / ``utils.is_normal``. An image
    accepted by more than one predicate is dropped from all three lists, so
    each returned list contains only images of exactly one class.

    Args:
        train_data: Directory (``str`` or ``Path``) holding the training PNGs.

    Returns:
        Tuple ``(tumor_images, stroma_images, normal_images)`` of path strings.
        Each list is sorted: set iteration order depends on string-hash
        randomisation and glob order depends on the filesystem, so without
        sorting a seeded ``random.choice`` over these lists would not be
        reproducible between runs.
    """
    print(train_data)
    tumor_set, stroma_set, normal_set = set(), set(), set()
    for path in Path(train_data).glob('*.png'):
        path_str = str(path)
        if utils.is_tumor(path):
            tumor_set.add(path_str)
        if utils.is_stroma(path):
            stroma_set.add(path_str)
        if utils.is_normal(path):
            normal_set.add(path_str)

    tumor_images = sorted(tumor_set - stroma_set - normal_set)
    stroma_images = sorted(stroma_set - tumor_set - normal_set)
    normal_images = sorted(normal_set - tumor_set - stroma_set)

    return tumor_images, stroma_images, normal_images

In [3]:
def get_background(region):
    """Return a 0/255 ``uint8`` mask of the bright areas of an RGB *region*.

    The region is converted to grayscale and thresholded at 200; connected
    bright components smaller than 50 pixels (connectivity 1) are removed.
    """
    grayscale = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY)
    _, thresholded = cv2.threshold(grayscale, 200, 255, cv2.THRESH_BINARY)
    bright = np.uint8(thresholded) == 255
    cleaned = morphology.remove_small_objects(bright, min_size=50, connectivity=1)
    # Boolean -> {0, 1} -> {0, 255}
    return np.array(cleaned, dtype=np.uint8) * 255


In [4]:
def timeit(fn):
    """Decorator that logs how long each call to *fn* takes.

    The duration is measured with ``time.perf_counter`` (monotonic and
    high-resolution, unlike wall-clock ``time.time``, which may jump when the
    system clock is adjusted) and reported through ``logging.info``. The
    wrapped function's return value is passed through unchanged; if *fn*
    raises, nothing is logged and the exception propagates.
    """
    @wraps(fn)
    def measure_time(*args, **kwargs):
        started = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - started
        logging.info(f"@timefn: {fn.__name__} took {elapsed: .5f} s")
        return result
    return measure_time

def visualize(save=None, **images):
    """Plot images in one row.

    Args:
        save: Destination handed to ``plt.savefig``. When ``None`` the figure
            is shown with ``plt.show()`` instead of being saved.
        **images: Panels to draw, given as ``title=image`` pairs and plotted
            left to right. A panel named exactly ``'mask'`` is rendered with
            the tumor/stroma/normal/background palette, a panel whose name
            contains ``'background'`` with a two-entry white/black palette,
            and any other panel is passed to ``imshow`` as-is. Torch tensors
            are accepted; 3-D tensors are treated as channel-first.
    """
    fontsize = 14
    mask_palette = [0, 64, 128, 64, 128, 0, 243, 152, 0, 255, 255, 255] + [0] * 252 * 3
    background_palette = [255, 255, 255, 0, 0, 0]

    def axarr_show(ax, image, name):
        if isinstance(image, torch.Tensor):
            if image.ndim == 3:
                image = image.permute(1, 2, 0)
            # Always convert, not only for CUDA tensors: CPU tensors that
            # require grad cannot be handed to NumPy/matplotlib directly.
            image = image.detach().cpu().numpy()

        if name == 'mask':
            palette = mask_palette
        elif 'background' in name:
            palette = background_palette
        else:
            palette = None

        if palette is not None:
            image = Image.fromarray(np.uint8(image), mode='P')
            image.putpalette(palette)

        ax.imshow(image)
        ax.set_title(name, fontsize=fontsize)
        ax.set_yticks([])
        ax.set_xticks([])

    n = len(images)
    # squeeze=False always yields a 2-D array of axes, so a single panel
    # needs no special casing.
    fig, axarr = plt.subplots(nrows=1, ncols=n, figsize=(8, 8), squeeze=False)
    for ax, (name, image) in zip(axarr.ravel(), images.items()):
        axarr_show(ax, image, name)

    plt.tight_layout()
    if save is not None:
        plt.savefig(save)
    else:
        plt.show()
    plt.close()

In [5]:
import numpy as np
from scipy.special import binom
import matplotlib.pyplot as plt


def bernstein(n, k, t):
    """Evaluate the Bernstein basis polynomial of degree *n*, index *k*, at *t*."""
    return binom(n, k) * t**k * (1. - t)**(n - k)


def bezier(points, num=200):
    """Sample the Bezier curve defined by 2-D control *points*.

    The curve is evaluated at *num* evenly spaced parameters in ``[0, 1]``
    and returned as an array of shape ``(num, 2)``.
    """
    degree = len(points) - 1
    params = np.linspace(0, 1, num=num)
    samples = np.zeros((num, 2))
    for index, control_point in enumerate(points):
        # Each control point contributes its Bernstein weight at every parameter.
        samples += np.outer(bernstein(degree, index, params), control_point)
    return samples

class Segment():
    """Cubic Bezier segment between two points with prescribed end directions.

    Args:
        p1, p2: End points of the segment (2-element arrays).
        angle1, angle2: Direction angles (radians) of the curve at *p1* and
            *p2*.
        **kw: ``numpoints`` (default 100) is the number of samples on the
            curve; ``r`` (default 0.3) is the control-point distance as a
            fraction of the distance between *p1* and *p2*. Other keys are
            ignored.

    Attributes set on the instance: ``p`` (the 4x2 control-point array),
    ``r`` (the absolute control-point distance) and ``curve`` (the sampled
    points, produced by ``bezier``).
    """

    def __init__(self, p1, p2, angle1, angle2, **kw):
        self.p1 = p1
        self.p2 = p2
        self.angle1 = angle1
        self.angle2 = angle2
        self.numpoints = kw.get("numpoints", 100)
        r = kw.get("r", 0.3)
        d = np.sqrt(np.sum((self.p2 - self.p1) ** 2))
        self.r = r * d
        self.p = np.zeros((4, 2))
        self.p[0, :] = self.p1[:]
        self.p[3, :] = self.p2[:]
        self.calc_intermediate_points(self.r)

    def calc_intermediate_points(self, r):
        """Place the two inner control points at distance *r* and resample.

        The first inner point lies *r* away from ``p1`` along ``angle1``; the
        second lies *r* away from ``p2`` along ``angle2 + pi`` (i.e. behind
        ``p2``). The *r* argument is honoured here; previously it was
        accepted but ignored in favour of ``self.r``.
        """
        self.p[1, :] = self.p1 + np.array([r * np.cos(self.angle1),
                                           r * np.sin(self.angle1)])
        self.p[2, :] = self.p2 + np.array([r * np.cos(self.angle2 + np.pi),
                                           r * np.sin(self.angle2 + np.pi)])
        self.curve = bezier(self.p, self.numpoints)


def get_curve(points, **kw):
    """Chain Bezier segments through consecutive rows of *points*.

    Each row of *points* is ``(x, y, angle)``. One ``Segment`` is built per
    pair of neighbouring rows, with *kw* forwarded to it. Returns the list
    of segments and their sampled curves concatenated into one array.
    """
    segments = [
        Segment(start[:2], end[:2], start[2], end[2], **kw)
        for start, end in zip(points[:-1], points[1:])
    ]
    curve = np.concatenate([segment.curve for segment in segments])
    return segments, curve

def ccw_sort(p):
    """Order the rows of *p* by their angle around the centroid.

    The angle is ``arctan2(dx, dy)`` of each point's offset from the mean of
    all points, and rows are returned in ascending order of that angle.
    """
    offsets = p - np.mean(p, axis=0)
    angles = np.arctan2(offsets[:, 0], offsets[:, 1])
    order = np.argsort(angles)
    return p[order, :]

def get_bezier_curve(a, rad=0.2, edgy=0):
    """Create a closed curve through the points in *a*.

    *rad* is a number between 0 and 1 that steers the distance of the
    control points. *edgy* controls how "edgy" the curve is; ``edgy=0`` is
    the smoothest.

    Returns ``(x, y, a)``: the curve coordinates, and the angularly sorted
    point array, closed by repeating its first row and extended with a third
    column holding the tangent angle at each point.
    """
    blend = np.arctan(edgy) / np.pi + .5
    ordered = ccw_sort(a)
    # Repeat the first point at the end so the outline closes on itself.
    closed = np.append(ordered, np.atleast_2d(ordered[0, :]), axis=0)
    steps = np.diff(closed, axis=0)
    headings = np.arctan2(steps[:, 1], steps[:, 0])
    # Map headings from (-pi, pi] into [0, 2*pi).
    headings = (headings >= 0) * headings + (headings < 0) * (headings + 2 * np.pi)
    previous = np.roll(headings, 1)
    # Tangent at each point: blend of the outgoing and incoming headings,
    # corrected by pi when the two differ by more than half a turn.
    tangents = (blend * headings + (1 - blend) * previous
                + (np.abs(previous - headings) > np.pi) * np.pi)
    tangents = np.append(tangents, [tangents[0]])
    a = np.append(closed, np.atleast_2d(tangents).T, axis=1)
    _, curve = get_curve(a, r=rad, method="var")
    x, y = curve.T
    return x, y, a


def get_random_points(n=5, scale=0.8, mindst=None, rec=0):
    """Create *n* random points in the unit square, then scale them.

    Points are redrawn until consecutive points (in the order given by
    ``ccw_sort``) are at least *mindst* apart, or until the attempt counter
    reaches 200, in which case the last draw is returned as-is.

    Args:
        n: Number of points.
        scale: Factor applied to the accepted points before returning.
        mindst: Minimum Euclidean distance between consecutive sorted points
            in the unit square; defaults to ``0.7 / n``.
        rec: Starting value of the attempt counter (kept for compatibility
            with the former recursive implementation).

    Returns:
        Array of shape ``(n, 2)``.
    """
    mindst = mindst or .7 / n
    while True:
        a = np.random.rand(n, 2)
        steps = np.diff(ccw_sort(a), axis=0)
        # Euclidean length of each step. The square must be applied per
        # coordinate before summing; squaring the sum gives |dx + dy|.
        d = np.sqrt(np.sum(steps ** 2, axis=1))
        if np.all(d >= mindst) or rec >= 200:
            return a * scale
        rec += 1

In [6]:
def get_bezier_mask(n, scale, rad=0.2, edgy=0.05):
    """Rasterise a random closed Bezier outline into a binary mask.

    *n* random anchor points scaled by *scale* define the outline; its
    interior is filled with 1 on a ``(scale, scale)`` ``uint8`` canvas of
    zeros. *rad* and *edgy* are forwarded to ``get_bezier_curve``.
    """
    # NOTE: a previously tried variant (disabled) drew the anchors at 2/3 of
    # the scale and shifted the outline by a random offset of up to scale // 3.
    anchors = get_random_points(n=n, scale=scale)
    xs, ys, _ = get_bezier_curve(anchors, rad=rad, edgy=edgy)
    outline = np.stack([np.round(xs), np.round(ys)], axis=1)
    canvas = np.zeros((scale, scale), dtype=np.uint8)
    filled = cv2.fillPoly(canvas, np.int32([outline]), 1)

    return filled

def get_onelabel_mask(category, scale):
    """Return a ``(scale, scale)`` ``uint8`` mask filled with one class index.

    ``"tumor"`` maps to 0, ``"stroma"`` to 1, and any other category to 2.
    """
    if category == "tumor":
        class_index = 0
    elif category == "stroma":
        class_index = 1
    else:
        class_index = 2
    return np.full((scale, scale), class_index, dtype=np.uint8)


In [7]:
# Index the single-class training images by class name; synthesize_one()
# draws its background and foreground images from this mapping.
train_dir = Path("./data/WSSS4LUAD/1.training")
tumor_images, stroma_images, normal_images = create_data(train_dir)
dataset_dict = dict(
    tumor=tumor_images,
    stroma=stroma_images,
    normal=normal_images,
)

data/WSSS4LUAD/1.training


In [8]:
def _load_resized_rgb(image_path, size):
    """Open *image_path*, force it to RGB, resize to size x size, return an array."""
    # The context manager closes the file handle deterministically, and
    # convert("RGB") guarantees a (size, size, 3) array even for grayscale,
    # palettised or RGBA PNGs, which would otherwise mis-broadcast in the blend.
    with Image.open(image_path) as handle:
        return np.array(handle.convert("RGB").resize((size, size)))


def synthesize_one(n=12, rad=0.2, edgy=0.05, background_class="tumor", foreground_class="stroma", size=224):
    """Blend two single-class images along a random Bezier blob.

    One random image of *foreground_class* is pasted, inside a random Bezier
    mask, on top of one random image of *background_class*; both are taken
    from ``dataset_dict``.

    Args:
        n: Number of anchor points of the Bezier blob.
        rad, edgy: Shape parameters forwarded to ``get_bezier_mask``.
        background_class: Key of ``dataset_dict`` used outside the blob.
        foreground_class: Key of ``dataset_dict`` used inside the blob.
        size: Side length in pixels of the square output (default 224).

    Returns:
        ``(synthesized_image, synthesized_mask)``: the blended RGB image of
        shape ``(size, size, 3)`` and the matching per-pixel class-index mask
        of shape ``(size, size)``.
    """
    background_image_path = random.choice(dataset_dict[background_class])
    foreground_image_path = random.choice(dataset_dict[foreground_class])

    background_image = _load_resized_rgb(background_image_path, size)
    foreground_image = _load_resized_rgb(foreground_image_path, size)

    background_mask = get_onelabel_mask(background_class, scale=size)
    foreground_mask = get_onelabel_mask(foreground_class, scale=size)

    bezier_mask = get_bezier_mask(n=n, scale=size, rad=rad, edgy=edgy)
    inverse_mask = 1 - bezier_mask

    # bezier_mask is 0/1, so each product selects exactly one source per pixel.
    synthesized_image = (bezier_mask[:, :, np.newaxis] * foreground_image
                         + inverse_mask[:, :, np.newaxis] * background_image)
    synthesized_mask = bezier_mask * foreground_mask + inverse_mask * background_mask

    return synthesized_image, synthesized_mask


In [9]:
def synthesize_and_save(save_dir, i, n=12, rad=0.2, edgy=0.05, background_class="tumor", foreground_class="stroma"):
    """Synthesize one image/mask pair and write both as PNG files.

    The image goes to ``<save_dir>/img`` and the palettised mask to
    ``<save_dir>/mask``. Both share the file name ``<i:05d>-<label>.png``,
    where ``label`` is the ``[tumor, stroma, normal]`` presence list of the
    two classes used.
    """
    image_array, mask_array = synthesize_one(n, rad, edgy, background_class, foreground_class)

    class_colours = (
        (0, 64, 128),  # tumor
        (64, 128, 0),  # stroma
        (243, 152, 0),  # normal
        (255, 255, 255),  # background
    )
    palette = [channel for colour in class_colours for channel in colour]
    palette.extend([0] * (252 * 3))  # pad up to the full 256-entry palette
    mask_picture = Image.fromarray(np.uint8(mask_array), mode='P')
    mask_picture.putpalette(palette)

    # 1 where the class takes part in this mosaic, 0 otherwise.
    used_classes = (background_class, foreground_class)
    label = [int(name in used_classes) for name in ('tumor', 'stroma', 'normal')]

    file_name = f"{i:05d}-{label}.png"
    Image.fromarray(image_array).save(os.path.join(save_dir, 'img', file_name))
    mask_picture.save(os.path.join(save_dir, 'mask', file_name))

In [10]:
from joblib import Parallel, delayed

# (background, foreground) combinations; each one fills a contiguous block
# of N_train sample indices, in this order.
class_pairs = [
    ("tumor", "stroma"),
    ("stroma", "tumor"),
    ("normal", "stroma"),
    ("stroma", "normal"),
]

for run in range(5, 10):
    save_dir = f"data/WSSS4LUAD/bezier224_5_0.2_0.05_1d1_run{run}"

    for subdir in ('img', 'mask'):
        target_dir = os.path.join(save_dir, subdir)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

    N_train = 2_500
    for block_index, (background, foreground) in enumerate(class_pairs):
        first_index = block_index * N_train
        for i in tqdm(range(first_index, first_index + N_train), total=N_train):
            synthesize_and_save(save_dir, i, background_class=background, foreground_class=foreground)

  1%|          | 15/2500 [00:00<01:14, 33.58it/s]

100%|██████████| 2500/2500 [01:07<00:00, 37.25it/s]
100%|██████████| 2500/2500 [01:03<00:00, 39.39it/s]
100%|██████████| 2500/2500 [01:04<00:00, 39.02it/s]
100%|██████████| 2500/2500 [01:01<00:00, 40.49it/s]
100%|██████████| 2500/2500 [01:00<00:00, 41.62it/s]
100%|██████████| 2500/2500 [01:00<00:00, 41.17it/s]
100%|██████████| 2500/2500 [01:03<00:00, 39.50it/s]
100%|██████████| 2500/2500 [01:08<00:00, 36.75it/s]
100%|██████████| 2500/2500 [01:19<00:00, 31.57it/s]
100%|██████████| 2500/2500 [01:19<00:00, 31.55it/s]
100%|██████████| 2500/2500 [01:23<00:00, 29.85it/s]
100%|██████████| 2500/2500 [01:22<00:00, 30.47it/s]
100%|██████████| 2500/2500 [01:19<00:00, 31.30it/s]
100%|██████████| 2500/2500 [01:20<00:00, 30.89it/s]
100%|██████████| 2500/2500 [01:21<00:00, 30.57it/s]
100%|██████████| 2500/2500 [01:22<00:00, 30.45it/s]
100%|██████████| 2500/2500 [01:20<00:00, 31.13it/s]
100%|██████████| 2500/2500 [01:22<00:00, 30.40it/s]
100%|██████████| 2500/2500 [01:21<00:00, 30.58it/s]
100%|███████