In [1]:
from import_src import import_src
import_src()

from src.image_process.image_io import LoadImage, WriteImage
import src.image_process.pre_process as prep
import src.image_process.mask.mask_getters as maskget
from src.image_process.edp_center.centroid import get_centroid
from src.image_process.edp_center.center_optimization.opt_funcs import Distance
from src.image_process.edp_center.center_optimization.optimization import optimize_center
from src.image_process.edp_center.autocorrelation import AutoCorrelation



import os
import numpy as np
from typing import Iterable, List, Union, Callable, TypeVar, Tuple
from collections.abc import Iterable as IterableABC


In [2]:
T = TypeVar("T")
U = TypeVar("U")


def batch_run(batch_input: Iterable[T], operate_single: Callable[[T], U]) -> List[U]:
    """
    Applies the given single-item function to each item in the batch.

    The input is materialized into a list first, so one-shot iterators (e.g.
    generators) are not exhausted by the type check before being processed.

    Args:
        batch_input: An iterable of inputs to process.
        operate_single: A function that processes a single input of type `T`.

    Returns:
        A list of processed inputs, in input order. An empty input yields an
        empty list.

    Raises:
        TypeError: If the elements are not all instances of the first element's type.
    """
    items = list(batch_input)
    if not items:
        return []
    # The first element defines the expected type for the whole batch.
    first_element_type = type(items[0])
    if not all(isinstance(x, first_element_type) for x in items):
        raise TypeError(f"All elements in the iterable must be of type {first_element_type.__name__}.")
    return [operate_single(x) for x in items]


def make_batch_version(
    operate_single: Callable[[T], U],
    single_types: Tuple[type, ...] = (str, bytes, np.ndarray),
) -> Callable[[Union[Iterable[T], T]], Union[List[U], U]]:
    """
    Creates a batch-compatible version of a single-item processing function.

    Args:
        operate_single: A function that processes a single input of type `T`.
        single_types: Types that are iterable but must still be treated as ONE
            item rather than as a batch. Defaults to strings, bytes and numpy
            arrays (a whole array, e.g. one image, is a single input).

    Returns:
        A function that can process either a single input or an iterable of inputs.
    """
    def batch_version(input: Union[Iterable[T], T]) -> Union[List[U], U]:
        """
        Processes either a single input or an iterable of inputs.

        Args:
            input: A single input, or a non-empty iterable of inputs.

        Returns:
            The processed input, or a list of processed inputs.

        Raises:
            TypeError: If the elements of a batch are not all of the same type
                (raised by `batch_run`).
            ValueError: If the iterable is empty.
        """
        # numpy arrays are iterable, but treating one as a batch would iterate
        # its rows, and `not array` raises for multi-element arrays.
        if isinstance(input, single_types) or not isinstance(input, IterableABC):
            return operate_single(input)
        # Materialize so empty generators are detected and not consumed twice.
        items = list(input)
        if not items:
            raise ValueError("The input iterable must not be empty.")
        return batch_run(items, operate_single)

    return batch_version

In [3]:
def load_single(path: str) -> np.ndarray:
    """Load the file at `path` with `LoadImage` and hand back its `.data` attribute."""
    loaded_image = LoadImage(path)
    return loaded_image.data

# Accepts one path or an iterable of paths.
load = make_batch_version(load_single)

In [4]:
def preprocess_single(data, kernel_size: int = 5):
    """
    Runs one image through the pre-processing pipeline.

    Builds a `prep.PreProcessPipe` from `[prep.MedianFilter, prep.AllPositive]`
    (presumably applied in that order -- confirm in PreProcessPipe) and runs it
    on `data`.

    Args:
        data: A single loaded image (presumably a 2-D numpy array -- TODO confirm).
        kernel_size: Kernel size passed to `prep.MedianFilter`. Defaults to 5,
            the value that was previously hard-coded here.

    Returns:
        The result of `PreProcessPipe.pre_process_pipe(data)`.
    """
    pre_processors = [
        prep.MedianFilter(kernel_size=kernel_size),
        prep.AllPositive(),
    ]
    pre_pipe = prep.PreProcessPipe(pre_processors=pre_processors)
    return pre_pipe.pre_process_pipe(data)

# Accepts one image or an iterable of images; uses the default kernel size.
preprocess = make_batch_version(preprocess_single)

In [5]:
def get_mask_single(data, threshold: float = 0.1):
    """
    Builds the mask for one image from a single `MeanTreshMask`.

    Args:
        data: A single (pre-processed) image.
        threshold: Value passed to `maskget.MeanTreshMask`. Defaults to 0.1,
            the value that was previously hard-coded here. Its exact meaning is
            defined by MeanTreshMask (presumably relative to the image mean --
            TODO confirm).

    Returns:
        The result of `maskget.superpose_masks(data, [mean_mask])`.
    """
    mean_mask = maskget.MeanTreshMask(threshold)
    return maskget.superpose_masks(data, [mean_mask])

# Accepts one image or an iterable of images; uses the default threshold.
get_mask = make_batch_version(get_mask_single)

In [6]:
from typing import Tuple
from functools import partial

def autocorr(data, mask):
    """
    Center estimate from `AutoCorrelation(data, mask).compute()`.

    `compute()` returns a pair; only its first element (the center) is kept.
    """
    estimated_center, _unused = AutoCorrelation(data, mask).compute()
    return estimated_center

