### Data exploration and preprocessing for the iSAID dataset

In [None]:
"""
Load the source dataset and preprocess it

__ imgs
    |
    |__ fold_1
          |__ datasetname_id.png
          |__ ...
    |__ fold_2
          |__ datasetname_id.png
          |__ ...
    |__ fold_3
          |__ datasetname_id.png
          |__ ...
    |__ fold_4
          |__ datasetname_id.png
          |__ ...
    |__ fold_5
          |__ datasetname_id.png
          |__ ...

__ masks
    |
    |__ fold_1
          |__ datasetname_id.png  --> segmentation mask
          |__ datasetname_id.json --> segmentation box (COCO format)
          |__ ...
    |__ fold_2
          |__ datasetname_id.png
          |__ datasetname_id.json
          |__ ...
    |__ fold_3
          |__ datasetname_id.png
          |__ datasetname_id.json
          |__ ...
    |__ fold_4
          |__ datasetname_id.png
          |__ datasetname_id.json
          |__ ...
    |__ fold_5
          |__ datasetname_id.png
          |__ datasetname_id.json
          |__ ...

For segmentation masks --> 0 and 1 (if more than 2 labels --> use SimpleITK to encode integers in images)
For segmentation boxes --> use the COCO format to encode boxes
"""

In [None]:
import os
import zipfile
import random
import shutil
from PIL import Image
from tqdm import tqdm
import json 
from PIL import ImageDraw

import matplotlib.pyplot as plt
import numpy as np

from helpers import check_dir_consistency, visualize_images, get_all_dims, create_histograms_of_dims

Please follow the suggested download order.

## 1. Train
### 1.1 Original images from DOTA v1
- Get the 'image' folder from https://drive.google.com/drive/folders/1gmeE3D7R62UAtuIFOB9j2M5cUPTwtsxK
- 3 zips in total (1411 images):
    - `part1.zip` (469 images)
    - `part2.zip` (474 images)
    - `part3.zip` (468 images)
- Note: do not take '~1/part1.zip' in addition to '~part1.zip'

### 1.2 Masks from isaid
- Get all folders from https://drive.google.com/drive/folders/19RPVhC0dWpLF9Y_DYjxjUrwLbKUBQZ2K
- 1 zip `train-20240131T084536Z-001.zip`


## 2. Valid
### 2.1. Original images from DOTA v1
- Get the 'image' folder from https://drive.google.com/drive/folders/1RV7Z5MM4nJJJPUs6m9wsxDOJxX6HmQqZ
- 1 zip file (458 images): 
    - `part1.zip` (458 images)
- If you have followed the suggested download order, this file should have been renamed by your operating system (e.g., `part1 (1).zip`)
### 2.2. Masks from isaid
- Get all folders from https://drive.google.com/drive/folders/17MErPhWQrwr92Ca1Maf4mwiarPS5rcWM
- 1 zip `val-20240131T084706Z-001.zip`

## 3. Test
Download is not needed given that annotations of the test set are not publicly available.

There are a total of 1869 images for which annotations are provided.
The remaining 937 images from the test set are not used.

## Prepare data

In [None]:
# Random seed passed to the visualisation and fold-splitting cells below
seed = 42
# Number of folds the images and masks are split into
nb_folds = 5

### Directories

In [None]:
# Folder whose entries are the downloaded zip archives (every entry is unzipped below)
input_dir = '/home/rob/Documents/3_projects/bench/isaid/data' # <--- modify this
# Root folder receiving the temporary data, the folds and the final patches
output_dir = '/home/rob/Documents/3_projects/bench/isaid/processed' # <--- modify this

# Names of all entries found in input_dir (order as returned by os.listdir)
file_names = os.listdir(input_dir)
file_names

In [None]:
# Full path of every entry of input_dir; each one is handed to unzip_dirs
data_sources = [os.path.join(input_dir, f) for f in file_names]
print(f'Data file location: {data_sources}')

In [None]:
# Working folder for extracted archives and intermediate image/mask folders
tmp_data_dir = os.path.join(output_dir, 'tmp_data')
os.makedirs(tmp_data_dir, exist_ok=True)
print(tmp_data_dir)

