# Decompose an image into a grid of tiles 🖽

In [None]:
import numpy as np
from PIL import Image

def tile_image(p_img, folder, size: int = 1024) -> tuple:
    """Split an image into ``size`` x ``size`` tiles and save each one as PNG.

    Tiles are cut in row-major order (left to right, then top to bottom).
    Tiles on the bottom/right border that are smaller than ``size`` are
    zero-padded on their bottom/right side, so every saved tile is exactly
    ``size`` x ``size`` pixels.

    Args:
        p_img: Path to the source image (anything Pillow can open).
        folder: Directory the tiles are written to; it is not created here.
        size: Tile edge length in pixels.

    Returns:
        A ``(files, idxs)`` pair. ``files`` lists the written PNG paths, named
        ``<stem>_<k:03>.png``. ``idxs`` holds, per tile, the window
        ``(row_start, row_end, col_start, col_end)`` in the source image.
        ``row_end``/``col_end`` are NOT clipped to the image size, so callers
        reassembling the image must clip them.
    """
    import os  # local: the notebook only imports ``os`` in a later cell

    w = h = size
    im = np.array(Image.open(p_img))
    name, _ = os.path.splitext(os.path.basename(p_img))
    # Row-major grid of windows; https://stackoverflow.com/a/47581978/4521646
    idxs = [(i, i + h, j, j + w)
            for i in range(0, im.shape[0], h)
            for j in range(0, im.shape[1], w)]
    files = []
    for k, (i1, i2, j1, j2) in enumerate(idxs):
        tile = im[i1:i2, j1:j2, ...]
        if tile.shape[:2] != (h, w):
            # Pad onto a full-size canvas. The first tile's shape is not a safe
            # template: it is itself short when the image is smaller than `size`.
            padded = np.zeros((h, w) + tile.shape[2:], dtype=tile.dtype)
            padded[:tile.shape[0], :tile.shape[1], ...] = tile
            tile = padded
        p_tile = os.path.join(folder, f"{name}_{k:03}.png")
        Image.fromarray(tile).save(p_tile)
        files.append(p_tile)
    return files, idxs

In [None]:
import os

# Kaggle input locations: the HuBMAP training images, and a separate dataset
# that provides the annotation masks as PNG files.
DATASET_IMAGES = "../input/hubmap-organ-segmentation/train_images/"
DATASET_MASKS = "../input/hacking-the-human-body-annotation-masks/train_masks"

# Output folders for the tiles. tile_image() saves into them but does not
# create them itself. (IPython shell magics: these only run in a notebook.)
!mkdir -p /kaggle/temp/images
!mkdir -p /kaggle/temp/masks

# Tile sample 12233 and its mask on the same 512 px grid. Only the mask's window
# list is kept. It matches the image's as long as both have the same
# height/width (presumably true for this dataset; verify if in doubt).
tiles_img, _ = tile_image(os.path.join(DATASET_IMAGES, "12233.tiff"), "/kaggle/temp/images", size=512)
tiles_seg, idxs = tile_image(os.path.join(DATASET_MASKS, "12233.png"), "/kaggle/temp/masks", size=512)

# Inspect the written tiles and their file sizes.
!ls -lh /kaggle/temp/images
!ls -lh /kaggle/temp/masks

## Show the image tiles with segmentations

In [None]:
import matplotlib.pyplot as plt
from skimage import color

# Lay the tiles out in the rows/columns they were cut from. The grid shape comes
# from the tile windows rather than sqrt(#tiles). That shortcut only holds for
# square grids: non-square images would drop tiles or raise IndexError.
n_cols = len({j1 for _, _, j1, _ in idxs})
n_rows = len(idxs) // n_cols

# squeeze=False keeps `axes` 2-D even when the grid has a single row/column.
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(9, 9), squeeze=False)
for i, (p_img, p_seg) in enumerate(zip(tiles_img, tiles_seg)):
    img = plt.imread(p_img)
    mask = np.array(Image.open(p_seg))
    # Overlay the mask labels on the image tile; label 0 is treated as background.
    ax = axes[i // n_cols, i % n_cols]
    ax.imshow(color.label2rgb(mask, img, bg_label=0, bg_color=(1., 1., 1.), alpha=0.25))
    ax.set_axis_off()
fig.tight_layout()

# Back reconstruction of the full mask from its tiles 🖼️

In [None]:
# Reassemble the full mask from its saved tiles, pasting each tile back at the
# (row_start, row_end, col_start, col_end) window recorded in `idxs`.
tiles = [np.array(Image.open(p_seg)) for p_seg in tiles_seg]
# The source image is read only to size the canvas and to report its shape.
im = plt.imread(os.path.join(DATASET_IMAGES, "12233.tiff"))
height, width = im.shape[:2]
seg = np.zeros((height, width), dtype=np.uint8)
for tile, (top, bottom, left, right) in zip(tiles, idxs):
    # Border windows extend past the image, and their tiles were zero-padded.
    # Clip the window and drop the padding.
    bottom, right = min(bottom, height), min(right, width)
    seg[top:bottom, left:right] = tile[:bottom - top, :right - left]
plt.imshow(seg)

print(im.shape, seg.shape)