In [None]:
from typing import List
import numpy as np
from PIL import Image
import os
import argparse
from tqdm import tqdm
from collections import Counter
from pathlib import Path

import torch
import matplotlib.pyplot as plt


In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Args:
        **images: Keyword arguments mapping a title to an image. Each image
            may be anything ``imshow`` accepts or a ``torch.Tensor``. 3-D
            tensors have axis 0 moved to the end (presumably channel-first
            input) and every tensor is converted to a NumPy array. Entries
            whose name contains ``'mask'`` are cast to uint8 and drawn as a
            palette image using a fixed 4-colour palette.

    Returns:
        None. The figure is shown and then closed. Calling the function
        without any image is a no-op.
    """
    fontsize = 14
    n = len(images)
    if n == 0:
        # Nothing to draw; a grid with zero columns cannot be created.
        return

    # squeeze=False keeps `axarr` a 2-D array even for a single image, so the
    # indexing below works for any n >= 1.
    fig, axarr = plt.subplots(nrows=1, ncols=n, figsize=(8, 8), squeeze=False)
    axes = axarr.ravel()

    for ax, (name, image) in zip(axes, images.items()):
        if isinstance(image, torch.Tensor):
            if image.ndim == 3:
                image = image.permute(1, 2, 0)
            # Convert every tensor (not only CUDA ones) so the code below
            # always works on a plain NumPy array.
            image = image.detach().cpu().numpy()
        if 'mask' in name:
            # Colours for indices 0-3; the other 252 palette entries are black.
            palette = [0, 64, 128, 64, 128, 0, 243, 152, 0, 255, 255, 255] + [0] * 252 * 3
            image = Image.fromarray(np.uint8(image), mode='P')
            image.putpalette(palette)
        ax.imshow(image)
        ax.set_title(name, fontsize=fontsize)
        ax.set_yticks([])
        ax.set_xticks([])

    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
def online_cut_patches(im, im_size, stride):
    """
    Crop an image into sub-patches on a regular grid, border included.

    Patches are taken every `stride` pixels along both axes. If the regular
    grid does not end exactly at the image border, one extra patch flush with
    the bottom and/or right border is added so every pixel is covered exactly
    by the grid plus that border patch (no duplicated positions). An axis that
    is shorter than `im_size` yields a single offset 0, so the patch is
    smaller than `im_size` along that axis.

    Args:
        im (np.ndarray): the image to crop, indexed as (row, col[, channel]).
        im_size (int): side length of the square sub-image.
        stride (int): the pixels between two neighbouring sub-images (> 0).

    Returns:
        (list, list): the uint8 patches and, for each patch, the (row, col)
        position of its upper-left corner as plain Python ints.

    Raises:
        ValueError: if `stride` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    def _axis_offsets(length):
        # Upper-left offsets along one axis, ending flush with the border.
        if length < im_size:
            return np.array([0])
        offsets = np.arange(0, length - im_size + 1, stride)
        last = length - im_size
        # Compare against the real last offset: `length % stride` is only a
        # valid test when im_size happens to be a multiple of stride.
        if offsets[-1] != last:
            offsets = np.append(offsets, last)
        return offsets

    h, w = im.shape[:2]
    im_list = []
    position_list = []
    for i in _axis_offsets(h):
        for j in _axis_offsets(w):
            im_list.append(np.uint8(im[i:i + im_size, j:j + im_size]))
            position_list.append((int(i), int(j)))
    return im_list, position_list


In [None]:
def multiscale_online_crop(im, im_size, stride, scales):
    """
    Resize the image once per scale factor and cut each resized copy into
    patches with `online_cut_patches`.

    Args:
        im (np.ndarray): source image, converted with `Image.fromarray`.
        im_size (int): patch size forwarded to `online_cut_patches`.
        stride (int): stride forwarded to `online_cut_patches`.
        scales (iterable): scale factors; the resized size is
            (int(width * scale), int(height * scale)).

    Returns:
        scale_im_list: one list of patches per scale
        [
            [im_1, im_2, ...],   # scale 1
            [im_1, im_2, ...],   # scale 2
            ...
        ]
        scale_position_list: one list of patch positions per scale, exactly
        as returned by `online_cut_patches` for the resized image
        [
            [pos_1, pos_2, ...],   # scale 1
            [pos_1, pos_2, ...],   # scale 2
            ...
        ]
    """
    pil_image = Image.fromarray(im)
    base_w, base_h = pil_image.size

    # One (patches, positions) pair per requested scale, in input order.
    per_scale = [
        online_cut_patches(
            np.asarray(pil_image.resize((int(base_w * factor), int(base_h * factor)))),
            im_size,
            stride,
        )
        for factor in scales
    ]

    scale_im_list = [patches for patches, _ in per_scale]
    scale_position_list = [positions for _, positions in per_scale]
    return scale_im_list, scale_position_list