In [None]:
def unzip_dirs(data_sources: list, tmp_data_dir: str) -> list:
    """
    Extract every zip archive of `data_sources` into its own sub-directory of `tmp_data_dir`.

    An archive `<name>.zip` is extracted into `<tmp_data_dir>/<name>`. When the extracted
    content holds a top-level `images` folder, its entries are moved one level up and the
    emptied folder is removed. Entries that are not zip archives are reported and skipped.

    Parameters:
    data_sources (list): Paths of the zip archives to extract.
    tmp_data_dir (str): Directory under which one sub-directory per archive is created.

    Returns:
    list: Paths of the sub-directories that received an archive, in input order.
    """
    subdirs: list = []

    for data_source in tqdm(data_sources, desc='Unzip dir'):
        # The input folder is listed blindly, so it may hold stray entries
        # (hidden files, folders); report them instead of failing on BadZipFile
        if not zipfile.is_zipfile(data_source):
            print(f'Skipping {data_source}: not a zip archive')
            continue

        subdir_name = os.path.splitext(os.path.basename(data_source))[0]
        subdir_path = os.path.join(tmp_data_dir, subdir_name)
        subdirs.append(subdir_path)

        # Make destination dir
        os.makedirs(subdir_path, exist_ok=True)

        # Unzip
        with zipfile.ZipFile(data_source, 'r') as zip_file:
            print(f'Unzipping {data_source} into {subdir_path}')
            zip_file.extractall(subdir_path)

        # Flatten the archive layout: move the content of 'images' one level up
        images_dir = os.path.join(subdir_path, 'images')
        if os.path.exists(images_dir):
            for file_name in os.listdir(images_dir):
                shutil.move(os.path.join(images_dir, file_name), subdir_path)
            # Remove the now empty 'images' subdirectory
            os.rmdir(images_dir)
    return subdirs

subdirs = unzip_dirs(data_sources, tmp_data_dir)

###

## Unzip train & val masks

In [None]:
def unzip_masks(source, destination) -> int:
    """
    Extract the zip archive `source` into `destination` and flatten its `images` folder.

    `destination` is created when missing. If the extracted content holds an `images`
    sub-folder, each of its entries is moved directly into `destination` and the emptied
    sub-folder is deleted.

    Parameters:
    source: Path of the zip archive to extract.
    destination: Folder receiving the extracted files.

    Returns:
    int: Number of entries moved out of the `images` sub-folder (0 when there is none).
    """
    print(f'Begin to unzip from {source} to {destination}')
    os.makedirs(destination, exist_ok=True)

    with zipfile.ZipFile(source, 'r') as archive:
        archive.extractall(destination)

    print(f'All Semantic masks extracted in {destination}')

    nested_dir = os.path.join(destination, 'images')
    if not os.path.exists(nested_dir):
        # Nothing to flatten
        return 0

    entries = os.listdir(nested_dir)
    for entry in entries:
        shutil.move(os.path.join(nested_dir, entry), destination)
    os.rmdir(nested_dir)

    return len(entries)

In [None]:
# Despite the `_dir` suffix this is the path of a zip file: the semantic masks
# archive nested inside the extracted train download
train_masks_dir = os.path.join(tmp_data_dir, 'train-20240131T084536Z-001', 'train', 'Semantic_masks', 'images.zip')
# Single folder receiving the masks of both the train and the validation split
mask_output_dir = os.path.join(tmp_data_dir, 'all_masks')

# Same nested zip file, for the validation download
valid_masks_dir = os.path.join(tmp_data_dir, 'val-20240131T084706Z-001', 'val', 'Semantic_masks', 'images.zip')
# NOTE(review): same path as mask_output_dir and not used by the calls below
valid_masks_new_dir = os.path.join(tmp_data_dir, 'all_masks')

# Both archives are extracted into the same folder
unzip_masks(train_masks_dir, mask_output_dir)
unzip_masks(valid_masks_dir, mask_output_dir)


## Move all images

In [None]:
# Folders created by unzip_dirs for the image archives; 'part1 (1)' is the
# validation archive as renamed by the operating system (see download notes)
images_source_folders: list = ['part1', 'part1 (1)', 'part2', 'part3']
# Single folder gathering the images of all archives
images_output_dir: str = os.path.join(tmp_data_dir, 'all_images')

In [None]:
# shutil.move does not create missing parent folders: make sure the destination
# exists before the first file is moved into it
os.makedirs(images_output_dir, exist_ok=True)

