# Dataset Statistics/exploration

This notebook explores the dataset to better understand the quality of the data.
There are five types of data in total:

    1. DMD patterns 
    2. Chromox real beam
    3. YAG real beam
    4. Chromox laser scan
    5. YAG laser scan

Beam data collection started on Wednesday 2025-11-19 (Chromox), continued on Friday 2025-11-21 and Saturday 2025-11-22 (Chromox), and was followed by Sunday 2025-11-23 (YAG). Laser scans were also taken on Saturday and Sunday, and DMD data was collected in between the beam sessions.



In [1]:
import os
from pathlib import Path

# Remember the directory the kernel started in. Changing to `Path.cwd().parent`
# directly is not idempotent: every re-run of this cell would climb one more
# level. Anchoring to the recorded start directory makes re-runs land in the
# same place as the first run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = Path.cwd()
os.chdir(NOTEBOOK_START_DIR.parent)   # go one level up from the start directory
print(os.getcwd())

from xflow import SqlProvider, pipe_each, TransformRegistry as T
from xflow.utils import plot_image
import xflow.extensions.physics
from xflow.utils.io import scan_files
from xflow.utils.sql import union_sqlite_db_tables, merge_sqlite_dbs

from config_utils import load_config, detect_machine
from utils import *

from functools import partial
import pandas as pd
import sqlite3
import json

# Experiment selection: the YAML file name is derived from the experiment name,
# and the machine profile passed to load_config comes from detect_machine().
experiment_name = "CAE_validate_clear"
machine = detect_machine()

config_file = f"{experiment_name}.yaml"
config = load_config(config_file, machine=machine)

c:\Users\qiyuanxu\Documents\GitHub\fiber-image-reconstruction-comparison
[config_utils] Using machine profile: win-qiyuanxu


# Scope 1 - DMD synthetic data and its corresponding Real data

Create a ready-to-use dataset for training and evaluation (the training/testing roles could also be reversed).

In [2]:
# ============================
# Merge the entire CLEAR 2025 dataset into a single database
# ============================

# Build the locations from the current user's home directory instead of a
# hard-coded "C:\\Users\\<name>" prefix, so the cell also runs under other
# accounts. On the original machine this yields exactly the same strings.
home_dir = Path.home()
datasets_dir = home_dir / "Documents" / "DataHub" / "datasets"
merged_path = str(home_dir / "Desktop" / "clear_2025_dataset.db")

# Collect every .db file under the datasets directory and merge them; the
# originating file of each row is recorded in the "db_path" column.
# NOTE(review): behaviour when `merged_path` already exists (overwrite vs.
# append) depends on merge_sqlite_dbs — confirm before re-running.
db_paths = scan_files(str(datasets_dir), extensions=[".db"], return_type="str")
merge_sqlite_dbs(db_paths, output_path=merged_path, source_column="db_path")

  merged_df = pd.concat(dfs, ignore_index=True)


'C:\\Users\\qiyuanxu\\Desktop\\clear_2025_dataset.db'

In [3]:
# ============================
# Left join metadata into the big merged database to form a single table
# ============================

# Every metadata row is kept; experiment-config columns are attached when a
# config with the same id exists in the same source database (db_path).
sql = """
SELECT
    d.*,
    c.experiment_description,
    c.image_source,
    c.image_device,
    c.fiber_config,
    c.camera_config,
    c.other_config
FROM mmf_dataset_metadata AS d
LEFT JOIN mmf_experiment_config AS c
  ON c.id = d.config_id
 AND c.db_path = d.db_path;
"""

# `with sqlite3.connect(...)` only commits or rolls back the transaction; it
# does not close the connection. Close it explicitly so the file handle is
# released (an open handle keeps the database file locked on Windows).
con = sqlite3.connect(str(merged_path))
try:
    tables_df = pd.read_sql_query(sql, con)
finally:
    con.close()

# optional: drop duplicate column names, keeping the first occurrence
tables_df = tables_df.loc[:, ~tables_df.columns.duplicated()]
print(tables_df.shape)

