In [3]:
## replace PatchExtractor with iafoss tile extractor
##
## extracts tiles from a single image
##

import os
from functools import reduce

import pandas as pd
import skimage.io
import numpy as np
import matplotlib.pyplot as plt

# Input directory layout.
INPUT_DIR = "../input/prostate-cancer-grade-assessment"
TRAIN_DIR = f"{INPUT_DIR}/train_images"
MASK_DIR = f"{INPUT_DIR}/train_label_masks"


# RGB colour triples shared by the helpers in this notebook.
BLACK = (0, 0, 0)
GRAY = (200, 200, 200)
WHITE = (255, 255, 255)
RED = (255, 0, 0)

# Number of tiles PatchExtractor keeps per image.
N = 16


class PatchExtractor:
    """
    Cut an image into square tiles and pick the tiles with the lowest pixel sum.

    The image is padded with white so that both spatial sides become
    multiples of ``patch_size``. It is then split into a grid of
    non-overlapping tiles. Tiles are ranked by ascending pixel sum, so the
    darkest tiles come first and pure white padding tiles rank last.
    """

    def __init__(self, img, patch_size, n_tiles=None):
        """
        :param img: image array of shape (H, W) or (H, W, C)
        :param patch_size: integer, edge length in pixels of each square tile
        :param n_tiles: number of tiles to select; when None, the
            module-level ``N`` is used (looked up when extracting)
        """
        self.img = img
        self.size = patch_size
        self.n_tiles = n_tiles
        # Grid dimensions of the padded image. They stay 0 until
        # extract_patches() has been called.
        self.n_cols = 0
        self.n_rows = 0

    def extract_patches(self):
        """
        Pad the image, tile it, and select the tiles with the lowest pixel sum.

        Tiles are numbered in row-major order over the padded grid, so
        ``index = row * n_cols + col``. The grid size is available via
        :py:meth:`shape` after this call.

        :returns: ``(padded, tiles_selected)``. ``padded`` is the white-padded
            image. ``tiles_selected`` is an integer array of at most
            ``n_tiles`` tile indices, ordered by ascending pixel sum.
        """
        size = self.size
        n_tiles = N if self.n_tiles is None else self.n_tiles

        H, W = self.img.shape[:2]
        # Round each spatial side up to the next multiple of `size`.
        # The padding is split as evenly as possible between the two sides.
        pad_h = (size - H % size) % size
        pad_w = (size - W % size) % size

        print("pad_h:", pad_h)
        print("pad_w", pad_w)

        # Only the two spatial axes are padded; channel axes (if any) are not.
        pad_width = [[pad_h // 2, pad_h - pad_h // 2],
                     [pad_w // 2, pad_w - pad_w // 2]]
        pad_width += [[0, 0]] * (self.img.ndim - 2)
        padded = np.pad(self.img, pad_width, constant_values=WHITE[0])

        n_rows = padded.shape[0] // size
        n_cols = padded.shape[1] // size

        self.n_rows = n_rows
        self.n_cols = n_cols

        print("n_rows :", n_rows)
        print("n_cols :", n_cols)
        print("N_TILES:", n_rows * n_cols)

        channels = padded.shape[2:]
        # Reshape to (rows, size, cols, size, *C), then swap to
        # (rows, cols, size, size, *C). Flattening the leading two axes
        # yields the tiles in row-major grid order.
        reshaped = padded.reshape(n_rows, size, n_cols, size, *channels)
        transposed = reshaped.transpose(0, 2, 1, 3, *range(4, reshaped.ndim))
        tiles = transposed.reshape(-1, size, size, *channels)

        print("reshaped.shape  :", reshaped.shape)
        print("transposed.shape:", transposed.shape)
        print("tiles.shape     :", tiles.shape)

        sums = tiles.reshape(tiles.shape[0], -1).sum(axis=-1)
        tiles_selected = np.argsort(sums)[:n_tiles]

        return padded, tiles_selected

    def shape(self):
        """Return ``(n_rows, n_cols)`` of the tile grid (both 0 before extraction)."""
        return self.n_rows, self.n_cols


In [None]:
# Helpers and a small driver for visually testing PatchExtractor
# (can probably be deleted after testing).

# Edge length in pixels of the solid-colour patches built by make_patch_func.
# NOTE(review): the resulting overlay is blended pixel-for-pixel with the
# padded image, so this presumably must equal the extractor's patch size.
# Confirm before changing either value.
SIZE = 128

def merge_tiles(tiles, funcs=None):
    """
    Stitch a 2-D grid of tiles back into one image.

    Each row of `tiles` is concatenated horizontally, and the resulting
    strips are stacked vertically. If `funcs` is a non-empty sequence of
    callables, they are applied in order to every tile before merging.
    """
    def transform(tile):
        if not funcs:
            return tile
        return reduce(lambda acc, f: f(acc), funcs, tile)

    strips = [np.hstack([transform(tile) for tile in row]) for row in tiles]
    return np.vstack(strips)


def draw_borders(img):
    """
    Return a copy of `img` with a one-pixel GRAY frame around its edges.

    The input array itself is not modified.
    """
    framed = img.copy()
    # Horizontal edges first (top, bottom), then vertical edges (left, right).
    for edge in (0, -1):
        framed[edge, :] = GRAY
    for edge in (0, -1):
        framed[:, edge] = GRAY
    return framed


def fill_tiles(tiles, fill_func):
    """
    Map `fill_func` over every cell of a 2-D grid and pack the results
    into a single numpy array.
    """
    filled_rows = []
    for row in tiles:
        filled_rows.append([fill_func(cell) for cell in row])
    return np.array(filled_rows)


def make_patch_func(true_color, false_color, size=None):
    """
    Build a function that turns a flag into a solid-colour square patch.

    :param true_color: colour triple used when the flag is truthy
    :param false_color: colour triple used when the flag is falsy
    :param size: edge length in pixels of the patch; when None, the
        module-level ``SIZE`` is used (looked up each time a patch is built)
    :returns: callable ``f(x)`` returning a ``(size, size, len(color))``
        uint8 array
    """
    def ret(x):
        """
        Returns a colour patch. The colour is `true_color` if `x` is truthy, otherwise `false_color`.
        """
        color = true_color if x else false_color
        edge = SIZE if size is None else size
        return np.tile(color, (edge, edge, 1)).astype(np.uint8)

    return ret


def imshow(
    img,
    title=None,
    show_shape=True,
    figsize=(8, 8)
):
    """
    Display an image on a new figure with the axis ticks hidden.

    :param img: array accepted by ``Axes.imshow``
    :param title: optional title drawn above the image
    :param show_shape: if True, show ``img.shape`` as the x-axis label
    :param figsize: figure size passed to ``plt.subplots``
    :returns: the Axes the image was drawn on
    """
    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(img)
    # Pass a real boolean. A string such as "off" is truthy and would turn
    # the grid *on* in current Matplotlib.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    if show_shape:
        ax.set_xlabel(f"Shape: {img.shape}", fontsize=16)

    if title:
        ax.set_title(title, fontsize=16)

    return ax

def imread(path):
    """
    Open `path` with ``skimage.io.MultiImage``.

    Raises FileNotFoundError up front when nothing exists at `path`.
    """
    if os.path.exists(path):
        return skimage.io.MultiImage(path)
    raise FileNotFoundError(f"No such file or directory: '{path}'")


img_id = "000920ad0b612851f8e01bcc880d9b3d"
img_org = imread(os.path.join(TRAIN_DIR, f"{img_id}.tiff"))[-1]

imshow(img_org, "Original image")

# Tile size for the extractor. It is tied to SIZE because make_patch_func
# builds SIZE x SIZE colour patches, and the merged mask is blended
# pixel-for-pixel with the padded image below.
PATCH_SIZE = SIZE

extractor = PatchExtractor(img_org, PATCH_SIZE)
# extract_patches returns exactly two values: the padded image and the
# indices of the selected tiles.
padded, tiles = extractor.extract_patches()

n_rows, n_cols = extractor.shape()

# Boolean grid marking the selected tiles (tile indices are row-major).
mask = np.isin(np.arange(n_rows * n_cols), tiles).reshape(n_rows, n_cols)
mask = fill_tiles(mask, make_patch_func(RED, WHITE))
mask = merge_tiles(mask, [draw_borders])

# Blend 70% of the padded image with 30% of the red/white tile overlay.
with_mask = np.ubyte(0.7 * padded + 0.3 * mask)

imshow(with_mask, "Selected tiles")
