## Load the necessary imports

In [None]:
from __future__ import annotations

import sys

sys.path.append("../scripts")

In [None]:
from tqdm import tqdm
from testy import count_black_pixels, convert_slice_stats_to_csv, convert_volume_stats_to_csv, comparison_plot
from wup_analysis import initialize_dataset, DATA_PATH_TEMPLATE, update_slice_stats, update_volume_stats
from preprocessing import otsu_thresholding
import cv2
from matplotlib import pyplot as plt
import numpy as np

## Helper functions for image preprocessing

In [None]:
def remove_black_area(img: np.ndarray, plot: bool = False) -> np.ndarray:
    """Invert the image inside its largest contour and discard everything outside it.

    The largest external contour of ``img`` is filled into a mask. The input
    is then subtracted *from* that mask with saturating arithmetic
    (``cv2.subtract``).
    - Inside the contour, a pixel value ``v`` becomes ``255 - v``, so black
      pixels become white.
    - Outside the contour, every pixel saturates to 0, so the surrounding
      black background is removed from the result.

    :param img: Single-channel 8-bit image. It is presumably binary (0/255),
        e.g. the output of ``otsu_thresholding``; TODO confirm for other inputs.
    :param plot: If True, show the input and the result side by side.
    :return: A uint8 image with the same shape as ``img``. If ``img`` contains
        no contours at all (e.g. it is entirely black), an all-zero image is
        returned.
    """
    # Indexing with [-2] picks the contour list for both OpenCV return shapes:
    # 3.x returns (image, contours, hierarchy) and 4.x returns (contours, hierarchy).
    contours = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # Fill the largest-area contour into a mask. With no contours the mask stays
    # empty, so the result below is all zeros. This avoids a ValueError from
    # max() on an empty sequence.
    largest_contour_mask = np.zeros_like(img, dtype=np.uint8)
    if len(contours) > 0:
        largest_contour = max(contours, key=cv2.contourArea)
        cv2.drawContours(largest_contour_mask, [largest_contour], -1, 255, cv2.FILLED)

    # mask - img, saturating at 0: this inverts pixels inside the contour and
    # zeroes everything outside it.
    inverted_region = cv2.subtract(largest_contour_mask, img)

    if plot:
        comparison_plot(img, inverted_region, "Grayscale Image", "Image Excluding Largest Contour")

    return inverted_region

## Main function for the 2D image analysis

In [None]:
def run_analysis(
    data: list[np.ndarray],
    dataset: str,
    save: bool = False,
    *,
    output_dir: str = '../data/aske',
) -> tuple[dict, dict]:
    """Compute per-slice and per-volume air-pocket statistics for a dataset.

    Each image goes through these steps:
    1. Otsu-threshold it.
    2. Count its black pixels.
    3. Pass it through ``remove_black_area``.
    4. Label the connected components (8-connectivity) of the result and feed
       them to ``update_slice_stats`` and ``update_volume_stats``.

    After all slices are processed, the minimum and maximum of the per-slice
    ``'depth'`` entries are added to the volume statistics.

    :param data: Images (slices) to analyse, processed in list order.
    :param dataset: Dataset name. Used to build the CSV output path when saving.
    :param save: If True, write both statistics dictionaries to CSV.
        Defaults to False.
    :param output_dir: Root directory for CSV output. Files are written under
        ``{output_dir}/{dataset}/csv_files/``. Defaults to ``'../data/aske'``.
    :return: A ``(slice_stats, volume_stats)`` tuple.
    """
    slice_stats: dict = {}
    # Sentinel start values for the running min/max statistics. They are
    # presumably tightened by update_volume_stats; confirm in wup_analysis.
    min_init, max_init = float('inf'), 0
    volume_stats = {
        'min_air_pockets': min_init,
        'max_air_pockets': max_init,
        'min_air_pocket_size': min_init,
        'max_air_pocket_size': max_init,
        'min_black_pixel_count': min_init,
        'max_black_pixel_count': max_init,
        'min_air_pocket_percentage': min_init,
        'max_air_pocket_percentage': max_init,
    }

    for img in tqdm(data, total=len(data)):
        center_x = img.shape[1] // 2
        center_y = img.shape[0] // 2
        thresh = otsu_thresholding(img)
        black_pixel_count = count_black_pixels(thresh)
        inverted_img = remove_black_area(thresh)

        num_labels, _, stats, centroids = cv2.connectedComponentsWithStats(inverted_img, connectivity=8)

        slice_stats, areas, max_area, min_area = update_slice_stats(
            slice_stats, num_labels, stats, center_x, center_y, centroids
        )
        volume_stats = update_volume_stats(
            volume_stats, num_labels, min_area, max_area, black_pixel_count, areas
        )

    # Collect the depths once. The `default=` arguments stop an empty dataset
    # from raising ValueError, and they reuse the same sentinels as the other
    # unset min/max entries.
    depths = [entry.get('depth') for entry in slice_stats.values()]
    volume_stats['min_air_pocket_depth'] = min(depths, default=min_init)
    volume_stats['max_air_pocket_depth'] = max(depths, default=max_init)

    if save:
        csv_dir = f'{output_dir}/{dataset}/csv_files'
        convert_slice_stats_to_csv(slice_stats, f'{csv_dir}/slice_stats')
        convert_volume_stats_to_csv(volume_stats, f'{csv_dir}/volume_stats')

    return slice_stats, volume_stats

## Start the analysis

In [None]:
# Dataset to analyse. Its lower-case name fills DATA_PATH_TEMPLATE to locate the input data.
dataset = 'WUP2'
path = DATA_PATH_TEMPLATE.format(dataset.lower())
DATA = initialize_dataset(
    path,
    dataset,
    save=True,
    width=1028,
)

In [None]:
# Analyse every loaded slice and save the CSV statistics. As the cell's last
# expression, the returned (slice_stats, volume_stats) tuple is shown as output.
run_analysis(
    DATA,
    dataset=dataset.lower(),
    save=True,
)

## Visual inspection: apply the preprocessing (Otsu thresholding + black-area removal) to a random subset of images

In [None]:
# import random

# # Number of random images to select from DATA for visual inspection
# subset_size = 30

# # Randomly select indices for the subset
# subset_indices = random.sample(range(len(DATA)), subset_size)

# # Create a list of images from the selected indices
# random_subset = [(DATA[i], i) for i in subset_indices]

# for img, i in random_subset:
#     print(i)
#     remove_black_area(otsu_thresholding(img), plot=True)