In [None]:
# Validation images and masks live in sibling folders; a mask has the same
# file stem as its image.
validation_image_dir = Path('data/WSSS4LUAD/2.validation/img')
validation_mask_dir = Path('data/WSSS4LUAD/2.validation/mask')

# Sort the images so the image/mask pairing order is deterministic.
validation_image_paths = sorted(validation_image_dir.glob('*.png'))
validation_mask_paths = [
    validation_mask_dir / f'{p.stem}.png'
    for p in validation_image_paths
]
len(validation_mask_paths)  # 40

In [None]:
# Patch-extraction settings for this run; they also name the output folder.
im_size = 224
stride = 56
scales = [1]

# Colours for mask indices 0-3; the remaining 252 palette entries are black.
palette = [0, 64, 128, 64, 128, 0, 243, 152, 0, 255, 255, 255] + [0] * 252 * 3

# The output folders depend only on the settings above, so create them once.
image_patch_dir = Path(f'data/WSSS4LUAD/2.validation/patches_{im_size}_{stride}/img')
image_patch_dir.mkdir(parents=True, exist_ok=True)
mask_patch_dir = Path(f'data/WSSS4LUAD/2.validation/patches_{im_size}_{stride}/mask')
mask_patch_dir.mkdir(parents=True, exist_ok=True)

# zip() has no length, so pass `total` explicitly to get a real progress bar.
for image_path, mask_path in tqdm(zip(validation_image_paths, validation_mask_paths),
                                  total=len(validation_image_paths)):
    image_name = image_path.stem
    # The context manager closes the file once the pixels are in memory.
    with Image.open(image_path) as image_file:
        image_data = np.array(image_file)

    scale_img_list, scale_position_list = multiscale_online_crop(image_data, im_size, stride, scales)

    # Open the mask once per image instead of twice per scale.
    with Image.open(mask_path) as mask_image:
        w, h = mask_image.size

        for scale, im_list, position_list in zip(scales, scale_img_list, scale_position_list):
            # The mask values are class indices: interpolating between them
            # would invent labels that are not in the original mask, so
            # resize with nearest-neighbour.
            mask_data = np.array(mask_image.resize((int(w*scale), int(h*scale)), resample=Image.NEAREST))

            for image_patch, position in zip(im_list, position_list):
                i, j = position
                mask_patch = np.uint8(mask_data[i:i+im_size, j:j+im_size])

                # One presence flag per class index 0, 1 and 2.
                patch_label = [int(cat in mask_patch) for cat in range(3)]

                image_patch_img = Image.fromarray(np.uint8(image_patch))
                mask_patch_img = Image.fromarray(mask_patch, mode='P')
                mask_patch_img.putpalette(palette)

                # Image and mask patch share the same file name.
                patch_name = f'{image_name}_{scale}_{i}_{j}-{patch_label}.png'
                image_patch_img.save(image_patch_dir / patch_name)
                mask_patch_img.save(mask_patch_dir / patch_name)
        
        


In [None]:
# Patch-extraction settings for this run; they also name the output folder.
im_size = 224
stride = 112
scales = [1, 1.25, 1.5, 1.75, 2]

# Colours for mask indices 0-3; the remaining 252 palette entries are black.
palette = [0, 64, 128, 64, 128, 0, 243, 152, 0, 255, 255, 255] + [0] * 252 * 3

# The output folders depend only on the settings above, so create them once.
image_patch_dir = Path(f'data/WSSS4LUAD/2.validation/patches_{im_size}_{stride}/img')
image_patch_dir.mkdir(parents=True, exist_ok=True)
mask_patch_dir = Path(f'data/WSSS4LUAD/2.validation/patches_{im_size}_{stride}/mask')
mask_patch_dir.mkdir(parents=True, exist_ok=True)

