# Dataset Statistics / Exploration

Explore the dataset to better understand the quality of the data.

In [None]:
import os
from pathlib import Path

# Work from the notebook's parent directory (presumably the project root) so the
# local modules imported below (config_utils, utils) and the config file resolve.
# The starting directory is recorded only once, so re-running this cell stays
# put instead of climbing one more level each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)   # go one level up
print(os.getcwd())
from functools import partial
from xflow import SqlProvider, pipe_each, TransformRegistry as T
from xflow.utils import plot_image
import xflow.extensions.physics  # not referenced directly; presumably imported for its side effects (transform registration) — TODO confirm
from config_utils import load_config, detect_machine
# NOTE(review): star import — presumably the source of `draw_red_square` used
# further down; prefer importing the names explicitly.
from utils import *

experiment_name = "CAE_validate_clear"
machine = detect_machine()

config = load_config(
    f"{experiment_name}.yaml",
    machine=machine,
)

In [None]:
# Connect to the metadata database and read the image sample paths.
# Each recording session lives under its own config path and uses different
# batch numbers. Select one via DATASET instead of (un)commenting query blocks.
DATASETS = {
    "wednesday": ("chromox_2025-11-19", (10, 11, 12)),
    "friday": ("chromox_2025-11-21", (1, 2, 3)),
    "saturday": ("chromox_2025-11-22-morning", (10, 11, 12)),
}
DATASET = "friday"
ROW_LIMIT = None  # e.g. 20 for a quick smoke test; None reads every row

path_key, batches = DATASETS[DATASET]
dirs = config["paths"][path_key]

# The batch numbers are trusted literals from DATASETS (forced to int), so
# formatting them into the SQL is safe here. Do not reuse this for external input.
batch_list = ", ".join(str(int(b)) for b in batches)
query = f"""
SELECT
    image_path
FROM
    mmf_dataset_metadata
WHERE
    batch IN ({batch_list})
"""
if ROW_LIMIT is not None:
    query += f"LIMIT {int(ROW_LIMIT)}\n"

db_path = f"{dirs}/db/dataset_meta.db"
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query},
    output_config={"list": "image_path"},
)
image_paths = realbeam_provider()
print(f"Found {len(image_paths)} entries in the database.")

In [None]:
from xflow.extensions.physics.beam import extract_beam_parameters
# Geometry settings for the transform pipeline below.
# The crop corners are referenced only by the commented-out crop stage.
# NOTE(review): in crop_gt the second coordinate decreases (728 -> 488), while
# in crop_fo it increases (0 -> 1200). Confirm the intended point order.
params = dict(
    crop_gt=[[74, 728], [1158, 488]],  # crop for the ground-truth image
    crop_fo=[[360, 0], [1560, 1200]],  # crop for the fiber-output image
    inp_size=(256, 256),  # resize target for the fiber-output image
    out_size=(256, 256),  # resize target for the ground-truth image
)

def default_meta(path):
    """Build a minimal metadata dict for a file path.

    Args:
        path: a string or path-like pointing at a file.

    Returns:
        dict with "path" (the path as a string), "filename" (name without its
        last suffix) and "extension" (the last suffix, including the dot, or "").
    """
    as_path = Path(path)
    stem, suffix = as_path.stem, as_path.suffix
    return dict(path=str(as_path), filename=stem, extension=suffix)
    
# Beam-parameter extractor using the "moments" method. It is referenced only by
# the commented-out `apply` stage in the pipeline below.
extract_moments = partial(extract_beam_parameters, method="moments")