(93787, 31)


In [8]:
# ============================
# DMD hourly data collection statistics
# ============================

def _parse_other_config(raw):
    """Return `other_config` as a dict.

    Dicts pass through unchanged (so re-running the cell is harmless), strings
    that look like a JSON object are decoded, and anything else becomes None.
    """
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str) and raw.strip().startswith("{"):
        return json.loads(raw)
    return None


def _is_dmd_row(cfg):
    """True when the config has a non-"DummyDMD" DMD type and no beam settings."""
    if not isinstance(cfg, dict):
        return False
    # `or {}` also covers an explicit null ("dmd_config": null): with
    # `cfg.get("dmd_config", {})` the default is only used when the key is
    # missing, so a stored None would make the chained `.get` raise.
    dmd_cfg = cfg.get("dmd_config") or {}
    # NOTE(review): beam settings are read from inside `other_config` here,
    # while the Chromox/YAG cells use the `beam_settings` column — confirm
    # both carry the same information.
    return dmd_cfg.get("type") != "DummyDMD" and cfg.get("beam_settings") is None


tables_df["other_config"] = tables_df["other_config"].map(_parse_other_config)

mask = tables_df["other_config"].map(_is_dmd_row)

dmd_df = tables_df.loc[mask].copy()
print("Total rows kept:", int(mask.sum()))

# Count rows per "YYYY-MM-DD HH" bucket; unparseable timestamps are coerced to
# NaT and dropped before counting.
out = (
    pd.to_datetime(dmd_df["create_time"], errors="coerce")
      .dt.strftime("%Y-%m-%d %H")
      .dropna()
      .value_counts()
      .rename_axis("hour")
      .reset_index(name="count")
      .sort_values("hour")
      .reset_index(drop=True)
)
out

Total rows kept: 10337


Unnamed: 0,hour,count
0,2025-11-19 16,225
1,2025-11-19 17,891
2,2025-11-20 06,25
3,2025-11-20 07,1096
4,2025-11-20 08,1436
5,2025-11-20 10,434
6,2025-11-20 11,2432
7,2025-11-20 12,173
8,2025-11-20 19,558
9,2025-11-20 20,394


In [6]:
# ============================
# Chromox hourly data collection statistics
# ============================
# Keep rows that have beam settings and whose image device mentions "Chromox".
has_beam_settings = tables_df["beam_settings"].notna()
device_names = tables_df["image_device"].astype(str)
mask = has_beam_settings & device_names.str.contains("Chromox", na=False)

chromox_df = tables_df.loc[mask].copy()
print("Total rows kept:", int(mask.sum()))

# Label each row with its "YYYY-MM-DD HH" hour; timestamps that fail to parse
# become NaT and are dropped before counting.
created_at = pd.to_datetime(chromox_df["create_time"], errors="coerce")
hour_labels = created_at.dt.strftime("%Y-%m-%d %H").dropna()
hour_counts = hour_labels.value_counts().rename_axis("hour").reset_index(name="count")

# Present the buckets chronologically with a clean 0..n-1 index.
out = hour_counts.sort_values("hour").reset_index(drop=True)
out

Total rows kept: 49583


Unnamed: 0,hour,count
0,2025-11-19 12,326
1,2025-11-19 13,35
2,2025-11-19 14,785
3,2025-11-19 15,1842
4,2025-11-21 08,28
5,2025-11-21 09,4407
6,2025-11-21 10,4836
7,2025-11-21 11,247
8,2025-11-21 22,2913
9,2025-11-21 23,5047


In [7]:
# ============================
# YAG hourly data collection statistics
# ============================

# Keep rows that have beam settings and whose image device mentions "YAG".
has_beam_settings = tables_df["beam_settings"].notna()
device_names = tables_df["image_device"].astype(str)
mask = has_beam_settings & device_names.str.contains("YAG", na=False)

yag_df = tables_df.loc[mask].copy()
print("Total rows kept:", int(mask.sum()))