for folder in images_source_folders:
    # Path of the extracted archive folder
    source_folder_path = os.path.join(tmp_data_dir, folder)
    # A missing folder is reported but does not stop the other folders
    if not os.path.exists(source_folder_path):
        print(f"Source folder {source_folder_path} does not exist")
        continue

    # Move every file of the source folder into the common output directory
    for filename in os.listdir(source_folder_path):
        source_file_path = os.path.join(source_folder_path, filename)
        destination_file_path = os.path.join(images_output_dir, filename)
        shutil.move(source_file_path, destination_file_path)

print(f"All images have been moved to {images_output_dir}")

## Create patches

In [None]:
def check_dir_consistency_info(img_folder: str, mask_folder: str):
    """
    Report whether every image of img_folder has a matching mask in mask_folder.

    The mask expected for an image is `<name>_instance_color_RGB.png`, where `<name>` is the
    part of the image file name located before its first dot. Images without a mask are
    reported on standard output; they do not raise.

    Parameters:
    img_folder (str): Path to the folder containing images.
    mask_folder (str): Path to the folder containing masks.

    Raises:
    AssertionError: If the number of files in img_folder and mask_folder is not equal.

    Returns:
    tuple: A tuple of two lists, the first being the image files and the second being the mask files.

    """
    mask_postfix = 'instance_color_RGB'
    img_files = os.listdir(img_folder)
    mask_files = os.listdir(mask_folder)

    # Raised explicitly (rather than with `assert`) so the check is not stripped by `python -O`
    if len(img_files) != len(mask_files):
        raise AssertionError(
            f"Found {len(img_files)} images in `{img_folder}` but {len(mask_files)} masks in `{mask_folder}`."
        )

    # Set lookup keeps the check linear in the number of files
    known_masks = set(mask_files)
    missing_files = False

    # Verify that all images have a corresponding mask
    for img_file in img_files:
        expected_mask = img_file.split('.')[0] + '_' + mask_postfix + '.png'
        if expected_mask not in known_masks:
            # Report the image itself, not the mask name derived from it
            print(f"Mask corresponding to image `{img_file}` was not found.")
            missing_files = True

    print(f"Missing files ? {missing_files}")

    return img_files, mask_files

In [None]:
# Short aliases for the folders gathering all images and all masks
img_dir, mask_dir = images_output_dir, mask_output_dir
# Prints the consistency report and returns both file listings
img_files, mask_files = check_dir_consistency_info(img_dir, mask_dir)

### Visualize Data

In [None]:
# Helper imported from `helpers` (implementation not visible in this notebook);
# presumably displays image/mask pairs chosen with `seed` — TODO confirm
visualize_images(img_dir, mask_dir, seed=seed, mask_postfix='instance_color_RGB')

In [None]:
# Helper imported from `helpers`; the cell below indexes each returned entry
# with [0] and [1] and treats them as width and height
all_dims = get_all_dims(img_dir)

In [None]:
# Smallest first and second component over all entries of all_dims
# (assumes each entry is (width, height) — TODO confirm against helpers.get_all_dims)
min_width = min(dim[0] for dim in all_dims)
min_height = min(dim[1] for dim in all_dims)
print(f"Minimum width: {min_width}")
print(f"Minimum height: {min_height}")

In [None]:
# Helper imported from `helpers`; the meaning of `mult` is defined there — TODO confirm
create_histograms_of_dims(img_dir, mult=True)

In [None]:
#create_histograms_of_dims(img_dir)

In [None]:
def count_images_below_threshold(img_dir, threshold: int = 448):
    """
    Find the images of `img_dir` whose width or height is smaller than `threshold`.

    Files that cannot be opened as images are reported on standard output and ignored.

    Parameters:
    img_dir: Folder whose entries are inspected.
    threshold (int): Minimum accepted size, in pixels, for both width and height.

    Returns:
    tuple: The number of undersized images and the list of their paths.
    """
    images_too_small: list = []

    for entry in os.listdir(img_dir):
        file_path = os.path.join(img_dir, entry)

        # Only the dimensions are needed; the file is closed right after reading them
        try:
            with Image.open(file_path) as picture:
                width, height = picture.size
        except IOError:
            print(f"Cannot open the file: {file_path}")
            continue

        # The smaller side decides whether a full patch fits
        if min(width, height) < threshold:
            images_too_small.append(file_path)

    return len(images_too_small), images_too_small