def first_guess(data, mask, method: str = "centroid"):
    """
    Produces the initial center estimate used to seed the optimizer.

    Args:
        data: Input data (e.g., a numpy array).
        mask: Mask for the input data (used only by "autocorrelation").
        method: Either "autocorrelation" or "centroid".

    Returns:
        The initial center estimate produced by the selected method.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    # Lambdas defer the call so only the selected method actually runs.
    strategies = {
        "autocorrelation": lambda: autocorr(data, mask),
        "centroid": lambda: get_centroid(data),
    }
    chosen = strategies.get(method)
    if chosen is None:
        raise ValueError(f"Unknown first guess method: {method}")
    return chosen()

def find_center_single(data_mask: Tuple, distance_metric: str = "manhattan", first_guess_method: str = "centroid"):
    """
    Locates the center of one image by optimizing a distance-based penalty.

    Args:
        data_mask: A `(data, mask)` pair.
        distance_metric: Metric name forwarded to `Distance`. Defaults to "manhattan".
        first_guess_method: Seeding method forwarded to `first_guess`. Defaults to "centroid".

    Returns:
        Whatever `optimize_center` returns for this image.
    """
    image, image_mask = data_mask

    # Seed the search, then build the penalty for the optimizer to minimize.
    seed = first_guess(image, image_mask, method=first_guess_method)
    penalty = Distance(image, image_mask, distance_metric=distance_metric).get_penalty_func()

    return optimize_center(penalty, image.shape, initial_guess=seed)

# Configuration
first_guess_method = "centroid"
distance_metric = "manhattan"

# Single-pair center finder with the configuration above bound in.
_find_center_configured = partial(
    find_center_single, distance_metric=distance_metric, first_guess_method=first_guess_method
)
# Batch version: accepts a non-empty iterable of (data, mask) pairs.
_find_center_batch = make_batch_version(_find_center_configured)


def _is_single_data_mask(candidate) -> bool:
    """Return True when `candidate` is one `(data, mask)` pair rather than a batch of pairs."""
    return (
        isinstance(candidate, tuple)
        and len(candidate) == 2
        and isinstance(candidate[0], np.ndarray)
    )


def find_center(input):
    """
    Finds the center of one `(data, mask)` pair, or of each pair in an iterable.

    A bare tuple is iterable, so without this check a single pair would be
    treated as a batch of two items (data and mask), and unpacking each of
    those as a pair would fail.

    Args:
        input: One `(data, mask)` tuple, or a non-empty iterable of such tuples.

    Returns:
        One center for a single pair, or a list of centers for a batch.

    Raises:
        ValueError: If an empty iterable is given.
        TypeError: If the elements of a batch are not all of the same type.
    """
    if _is_single_data_mask(input):
        return _find_center_configured(input)
    return _find_center_batch(input)

In [7]:
# Input image locations. `os.listdir` returns entries in arbitrary order, so
# the names are sorted to keep processing (and printed results) reproducible.
samples_path = "../raw_data/samples/"
samples_file_names = sorted(os.listdir(samples_path))
samples_paths = [os.path.join(samples_path, x) for x in samples_file_names]

calibration_samples_path = "../raw_data/gold/"
calibration_file_names = sorted(os.listdir(calibration_samples_path))
calibration_paths = [os.path.join(calibration_samples_path, x) for x in calibration_file_names]

# Every image processed below: samples first, then the calibration images.
all_paths = samples_paths + calibration_paths

In [8]:
# Pipeline: load -> pre-process -> mask -> center search, each stage applied
# to the whole list produced by the previous one.
print("Loading data...")
loaded_data = load(all_paths)

print("Preprocessing data...")
preprocessed_data = preprocess(loaded_data)

print("Getting masks...")
masks = get_mask(preprocessed_data)

print("Finding centers...")
# Pair each pre-processed image with its own mask.
centers = find_center(list(zip(preprocessed_data, masks)))

Loading data...
Preprocessing data...
Getting masks...
Finding centers...


In [9]:
import matplotlib.pyplot as plt

to_print = centers

# Print each result under the path it came from: arrays are drawn as images,
# anything else (here, center coordinates) is printed as text.
for i, item in enumerate(to_print):
    print(all_paths[i])
    if isinstance(item, np.ndarray):
        plt.imshow(item)
        plt.show()
    else:
        print(item)

../raw_data/samples/Carbono_beamstopper_Dia9_original.dm3
(1328.0036977709453, 1354.9999998283872)
../raw_data/samples/Cu120W_Zr144W_beamstopper_original.dm3
(1375.4204734722014, 1346.999995678504)
../raw_data/samples/Cu30W_beamstopper_original.dm3
(1350.001616883015, 1382.9999978076628)
../raw_data/samples/Cu30W_Zr100W_beamstopper_original.dm3
(1320.9999991985503, 1363.0000045854842)
../raw_data/samples/Cu30W_Zr120W_beamstopper_original.dm3
(1341.0000000097098, 1339.9999948130157)
../raw_data/samples/Cu30W_Zr210W_beamstopper_original.dm3
(1357.9999967504048, 1390.9999990948202)
../raw_data/samples/Cu60W_Zr72W_beamstopper_original.dm3
(1355.0000000081618, 1401.0073673419215)
../raw_data/samples/Zr30W_beamstopper_original.dm3
(1333.9999998285296, 1348.99998589857)
../raw_data/gold/Ouro_beamstopper_final_original_Dia8.dm3
(1356.999826273311, 1380.9999986035434)
../raw_data/gold/Ouro_beamstopper_inicial_original_Dia8.dm3
(1091.9999822907412, 1103.9999994456232)