# zip() has no length, so pass `total` explicitly to get a real progress bar.
for image_path, mask_path in tqdm(zip(validation_image_paths, validation_mask_paths),
                                  total=len(validation_image_paths)):
    image_name = image_path.stem
    # The context manager closes the file once the pixels are in memory.
    with Image.open(image_path) as image_file:
        image_data = np.array(image_file)

    scale_img_list, scale_position_list = multiscale_online_crop(image_data, im_size, stride, scales)

    # Open the mask once per image instead of twice per scale.
    with Image.open(mask_path) as mask_image:
        w, h = mask_image.size

        for scale, im_list, position_list in zip(scales, scale_img_list, scale_position_list):
            # The mask values are class indices: interpolating between them
            # while upscaling would invent labels that are not in the original
            # mask (and corrupt patch_label), so resize with nearest-neighbour.
            mask_data = np.array(mask_image.resize((int(w*scale), int(h*scale)), resample=Image.NEAREST))

            for image_patch, position in zip(im_list, position_list):
                i, j = position
                mask_patch = np.uint8(mask_data[i:i+im_size, j:j+im_size])

                # One presence flag per class index 0, 1 and 2.
                patch_label = [int(cat in mask_patch) for cat in range(3)]

                image_patch_img = Image.fromarray(np.uint8(image_patch))
                mask_patch_img = Image.fromarray(mask_patch, mode='P')
                mask_patch_img.putpalette(palette)

                # Image and mask patch share the same file name.
                patch_name = f'{image_name}_{scale}_{i}_{j}-{patch_label}.png'
                image_patch_img.save(image_patch_dir / patch_name)
                mask_patch_img.save(mask_patch_dir / patch_name)
        
        


# Split Test Dataset

In [None]:
# Test images and masks live in sibling folders; a mask has the same file
# stem as its image.
test_image_dir = Path('data/WSSS4LUAD/3.testing/img')
test_mask_dir = Path('data/WSSS4LUAD/3.testing/mask')

# Sort the images so the image/mask pairing order is deterministic.
test_image_paths = sorted(test_image_dir.glob('*.png'))
test_mask_paths = [
    test_mask_dir / f'{p.stem}.png'
    for p in test_image_paths
]
len(test_mask_paths)  # 80

In [None]:
# Patch-extraction settings for this run; they also name the output folder.
im_size = 224
stride = 112
scales = [1, 1.25, 1.5, 1.75, 2]

# Colours for mask indices 0-3; the remaining 252 palette entries are black.
palette = [0, 64, 128, 64, 128, 0, 243, 152, 0, 255, 255, 255] + [0] * 252 * 3

# The output folders depend only on the settings above, so create them once.
image_patch_dir = Path(f'data/WSSS4LUAD/3.testing/patches_{im_size}_{stride}/img')
image_patch_dir.mkdir(parents=True, exist_ok=True)
mask_patch_dir = Path(f'data/WSSS4LUAD/3.testing/patches_{im_size}_{stride}/mask')
mask_patch_dir.mkdir(parents=True, exist_ok=True)

# zip() has no length, so pass `total` explicitly to get a real progress bar.
for image_path, mask_path in tqdm(zip(test_image_paths, test_mask_paths),
                                  total=len(test_image_paths)):
    image_name = image_path.stem
    # The context manager closes the file once the pixels are in memory.
    with Image.open(image_path) as image_file:
        image_data = np.array(image_file)

    scale_img_list, scale_position_list = multiscale_online_crop(image_data, im_size, stride, scales)

    # Open the mask once per image instead of twice per scale.
    with Image.open(mask_path) as mask_image:
        w, h = mask_image.size

        for scale, im_list, position_list in zip(scales, scale_img_list, scale_position_list):
            # The mask values are class indices: interpolating between them
            # while upscaling would invent labels that are not in the original
            # mask (and corrupt patch_label), so resize with nearest-neighbour.
            mask_data = np.array(mask_image.resize((int(w*scale), int(h*scale)), resample=Image.NEAREST))

            for image_patch, position in zip(im_list, position_list):
                i, j = position
                mask_patch = np.uint8(mask_data[i:i+im_size, j:j+im_size])

                # One presence flag per class index 0, 1 and 2.
                patch_label = [int(cat in mask_patch) for cat in range(3)]

                image_patch_img = Image.fromarray(np.uint8(image_patch))
                mask_patch_img = Image.fromarray(mask_patch, mode='P')
                mask_patch_img.putpalette(palette)

                # Image and mask patch share the same file name.
                patch_name = f'{image_name}_{scale}_{i}_{j}-{patch_label}.png'
                image_patch_img.save(image_patch_dir / patch_name)
                mask_patch_img.save(mask_patch_dir / patch_name)
        
        
