# TMA generator from WSI
This notebook generates multiple simulated TMAs from WSI slides. It picks random ellipses inside each WSI, guided by the provided tumor mask, by a generated Otsu tissue mask, or at random when no mask is available.

About magnification: UBC WSIs are at x20 and TMAs (only 25 instances) are at x40. We generate TMAs at x20 from the WSIs, with no downscaling.

The TMA simulator is based on Albumentations. It keeps a random ellipse from a given tile, adds noise to its contour and applies a (randomized) background color outside it.

Tile selection has some constraints, such as:
- Drop tiles dominated by a single color (including black)
- Include some tissue or tumor, based on a minimum ratio
- Square tiles with a side among (1482, 1568, 1694). See SplitConfig.

Any captured black pixels (e.g. slide borders) are replaced by white.

Then the following stain/color augmentations are applied randomly on TMA:
- Vahadane (with some TMA references)
- Macenko (with generic reference)
- Reinhard (with some TMA references)

TMA generation can run in parallel across multiple CPUs. The limit is driven by the available memory and the image size.

<b>Generated TMA examples are displayed at the end of this notebook.</b>

### Packages
Install PyVips, StainTools/Spams and TorchStain packages

In [None]:
# Fix spams-python
# https://github.com/getspams/spams-python
# Build spams-python from source so that its `np.bool` usages can be patched
# to `np.bool_` before installation (presumably because recent NumPy releases
# removed the `np.bool` alias).
!git clone https://github.com/getspams/spams-python

In [None]:
cd spams-python

In [None]:
!sed -i 's/np.bool/np.bool_/g' spams/spams.py

In [None]:
!pip install -e .

In [None]:
# Make sure staintools and torchstain are installed
# staintools provides the Vahadane normalizer used by Stainer (it relies on
# spams); torchstain provides the Macenko and Reinhard normalizers.
!pip install staintools
!pip install torchstain

In [None]:
import numpy as np
import pandas as pd
pd.set_option('display.max_colwidth', 180)
import glob, os, time, random, gc
# libvips tuning (presumably: worker-thread count, and the image size above
# which libvips decompresses to a temporary disc file rather than memory).
# Set before `import pyvips` below so libvips can read them when it starts up
# — confirm against the libvips environment-variable documentation.
os.environ['VIPS_CONCURRENCY'] = '4'
os.environ['VIPS_DISC_THRESHOLD'] = '15gb'

from tqdm.notebook import tqdm
import PIL
from PIL import Image
# Disable PIL's decompression-bomb size check: WSI images are very large.
Image.MAX_IMAGE_PIXELS = None
import cv2
import pyvips
import joblib
from concurrent.futures import ProcessPoolExecutor
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import shutil
from torchvision import transforms
import torch

import albumentations as A
import staintools
import torchstain

# Library versions, useful to reproduce the generated TMAs.
print('CV2:', cv2.__version__)
print('PIL:', PIL.__version__)
print('pyvips:', pyvips.__version__)
print('joblib:', joblib.__version__)
print("Pytorch", torch.__version__)
print("Albumentations", A.__version__)