# Side length (in pixels) of the square patches; smaller images cannot be patched
patch_size = 448
num_images_too_small, img_too_small = count_images_below_threshold(img_dir, threshold=patch_size)
num_mask_too_small, mask_too_small = count_images_below_threshold(mask_dir, threshold=patch_size)
print(f'Found {num_images_too_small} images and {num_mask_too_small} masks below threhsold.')

# Remove images and masks
# An image and its mask are expected to be undersized together: the image count
# is compared with the mask count so that a mismatch aborts before any deletion
if num_images_too_small == num_mask_too_small:
    for small_file in img_too_small + mask_too_small:
        os.remove(small_file)
else:
    raise Exception(f"There is a mismatch between your images and masks!")

In [None]:
# NOTE(review): presumably the number of patches per image; it is not passed to
# create_patches in the cells visible here — confirm
n_patches = 30
output_imgs = os.path.join(tmp_data_dir, 'imgs_patches')
output_masks = os.path.join(tmp_data_dir, 'masks_patches')

# Plain loop: the folders are created for their side effect only, and the
# builtin `dir` is no longer shadowed
for patch_dir in (output_imgs, output_masks):
    os.makedirs(patch_dir, exist_ok=True)

## Create Splits

In [None]:
def create_folds_and_split_data(img_folder, mask_folder, 
                                output_dir, seed=42, n_folds=5,
                                mask_postfix='instance_color_RGB') -> tuple:
    """
    Shuffle the images of `img_folder` and copy them, with their masks, into fold folders.

    Images are copied to `<output_dir>/tmp_imgs/fold_<k>` and masks to
    `<output_dir>/tmp_masks/fold_<k>` for k in 0..n_folds-1. The mask of an image is
    `<name>_<mask_postfix>.png`, `<name>` being the image file name up to its first dot.
    Folds hold `len(images) // n_folds` images each; the last fold also receives the remainder.

    Parameters:
    img_folder: Folder containing the source images.
    mask_folder: Folder containing the source masks.
    output_dir: Folder under which `tmp_imgs` and `tmp_masks` are created.
    seed (int): Seed given to numpy before shuffling.
    n_folds (int): Number of folds.
    mask_postfix (str): Suffix appended to the image name to build the mask name.

    Raises:
    FileNotFoundError: If the mask of an image is missing (raised by shutil.copy).

    Returns:
    tuple: The image fold root and the mask fold root.
    """
    # Set random seed
    np.random.seed(seed)
    
    # Create folder structure
    img_output_dir = os.path.join(output_dir, 'tmp_imgs')
    mask_output_dir = os.path.join(output_dir, 'tmp_masks')

    for i in range(n_folds):
        os.makedirs(os.path.join(img_output_dir, f'fold_{i}'), exist_ok=True)
        os.makedirs(os.path.join(mask_output_dir, f'fold_{i}'), exist_ok=True)

    # Sort first so that the shuffle only depends on the seed, not on listing order
    img_files = sorted(os.listdir(img_folder))
    np.random.shuffle(img_files)

    # Never below 1: with fewer images than folds the integer division gives 0,
    # which would make the fold computation below divide by zero
    fold_size = max(1, len(img_files) // n_folds)
    print(f'Size of fold: {fold_size}')

    # tqdm wraps the list (not the enumerate object) so that it knows the total
    for i, img_file in enumerate(tqdm(img_files, desc='files')):
        fold_num = min(i // fold_size, n_folds - 1)  # Avoid exceeding the number of folds

        # Copy image file
        img_src = os.path.join(img_folder, img_file)
        img_dst = os.path.join(img_output_dir, f'fold_{fold_num}', img_file)
        shutil.copy(img_src, img_dst)

        # Copy corresponding mask
        mask_file = img_file.split('.')[0] + '_' + mask_postfix + '.png'
        mask_src = os.path.join(mask_folder, mask_file)
        mask_dst = os.path.join(mask_output_dir, f'fold_{fold_num}', mask_file)
        shutil.copy(mask_src, mask_dst)

    return img_output_dir, mask_output_dir

img_output_dir, mask_output_dir = create_folds_and_split_data(img_dir, mask_dir, output_dir, seed, n_folds=nb_folds)

In [None]:
def verify_image_mask_correspondence(output_dir, img_folder, mask_folder, mask_postfix='instance_color_RGB', n_folds=5):
    """
    Verify that every image in each fold has a corresponding mask and vice versa,
    and that the total count of images and masks matches the original dataset.

    Args:
    output_dir (str): The base directory where the 'tmp_imgs' and 'tmp_masks' subdirectories are located.
    img_folder (str): The directory containing the original images.
    mask_folder (str): The directory containing the original masks.
    mask_postfix (str): Suffix of mask names; the mask of `<name>.<ext>` is `<name>_<mask_postfix>.png`.
    n_folds (int): Number of folds to check for correspondence.

    Raises:
    Exception: If an image doesn't have a corresponding mask, a mask doesn't have a corresponding image,
               or the total count of images and masks doesn't match the original dataset.

    Returns:
    bool: True when all tests pass (a failing test raises instead of returning).
    """
    img_output_dir = os.path.join(output_dir, 'tmp_imgs')
    mask_output_dir = os.path.join(output_dir, 'tmp_masks')
    original_img_count = len(os.listdir(img_folder))
    original_mask_count = len(os.listdir(mask_folder))

    mask_suffix = f"_{mask_postfix}.png"
    total_img_count = 0
    total_mask_count = 0

    for i in range(n_folds):
        img_fold_dir = os.path.join(img_output_dir, f'fold_{i}')
        mask_fold_dir = os.path.join(mask_output_dir, f'fold_{i}')

        img_files = set(os.listdir(img_fold_dir))
        mask_files = set(os.listdir(mask_fold_dir))
        total_img_count += len(img_files)
        total_mask_count += len(mask_files)

        img_base_names = {os.path.splitext(img_file)[0] for img_file in img_files}

        # Test 1: Each image has a corresponding mask
        for base_name in img_base_names:
            if f"{base_name}{mask_suffix}" not in mask_files:
                raise Exception(f"Image {base_name} in fold_{i} does not have a corresponding mask.")

        # Test 2: Each mask has a corresponding image
        for mask_file in mask_files:
            # Strip the known suffix instead of cutting at the first '_', so that
            # image names which themselves contain underscores are matched correctly
            if not mask_file.endswith(mask_suffix) or mask_file[:-len(mask_suffix)] not in img_base_names:
                raise Exception(f"Mask {mask_file} in fold_{i} does not have a corresponding image.")

    # Test 3: Total count of images and masks matches the original dataset
    if total_img_count != original_img_count or total_mask_count != original_mask_count:
        raise Exception("The total count of images or masks does not match the original dataset.")

    return True

In [None]:
# Run the fold verification; a failure is reported as a message instead of a traceback
try:
    tests_passed = verify_image_mask_correspondence(output_dir, img_dir, mask_dir, n_folds=nb_folds)
except Exception as e:
    print(f"Error: {e}")
else:
    if tests_passed:
        print("All tests passed successfully.")

In [None]:
# Final output folders: they receive one `fold_<k>` sub-folder of patches each
images_patches_dir = os.path.join(output_dir, 'imgs')
masks_patches_dir = os.path.join(output_dir, 'masks')
os.makedirs(images_patches_dir, exist_ok=True)
os.makedirs(masks_patches_dir, exist_ok=True)

In [None]:
def extract_and_save_patches(img_dir, mask_dir, output_imgs_dir, 
                             output_masks_dir, num_patches=30, 
                             patch_size=(256, 256), seed=42, mask_postfix='instance_color_RGB'):
    """
    Crop `num_patches` random patches from every image of `img_dir` and from its mask.

    For each image, patch `i` is saved as `<base_name>_patch_<i>.png` in `output_imgs_dir`;
    when the mask `<name>_<mask_postfix>.png` exists in `mask_dir`, the same region of the
    mask is saved under the same file name in `output_masks_dir`. Images without a mask
    still produce image patches.

    Args:
    img_dir: Folder containing the source images.
    mask_dir: Folder containing the source masks.
    output_imgs_dir: Existing folder receiving the image patches.
    output_masks_dir: Existing folder receiving the mask patches.
    num_patches (int): Number of patches cropped per image.
    patch_size (tuple): (width, height) of a patch, in pixels.
    seed (int): Seed given to `random` before cropping.
    mask_postfix (str): Suffix appended to the image name to build the mask name.

    Raises:
    ValueError: If an image is smaller than `patch_size`.
    """
    random.seed(seed)
    patch_width, patch_height = patch_size[0], patch_size[1]

    # Sorted so that the random positions drawn for an image depend only on the
    # seed, not on the order in which the file system lists the folder
    img_files = sorted(f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f)))

    for img_file in tqdm(img_files, desc='Extracting patches'):
        img_path = os.path.join(img_dir, img_file)
        base_name, _ = os.path.splitext(img_file)
        mask_file = img_file.split('.')[0] + '_' + mask_postfix + '.png'
        print(f'mask_file: {mask_file}')
        mask_path = os.path.join(mask_dir, mask_file)
        print(f'mask_path: {mask_path}')

        with Image.open(img_path) as img:
            # Fail with an explicit message rather than an "empty range" error from randint
            if img.width < patch_width or img.height < patch_height:
                raise ValueError(
                    f"Image {img_path} ({img.width}x{img.height}) is smaller than "
                    f"the patch size {patch_width}x{patch_height}."
                )

            # Open the mask once per image instead of once per patch
            mask = Image.open(mask_path) if os.path.exists(mask_path) else None
            try:
                for i in range(num_patches):
                    # Generate random position for cropping
                    x = random.randint(0, img.width - patch_width)
                    y = random.randint(0, img.height - patch_height)
                    box = (x, y, x + patch_width, y + patch_height)
                    patch_file_name = f"{base_name}_patch_{i}.png"

                    # Extract and save image patch
                    img.crop(box).save(os.path.join(output_imgs_dir, patch_file_name))

                    # Extract and save the matching mask patch
                    if mask is not None:
                        mask.crop(box).save(os.path.join(output_masks_dir, patch_file_name))
            finally:
                if mask is not None:
                    mask.close()