# Label each row with its "YYYY-MM-DD HH" hour; timestamps that fail to parse
# become NaT and are dropped before counting.
created_at = pd.to_datetime(yag_df["create_time"], errors="coerce")
hour_labels = created_at.dt.strftime("%Y-%m-%d %H").dropna()
hour_counts = hour_labels.value_counts().rename_axis("hour").reset_index(name="count")

# Present the buckets chronologically with a clean 0..n-1 index.
out = hour_counts.sort_values("hour").reset_index(drop=True)
out

Total rows kept: 15341


Unnamed: 0,hour,count
0,2025-11-23 09,1954
1,2025-11-23 10,3648
2,2025-11-23 11,3406
3,2025-11-23 12,4694
4,2025-11-23 13,1639


In [None]:
# Connect to database and read image sample paths.
# Each entry maps a key of config["paths"] to the batches to load from that
# run's metadata database. Switch runs by changing SELECTED_RUN instead of
# commenting code blocks in and out.
REALBEAM_RUNS = {
    "chromox_2025-11-19": (10, 11, 12),          # Wednesday Chromox
    "chromox_2025-11-21": (1, 2, 3),             # Friday Chromox
    "chromox_2025-11-22-morning": (10, 11, 12),  # Saturday Chromox
}
SELECTED_RUN = "chromox_2025-11-21"  # Friday Chromox

dirs = config["paths"][SELECTED_RUN]

# The batch numbers come from the literal table above, so formatting them into
# the SQL text is safe.
batch_list = ", ".join(str(batch) for batch in REALBEAM_RUNS[SELECTED_RUN])
query = f"""
SELECT 
    image_path
FROM 
    mmf_dataset_metadata 
WHERE
    batch IN ({batch_list})
--LIMIT 20
"""

# The run directory holds its own metadata database under db/.
db_path = f"{dirs}/db/dataset_meta.db"
realbeam_provider = SqlProvider(
    sources={"connection": db_path, "sql": query}, output_config={'list': "image_path"}
)
image_paths = realbeam_provider()
print(f"Found {len(image_paths)} entries in the database.")

In [None]:
from xflow.extensions.physics.beam import extract_beam_parameters
# Geometry settings for the image pipeline.
# crop_gt / crop_fo are only referenced by the currently disabled crop step
# (presumably pixel corner points of the ground-truth and fiber-output regions —
# TODO confirm); inp_size / out_size are the resize targets.
params = dict(
    crop_gt=[[74, 728], [1158, 488]],
    crop_fo=[[360, 0], [1560, 1200]],
    inp_size=(256, 256),
    out_size=(256, 256),
)

def default_meta(path):
    """Build the basic metadata record for a file path.

    Returns a dict with the path as a string ("path"), the file name without
    its last suffix ("filename") and that last suffix including the leading
    dot ("extension").
    """
    location = Path(path)
    return dict(
        path=str(location),
        filename=location.stem,
        extension=location.suffix,
    )
    
# Moment-based variant of the beam-parameter extractor. It is only referenced
# by the disabled parameter-extraction step noted inside the pipeline below.
extract_moments = partial(extract_beam_parameters, method="moments")


# Look up the registered transforms first so the pipeline reads as a list of steps.
step_add_parent = partial(T.get("add_parent_dir"), parent_dir=dirs)
step_load = partial(T.get("torch_load_image_with_meta"), meta_fn=default_meta)
step_to_tensor = T.get("torch_to_tensor")
step_to_gray = T.get("torch_to_grayscale")
step_remap = T.get("torch_remap_range")
step_split = partial(T.get("torch_split_width"), swap=True)
step_resize_inp = partial(T.get("torch_resize"), size=params["inp_size"])
step_resize_out = partial(T.get("torch_resize"), size=params["out_size"])
step_require = T.get("raise_if_none")
step_join = T.get("join_image")
step_save = T.get("save_image_from_meta")