In [None]:
def seed_everything(seed):
    """
    Seed every random number generator used by this notebook.

    Covers Python's ``random``, NumPy's global generator and PyTorch (CPU and
    CUDA). ``PYTHONHASHSEED`` is exported too; it cannot change string hashing
    of the interpreter that is already running, but processes started
    afterwards inherit it.

    Args:
        seed (int): Value given to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    # Full cuDNN determinism would additionally need:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False


seed_everything(42)

In [None]:
# Input/output locations, all relative to the notebook working directory.
HOME = "."
DATA_HOME = os.path.join(HOME, "train_images")
DATA_WSI_THUMBNAILS_HOME = os.path.join(HOME, "train_thumbnails")
MASK_HOME = os.path.join(HOME, "supplemental-masks")
TMA_FOLDER = os.path.join(HOME, "TMA")

# Bounds on parallel tiling jobs. MAX_MEM is the budget that each image's
# surface (in megapixels) is divided into when sizing the job groups.
MIN_CPU, MAX_CPU = 1, 4
MAX_MEM = 5000

### TMA simulation
- Pick a random ellipse
- Add noise to its contour
- Add a background color outside it

In [None]:
class SimulateTMA(A.DualTransform):
    """Make a tile look like a tissue micro-array (TMA) core.

    A randomly sized and rotated ellipse centred on the tile is kept; every
    pixel outside it is painted with a background color. The ellipse contour
    can be jittered to mimic the ragged border of a real core.

    Args:
        std: ``(low, high)`` range of the tile's pixel standard deviation
            required for the simulation; other tiles are returned unchanged.
            ``std[0] == -1`` disables the check.
        radius_ratio: ``(low, high)`` range for the horizontal semi-axis, as a
            fraction of half the tile width.
        ellipse_ratio: ``(low, high)`` range for the vertical/horizontal
            semi-axis ratio.
        angle: ``(low, high)`` range of the ellipse rotation, in degrees.
        background_color: color painted outside the ellipse. ``(-1, -1, -1)``
            derives it from the per-channel maximum of the tile multiplied by
            a factor taken from ``background_color_ratio``.
        background_color_ratio: that factor, either a fixed number or a
            ``(low, high)`` range it is drawn from.
        noise_level: ``(low, high)``; contour points are shifted by integers
            drawn from ``[-low, high)``. No jitter when ``high <= 0``.
        black_replacement_color: when set, pure black pixels are first
            replaced by this color (on a copy of the input).
        always_apply, p: forwarded to albumentations.
    """

    def __init__(self, std, radius_ratio=(0.9, 1.0), ellipse_ratio=(0.9, 1.0), angle=(-90., 90.), background_color=(-1, -1, -1), background_color_ratio=1.0, noise_level=(0.0, 0.0), black_replacement_color=None, always_apply=False, p=1.0):
        super(SimulateTMA, self).__init__(always_apply, p)
        self.std = std
        self.radius_ratio = radius_ratio
        self.ellipse_ratio = ellipse_ratio
        self.background_color = background_color
        self.background_color_ratio = background_color_ratio
        self.angle = angle
        self.noise_level = noise_level
        self.black_replacement_color = black_replacement_color

    def _background_color(self, img):
        """Return the color painted outside the ellipse for this tile."""
        if not np.array_equal(self.background_color, (-1, -1, -1)):
            return self.background_color
        ratio = self.background_color_ratio
        # Accept a fixed factor (the constructor default is 1.0) or a range.
        if np.isscalar(ratio):
            bg_ratio = float(ratio)
        else:
            bg_ratio = random.uniform(ratio[0], ratio[1])
        return tuple((np.max(img, axis=(0, 1)) * bg_ratio).astype(np.uint8))  # Auto color

    def apply(self, img, **params):
        height, width = img.shape[:2]
        if self.black_replacement_color is not None:
            # Work on a copy so the caller's array is not modified in place.
            img = img.copy()
            img[np.all(img == [0, 0, 0], axis=-1)] = self.black_replacement_color
        # Only simulate tiles whose pixel std lies in self.std (e.g. (20, 50)).
        if self.std[0] != -1 and not (self.std[0] <= np.std(img) <= self.std[1]):
            return img

        # Filled ellipse centred on the tile, random semi-axes and rotation.
        center = (width // 2, height // 2)
        radius_w = int((width // 2) * random.uniform(self.radius_ratio[0], self.radius_ratio[1]))
        radius_h = int(radius_w * random.uniform(self.ellipse_ratio[0], self.ellipse_ratio[1]))
        angle = int(random.uniform(self.angle[0], self.angle[1]))
        mask = cv2.ellipse(np.zeros_like(img), center, (radius_w, radius_h), angle, 0, 360,
                           color=(255, 255, 255), thickness=-1)

        # Jitter the contour to mimic the irregular border of a real TMA core.
        if self.noise_level[1] > 0:
            # [-2] is the contour list under both the OpenCV 3 (3 outputs) and
            # OpenCV 4 (2 outputs) return conventions.
            contours = cv2.findContours(mask[:, :, 0], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(contours) > 0:
                noise = np.random.randint(-self.noise_level[0], self.noise_level[1], contours[0].shape)
                # drawContours expects int32 point coordinates.
                contour_with_noise = (contours[0] + noise).astype(np.int32)
                mask = cv2.drawContours(np.zeros_like(img), [contour_with_noise], -1, (255, 255, 255), -1)

        # Keep the pixels inside the mask, paint the rest with the background.
        background = np.zeros_like(img)
        background[:] = self._background_color(img)
        background = cv2.bitwise_and(background, cv2.bitwise_not(mask))
        return cv2.add(cv2.bitwise_and(img, mask), background)

    def get_transform_init_args_names(self):
        return ("std", "radius_ratio", "ellipse_ratio", "angle", "background_color", "background_color_ratio", "noise_level", "black_replacement_color")

### Stain/Color augmentations
- vahadane (with some TMA references)
- macenko (with generic reference)
- reinhard (with some TMA references)

In [None]:
class Stainer(A.DualTransform):
    """Albumentations transform that re-stains an RGB tile.

    method:
        'macenko'  -> a torchstain Macenko normalizer that is never fitted here,
                      so it uses torchstain's built-in reference (presumably
                      its default H&E matrix — confirm in torchstain).
        'reinhard' -> one torchstain Reinhard normalizer fitted per reference
                      image; one of them is drawn at random for each call.
        otherwise  -> (e.g. 'vahadane') one staintools StainNormalizer fitted per
                      reference image, drawn at random for each call.
    luminosity: when True, staintools luminosity standardization is applied to
        the input tile, and to the reference images of the staintools branch.
    """

    def __init__(self, ref_images, method, luminosity=True, always_apply=False, p=1.0):
        super(Stainer, self).__init__(always_apply, p)
        self.luminosity = luminosity
        self.method = method
        # torchstain works on float tensors in the 0-255 range (ToTensor gives 0-1).
        self.torchstain_T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 255)
        ])
        if method == 'macenko':
            self.stain_normalizer = [torchstain.normalizers.MacenkoNormalizer(backend='torch')]
        else:
            self.stain_normalizer = [self._fit_reference(ref_path) for ref_path in ref_images]

    def _fit_reference(self, ref_path):
        """Return a normalizer of ``self.method`` fitted on the image at ``ref_path``."""
        reference = np.array(Image.open(ref_path))
        if self.method == 'reinhard':
            normalizer = torchstain.normalizers.ReinhardNormalizer(backend='torch')
            normalizer.fit(self.torchstain_T(reference))
            return normalizer
        if self.luminosity == True:
            reference = staintools.LuminosityStandardizer.standardize(reference)
        normalizer = staintools.StainNormalizer(method=self.method)
        normalizer.fit(reference)
        return normalizer

    def apply(self, img, **params):
        # Standardize brightness (optional, can improve the tissue mask calculation)
        if self.luminosity == True:
            img = staintools.LuminosityStandardizer.standardize(img)
        if self.method == 'macenko':
            normalized, _, _ = self.stain_normalizer[0].normalize(I=self.torchstain_T(img), stains=False)
            return normalized.contiguous().cpu().numpy().astype(np.uint8)
        chosen = np.random.choice(self.stain_normalizer, 1)[0]
        if self.method == 'reinhard':
            normalized = chosen.normalize(I=self.torchstain_T(img))
            return normalized.contiguous().cpu().numpy().astype(np.uint8)
        return chosen.transform(img)

    def get_transform_init_args_names(self):
        return ("ref_images", "method", "luminosity")

### OTSU mask to capture tissue

In [None]:
def get_otsu_mask(img, wsi_scale, mthresh=7, sthresh=20, sthresh_up=255, use_otsu=True):
    """Return a binary tissue mask (values 0 / ``sthresh_up``) for an RGB image.

    The HSV saturation channel is median-blurred with kernel size ``mthresh``
    (OpenCV requires an odd size), thresholded with Otsu's method (or at the
    fixed value ``sthresh`` when ``use_otsu`` is False), then morphologically
    closed with a square kernel of side ``int(32 * wsi_scale)``. The closing
    step is skipped when that side rounds down to 0.
    """
    saturation = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)[:, :, 1]
    blurred = cv2.medianBlur(saturation, mthresh)
    flags = cv2.THRESH_OTSU + cv2.THRESH_BINARY if use_otsu else cv2.THRESH_BINARY
    _, tissue = cv2.threshold(blurred, sthresh, sthresh_up, flags)
    side = int(32 * wsi_scale)
    if side <= 0:
        return tissue
    return cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, np.ones((side, side), np.uint8))

### Generate a thumbnail showing where the TMAs are

In [None]:
def generate_thumbnail(image, tiles, conf):
    """Save a downscaled copy of ``image`` with every selected tile outlined.

    Each tile is drawn as the circle inscribed in its box and labelled with its
    uid. The result is written to
    ``<conf.tma_folder>/<image_id>/thumbnail_map.png``. Nothing is done when
    ``conf.tma_thumbnail_scale`` is None or when ``tiles`` is empty (the image
    id is only known through the tile records).

    Args:
        image: pyvips image the tiles were cut from.
        tiles: tile records as built by ``tile_single_image``; items 1 (image
            id), 4-7 (x1, y1, x2, y2 in full-resolution pixels) and -1 (uid)
            are used.
        conf: configuration providing ``tma_thumbnail_scale``, ``kernel`` and
            ``tma_folder``.
    """
    scale = conf.tma_thumbnail_scale
    # No tiles: nothing to map, and no image id to name the output folder.
    if scale is None or not tiles:
        return
    image_id = tiles[0][1]
    thumb = image.resize(scale, kernel=conf.kernel).numpy()[..., :3]
    font = cv2.FONT_HERSHEY_DUPLEX
    font_scale = 16.0 * scale
    # Keep at least 1 px so that very small scales never give a zero thickness.
    font_thickness = max(1, int(30 * scale))
    color = (0, 255, 255)
    for t in tiles:
        uid = str(t[-1])
        x1, y1, x2, y2 = (int(v * scale) for v in t[4:8])
        text_height = cv2.getTextSize(uid, font, font_scale, font_thickness)[0][1]
        text_y = y1 + int(text_height * 1.2)
        # Circle inscribed in the (square) tile box.
        thumb = cv2.circle(thumb, ((x1 + x2) // 2, (y1 + y2) // 2), (y2 - y1) // 2, color=color, thickness=2)
        # A black copy shifted by 2 px acts as a shadow that keeps the label readable.
        thumb = cv2.putText(thumb, text=uid, org=(x1 + 2, text_y + 2), fontFace=font,
                            fontScale=font_scale, color=(0, 0, 0), thickness=font_thickness)
        thumb = cv2.putText(thumb, text=uid, org=(x1, text_y), fontFace=font,
                            fontScale=font_scale, color=color, thickness=font_thickness)
    if conf.tma_folder is not None:
        img_dir = os.path.join(conf.tma_folder, "%s" % image_id)
        os.makedirs(img_dir, exist_ok=True)
        Image.fromarray(thumb).save(os.path.join(img_dir, "thumbnail_map.png"))

### Generate TMAs from WSI
- Use mask (if available) with tumor ratio
- Use OTSU mask (if enough memory to manage it) with ratio
- Use a standard-deviation criterion if neither a mask nor an OTSU mask is available

In [None]:
def _grid_origins(width, height, step):
    """Return the top-left corners ``(y, x)`` of a ``step``-spaced grid, shuffled with ``random``."""
    origins = [(y, x) for y in range(0, height, step) for x in range(0, width, step)]
    random.shuffle(origins)
    return origins


def _random_crop_box(x, y, width, height, conf):
    """Draw a side from ``conf.tma_crop``; return ``(side, w, h)`` with ``w``/``h`` clipped to the image."""
    side = int(np.random.choice(conf.tma_crop, 1)[0])
    return side, min(side, width - x), min(side, height - y)


def _black_to_white(tile_image):
    """Paint pure-black pixels (e.g. slide borders) white, in place, and return the tile."""
    tile_image[np.all(tile_image == [0, 0, 0], axis=2)] = [255, 255, 255]
    return tile_image


def _is_mostly_one_color(tile_image, max_ratio):
    """Return True when a single grey level covers at least ``max_ratio`` of the tile."""
    _, counts = np.unique(np.array(Image.fromarray(tile_image).convert("L")), return_counts=True)
    return bool((counts / (tile_image.shape[0] * tile_image.shape[1]) >= max_ratio).any())


def _maybe_simulate(tile_image, conf):
    """Run ``conf.tma_simulation`` when a uniform draw exceeds ``conf.tma_simulation_prob``.

    ``tma_simulation_prob`` therefore acts as the probability of *skipping* the
    simulation (0.0 means it practically always runs). The draw is consumed
    even when no simulation is configured.
    """
    if np.random.random() > conf.tma_simulation_prob and conf.tma_simulation is not None:
        tile_image = conf.tma_simulation(image=tile_image)["image"]
    return tile_image


def _save_tile(tile_image, conf, image_id, file_name):
    """Save the tile as ``<conf.tma_folder>/<image_id>/<file_name>``; return its path (None without a folder)."""
    if conf.tma_folder is None:
        return None
    img_dir = os.path.join(conf.tma_folder, "%s" % image_id)
    os.makedirs(img_dir, exist_ok=True)
    tile_file = os.path.join(img_dir, file_name)
    Image.fromarray(tile_image).save(tile_file)
    return tile_file


def _collect_tiles(image, conf, filepath, image_id, is_tma, has_mask, score_box, keep_tile, tile_name):
    """Scan a shuffled grid over ``image`` and return the records of the accepted tiles.

    Args:
        score_box: ``(x, y, w, h, side) -> ratio or None``, evaluated before the
            pixels are read; None rejects the box, any other value is stored
            in the tile record.
        keep_tile: ``(tile_image) -> bool``, extra check on the pixels after
            black-to-white replacement.
        tile_name: ``(ratio, uid) -> str``, PNG file name of the saved tile.
    """
    tiles = []
    width, height = image.width, image.height
    # The grid step is the last (biggest) crop size, so grid boxes never overlap.
    for uid, (y, x) in enumerate(_grid_origins(width, height, conf.tma_crop[-1])):
        side, w, h = _random_crop_box(x, y, width, height, conf)
        ratio = score_box(x, y, w, h, side)
        if ratio is None:
            continue
        tile_image = _black_to_white(image.crop(x, y, w, h).numpy()[..., :3])
        # Boxes clipped at the right/bottom border are not square: skip them.
        if tile_image.shape[0] != side or tile_image.shape[1] != side:
            continue
        if not keep_tile(tile_image):
            continue
        tile_image = _maybe_simulate(tile_image, conf)
        tile_file = _save_tile(tile_image, conf, image_id, tile_name(ratio, uid))
        tile_height, tile_width = tile_image.shape[:2]
        tiles.append((filepath, image_id, width, height, x, y, x + w, y + h, tile_width, tile_height, 1,
                      is_tma, tile_file, has_mask, ratio, 0, uid))
        if len(tiles) >= conf.tma_max_tiles:
            break
    return tiles


def _tiles_from_tumor_mask(image, conf, filepath, image_id, is_tma, has_mask):
    """Keep tiles whose tumor-mask coverage exceeds ``conf.tma_tumoral_ratio``."""
    mask = pyvips.Image.new_from_file(os.path.join(MASK_HOME, "%s.png" % image_id))
    if conf.wsi_scale != 1.:
        # NEAREST avoids blending mask values when rescaling (x20 => x10).
        mask = mask.resize(conf.wsi_scale, kernel=pyvips.enums.Kernel.NEAREST)
    if (mask.width, mask.height) != (image.width, image.height):
        raise ValueError("Mask size %dx%d does not match image size %dx%d for image %s"
                         % (mask.width, mask.height, image.width, image.height, image_id))

    def score_box(x, y, w, h, side):
        tile_mask = mask.crop(x, y, w, h).numpy()
        # Tumor = non-zero first channel; a 2-D (single-band) mask is used as is.
        if tile_mask.ndim == 3:
            tile_mask = tile_mask[:, :, 0]
        # Normalised by the full crop area, so clipped border boxes score lower.
        ratio = np.count_nonzero(tile_mask) / (side * side)
        return ratio if ratio > conf.tma_tumoral_ratio else None

    def tile_name(ratio, uid):
        return os.path.basename(filepath).replace(".svs", ".png").replace(".png", "_%.2f_%d.png" % (ratio, uid))

    return _collect_tiles(image, conf, filepath, image_id, is_tma, has_mask,
                          score_box, lambda tile_image: True, tile_name)


def _tiles_from_otsu_mask(image, conf, filepath, image_id, is_tma, has_mask):
    """Keep non-uniform tiles whose Otsu background fraction is below ``conf.otsu_mask_zero_ratio``."""
    rgb = image.numpy()[..., :3]
    mask = get_otsu_mask(rgb, conf.wsi_scale)
    # Release the full-resolution pixel copy before tiling.
    del rgb
    gc.collect()
    mask_height, mask_width = mask.shape[:2]
    if (mask_width, mask_height) != (image.width, image.height):
        raise ValueError("Otsu mask size %dx%d does not match image size %dx%d for image %s"
                         % (mask_width, mask_height, image.width, image.height, image_id))
    if conf.tma_folder is not None:
        img_dir = os.path.join(conf.tma_folder, str(image_id))
        os.makedirs(img_dir, exist_ok=True)
        # Quarter-resolution preview of the tissue mask.
        preview = Image.fromarray(mask.astype(np.uint8)).resize((mask_width // 4, mask_height // 4))
        preview.save(os.path.join(img_dir, "otsu_mask.png"))

    def score_box(x, y, w, h, side):
        tile_mask = mask[y:y + h, x:x + w]
        # Fraction of background (zero) pixels; too much background rejects the box.
        ratio = (tile_mask == 0).sum() / tile_mask.size
        return None if ratio >= conf.otsu_mask_zero_ratio else ratio

    def keep_tile(tile_image):
        return not _is_mostly_one_color(tile_image, conf.drop_tile_color_ratio)

    def tile_name(ratio, uid):
        return os.path.basename(filepath).replace(".png", "_%d.png" % uid)

    return _collect_tiles(image, conf, filepath, image_id, is_tma, has_mask,
                          score_box, keep_tile, tile_name)


def _tiles_from_std(image, conf, filepath, image_id, is_tma, has_mask):
    """Keep non-uniform tiles whose pixel std lies in ``(conf.std[0], conf.std[1]]`` (any std when ``conf.std`` is None)."""
    def keep_tile(tile_image):
        if _is_mostly_one_color(tile_image, conf.drop_tile_color_ratio):
            return False
        return conf.std is None or conf.std[0] < np.std(tile_image) <= conf.std[1]

    def tile_name(ratio, uid):
        return os.path.basename(filepath).replace(".png", "_%d.png" % uid)

    return _collect_tiles(image, conf, filepath, image_id, is_tma, has_mask,
                          lambda x, y, w, h, side: 0, keep_tile, tile_name)


def tile_single_image(file, conf):
    """Extract TMA-like tiles from one image and save them as PNG files.

    Python/NumPy/PyTorch RNGs are re-seeded with 42 for every image, so the
    tiles of an image do not depend on the processing order.

    A TMA input is kept whole as a single tile. For a WSI, a shuffled grid
    (step ``conf.tma_crop[-1]``) is scanned, a square side is drawn from
    ``conf.tma_crop`` for each cell, and tiles are kept by the first
    applicable strategy:

    1. ``has_mask``: tumor-mask coverage above ``conf.tma_tumoral_ratio``;
    2. Otsu tissue mask, when ``conf.otsu_mask_zero_ratio`` is set and the image
       has at most ``conf.otsu_mask_size_limit`` pixels: background fraction
       below ``conf.otsu_mask_zero_ratio`` and not dominated by one color;
    3. otherwise: not dominated by one color and pixel std within ``conf.std``.

    Kept tiles may go through ``conf.tma_simulation`` and are saved under
    ``<conf.tma_folder>/<image_id>/``; a thumbnail map is written for WSIs.

    Args:
        file: ``(filepath, is_tma, has_mask)``; the file name is ``<image_id>.png``.
        conf: configuration object such as ``SplitConfig``.

    Returns:
        List of tuples ``(filepath, image_id, image_width, image_height, x1, y1,
        x2, y2, tile_width, tile_height, 1, is_tma, tile_file, has_mask, ratio,
        0, uid)``. ``ratio`` is the tumor coverage (strategy 1), the background
        fraction (strategy 2), 0 (strategy 3) or 1.0 (TMA); ``tile_file`` is None
        when ``conf.tma_folder`` is None. Empty when the output folder of the
        image already exists (resume).
    """
    seed_everything(42)

    filepath, is_tma, has_mask = file
    image_id = int(filepath.split("/")[-1].replace(".png", ""))

    image = pyvips.Image.new_from_file(filepath)
    scale = conf.tma_scale if is_tma else conf.wsi_scale
    if is_tma and isinstance(scale, tuple):
        # (height, width) target size for TMAs
        image = image.thumbnail_image(scale[1], height=scale[0])
    elif scale != 1.:
        # e.g. TMA x40 => x20, WSI x20 => x10
        image = image.resize(scale, kernel=conf.kernel)
    image_width, image_height = image.width, image.height

    # Resume: an existing output folder means the image was already processed.
    if conf.tma_folder is not None and os.path.exists(os.path.join(conf.tma_folder, "%s" % image_id)):
        return []

    if is_tma:
        tile_file = None
        if conf.tma_folder is not None:
            tile_name = os.path.basename(filepath).replace(".png", "_%d.png" % 0)
            tile_file = _save_tile(image.numpy()[..., :3], conf, image_id, tile_name)
        tiles = [(filepath, image_id, image_width, image_height, 0, 0, image_width - 1, image_height - 1,
                  image_width, image_height, 1, is_tma, tile_file, has_mask, 1.0, 0, 0)]
    else:
        if has_mask:
            select_tiles = _tiles_from_tumor_mask
        elif (conf.otsu_mask_zero_ratio is not None
              and image_width * image_height <= conf.otsu_mask_size_limit):
            # The Otsu mask needs the full image in memory, hence the size limit.
            select_tiles = _tiles_from_otsu_mask
        else:
            select_tiles = _tiles_from_std
        tiles = select_tiles(image, conf, filepath, image_id, is_tma, has_mask)
        generate_thumbnail(image, tiles, conf)

    del image
    gc.collect()
    return tiles

### Files and configuration
The more TMA references, the longer the initialization (a stain normalizer is fitted on each reference).

In [None]:
# List WSI files, thumbnails and masks
files = glob.glob(os.path.join(DATA_HOME, "*.png"))
thumbnails_files = glob.glob(os.path.join(DATA_WSI_THUMBNAILS_HOME, "*.png"))
masks_files = glob.glob(os.path.join(MASK_HOME, "*.png"))
files_pd = pd.DataFrame(files, columns=["file"])
files_pd["image_id"] = files_pd["file"].apply(lambda x: int(x.split("/")[-1].replace(".png", "")))
# Thumbnails are only for WSI so we know which images are TMA
thumbnails_files_pd = pd.DataFrame(thumbnails_files, columns=["file"])
thumbnails_files_pd["image_id"] = thumbnails_files_pd["file"].apply(lambda x: int(x.split("/")[-1].replace("_thumbnail.png", "")))
tma_list = thumbnails_files_pd["image_id"].unique()
masks_files_pd = pd.DataFrame(masks_files, columns=["file"])
masks_files_pd["image_id"] = masks_files_pd["file"].apply(lambda x: int(x.split("/")[-1].replace(".png", "")))
masks_list = masks_files_pd["image_id"].unique()
files_pd["is_tma"] = False
files_pd.loc[(~files_pd["image_id"].isin(tma_list)), "is_tma"] = True
files_pd["has_mask"] = False
files_pd.loc[(files_pd["image_id"].isin(masks_list)), "has_mask"] = True
print("WSI:", files_pd[files_pd["is_tma"] == False].shape)
print("TMA:", files_pd[files_pd["is_tma"] == True].shape)
print("MASK:", files_pd[files_pd["has_mask"] == True].shape)
tma_pd = files_pd[(files_pd["is_tma"] == True)].reset_index(drop=True)
files = files_pd[["file", "is_tma", "has_mask"]].values
print(files.shape, files[0:10])

In [None]:
# TMA references for stain augmentation
tma_images = tma_pd["file"].unique()
tma_pd.head()

In [None]:
def tma_augmentation(p=1.0):
    """Build the TMA simulation pipeline (run as a whole with probability ``p``).

    1. SimulateTMA, always applied: no std filter, semi-axis ratio 0.6-1.0,
       ellipse ratio 0.85-1.15, background derived from the tile
       (factor 0.80-1.0), contour jitter in [-4, 20).
    2. With probability 0.60, one stain augmentation chosen by weights
       0.34 / 0.33 / 0.33: Vahadane and Reinhard fitted on the TMA references,
       Macenko with its generic reference.
    """
    simulate = SimulateTMA(
        (-1, -1),
        radius_ratio=(0.6, 1.0),
        ellipse_ratio=(0.85, 1.15),
        angle=(-90., 90.),
        background_color=(-1, -1, -1),
        background_color_ratio=(0.80, 1.0),
        noise_level=(20. / 5, 100. / 5),
        black_replacement_color=None,
        p=1.0,
        always_apply=True,
    )
    stainers = [
        Stainer(ref_images=tma_images, method='vahadane', luminosity=True, p=0.34),
        Stainer(ref_images=None, method='macenko', luminosity=False, p=0.33),
        Stainer(ref_images=tma_images, method='reinhard', luminosity=False, p=0.33),
    ]
    return A.Compose([simulate, A.OneOf(stainers, p=0.60)], p=p)

In [None]:
class SplitConfig:
    """Settings for probing and tiling the images.

    Used as a plain namespace: the class itself (not an instance) is passed
    around as ``conf``.
    """

    # Tiles
    max_tiles = 500  # not used as a cap; tma_max_tiles is the effective per-image limit
    std = (20, 80)  # (low, high] pixel std range for tiles when no mask is used

    # Resize interpolation
    resize_interpolation = Image.LANCZOS  # not referenced elsewhere in this notebook
    kernel = pyvips.enums.Kernel.LANCZOS3  # pyvips resize kernel

    wsi_scale = 1.0  # Keep x20
    tma_scale = 0.5  # x40 => x20; a (height, width) tuple uses pyvips thumbnail_image instead
    # Kept so tiling code that reads `conf.tma_generation` for TMA inputs does
    # not fail with AttributeError.
    tma_generation = True
    tma_crop = [1482, 1568, 1694]  # candidate square sides; the last (biggest) is also the grid step
    tma_folder = TMA_FOLDER
    tma_tumoral_ratio = 0.75  # minimum tumor-mask coverage of a tile
    tma_max_tiles = 500  # maximum number of tiles kept per image
    tma_simulation = tma_augmentation(p=1.0)
    # Simulation runs when a uniform draw exceeds this value, i.e. it is a
    # skip probability (0.0 => practically always simulated). Previously 0.25.
    tma_simulation_prob = 0.0
    drop_tile_color_ratio = 0.30  # drop a tile when one grey level covers this fraction (was 0.50)
    otsu_mask_size_limit = 40000*40000  # To prevent OOM for large WSI
    otsu_mask_zero_ratio = 0.80  # maximum background fraction in the Otsu mask
    tma_thumbnail_scale = 0.10  # scale of the thumbnail map

split_config = SplitConfig

### Sanity check
- Image size vs mask size
- Maximum TMAs per WSI

In [None]:
def probe_single_image(file, conf):
    """Return size information for one image and check it against its mask.

    Args:
        file: ``(filepath, is_tma, has_mask)``; the file name is ``<image_id>.png``.
        conf: configuration; only ``tma_crop`` is read.

    Returns:
        One-element list ``[(filepath, image_id, width, height, n_tiles,
        is_tma, has_mask)]``. ``n_tiles`` is the number of cells in the tiling
        grid (step ``conf.tma_crop[-1]``, the biggest crop), i.e. the maximum
        number of tiles the image can yield at scale 1.

    Raises:
        FileNotFoundError: ``has_mask`` is set but the mask file is missing.
        ValueError: the mask and image sizes differ.
    """
    # Re-applied here in case this runs in a worker process that lacks the
    # notebook's module-level setting.
    Image.MAX_IMAGE_PIXELS = None
    filepath, is_tma, has_mask = file
    image_id = int(filepath.split("/")[-1].replace(".png", ""))
    image = pyvips.Image.new_from_file(filepath)
    image_width, image_height = image.width, image.height
    del image

    if has_mask:
        mask_filepath = os.path.join(MASK_HOME, "%s.png" % image_id)
        if not os.path.exists(mask_filepath):
            raise FileNotFoundError("Expected mask not found %s" % mask_filepath)
        # Only the size is needed; the context manager closes the file handle.
        with Image.open(mask_filepath) as mask:
            mask_width, mask_height = mask.size
        if (mask_width, mask_height) != (image_width, image_height):
            raise ValueError("Mask %s is %dx%d but image is %dx%d"
                             % (mask_filepath, mask_width, mask_height, image_width, image_height))

    # The tiling grid steps by the biggest crop size.
    step = conf.tma_crop[-1]
    n_tiles = len(range(0, image_height, step)) * len(range(0, image_width, step))
    return [(filepath, image_id, image_width, image_height, n_tiles, is_tma, has_mask)]

# Probe every file sequentially (n_jobs=1)
images_info = joblib.Parallel(n_jobs=1)(
    joblib.delayed(probe_single_image)(file, split_config) for file in files
)
images_probe = [record for records in images_info for record in records]
images_probe_pd = pd.DataFrame(
    images_probe,
    columns=["file", "image_id", "width", "height", "tiles", "is_tma", "has_mask"],
)
# Surface in megapixels
images_probe_pd["surface"] = images_probe_pd["width"] * images_probe_pd["height"] / 1000000


def _jobs_for_surface(surface):
    """Parallel jobs for an image: floor(MAX_MEM / surface), clamped to [MIN_CPU, MAX_CPU]."""
    return max(MIN_CPU, min(MAX_CPU, np.floor(MAX_MEM / surface)))


# Create groups for adaptive jobs
images_probe_pd["jobs"] = images_probe_pd["surface"].apply(_jobs_for_surface).astype(np.int16)
# Smallest images first; keep non-TMA images only
images_probe_pd = images_probe_pd.sort_values(["surface"], ascending=[True])
images_probe_pd = images_probe_pd[images_probe_pd["is_tma"] == False].reset_index(drop=True)
images_probe_pd

In [None]:
# Limit to a few WSI
MAX_WSI = 16
images_groups = pd.concat([images_probe_pd[images_probe_pd["has_mask"] == True].head(MAX_WSI//2), images_probe_pd[images_probe_pd["has_mask"] == False].head(MAX_WSI//2)], ignore_index=True)
images_groups

In [None]:
# One row per job count, holding the list of (file, is_tma, has_mask) tuples;
# the groups that allow the most parallel jobs run first.
images_groups = (
    images_groups.groupby("jobs")[["file", "is_tma", "has_mask"]]
    .apply(lambda group: [tuple(row) for row in group.values])
    .reset_index()
    .sort_values(["jobs"], ascending=[False])
)
images_groups.rename(columns={0: "file"}, inplace=True)

tiles = []
# Run by batch, each in parallel with CPUs depending on acceptable memory limit
for _, row in images_groups.iterrows():
    jobs = int(row["jobs"])
    images_list = row["file"]
    progress = tqdm(images_list, total=len(images_list))
    images_tiled = joblib.Parallel(n_jobs=jobs)(
        joblib.delayed(tile_single_image)(file, split_config) for file in progress
    )
    tiles.extend(tile for image_tiles in images_tiled for tile in image_tiles)

### TMA visualization

In [None]:
# For every processed WSI show: its tumor mask, Otsu mask or thumbnail; the map
# of selected tiles; then up to 8 rows of T generated TMAs.
# Iterates all batches (not only the last one) and uses `tma_files` so the
# global `files` list from the listing cell is not overwritten.
T = 4  # TMAs per row (was 8)
processed_images = [image for group in images_groups["file"] for image in group]
for filepath, is_tma, has_mask in processed_images:
    try:
        image_id = int(filepath.split("/")[-1].replace(".png", ""))
        image_dir = os.path.join(TMA_FOLDER, str(image_id))
        otsu_file = os.path.join(image_dir, "otsu_mask.png")
        tma_files = sorted(glob.glob(os.path.join(image_dir, "*.png")))

        fig, ax = plt.subplots(1, 1, figsize=(32, 20))
        if has_mask:
            ax.imshow(Image.open(os.path.join(MASK_HOME, "%d.png" % image_id)), interpolation='none')
            ax.set_title("Tumor mask - We keep red only")
        elif otsu_file in tma_files:
            ax.imshow(Image.open(otsu_file), interpolation='none')
            ax.set_title("OTSU mask")
        else:
            ax.imshow(Image.open(os.path.join(DATA_WSI_THUMBNAILS_HOME, "%d_thumbnail.png" % image_id)))
        plt.show()

        fig, ax = plt.subplots(1, 1, figsize=(32, 32))
        ax.imshow(Image.open(os.path.join(image_dir, "thumbnail_map.png")))
        ax.set_title("Tiles selected for TMA - %d" % image_id)
        plt.show()

        tma_files = [f for f in tma_files if os.path.basename(f) not in ("otsu_mask.png", "thumbnail_map.png")]
        # Whole rows of T tiles only (a single possibly partial row when fewer than T exist).
        chunks = max(1, len(tma_files) // T)
        tma_files = tma_files[0:T * chunks]
        for j, row_files in enumerate(np.array_split(tma_files, chunks)):
            fig, ax = plt.subplots(1, T, figsize=(32, 20))
            for i, tma in enumerate(row_files):
                ax[i].imshow(Image.open(tma))
                ax[i].set_title(tma.split("/")[-1])
            plt.show()
            if j > 6:
                break
    except Exception as ex:
        # Best effort: a missing file for one image must not stop the others.
        print(filepath, ex)