In [1]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
from typing import Dict, List, Tuple

In [2]:
def map_array(array: np.ndarray, mapping: Dict[int, int]) -> np.ndarray:
    unique, inv = np.unique(array, return_inverse=True)
    return np.array([mapping[x] for x in unique])[inv].reshape(array.shape)


def pad_image(array: np.ndarray, tile_size: int) -> np.ndarray:
    if array.shape[0] == tile_size and array.shape[1] == tile_size:
        return array
    padded = 255 * np.ones((tile_size, tile_size, 3))
    padded = padded.astype(array.dtype)
    padded[:array.shape[0], :array.shape[1], :] = array
    return padded


def pad_mask(array: np.ndarray, tile_size: int) -> np.ndarray:
    if array.shape[0] == tile_size and array.shape[1] == tile_size:
        return array
    padded = np.zeros((tile_size, tile_size), dtype=array.dtype)
    padded[:array.shape[0], :array.shape[1]] = array
    return padded


def correct_mask_ids(array: np.ndarray) -> np.ndarray:
    mapping = {elem: idx + 1 for idx, elem in enumerate(set(array[np.nonzero(array > 0)]))}
    mapping[0] = 0
    return map_array(array, mapping)


def parse_lizard_instance(image_path: str, label_path: str, tile_size: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    tiles, tile_masks = [], []
    image = np.array(Image.open(image_path))
    label = sio.loadmat(label_path)
    counter = 1
    id_dict = {0: 0}
    for _id, _class in zip(np.squeeze(label["id"]), np.squeeze(label["class"])):
        if _class == 1:
            id_dict[_id] = counter
            counter += 1
        else:
            id_dict[_id] = 0
    mapped_instance_mask = map_array(label["inst_map"], id_dict)
    if not np.any(mapped_instance_mask):
        return tiles, tile_masks
    r_idx, c_idx = np.nonzero(mapped_instance_mask)
    r_min, r_max, c_min, c_max = int(np.min(r_idx)), int(np.max(r_idx)), int(np.min(c_idx)), int(np.max(c_idx))
    num_tiles_r = (r_max - r_min) // tile_size + ((r_max - r_min) % tile_size != 0)
    num_tiles_c = (c_max - c_min) // tile_size + ((c_max - c_min) % tile_size != 0)
    for _r in range(num_tiles_r):
        for _c in range(num_tiles_c):
            tile_mask = mapped_instance_mask[_r * tile_size: (_r + 1) * tile_size, _c * tile_size: (_c + 1) * tile_size]
            if not np.any(tile_mask):
                continue
            tile_masks.append(pad_mask(correct_mask_ids(tile_mask), tile_size))
            tile = image[_r * tile_size: (_r + 1) * tile_size, _c * tile_size: (_c + 1) * tile_size, :]
            tiles.append(pad_image(tile, tile_size))
    return tiles, tile_masks

In [3]:
image_dirs = [
    r"C:\Users\abdul\Desktop\TUM\thesis\lizard\lizard_images\train",
    r"C:\Users\abdul\Desktop\TUM\thesis\lizard\lizard_images\val",
]
label_dir = r"C:\Users\abdul\Desktop\TUM\thesis\lizard\lizard_labels\Labels"
data = dict()
for image_dir in image_dirs:
    for f in os.listdir(image_dir):
        tiles, masks = parse_lizard_instance(
            os.path.join(image_dir, f),
            os.path.join(label_dir, f.replace(".png", ".mat")),
            256,
        )
        if tiles and masks:
            data[f] = {
                "tiles": tiles,
                "masks": masks,
            }

In [4]:
# apply 80% - 20% split

num_neutrophils = [(key, sum(np.max(mask) for mask in value["masks"])) for key, value in data.items()]
num_neutrophils.sort(key=lambda x: x[1], reverse=True)
test_file_names = list()
for i in range(len(num_neutrophils) // 5 + (len(num_neutrophils) % 5 != 0)):
    tmp = num_neutrophils[i * 5 : (i + 1) * 5]
    if len(tmp) >= 3:
        test_file_names.append(tmp[2][0])

train_images, train_masks, test_images, test_masks = [], [], [], []
for key in data:
    tiles, masks = data[key]["tiles"], data[key]["masks"]
    if key in test_file_names:
        test_images.extend(tiles)
        test_masks.extend(masks)
    else:
        train_images.extend(tiles)
        train_masks.extend(masks)

print(len(train_images), len(train_masks))
print(len(test_images), len(test_masks))

710 710
174 174


In [5]:
# save images

for idx, (image, mask) in enumerate(zip(train_images, train_masks)):
    Image.fromarray(image).save(os.path.join("data", "lizard", "train", "images", f"image_{idx}.tif"))
    Image.fromarray(mask).save(os.path.join("data", "lizard", "train", "masks", f"image_{idx}.tif"))

for idx, (image, mask) in enumerate(zip(test_images, test_masks)):
    Image.fromarray(image).save(os.path.join("data", "lizard", "test", "images", f"image_{idx}.tif"))
    Image.fromarray(mask).save(os.path.join("data", "lizard", "test", "masks", f"image_{idx}.tif"))