# Define transform pipeline: the stages are applied in order to each image path.
# List-valued stages appear to map one callable onto each element of a
# tuple-shaped sample (None = pass that element through unchanged) — presumably;
# confirm against xflow's pipe_each documentation.
# NOTE(review): the last cell of this notebook unpacks `results` as rows of 4
# beam parameters. That presumably needs the commented-out extract_moments stage
# enabled; with the current stages, `results` holds whatever
# `save_image_from_meta` returns.
transforms = [
    partial(T.get("add_parent_dir"), parent_dir=dirs),  # prefix each DB path with `dirs` (presumably)
    partial(T.get("torch_load_image_with_meta"), meta_fn=default_meta),  # meta built by default_meta above
    # [None, T.get("debug_print")],
    [T.get("torch_to_tensor"), None],
    [T.get("torch_to_grayscale"), None],
    [T.get("torch_remap_range"), None],
    [partial(T.get("torch_split_width"), swap=True), None],    # (fiber_output, ground_truth)
    # [
    #     partial(T.get("torch_crop_area"), points=params["crop_fo"]),
    #     partial(T.get("torch_crop_area"), points=params["crop_gt"]),
    # ],
    # Three slots here — presumably (fiber_output, ground_truth, meta); TODO confirm.
    [
        partial(T.get("torch_resize"), size=params["inp_size"]),
        partial(T.get("torch_resize"), size=params["out_size"]),
        None
    ],
    # [
    #     T.get("discard"),
    #     partial(T.get("apply"), fn=extract_moments)
    # ],
    # T.get("flatten_nested"),
    # T.get("collect"),
    T.get("raise_if_none"),
    T.get("join_image"),
    T.get("save_image_from_meta"),
]

# skip_errors=True: samples that fail are presumably dropped, so len(results)
# may be smaller than len(image_paths) — compare the two printed counts.
results = list(pipe_each(
    image_paths,
    *transforms,
    progress=True,
    desc="Processing images",
    skip_errors=True,
))
print(len(results))

In [None]:
# Show the first processed sample as a quick sanity check of the pipeline output.
# With skip_errors=True the pipeline can return nothing; fail with a clear
# message instead of a bare IndexError.
assert results, "No images were processed (all skipped?); check the pipeline cell output."
plot_image(results[0])

# Visualization

In [None]:
from xflow.utils.visualization import stack_log_remap, stack_linear_clip

# Second element of every pipeline result. TODO confirm what it holds after
# join_image / save_image_from_meta. Build the list once and reuse it.
frames = [sample[1] for sample in results]

# Log-remapped stack, under its own name so it is not clobbered below.
stacked_log = stack_log_remap(frames)
plot_image(stacked_log)

# Linear-clipped stack. It keeps the name `stacked` because the bounding-square
# cell below consumes it.
stacked = stack_linear_clip(frames)
plot_image(stacked)

In [None]:
from __future__ import annotations
from typing import Tuple, Optional
import numpy as np

# An (x, y) coordinate pair. square_from_projections returns these in
# normalized [0, 1] image coordinates, with (0, 0) at the top-left.
Point = Tuple[float, float]


def _min_mass_segment(weights: np.ndarray, frac: float, eps: float = 1e-12) -> Tuple[int, int]:
    """
    Smallest contiguous index range [i, j) whose sum >= frac * total.
    Returns (i, j) with j exclusive.
    """
    w = np.asarray(weights, dtype=float)
    w = np.clip(w, 0.0, None)
    n = w.size
    if n == 0:
        raise ValueError("Empty weights.")
    total = float(w.sum())
    if total <= eps:
        # No mass -> return full range
        return 0, n

    target = frac * total
    best_i, best_j = 0, n
    best_len = n + 1

    j = 0
    s = 0.0
    for i in range(n):
        while j < n and s < target:
            s += w[j]
            j += 1
        if s >= target:
            if (j - i) < best_len:
                best_len = j - i
                best_i, best_j = i, j
        s -= w[i]

    return best_i, best_j


def _clamp_interval(a: float, b: float, lo: float = 0.0, hi: float = 1.0) -> Tuple[float, float]:
    """Clamp [a,b] into [lo,hi] by shifting (keeps length if possible)."""
    length = b - a
    if length >= (hi - lo):
        return lo, hi
    if a < lo:
        b = b + (lo - a)
        a = lo
    if b > hi:
        a = a - (b - hi)
        b = hi
    a = max(lo, a)
    b = min(hi, b)
    return a, b