In [None]:
def create_patches(src_img_root_dir, src_masks_root_dir, output_imgs_root_dir, output_masks_root_dir, patch_size, nb_folds: int = 5, num_patches: int = 30, seed: int = 42):
    """
    Extract patches fold by fold, mirroring the `fold_<k>` layout in the output folders.

    Args:
    src_img_root_dir: Folder containing the `fold_<k>` image folders.
    src_masks_root_dir: Folder containing the `fold_<k>` mask folders.
    output_imgs_root_dir: Folder under which the `fold_<k>` image patch folders are created.
    output_masks_root_dir: Folder under which the `fold_<k>` mask patch folders are created.
    patch_size: Side length of the square patches, in pixels.
    nb_folds (int): Number of folds to process (fold_0 .. fold_<nb_folds - 1>).
    num_patches (int): Number of patches cropped per image.
    seed (int): Seed forwarded to extract_and_save_patches for every fold.
    """
    for i in range(nb_folds):
        fold = f'fold_{i}'
        img_dir = os.path.join(src_img_root_dir, fold)
        mask_dir = os.path.join(src_masks_root_dir, fold)

        # Prepare dir to store final data
        output_imgs_dir = os.path.join(output_imgs_root_dir, fold)
        output_masks_dir = os.path.join(output_masks_root_dir, fold)
        os.makedirs(output_imgs_dir, exist_ok=True)
        os.makedirs(output_masks_dir, exist_ok=True)

        extract_and_save_patches(img_dir, mask_dir, output_imgs_dir, output_masks_dir, num_patches=num_patches, patch_size=(patch_size, patch_size), seed=seed)