# Define transform pipeline
# NOTE(review): a list entry presumably applies one transform per element of
# the item flowing through the pipeline, with None leaving that element
# untouched — confirm against xflow's pipe_each.
transforms = [
    step_add_parent,
    step_load,
    # Disabled: [None, T.get("debug_print")],
    [step_to_tensor, None],
    [step_to_gray, None],
    [step_remap, None],
    [step_split, None],    # (fiber_output, ground_truth)
    # Disabled crop step:
    # [
    #     partial(T.get("torch_crop_area"), points=params["crop_fo"]),
    #     partial(T.get("torch_crop_area"), points=params["crop_gt"]),
    # ],
    # Three entries from here on (two resizes plus a pass-through).
    [step_resize_inp, step_resize_out, None],
    # Disabled parameter-extraction variant:
    # [
    #     T.get("discard"),
    #     partial(T.get("apply"), fn=extract_moments)
    # ],
    # T.get("flatten_nested"),
    # T.get("collect"),
    step_require,
    step_join,
    step_save,
]

# Run every image path through the transform pipeline and materialise the outputs.
# skip_errors=True is handed to pipe_each — presumably failing items are dropped
# instead of aborting the run (TODO confirm), so compare the printed count with
# len(image_paths).
pipeline_options = dict(progress=True, desc="Processing images", skip_errors=True)
results = list(pipe_each(image_paths, *transforms, **pipeline_options))
print(len(results))

In [None]:
# Show the first pipeline output as a quick visual sanity check.
first_result = results[0]
plot_image(first_result)

# Scope 2 - Real data only, with a data augmentation pipeline (superposition)
Create a ready-to-use dataset for training and evaluation.

# Visualization (temp)

In [None]:
from xflow.utils.visualization import stack_log_remap, stack_linear_clip
def _second_of_each(items):
    """Collect element [1] of every item, in order."""
    return [item[1] for item in items]


# Log-remapped stack of element [1] of every result.
stacked = stack_log_remap(_second_of_each(results))
plot_image(stacked)

# Linearly clipped stack of the same elements; `stacked` ends up holding this one.
stacked = stack_linear_clip(_second_of_each(results))
plot_image(stacked)

In [None]:
from __future__ import annotations
from typing import Tuple, Optional
import numpy as np

Point = Tuple[float, float]


def _min_mass_segment(weights: np.ndarray, frac: float, eps: float = 1e-12) -> Tuple[int, int]:
    """
    Smallest contiguous index range [i, j) whose sum >= frac * total.
    Returns (i, j) with j exclusive.
    """
    w = np.asarray(weights, dtype=float)
    w = np.clip(w, 0.0, None)
    n = w.size
    if n == 0:
        raise ValueError("Empty weights.")
    total = float(w.sum())
    if total <= eps:
        # No mass -> return full range
        return 0, n

    target = frac * total
    best_i, best_j = 0, n
    best_len = n + 1

    j = 0
    s = 0.0
    for i in range(n):
        while j < n and s < target:
            s += w[j]
            j += 1
        if s >= target:
            if (j - i) < best_len:
                best_len = j - i
                best_i, best_j = i, j
        s -= w[i]

    return best_i, best_j


def _clamp_interval(a: float, b: float, lo: float = 0.0, hi: float = 1.0) -> Tuple[float, float]:
    """Clamp [a,b] into [lo,hi] by shifting (keeps length if possible)."""
    length = b - a
    if length >= (hi - lo):
        return lo, hi
    if a < lo:
        b = b + (lo - a)
        a = lo
    if b > hi:
        a = a - (b - hi)
        b = hi
    a = max(lo, a)
    b = min(hi, b)
    return a, b


