In [None]:
import sys

sys.path.append("../")
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import re
import cv2
import json
from pprint import pprint
import random
import albumentations as alb

In [211]:
def extract_id(file_name: str) -> str | None:
    """
    Give a file name such as 'A_P000001_PAS_CPG.tif',
    Extract the file ID: 'A_P000001' using regular expression
    """
    match = re.match(r"([A-Z]_P\d+)_", file_name, re.IGNORECASE)

    if match:
        return match.group(1)
    else:
        return None


def cell_mask_to_rgb(cell_mask: np.ndarray) -> np.ndarray:
    """Render a 2D cell-centroid mask as an RGB image (visualization only).

    Input labels:
        1 - lymphocyte centroids
        2 - monocyte centroids
    Output colours:
        green - lymphocyte centroids
        blue  - monocyte centroids

    The mask is first dilated with a 13x13 elliptical kernel so that the
    centroids show up as discs. Dilation takes the local maximum, so where
    discs of both classes overlap the pixel ends up labelled 2 (blue).

    Returns a uint8 array of shape (height, width, 3); pixels with any
    other label stay black.
    """
    disc = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
    dilated = cv2.dilate(cell_mask, disc, iterations=1)

    height, width = dilated.shape[:2]
    rgb_mask = np.zeros((height, width, 3), dtype=np.uint8)

    # label value -> RGB colour painted wherever that label is present
    label_colours = ((1, (0, 255, 0)), (2, (0, 0, 255)))
    for label, colour in label_colours:
        rgb_mask[dilated == label] = colour
    return rgb_mask


def get_augmentation():
    """Build the albumentations pipeline used to augment image patches.

    The composed pipeline (p=1) contains, in order:
      1. one of HueSaturationValue / RGBShift (group p=1),
      2. one of GaussianBlur / Sharpen / ImageCompression (group p=0.5),
      3. RandomBrightnessContrast (p=0.5),
      4. ShiftScaleRotate with rotations of up to 180 degrees and a
         constant border filled with 0 (p=0.8),
      5. Flip (p=0.5).

    Returns the `alb.Compose` object; call it as
    `aug(image=..., mask=...)`.
    """
    # Colour perturbations: exactly one of the two is chosen.
    colour_shift = alb.OneOf(
        [
            alb.HueSaturationValue(
                hue_shift_limit=10,
                sat_shift_limit=(-40, 40),
                val_shift_limit=5,
                always_apply=False,
                p=0.5,
            ),  # .8
            alb.RGBShift(
                r_shift_limit=30,
                g_shift_limit=30,
                b_shift_limit=30,
                p=0.5,
            ),  # .7
        ],
        p=1,
    )

    # Image-quality perturbations: the group fires half of the time.
    quality_shift = alb.OneOf(
        [
            alb.GaussianBlur(blur_limit=(1, 3), p=0.5),
            alb.Sharpen(alpha=(0.1, 0.3), lightness=(1.0, 1.0), p=0.5),
            alb.ImageCompression(quality_lower=30, quality_upper=80, p=0.5),
        ],
        p=0.5,
    )

    brightness_contrast = alb.RandomBrightnessContrast(
        brightness_limit=0.1, contrast_limit=0.3, p=0.5
    )

    # Geometric perturbation: tiny shift/scale, rotation of up to 180
    # degrees, with the exposed border filled with the constant value 0.
    shift_scale_rotate = alb.ShiftScaleRotate(
        shift_limit=0.01,
        scale_limit=0.01,
        rotate_limit=180,
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
        p=0.8,
    )

    flip = alb.Flip(p=0.5)

    return alb.Compose(
        [
            colour_shift,
            quality_shift,
            brightness_contrast,
            shift_scale_rotate,
            flip,
        ],
        p=1,
    )


# Build the augmentation pipeline once at cell level; the patch
# visualisation code calls it as `aug(image=..., mask=...)`.
aug = get_augmentation()

In [None]:
# Pick a random patch, augment it together with its cell mask, and show
# the original (top row) and augmented (bottom row) versions side by side.

# Folder containing image patches
patch_image_dir = "/home/u1910100/Documents/Monkey/patches_256/images"
# Folder containing cell masks
cell_mask_dir = (
    "/home/u1910100/Documents/Monkey/patches_256/annotations/masks"
)
# Folder containing json files
json_dir = (
    "/home/u1910100/Documents/Monkey/patches_256/annotations/json"
)

# Visualize random patches
file_names = os.listdir(patch_image_dir)
files_sample = random.sample(file_names, 1)