In [None]:
# Read the folds written by create_folds_and_split_data and write the patches
# under <output_dir>/imgs and <output_dir>/masks
create_patches(img_output_dir, mask_output_dir, images_patches_dir, masks_patches_dir, patch_size, nb_folds=nb_folds)

In [None]:
# Integer index -> 3-tuple of integers
# (presumably class index -> RGB colour used in the iSAID masks — TODO confirm)
iSAID_palette = \
    {
        0: (0, 0, 0),
        1: (0, 0, 63),
        2: (0, 63, 63),
        3: (0, 63, 0),
        4: (0, 63, 127),
        5: (0, 63, 191),
        6: (0, 63, 255),
        7: (0, 127, 63),
        8: (0, 127, 127),
        9: (0, 0, 127),
        10: (0, 0, 191),
        11: (0, 0, 255),
        12: (0, 191, 127),
        13: (0, 127, 191),
        14: (0, 127, 255),
        15: (0, 100, 155)
    }

# Stored next to the processed data. In the JSON file the integer keys become
# strings and the tuples become arrays, so readers must convert them back.
file = os.path.join(output_dir, 'isaid_mask_palette.json')
with open(file, 'w') as json_file:
    json.dump(iSAID_palette, json_file, indent=4)

In [None]:
# Show where the processed dataset was written
print(output_dir)