def square_from_projections(
    img: np.ndarray,
    frac: float,
    *,
    make_square: bool = True,
    channel_reduce: str = "sum",  # "sum" or "mean"
    eps: float = 1e-12,
    channel_axis: int = 2,
) -> Tuple[Point, Point]:
    """
    Find a box enclosing `frac` of the image's intensity mass, via 1-D projections.

    1) Project the image onto x (column sums) and y (row sums). Negative pixel
       values are clipped to 0 first.
    2) Find the minimal contiguous x-interval containing `frac` of the
       x-projection mass, and likewise for y.
    3) Form the rectangle. If make_square=True, expand it to the smallest
       axis-aligned square *in pixel space* that contains the rectangle,
       centered on it, then clamp each axis into [0, 1]. For non-square images
       the square therefore spans different normalized extents in x and y.
       Clamping can still make the result non-square when the side exceeds the
       shorter image dimension.

    Args:
        img: 2-D (H, W) array, or a 3-D array with one channel axis.
        frac: mass fraction in (0, 1]. Values > 1 are treated as percentages
            in (1, 100].
        make_square: expand the rectangle to a square (see above).
        channel_reduce: how a 3-D image's channel axis is collapsed, either
            "sum" or "mean".
        eps: an axis whose total mass is at or below this keeps its full range.
        channel_axis: which axis of a 3-D image holds the channels. The default 2
            is channel-last (H, W, C); pass 0 for channel-first (C, H, W).

    Returns:
        (top_left, bottom_right) as (x, y) pairs in normalized coordinates,
        with (0, 0) at the top-left.

    Raises:
        ValueError: if frac is out of range, channel_reduce is unknown, img is
            not 2-D/3-D, or img has zero size.
    """
    if frac > 1.0:
        frac = frac / 100.0
    if not (0.0 < frac <= 1.0):
        raise ValueError("frac must be in (0,1] or (0,100].")

    a = np.asarray(img, dtype=float)
    if a.ndim == 3:
        if channel_reduce == "mean":
            a = a.mean(axis=channel_axis)
        elif channel_reduce == "sum":
            a = a.sum(axis=channel_axis)
        else:
            raise ValueError("channel_reduce must be 'sum' or 'mean'")
    elif a.ndim != 2:
        raise ValueError("img must be 2D or 3D array")

    # Ensure non-negative "mass"
    a = np.clip(a, 0.0, None)

    H, W = a.shape
    if H == 0 or W == 0:
        raise ValueError("img has zero size")

    # Half-open pixel-edge intervals [ix0, ix1) and [iy0, iy1).
    ix0, ix1 = _min_mass_segment(a.sum(axis=0), frac, eps=eps)  # x-projection, length W
    iy0, iy1 = _min_mass_segment(a.sum(axis=1), frac, eps=eps)  # y-projection, length H

    if not make_square:
        return (ix0 / W, iy0 / H), (ix1 / W, iy1 / H)

    # Size and centre the square in pixels, so it is geometrically square even
    # when H != W. Then convert each axis to normalized units separately.
    side = max(ix1 - ix0, iy1 - iy0)
    half = 0.5 * side
    cx = 0.5 * (ix0 + ix1)
    cy = 0.5 * (iy0 + iy1)

    sx0, sx1 = _clamp_interval((cx - half) / W, (cx + half) / W, 0.0, 1.0)
    sy0, sy1 = _clamp_interval((cy - half) / H, (cy + half) / H, 0.0, 1.0)

    return (sx0, sy0), (sx1, sy1)


# Example: the bounding square that holds 96% of the stacked image's mass,
# drawn on top of the image.
MASS_FRAC = 0.96
tl, br = square_from_projections(stacked, MASS_FRAC)
highlighted = draw_red_square(stacked, tl, br, thickness=2)
plot_image(highlighted)


In [None]:
# Normalized (x, y) corners of the bounding square: (top_left, bottom_right).
tl, br

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Expects one row of beam parameters per image, whose first four columns are
# (h_centroid, v_centroid, h_width, v_width). This presumably requires the
# extract_moments stage of the pipeline cell to be enabled — TODO confirm.
beam_params = np.asarray(results)
if beam_params.ndim != 2 or beam_params.shape[1] < 4:
    raise ValueError(
        f"Expected `results` as an (N, >=4) array of beam parameters, got shape "
        f"{beam_params.shape}; enable the extract_moments stage in the pipeline "
        "cell and re-run it."
    )
hc, vc, hw, vw = beam_params[:, :4].T

# Both panels use a fixed [0, 1] range — assumes normalized parameters; confirm.
panels = [
    (hc, vc, "h_centroid", "v_centroid", "Centroids"),
    (hw, vw, "h_width", "v_width", "Widths"),
]
fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True)
for ax, (xs, ys, xlabel, ylabel, title) in zip(axes, panels):
    ax.scatter(xs, ys, s=5, alpha=0.3)
    ax.set(xlim=(0, 1), ylim=(0, 1), xlabel=xlabel, ylabel=ylabel, title=title)
    ax.set_aspect("equal")

plt.show()