In [1]:
import os
import pyvips
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

os.environ['VIPS_DISC_THRESHOLD'] = '10gb'

# Params
drop_thr = 0.20
size = 2048
scale = 0.25

# Paths
root = '/media/latlab/MR/projects/kaggle-ubc-ocean'
data_dir = os.path.join(root, 'data')
results_dir = os.path.join(root, 'results')
train_csv = 'train.csv'
train_image_dir = os.path.join(data_dir, 'train_images')
train_thumbnail_dir = os.path.join(data_dir, 'train_thumbnails')
out_dir = os.path.join(results_dir, f'train_tiles_{size}_p{scale*100}_drop{drop_thr*100}')

def extract_image_tiles(p_img, folder, size: int, scale: float, drop_thr: float) -> list:
    name, _ = os.path.splitext(os.path.basename(p_img))
    im = pyvips.Image.new_from_file(p_img)
    w = h = size
    # https://stackoverflow.com/a/47581978/4521646
    
    # TMA tiles (center crop)
    if im.height < 5000 and im.width < 5000: # TODO exact detection of TMA
        # w = h = size = size*2
        # TMA center crop
        tile_num = min(im.height, im.width)//size
        orig_x = (im.width - tile_num*size)//2
        orig_y = (im.height - tile_num*size)//2
        idxs = [(y, y + h, x, x + w) for y in range(orig_y, im.height, h) for x in range(orig_x, im.width, w)]
    else:
        # WSI tiles
        idxs = [(y, y + h, x, x + w) for y in range(0, im.height, h) for x in range(0, im.width, w)]

    files = []
    # Tiling
    for k, (y, y_, x, x_) in enumerate(idxs):
        # https://libvips.github.io/pyvips/vimage.html#pyvips.Image.crop
        tile = im.crop(x, y, min(w, im.width - x), min(h, im.height - y)).numpy()[..., :3]
        mask_bg = np.sum(tile, axis=2) == 0
        if np.sum(mask_bg) >= (np.prod(mask_bg.shape) * drop_thr):
            # print(f"skip almost empty tile: {k:06}_{int(x_ / w)}-{int(y_ / h)}")
            continue
        if tile.shape[:2] != (h, w):
            if any((1-np.array(tile.shape[:2])/size) >= drop_thr):
                continue
            tile_ = tile
            tile_size = (h, w) if tile.ndim == 2 else (h, w, tile.shape[2])
            tile = np.zeros(tile_size, dtype=tile.dtype)
            tile[:tile_.shape[0], :tile_.shape[1], ...] = tile_
        p_img = os.path.join(folder, f"{k:06}_{int(x_ / w)}-{int(y_ / h)}.png")
        # print(tile.shape, tile.dtype, tile.min(), tile.max())
        new_size = int(size * scale), int(size * scale)
        Image.fromarray(tile).resize(new_size, Image.LANCZOS).save(p_img)
        files.append(p_img)
    return files, idxs

## Show the image tiles with segmentations for one test image

In [2]:
test_img_id = '4'
# test_img_id = '48734'
!mkdir /tmp/ubc
tiles_img, _ = extract_image_tiles(os.path.join(train_image_dir, f'{test_img_id}.png'), '/tmp/ubc', size=size, scale=scale, drop_thr=drop_thr)
print(len(tiles_img))

names = [os.path.splitext(os.path.basename(p_img))[0] for p_img in tiles_img]
pos = [name.split("_")[-1] for name in names]
idx_x, idx_y = zip(*[list(map(int, p.split("-"))) for p in pos])
nb_rows = len(set(idx_y))
nb_cols = len(set(idx_x))
print(f"{nb_rows=}\n{nb_cols=}")

fig, axes = plt.subplots(nrows=nb_rows, ncols=nb_cols,
    figsize=(nb_cols * 0.5, nb_rows * 0.5)
)
axes = np.array(axes).reshape(nb_rows, nb_cols)

for p_img, x, y in zip(tiles_img, idx_x, idx_y):
    img = plt.imread(p_img)
    ax = axes[y - 1, x - 1]
    ax.imshow(img)
print(f"image size: {img.shape}")

for i in range(nb_rows):
    for j in range(nb_cols):
        axes[i, j].set_xticklabels([])
        axes[i, j].set_yticklabels([])

plt.subplots_adjust(wspace=0, hspace=0)
# # fig.tight_layout()

mkdir: cannot create directory ‘/tmp/ubc’: File exists


## Export all image tiles¶

In [None]:
def extract_prune_tiles(idx_path_img, folder, size: int, scale: float, drop_thr: float) -> None:
    idx, p_img = idx_path_img
    print(f"processing #{idx}: {p_img}")
    name, _ = os.path.splitext(os.path.basename(p_img))
    folder = os.path.join(folder, name)
    os.makedirs(folder, exist_ok=True)
    tiles, _ = extract_image_tiles(p_img, folder, size, scale, drop_thr)

In [None]:
import glob
from tqdm.auto import tqdm
from joblib import Parallel, delayed

os.makedirs(out_dir, exist_ok=True)

ls = sorted(glob.glob(os.path.join(train_image_dir, '*.png')))
print(f"found images: {len(ls)}")
img_name = lambda p_img: os.path.splitext(os.path.basename(p_img))[0]
    
_= Parallel(n_jobs=10)(
    delayed(extract_prune_tiles)
    (id_pimg, out_dir, size=size, drop_thr=drop_thr, scale=scale)
    for id_pimg in tqdm(enumerate(ls), total=len(ls))
)

## Show some samples

In [None]:
ls = [p for p in glob.glob(out_dir + '/*') if os.path.isdir(p)]
print(f"found folders: {len(ls)}")
ls = glob.glob(out_dir + '/*/*.png')
print(f"found images: {len(ls)}")

fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(16, 16))
for i, p_img in enumerate(ls[:25]):
    img = plt.imread(p_img)
    ax = axes[i // 5, i % 5]
    ax.imshow(img)