def square_from_projections(
    img: np.ndarray,
    frac: float,
    *,
    make_square: bool = True,
    channel_reduce: str = "sum",  # "sum" or "mean"
    eps: float = 1e-12,
) -> Tuple[Point, Point]:
    """
    Locate the region that holds a given fraction of the image intensity.

    The image is summed along each axis to get an x-profile and a y-profile.
    For each profile the shortest contiguous interval containing `frac` of its
    mass is found (via `_min_mass_segment`), and the two intervals form a
    rectangle in normalized coordinates with (0,0) at the top-left corner.

    Args:
        img: 2D array, or 3D array whose last axis (axis 2) is reduced first.
        frac: Mass fraction in (0,1]; values above 1 are read as percentages.
        make_square: If True, grow the rectangle about its centre to a box with
            equal normalized width and height, then fit it into [0,1] per axis
            with `_clamp_interval` (which may shift it off-centre). Equal
            normalized sides are a pixel square only when the image is square.
        channel_reduce: "sum" or "mean"; only consulted for 3D input.
        eps: Forwarded to `_min_mass_segment` as its zero-mass threshold.

    Returns:
        (top_left, bottom_right), each an (x, y) pair in normalized coordinates.

    Raises:
        ValueError: for an out-of-range `frac`, an unknown `channel_reduce`
            (3D input), input that is not 2D/3D, or an image with a zero-length axis.
    """
    # Values above 1 are interpreted as a percentage.
    share = frac / 100.0 if frac > 1.0 else frac
    if not (0.0 < share <= 1.0):
        raise ValueError("frac must be in (0,1] or (0,100].")

    intensity = np.asarray(img, dtype=float)
    if intensity.ndim not in (2, 3):
        raise ValueError("img must be 2D or 3D array")
    if intensity.ndim == 3:
        if channel_reduce == "sum":
            intensity = intensity.sum(axis=2)
        elif channel_reduce == "mean":
            intensity = intensity.mean(axis=2)
        else:
            raise ValueError("channel_reduce must be 'sum' or 'mean'")

    # Negative values carry no mass.
    intensity = np.clip(intensity, 0.0, None)

    height, width = intensity.shape
    if height == 0 or width == 0:
        raise ValueError("img has zero size")

    # Column sums give the x-profile (length = width), row sums the y-profile.
    col_lo, col_hi = _min_mass_segment(intensity.sum(axis=0), share, eps=eps)
    row_lo, row_hi = _min_mass_segment(intensity.sum(axis=1), share, eps=eps)

    # Half-open pixel-edge indices -> normalized [0,1] coordinates.
    x0, x1 = col_lo / width, col_hi / width
    y0, y1 = row_lo / height, row_hi / height

    if make_square:
        # Use the longer normalized side for both axes, centred on the rectangle.
        side = max(x1 - x0, y1 - y0)
        half_side = 0.5 * side
        cx = 0.5 * (x0 + x1)
        cy = 0.5 * (y0 + y1)

        sx0, sx1 = _clamp_interval(cx - half_side, cx + half_side, 0.0, 1.0)
        sy0, sy1 = _clamp_interval(cy - half_side, cy + half_side, 0.0, 1.0)
        return (sx0, sy0), (sx1, sy1)

    return (x0, y0), (x1, y1)


# Example: box the region holding 96% of the stacked image's intensity and
# draw it on top of the stack.
bounding_square = square_from_projections(stacked, 0.96)
tl, br = bounding_square
highlighted = draw_red_square(stacked, tl, br, thickness=2)
plot_image(highlighted)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# NOTE(review): this cell indexes `results` as an N x 4 table of
# (h_centroid, v_centroid, h_width, v_width). That presumably matches the
# disabled parameter-extraction variant of the pipeline rather than the
# image-saving one — confirm which pipeline produced `results` before running.
a = np.array(results)
hc, vc, hw, vw = (a[:, column] for column in range(4))

fig, ax = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True)

# One scatter panel per quantity pair, both on the unit square with equal aspect.
panels = [
    (hc, vc, "h_centroid", "v_centroid", "Centroids"),
    (hw, vw, "h_width", "v_width", "Widths"),
]
for axis, (xs, ys, x_label, y_label, panel_title) in zip(ax, panels):
    axis.scatter(xs, ys, s=5, alpha=0.3)
    axis.set(xlim=(0, 1), ylim=(0, 1), xlabel=x_label, ylabel=y_label, title=panel_title)
    axis.set_aspect("equal")

plt.show()