for name in files_sample:
    name_without_ext = os.path.splitext(name)[0]

    # Get path to files: image, mask and json share the same base name
    patch_image_path = os.path.join(
        patch_image_dir, f"{name_without_ext}.npy"
    )
    cell_mask_path = os.path.join(
        cell_mask_dir, f"{name_without_ext}.npy"
    )
    # Path to the matching JSON annotation; it is built here but not read
    # anywhere in this cell.
    json_path = os.path.join(json_dir, f"{name_without_ext}.json")

    # load RGB patch and cell mask
    image_patch = np.load(patch_image_path)
    cell_mask = np.load(cell_mask_path)
    # Convert cell mask to RGB image for visualization
    rgb_cell_mask = cell_mask_to_rgb(cell_mask)

    # Augment the patch and the RGB rendering of its mask in one call so
    # that both go through the same pipeline invocation.
    augmented_data = aug(image=image_patch, mask=rgb_cell_mask)
    aug_image, aug_mask = (
        augmented_data["image"],
        augmented_data["mask"],
    )

    # Output
    pprint(name)
    fig, axes = plt.subplots(2, 3, figsize=(10, 10))

    # Top row: original patch, mask, and mask blended over the patch
    axes[0][0].imshow(image_patch)
    axes[0][0].title.set_text("RGB Patch")
    axes[0][1].imshow(rgb_cell_mask)
    axes[0][1].title.set_text("Cell Mask")
    axes[0][2].imshow(image_patch, alpha=0.5)
    axes[0][2].imshow(rgb_cell_mask, alpha=0.5)
    axes[0][2].title.set_text("Cell Mask overlay on RGB Patch")

    # Bottom row: the same three views after augmentation
    axes[1][0].imshow(aug_image)
    axes[1][0].title.set_text("Aug RGB Patch")
    axes[1][1].imshow(aug_mask)
    axes[1][1].title.set_text("Aug Cell Mask")
    axes[1][2].imshow(aug_image, alpha=0.5)
    axes[1][2].imshow(aug_mask, alpha=0.5)
    # This panel shows the augmented pair, so label it like its siblings
    # instead of reusing the top-row title.
    axes[1][2].title.set_text("Aug Cell Mask overlay on Aug RGB Patch")
    plt.show()

Dataloader

In [None]:
%reset -f
from monkey.data.dataset import get_dataloaders
from monkey.data.data_utils import imagenet_denormalise
from monkey.config import TrainingIOConfig
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)

# Number of samples per mini-batch requested from the loaders.
batch_size = 4

# I/O configuration handed to get_dataloaders.
# NOTE(review): presumably dataset_dir is the root of the extracted
# patches and save_dir is where outputs go — confirm against
# monkey.config.TrainingIOConfig.
IOconfig = TrainingIOConfig(
    save_dir="./",
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
)

# Build the training and validation loaders with augmentation enabled.
# NOTE(review): val_fold, task, disk_radius and module are project-defined
# options; their exact meaning should be checked in monkey.data.dataset.
train_loader, val_loader = get_dataloaders(
    IOconfig,
    module="detection",
    task=1,
    val_fold=4,
    batch_size=batch_size,
    disk_radius=13,
    do_augmentation=True,
)

In [None]:
# Fetch one batch from the training loader and plot every sample in it as
# a row of three panels: image, mask, and mask blended over the image.
data = next(iter(train_loader))

# Size the grid from the batch that was actually returned rather than from
# `batch_size`, so a short batch cannot be indexed past its end.
n_samples = len(data["image"])

# squeeze=False keeps `axes` two-dimensional even when there is a single
# row; by default plt.subplots would return a 1-D array in that case and
# the axes[i][j] indexing below would fail.
fig, axes = plt.subplots(
    n_samples, 3, figsize=(19, 19), squeeze=False
)

# Column headings, shown once on the first row.
for col, title in enumerate(("Image", "Mask", "Mask overlay on image")):
    axes[0][col].set_title(title)

for i in range(n_samples):
    print(data["id"][i])
    image = data["image"][i].numpy()
    # Take index 0 along the first axis of this sample's mask
    # (presumably the only/first mask channel — confirm with the dataset).
    mask = data["mask"][i][0].numpy()

    # Move axis 0 to the last position so imshow receives the image with
    # its first axis (presumably channels) last, then undo the
    # normalisation applied by the dataset.
    image = np.moveaxis(image, 0, 2)
    image = imagenet_denormalise(image)

    axes[i][0].imshow(image)
    axes[i][1].imshow(mask, cmap="gray")
    axes[i][2].imshow(image, alpha=0.5)
    axes[i][2].imshow(mask, alpha=0.5)
plt.show()