In [1]:
import json
import pickle
from pathlib import Path
from typing import List, Sequence, Union, Tuple, Dict, Any, Optional
import h5py
import numpy as np
from sklearn.decomposition import PCA

In [None]:
# The key difference here is that we remove canonical_ids = np.arange(558).reshape(-1, 1)
def create_hdf5_file(
    specimens_data: List[np.ndarray],
    output_path: Path,
    file_prefix: str = "specimen",
    start_idx: int = 0,
) -> int:
    """Write one HDF5 shard containing a batch of specimens.

    Every specimen is stored as a float32 dataset inside a ``specimens``
    group. Dataset names are ``{file_prefix}_{index:06d}``, where ``index``
    counts upward from ``start_idx`` so that names stay unique across shards.

    Args:
        specimens_data: arrays to store, one dataset per array.
        output_path: destination ``.h5`` file; missing parent directories are
            created, and an existing file is overwritten (mode ``"w"``).
        file_prefix: prefix of every dataset name.
        start_idx: running index assigned to the first specimen of this shard.

    Returns:
        ``start_idx + len(specimens_data)``, i.e. the index the next shard
        should start from.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(output_path, "w") as h5_file:
        group = h5_file.create_group("specimens")

        # enumerate(start=...) yields the running, shard-spanning index directly.
        for running_idx, array in enumerate(specimens_data, start=start_idx):
            dataset_name = f"{file_prefix}_{running_idx:06d}"
            group.create_dataset(dataset_name, data=array.astype(np.float32))

        # File-level metadata describing the shard.
        h5_file.attrs["num_specimens"] = len(specimens_data)
        h5_file.attrs["format_version"] = "1.1"
        h5_file.attrs["description"] = "C. elegans nuclei data: [canonical_id, x, y, z]"

    next_idx = start_idx + len(specimens_data)
    return next_idx

In [None]:
# Here, we do not have the shuffling and splitting of the dataset because it all refers to the test set
def convert_specimens_to_hdf5(
    specimens_data: Union[np.ndarray, Sequence[np.ndarray]],
    output_dir: Union[str, Path],
    specimens_per_file: int = 2**14,
) -> None:
    """Write all specimens as HDF5 shards of a single ``test`` split.

    The specimens are written, in the order given, to
    ``<output_dir>/test/test_NNNN.h5`` with at most ``specimens_per_file``
    specimens per shard. A ``dataset_info.json`` summary is written to
    ``output_dir``. No shuffling or train/val splitting is performed.

    Args:
        specimens_data: iterable of per-specimen arrays (iterating over it
            must yield one array per specimen).
        output_dir: root directory of the dataset; created if missing.
        specimens_per_file: maximum number of specimens per shard; must be
            a positive integer.

    Raises:
        ValueError: if ``specimens_per_file`` is not positive.
    """
    # Validate before touching the filesystem. A negative value would otherwise
    # produce an empty range, so no shard would be written while
    # dataset_info.json still reported the full specimen count; zero would fail
    # inside range() only after the directories had been created.
    if specimens_per_file <= 0:
        raise ValueError(
            f"specimens_per_file must be a positive integer; got {specimens_per_file}"
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    specimens_list = list(specimens_data)

    split_dir = output_dir / "test"
    split_dir.mkdir(exist_ok=True)

    # Running specimen index, threaded through the shards so that dataset names
    # are unique across the whole split.
    specimen_idx = 0
    shard_starts = range(0, len(specimens_list), specimens_per_file)

    for file_idx, shard_start in enumerate(shard_starts):
        # Slicing past the end is safe, so the last shard is simply shorter.
        batch_data = specimens_list[shard_start:shard_start + specimens_per_file]

        specimen_idx = create_hdf5_file(
            batch_data,
            split_dir / f"test_{file_idx:04d}.h5",
            file_prefix="specimen",
            start_idx=specimen_idx,
        )

    info = {
        "total_specimens": len(specimens_list),
        "split": "test",
        "specimens_per_file": specimens_per_file,
        "format": "[canonical_id, x, y, z]",
    }

    with open(output_dir / "dataset_info.json", "w") as f:
        json.dump(info, f, indent=2)

In [2]:
# Locations of the two pickled test-worm collections that are loaded below.
# NOTE(review): these are machine-specific absolute paths — adjust them when
# running the notebook on another system.
test_worm_1_pickle_path = Path("/fs/pool/pool-mlsb/bulat/Wormologist/synthetic_data_generator/test1worms.pkl")
test_worm_2_pickle_path = Path("/fs/pool/pool-mlsb/bulat/Wormologist/synthetic_data_generator/test2worms.pkl")

In [3]:
# Load both pickled worm collections and concatenate them.
# NOTE(review): pickle.load can execute arbitrary code contained in the file —
# only load pickles that come from a trusted source.
with open(test_worm_1_pickle_path, "rb") as f:
    worm1 = pickle.load(f)
with open(test_worm_2_pickle_path, "rb") as f:
    worm2 = pickle.load(f)
# Later cells iterate `all_worms` and unpack each element as a
# (canonical_ids, coords) pair. Presumably both pickles hold lists, so `+`
# concatenates them — TODO confirm.
all_worms = worm1 + worm2

# Initial random (with adapted file type)

In [None]:
# Configuration for the fixed-size random-subgraph test sets.
specimens_per_file = 2**14                 # maximum number of specimens per HDF5 shard
rng_seed = 42
rng = np.random.default_rng(rng_seed)      # seeded generator so the sampling is reproducible
# Subgraph sizes to generate: 10, 20, ..., 550, plus 558.
test_sizes = np.array(list(range(10, 560, 10)) + [558])
# Output root. It has to be a Path: the generation cell joins it with "/",
# which raises TypeError on a plain str. The empty path resolves to the current
# working directory — set this to the intended target before running.
subgraph_output_directory = Path("")

In [None]:
# Accept either a str or a Path as output root: joining a plain str with "/"
# raises TypeError, which is what happens when the config holds a string.
subgraph_output_root = Path(subgraph_output_directory)

# For every requested size, draw random node subsets from every worm and store
# them as one dataset directory per size.
for subgraph_size in test_sizes:
    subgraph_samples: List[np.ndarray] = []

    # Depends only on the subgraph size, so compute it once per size: twice the
    # number of subgraphs of this size needed to cover 558 nodes.
    # NOTE(review): 558 is presumably the nucleus count of a complete worm — confirm.
    sample_count = int(np.ceil(558 / subgraph_size)) * 2

    for answer, coords in all_worms:
        # Convert once per worm instead of once per drawn sample.
        worm_ids = np.asarray(answer)
        worm_coords = np.asarray(coords)
        nodes = len(answer)
        # Never request more nodes than the worm has (sampling is without replacement).
        draw_size = min(subgraph_size, nodes)

        for _ in range(sample_count):
            sampled_indices = rng.choice(nodes, size=draw_size, replace=False)

            # One row per sampled nucleus: [canonical_id, x, y, z], all float32.
            sample = np.zeros((len(sampled_indices), 4), dtype=np.float32)
            sample[:, 0] = worm_ids[sampled_indices].astype(np.float32)
            sample[:, 1:] = worm_coords[sampled_indices].astype(np.float32)
            subgraph_samples.append(sample)

    subgraph_dir = subgraph_output_root / f"subgraph_{int(subgraph_size):03d}"

    convert_specimens_to_hdf5(
        subgraph_samples,
        output_dir=subgraph_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(subgraph_samples)} subgraphs of size {subgraph_size} to {subgraph_dir}")

# Adapted random for comparison

In [None]:
# Configuration for the size-matched random-subgraph datasets.
specimens_per_file = 2**14                 # maximum number of specimens per HDF5 shard
rng_seed = 42
rng = np.random.default_rng(rng_seed)      # re-seeded here, independent of earlier cells
# Inclusive bounds of the subgraph size drawn for every sample
# (the generation cell calls rng.integers(min_size, max_size + 1)).
min_size = 53
max_size = 59
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
subgraph_output_directory = Path("/fs/pool/pool-mlsb/bulat/Wormologist/random_comparison_to_real_test_set_20")
subgraph_output_directory.mkdir(parents=True, exist_ok=True)
sample_count = 12                          # subgraphs drawn per worm, per dataset
num_datasets = 10                          # number of datasets generated (dataset_00, dataset_01, ...)

In [None]:
# Build `num_datasets` datasets; in each one every worm contributes
# `sample_count` random subgraphs whose size is drawn from [min_size, max_size].
for dataset_idx in range(num_datasets):
    subgraph_samples: List[np.ndarray] = []
    subgraph_sizes: List[int] = []

    for answer, coords in all_worms:
        worm_ids = np.asarray(answer)
        worm_coords = np.asarray(coords)
        nodes = len(answer)

        for _ in range(sample_count):
            # Draw the size first, then the node subset (one RNG call each).
            subgraph_size = rng.integers(min_size, max_size + 1)
            sample_size = min(subgraph_size, nodes)   # cannot exceed the worm's node count
            subgraph_sizes.append(sample_size)
            picked = rng.choice(nodes, size=sample_size, replace=False)

            # Rows are [canonical_id, x, y, z] in float32.
            sample = np.zeros((len(picked), 4), dtype=np.float32)
            sample[:, 0] = worm_ids[picked].astype(np.float32)
            sample[:, 1:] = worm_coords[picked].astype(np.float32)
            subgraph_samples.append(sample)

    subgraph_dir = subgraph_output_directory / f"dataset_{dataset_idx:02d}"

    convert_specimens_to_hdf5(
        subgraph_samples,
        output_dir=subgraph_dir,
        specimens_per_file=specimens_per_file,
    )

    # Size statistics; fall back to NaN / 0 when nothing was sampled.
    if subgraph_sizes:
        avg_size = float(np.mean(subgraph_sizes))
        std_size = float(np.std(subgraph_sizes))
        min_size_obs = int(np.min(subgraph_sizes))
        max_size_obs = int(np.max(subgraph_sizes))
    else:
        avg_size = float("nan")
        std_size = float("nan")
        min_size_obs = 0
        max_size_obs = 0

    print(
        f"Saved {len(subgraph_samples)} subgraphs "
        f"(avg size {avg_size:.2f} ± {std_size:.2f}) to {subgraph_dir}"
    )

    # Plain-text summary stored next to the HDF5 shards.
    stats_lines = [
        f"dataset_index: {dataset_idx}\n",
        f"num_samples: {len(subgraph_samples)}\n",
        f"size_range_config: [{min_size}, {max_size}]\n",
        f"size_min_observed: {min_size_obs}\n",
        f"size_max_observed: {max_size_obs}\n",
        f"size_mean: {avg_size:.2f}\n",
        f"size_std: {std_size:.2f}\n",
    ]
    stats_path = subgraph_dir / "subgraph_stats.txt"
    with stats_path.open("w") as fh:
        fh.writelines(stats_lines)

# Slicing

In [None]:
def get_slice_indices(points: np.ndarray,
                      n_slices: int = 40,
                      slice_thickness: float = 0.05
                      ) -> Tuple[np.ndarray, List[int]]:
    """Select the nuclei captured by fixed-width slices along the first PCA axis.

    ``n_slices`` slice centres are spaced evenly between the smallest and the
    largest coordinate along the first principal component. A nucleus belongs
    to a slice when its distance to the slice centre along that axis is at
    most ``slice_thickness / 2``. Diagnostics are printed.

    Args:
        points          : (N, 3) array of nucleus coordinates.
        n_slices        : number of slices.
        slice_thickness : thickness of each slice along the PCA axis.

    Returns:
        indices          : sorted unique indices of nuclei hit by at least one slice.
        per_slice_counts : number of nuclei inside each slice (a nucleus lying in
                           several slices is counted in each of them).

    Raises:
        ValueError: if ``points`` is not of shape (N, 3).
    """
    shape_ok = points.ndim == 2 and points.shape[1] == 3
    if not shape_ok:
        raise ValueError(f"Expected points to have shape (N, 3); got {points.shape}")

    # Coordinate of every nucleus along the first principal component.
    axial = PCA(n_components=1).fit_transform(points).ravel()

    slice_centers = np.linspace(axial.min(), axial.max(), n_slices)
    tolerance = slice_thickness / 2.0

    # Union of all slices, tracked as a boolean flag per nucleus.
    captured = np.zeros(axial.shape[0], dtype=bool)
    per_slice_counts: List[int] = []

    for slice_center in slice_centers:
        in_slice = np.abs(axial - slice_center) <= tolerance
        per_slice_counts.append(int(np.count_nonzero(in_slice)))
        captured |= in_slice

    # flatnonzero returns ascending indices, i.e. already sorted and unique.
    indices = np.flatnonzero(captured).astype(int)
    total_selected = len(indices)
    avg_per_slice = float(np.mean(per_slice_counts)) if per_slice_counts else 0.0

    print(f"[get_slice_indices] total nuclei selected: {total_selected}")
    print(f"[get_slice_indices] per-slice counts: {per_slice_counts}")
    print(f"[get_slice_indices] average nuclei per slice: {avg_per_slice:.2f}")

    return indices, per_slice_counts

In [None]:
# Slice every worm once (no shift) and store the captured nuclei as one dataset.
NUM_SLICES = 40
SLICE_THICKNESS = 0.005  # units must match the input coordinates
specimens_per_file = 2**14

sliced_samples: List[np.ndarray] = []

for canonical_ids, coords in all_worms:
    coords_arr = np.asarray(coords, dtype=np.float32)
    ids_arr = np.asarray(canonical_ids, dtype=np.int64)

    coords_ok = coords_arr.ndim == 2 and coords_arr.shape[1] == 3
    if not coords_ok:
        raise ValueError(f"Expected coords to be [N, 3], got {coords_arr.shape}")

    selected_idx, per_slice_counts = get_slice_indices(
        coords_arr,
        n_slices=NUM_SLICES,
        slice_thickness=SLICE_THICKNESS,
    )

    n_selected = len(selected_idx)
    if n_selected == 0:
        continue  # no nuclei captured for this worm

    # Rows are [canonical_id, x, y, z] in float32; every column is filled below.
    sample = np.empty((n_selected, 4), dtype=np.float32)
    sample[:, 0] = ids_arr[selected_idx].astype(np.float32)
    sample[:, 1:] = coords_arr[selected_idx]
    sliced_samples.append(sample)

    print(
        f"Sliced worm (len={len(ids_arr)}): "
        f"{len(selected_idx)} nuclei selected | per-slice counts = {per_slice_counts}"
    )

# Persist all sliced specimens (adjust the output path as needed). The root is
# whatever `subgraph_output_directory` currently holds in the kernel.
output_dir = subgraph_output_directory / "sliced_subgraphs"
convert_specimens_to_hdf5(
    sliced_samples,
    output_dir=output_dir,
    specimens_per_file=specimens_per_file,
)
print(f"Saved {len(sliced_samples)} sliced subgraphs to {output_dir}")


# Slicing with a shift

In [None]:
def get_slice_indices(points: np.ndarray,
                      n_slices: int = 40,
                      slice_thickness: float = 0.005,
                      shift: float = 0.0
                      ) -> Tuple[np.ndarray, List[int]]:
    """
    Fixed-width slicing along the first PCA axis, with an optional axial shift.

    Slice centres are spaced evenly between the smallest and largest coordinate
    along the first principal component and then all moved by ``shift``. A
    nucleus is captured by a slice when it lies within ``slice_thickness / 2``
    of the slice centre along that axis. A one-line summary is printed.

    Args:
        points          : [N, 3] array of nucleus centers.
        n_slices        : number of slices (default 40).
        slice_thickness : axial thickness of each slice (same units as points).
        shift           : offset added to every slice center; positive values move
                          the whole stack toward larger PCA coordinates.

    Returns:
        indices          : sorted unique nucleus indices captured by at least one slice.
        per_slice_counts : hit counts for each slice after the shift.

    Raises:
        ValueError: if ``points`` is not of shape (N, 3).
    """
    if not (points.ndim == 2 and points.shape[1] == 3):
        raise ValueError(f"Expected points to have shape (N, 3); got {points.shape}")

    # Axial coordinate of each nucleus (projection on the first component).
    axial = PCA(n_components=1).fit_transform(points).ravel()

    shifted_centers = np.linspace(axial.min(), axial.max(), n_slices) + shift
    tolerance = slice_thickness / 2.0

    per_slice_counts: List[int] = []
    hit_any = np.zeros(axial.shape[0], dtype=bool)   # True once a nucleus falls in some slice

    for slice_center in shifted_centers:
        inside = np.abs(axial - slice_center) <= tolerance
        per_slice_counts.append(int(np.count_nonzero(inside)))
        hit_any |= inside

    # Ascending order comes for free from flatnonzero.
    indices = np.flatnonzero(hit_any).astype(int)
    total_selected = len(indices)
    avg_per_slice = float(np.mean(per_slice_counts)) if per_slice_counts else 0.0

    print(
        f"[get_slice_indices] shift={shift:.4f} | total nuclei selected: {total_selected} | "
        f"average per slice: {avg_per_slice:.2f}"
    )

    return indices, per_slice_counts


In [None]:
# Configuration for the shifted-slicing test sets (no projection).
NUM_SLICES = 20  
SLICE_THICKNESS = 0.005 # could instead be defined relative to the worm length (as a ratio of it) so that it works with subgraph matching
# NOTE(review): an earlier note here read "12 shifts × 200 worms = 2400 samples"
# (chosen to match the random subgraphs), but the value below is 24 — confirm
# which shift count is intended.
NUM_SHIFTS = 24
                                 
center_spacing = 1 / (NUM_SLICES - 1)   # distance between consecutive slice centres; fixed to 1 / (NUM_SLICES - 1) on the assumption that all worms are scaled to length 1
# Largest shift that keeps a slice from reaching the unshifted position of its neighbour.
max_non_overlap_shift = max(0.0, center_spacing - SLICE_THICKNESS)

SHIFT_STEPS = np.linspace(0.0, max_non_overlap_shift, NUM_SHIFTS)
# With the values above: [0.0000, 0.0021, 0.0041, …, 0.0476] — still within the original band

specimens_per_file = 2**14 # should be changed, carries no meaning

# NOTE(review): machine-specific absolute path. The directory name says "0007"
# while SLICE_THICKNESS is 0.005 — confirm the name still matches the settings.
subgraph_output_directory = Path("/fs/pool/pool-mlsb/bulat/Wormologist/sliced_testing_data/20_0007_no_proj")
subgraph_output_directory.mkdir(parents=True, exist_ok=True)

# Echo the effective configuration so it is recorded in the cell output.
print(f"NUM_SLICES = {NUM_SLICES}")
print(f"SLICE_THICKNESS = {SLICE_THICKNESS}")
print(f"NUM_SHIFTS = {NUM_SHIFTS}")
print(f"center_spacing = {center_spacing:.6f}")
print(f"max_non_overlap_shift = {max_non_overlap_shift:.6f}")
print("SHIFT_STEPS =", ", ".join(f"{s:.6f}" for s in SHIFT_STEPS))
print(f"specimens_per_file = {specimens_per_file}")

In [None]:
# One dataset per shift: slice every worm with the shifted stack, store the
# captured nuclei, then write per-shift summary statistics.
for shift in SHIFT_STEPS:
    sliced_samples: List[np.ndarray] = []
    per_worm_selected: List[int] = []
    per_worm_slice_counts: List[List[int]] = []

    for canonical_ids, coords in all_worms:
        coords_arr = np.asarray(coords, dtype=np.float32)
        ids_arr = np.asarray(canonical_ids, dtype=np.int64)

        coords_ok = coords_arr.ndim == 2 and coords_arr.shape[1] == 3
        if not coords_ok:
            raise ValueError(f"Expected coords to be [N, 3], got {coords_arr.shape}")

        selected_idx, per_slice_counts = get_slice_indices(
            coords_arr,
            n_slices=NUM_SLICES,
            slice_thickness=SLICE_THICKNESS,
            shift=shift,
        )

        n_selected = len(selected_idx)
        if n_selected == 0:
            continue   # worms without any captured nucleus are left out entirely

        per_worm_selected.append(n_selected)
        per_worm_slice_counts.append(per_slice_counts)

        # Rows are [canonical_id, x, y, z] in float32; every column is filled below.
        sample = np.empty((n_selected, 4), dtype=np.float32)
        sample[:, 0] = ids_arr[selected_idx].astype(np.float32)
        sample[:, 1:] = coords_arr[selected_idx]
        sliced_samples.append(sample)

        print(
            f"Sliced worm (len={len(ids_arr)}) @ shift {shift:.4f}: "
            f"{len(selected_idx)} nuclei | per-slice counts = {per_slice_counts}"
        )

    if len(sliced_samples) == 0:
        print(f"No samples produced for shift {shift:.4f}; skipping file output.")
        continue

    output_dir = subgraph_output_directory / f"sliced_subgraphs_shift_{shift:.3f}"
    convert_specimens_to_hdf5(
        sliced_samples,
        output_dir=output_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(sliced_samples)} sliced subgraphs to {output_dir}")

    # ---- summary stats -----------------------------------------------------
    # Statistics cover only the worms that contributed a sample.
    avg_selected = float(np.mean(per_worm_selected))
    std_selected = float(np.std(per_worm_selected))

    # Rows: worms, columns: slices.
    slice_matrix = np.asarray(per_worm_slice_counts, dtype=float)
    avg_slice_counts = slice_matrix.mean(axis=0)
    std_slice_counts = slice_matrix.std(axis=0)
    avg_per_slice_overall = float(avg_slice_counts.mean())

    avg_counts_text = ", ".join(f"{v:.2f}" for v in avg_slice_counts)
    std_counts_text = ", ".join(f"{v:.2f}" for v in std_slice_counts)

    report_lines = [
        f"shift: {shift:.4f}\n",
        f"num_worms: {len(per_worm_selected)}\n",
        f"avg_total_nuclei: {avg_selected:.2f}\n",
        f"std_total_nuclei: {std_selected:.2f}\n",
        f"avg_per_slice_overall: {avg_per_slice_overall:.2f}\n",
        "avg_per_slice_counts: " + avg_counts_text + "\n",
        "std_per_slice_counts: " + std_counts_text + "\n",
    ]
    stats_path = output_dir / "slice_stats.txt"
    with stats_path.open("w") as fh:
        fh.writelines(report_lines)
    # ------------------------------------------------------------------------



# Projecting onto the slices

In [None]:
def get_slice_indices(points: np.ndarray,
                      n_slices: int = 40,
                      slice_thickness: float = 0.005,
                      shift: float = 0.0
                      ) -> Tuple[np.ndarray, List[int], np.ndarray]:
    """
    Slice along the first PCA axis and move every captured nucleus onto the
    plane of the slice it is closest to.

    Slice centres are evenly spaced between the extreme axial coordinates and
    then offset by ``shift``. A nucleus within ``slice_thickness / 2`` of a
    centre is a hit for that slice. A nucleus hit by several slices is assigned
    to the one with the smallest axial distance (the first such slice on ties),
    and its coordinates are moved along the axis onto that slice's plane.

    Returns:
        indices           : sorted unique nucleus indices captured by at least one slice
        per_slice_counts  : raw hit count of every slice (after applying `shift`);
                            a nucleus hit by several slices is counted in each
        projected_coords  : (len(indices), 3) coordinates lying on the slice
                            planes, ordered to match `indices`

    Raises:
        ValueError: if ``points`` is not of shape (N, 3).
    """
    if not (points.ndim == 2 and points.shape[1] == 3):
        raise ValueError(f"Expected points to have shape (N, 3); got {points.shape}")

    model = PCA(n_components=3)
    model.fit(points)
    axis = model.components_[0]                # direction of largest variance
    t = (points - model.mean_) @ axis          # axial coordinate of every nucleus

    centers = np.linspace(t.min(), t.max(), n_slices) + shift
    half_thickness = slice_thickness / 2.0

    per_slice_counts: List[int] = []
    # nucleus index -> (smallest |axial offset| so far, coordinates on that slice plane)
    best: Dict[int, Tuple[Any, np.ndarray]] = {}

    for center in centers:
        hits = np.flatnonzero(np.abs(t - center) <= half_thickness)
        per_slice_counts.append(int(hits.size))

        for idx in hits:
            offset = t[idx] - center           # signed axial distance to this slice centre
            distance = abs(offset)

            known = best.get(idx)
            if known is not None and not (distance < known[0]):
                continue                       # an earlier slice is at least as close

            # Removing the axial offset places the nucleus on the slice plane.
            best[idx] = (distance, points[idx] - offset * axis)

    if not best:
        return np.array([], dtype=int), per_slice_counts, np.zeros((0, 3), dtype=points.dtype)

    indices = np.array(sorted(best), dtype=int)
    projected_coords = np.stack([best[i][1] for i in indices], axis=0)

    total_selected = len(indices)
    avg_per_slice = float(np.mean(per_slice_counts)) if per_slice_counts else 0.0
    print(
        f"[get_slice_indices] shift={shift:.4f} | total nuclei selected: {total_selected} | "
        f"average per slice: {avg_per_slice:.2f}"
    )

    return indices, per_slice_counts, projected_coords


In [None]:
# Configuration for the shifted-slicing test sets with projection onto the slice planes.
NUM_SLICES = 40
SLICE_THICKNESS = 0.005
NUM_SHIFTS = 12                     # 12 shifts × 200 worms = 2400 samples (NOTE(review): the worm count is not visible here — confirm)
                                 
center_spacing = 1 / (NUM_SLICES - 1)   # distance between consecutive slice centres; fixed to 1 / (NUM_SLICES - 1) on the assumption that all worms are scaled to length 1
# Largest shift that keeps a slice from reaching the unshifted position of its neighbour.
max_non_overlap_shift = max(0.0, center_spacing - SLICE_THICKNESS)

SHIFT_STEPS = np.linspace(0.0, max_non_overlap_shift, NUM_SHIFTS)
# With the values above: [0.0000, 0.0019, 0.0038, …, 0.0206] — still within the original band

# NOTE(review): machine-specific absolute path. The directory name starts with
# "20_0007" while NUM_SLICES is 40 and SLICE_THICKNESS is 0.005 — confirm the
# name still matches the settings.
subgraph_output_directory = Path("/fs/pool/pool-mlsb/bulat/Wormologist/sliced_testing_data/20_0007_with_proj")
subgraph_output_directory.mkdir(parents=True, exist_ok=True)

specimens_per_file = 2**14 # should be changed, carries no meaning

# Echo the effective configuration so it is recorded in the cell output.
print(f"NUM_SLICES = {NUM_SLICES}")
print(f"SLICE_THICKNESS = {SLICE_THICKNESS}")
print(f"NUM_SHIFTS = {NUM_SHIFTS}")
print(f"center_spacing = {center_spacing:.6f}")
print(f"max_non_overlap_shift = {max_non_overlap_shift:.6f}")
print("SHIFT_STEPS =", ", ".join(f"{s:.6f}" for s in SHIFT_STEPS))
print(f"specimens_per_file = {specimens_per_file}")

In [None]:
# One dataset per shift: slice every worm, keep the slice-projected coordinates
# of the captured nuclei, then write per-shift summary statistics.
for shift in SHIFT_STEPS:
    sliced_samples: List[np.ndarray] = []
    per_worm_selected: List[int] = []
    per_worm_slice_counts: List[List[int]] = []

    for canonical_ids, coords in all_worms:
        coords_arr = np.asarray(coords, dtype=np.float32)
        ids_arr = np.asarray(canonical_ids, dtype=np.int64)

        coords_ok = coords_arr.ndim == 2 and coords_arr.shape[1] == 3
        if not coords_ok:
            raise ValueError(f"Expected coords to be [N, 3], got {coords_arr.shape}")

        selected_idx, per_slice_counts, projected_coords = get_slice_indices(
            coords_arr,
            n_slices=NUM_SLICES,
            slice_thickness=SLICE_THICKNESS,
            shift=shift,
        )

        n_selected = len(selected_idx)
        if n_selected == 0:
            continue   # worms without any captured nucleus are left out entirely

        per_worm_selected.append(n_selected)
        per_worm_slice_counts.append(per_slice_counts)

        # Column 0: canonical IDs (in `selected_idx` order);
        # columns 1–3: coordinates lying on the slice planes, cast to float32.
        sample = np.empty((n_selected, 4), dtype=np.float32)
        sample[:, 0] = ids_arr[selected_idx].astype(np.float32)
        sample[:, 1:] = projected_coords
        sliced_samples.append(sample)

        print(
            f"Sliced worm (len={len(ids_arr)}) @ shift {shift:.4f}: "
            f"{len(selected_idx)} nuclei | per-slice counts = {per_slice_counts}"
        )

    if len(sliced_samples) == 0:
        print(f"No samples produced for shift {shift:.4f}; skipping file output.")
        continue

    output_dir = subgraph_output_directory / f"sliced_subgraphs_shift_{shift:.3f}"
    convert_specimens_to_hdf5(
        sliced_samples,
        output_dir=output_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(sliced_samples)} sliced subgraphs to {output_dir}")

    # Statistics cover only the worms that contributed a sample.
    avg_selected = float(np.mean(per_worm_selected))
    std_selected = float(np.std(per_worm_selected))

    # Rows: worms, columns: slices.
    slice_matrix = np.asarray(per_worm_slice_counts, dtype=float)
    avg_slice_counts = slice_matrix.mean(axis=0)
    std_slice_counts = slice_matrix.std(axis=0)
    avg_per_slice_overall = float(avg_slice_counts.mean())

    avg_counts_text = ", ".join(f"{v:.2f}" for v in avg_slice_counts)
    std_counts_text = ", ".join(f"{v:.2f}" for v in std_slice_counts)

    report_lines = [
        f"shift: {shift:.4f}\n",
        f"num_worms: {len(per_worm_selected)}\n",
        f"avg_total_nuclei: {avg_selected:.2f}\n",
        f"std_total_nuclei: {std_selected:.2f}\n",
        f"avg_per_slice_overall: {avg_per_slice_overall:.2f}\n",
        "avg_per_slice_counts: " + avg_counts_text + "\n",
        "std_per_slice_counts: " + std_counts_text + "\n",
    ]
    stats_path = output_dir / "slice_stats.txt"
    with stats_path.open("w") as fh:
        fh.writelines(report_lines)

# Cross section cropping

In [None]:
# Configuration for the cross-section-cropped, slice-projected test sets.
NUM_SLICES = 40
SLICE_THICKNESS = 0.005
NUM_SHIFTS = 12                     # 12 shifts × 200 worms = 2400 samples (NOTE(review): the worm count is not visible here — confirm)
                                 
center_spacing = 1 / (NUM_SLICES - 1)   # distance between consecutive slice centres; fixed to 1 / (NUM_SLICES - 1) on the assumption that all worms are scaled to length 1
# Largest shift that keeps a slice from reaching the unshifted position of its neighbour.
max_non_overlap_shift = max(0.0, center_spacing - SLICE_THICKNESS)

SHIFT_STEPS = np.linspace(0.0, max_non_overlap_shift, NUM_SHIFTS)
# With the values above: [0.0000, 0.0019, 0.0038, …, 0.0206] — still within the original band

# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
subgraph_output_directory = Path("/fs/pool/pool-mlsb/bulat/Wormologist/sliced_tests/sliced_testing_data/cross_section/40_0005_with_proj_y_neg")
subgraph_output_directory.mkdir(parents=True, exist_ok=True)

specimens_per_file = 2**14 # should be changed, carries no meaning

# Cropping options, passed to get_slice_indices as crop_axis / crop_side / crop_fraction.
CROP_AXIS = 'y'
CROP_SIDE = 'negative'
CROP_FRACTION = 0.0  


# Echo the effective configuration so it is recorded in the cell output.
print(f"NUM_SLICES = {NUM_SLICES}")
print(f"SLICE_THICKNESS = {SLICE_THICKNESS}")
print(f"NUM_SHIFTS = {NUM_SHIFTS}")
print(f"center_spacing = {center_spacing:.6f}")
print(f"max_non_overlap_shift = {max_non_overlap_shift:.6f}")
print("SHIFT_STEPS =", ", ".join(f"{s:.6f}" for s in SHIFT_STEPS))
print(f"specimens_per_file = {specimens_per_file}")
print(f"CROP_AXIS = {CROP_AXIS}")
print(f"CROP_SIDE = {CROP_SIDE}")
print(f"CROP_FRACTION = {CROP_FRACTION}")

In [None]:
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from sklearn.decomposition import PCA

def _crop_keep_mask(slice_values: np.ndarray,
                    crop_side: str,
                    crop_fraction: float) -> np.ndarray:
    """Boolean keep-mask for one slice after removing one side of the centre line."""
    # Mirror the values so that the side being removed is always the positive one;
    # the "negative" case is the exact mirror image of the "positive" case.
    signed = slice_values if crop_side == "positive" else -slice_values

    keep = signed <= 0.0                       # the untouched side, including the centre line
    if crop_fraction > 0.0:
        removed = signed[signed > 0.0]
        if removed.size > 0:
            # Keep a band next to the centre line whose width is a fraction of
            # the farthest reach on the removed side within this slice.
            allowed = removed.max() * crop_fraction
            keep |= (signed > 0.0) & (signed <= allowed)
    return keep


def get_slice_indices(points: np.ndarray,
                      n_slices: int = 40,
                      slice_thickness: float = 0.005,
                      shift: float = 0.0,
                      *,
                      crop_axis: Optional[str] = None,
                      crop_side: str = "positive",
                      crop_fraction: float = 0.0
                      ) -> Tuple[np.ndarray, List[int], np.ndarray]:
    """
    Select nuclei with fixed-width slices, project them onto slice planes,
    and optionally drop one side of each slice. When a crop is requested, the
    chosen side is removed except for a thin band (crop_fraction) next to the
    centre line.

    The first PCA component is used as slicing axis. Of the remaining two
    components, the one best aligned with the global x axis serves as 'x' crop
    direction and the other as 'y' crop direction; both are oriented so that
    their positive side points toward global +x / +y. A nucleus hit by several
    slices is assigned to the slice with the smallest axial distance.

    Args:
        points          : (N, 3) array of Cartesian nucleus centers.
        n_slices        : number of slice planes along the main axis.
        slice_thickness : axial thickness of each slice.
        shift           : uniform axial offset applied to every slice center.
        crop_axis       : 'x', 'y', or None (no cropping).
        crop_side       : 'positive' or 'negative'; which side to remove.
                          Only checked when crop_axis is given.
        crop_fraction   : fraction of the removed side's per-slice span that is
                          kept next to the centre line (0 keeps nothing of it).

    Returns:
        indices           : sorted nucleus indices retained after slicing/cropping.
        per_slice_counts  : nuclei per slice after all filters, each nucleus
                            counted once, in the slice it was assigned to.
        projected_coords  : coordinates projected onto their slice planes,
                            ordered to match `indices`.

    Raises:
        ValueError: if `points` is not (N, 3), or if `crop_axis` / `crop_side`
                    hold an unsupported value.
    """
    # Same input check as the other slicing helpers in this notebook.
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError(f"Expected points to have shape (N, 3); got {points.shape}")
    # Reject misspelled options instead of silently skipping the crop or
    # silently cropping the negative side.
    if crop_axis not in (None, "x", "y"):
        raise ValueError(f"crop_axis must be 'x', 'y', or None; got {crop_axis!r}")
    if crop_axis is not None and crop_side not in ("positive", "negative"):
        raise ValueError(f"crop_side must be 'positive' or 'negative'; got {crop_side!r}")

    pca = PCA(n_components=3)               # axial direction plus an in-plane basis
    pca.fit(points)

    axial_axis = pca.components_[0]
    mean = pca.mean_
    t = (points - mean) @ axial_axis        # axial coordinate for each nucleus

    global_x = np.array([1.0, 0.0, 0.0])
    global_y = np.array([0.0, 1.0, 0.0])
    in_plane = [pca.components_[1], pca.components_[2]]  # orthonormal slice-plane basis

    # Pick the in-plane vector best aligned with global x as the 'x' crop direction.
    dot_lr = [abs(np.dot(vec, global_x)) for vec in in_plane]
    lr_index = int(np.argmax(dot_lr))
    plane_lr = in_plane[lr_index].copy()
    plane_dv = in_plane[1 - lr_index].copy()

    # PCA component signs are arbitrary; orient them toward global +x / +y.
    if np.dot(plane_lr, global_x) < 0:
        plane_lr *= -1
    if np.dot(plane_dv, global_y) < 0:
        plane_dv *= -1

    centered = points - mean
    local_lr = centered @ plane_lr          # in-plane coordinate used for crop_axis='x'
    local_dv = centered @ plane_dv          # in-plane coordinate used for crop_axis='y'

    centers = np.linspace(t.min(), t.max(), n_slices) + shift
    selected: Dict[int, Dict[str, Any]] = {}
    half_thickness = slice_thickness / 2.0

    for slice_idx, center in enumerate(centers):
        mask = np.abs(t - center) <= half_thickness
        hits = np.where(mask)[0]

        if crop_axis is not None and hits.size > 0:
            slice_values = local_lr[hits] if crop_axis == "x" else local_dv[hits]
            hits = hits[_crop_keep_mask(slice_values, crop_side, crop_fraction)]

        for idx in hits:
            offset = t[idx] - center        # signed axial distance to this slice centre
            abs_offset = abs(offset)

            # Keep only the closest slice for nuclei hit by more than one.
            prev = selected.get(idx)
            if prev is None or abs_offset < prev["abs_offset"]:
                projected = points[idx] - offset * axial_axis  # remove the axial offset → slice plane
                selected[idx] = {
                    "abs_offset": abs_offset,
                    "projected": projected,
                    "slice_index": slice_idx,
                }

    if not selected:
        empty_counts = [0] * n_slices
        empty_coords = np.zeros((0, 3), dtype=points.dtype)
        return np.array([], dtype=int), empty_counts, empty_coords

    # Count every retained nucleus once, in the slice it ended up in.
    per_slice_counts = [0] * n_slices
    for entry in selected.values():
        per_slice_counts[entry["slice_index"]] += 1

    indices = np.array(sorted(selected.keys()), dtype=int)
    projected_coords = np.stack([selected[idx]["projected"] for idx in indices], axis=0)

    total_selected = len(indices)
    avg_per_slice = float(np.mean(per_slice_counts)) if per_slice_counts else 0.0
    print(
        f"[get_slice_indices] shift={shift:.4f} | total nuclei selected: {total_selected} | "
        f"average per slice: {avg_per_slice:.2f}"
    )

    return indices, per_slice_counts, projected_coords

In [None]:
# One dataset per shift: slice every worm with the configured cross-section
# crop, keep the slice-projected coordinates, then write per-shift statistics.
for shift in SHIFT_STEPS:
    sliced_samples: List[np.ndarray] = []
    per_worm_selected: List[int] = []
    per_worm_slice_counts: List[List[int]] = []

    for canonical_ids, coords in all_worms:
        coords_arr = np.asarray(coords, dtype=np.float32)
        ids_arr = np.asarray(canonical_ids, dtype=np.int64)

        coords_ok = coords_arr.ndim == 2 and coords_arr.shape[1] == 3
        if not coords_ok:
            raise ValueError(f"Expected coords to be [N, 3], got {coords_arr.shape}")

        selected_idx, per_slice_counts, projected_coords = get_slice_indices(
            coords_arr,
            n_slices=NUM_SLICES,
            slice_thickness=SLICE_THICKNESS,
            shift=shift,
            crop_axis=CROP_AXIS,
            crop_side=CROP_SIDE,
            crop_fraction=CROP_FRACTION,
        )

        n_selected = len(selected_idx)
        if n_selected == 0:
            continue   # worms without any retained nucleus are left out entirely

        per_worm_selected.append(n_selected)
        per_worm_slice_counts.append(per_slice_counts)

        # Column 0: canonical IDs (in `selected_idx` order);
        # columns 1–3: coordinates lying on the slice planes, cast to float32.
        sample = np.empty((n_selected, 4), dtype=np.float32)
        sample[:, 0] = ids_arr[selected_idx].astype(np.float32)
        sample[:, 1:] = projected_coords
        sliced_samples.append(sample)

        print(
            f"Sliced worm (len={len(ids_arr)}) @ shift {shift:.4f}: "
            f"{len(selected_idx)} nuclei | per-slice counts = {per_slice_counts}"
        )

    if len(sliced_samples) == 0:
        print(f"No samples produced for shift {shift:.4f}; skipping file output.")
        continue

    output_dir = subgraph_output_directory / f"sliced_subgraphs_shift_{shift:.3f}"
    convert_specimens_to_hdf5(
        sliced_samples,
        output_dir=output_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(sliced_samples)} sliced subgraphs to {output_dir}")

    # Statistics cover only the worms that contributed a sample.
    avg_selected = float(np.mean(per_worm_selected))
    std_selected = float(np.std(per_worm_selected))

    # Rows: worms, columns: slices.
    slice_matrix = np.asarray(per_worm_slice_counts, dtype=float)
    avg_slice_counts = slice_matrix.mean(axis=0)
    std_slice_counts = slice_matrix.std(axis=0)
    avg_per_slice_overall = float(avg_slice_counts.mean())

    avg_counts_text = ", ".join(f"{v:.2f}" for v in avg_slice_counts)
    std_counts_text = ", ".join(f"{v:.2f}" for v in std_slice_counts)

    report_lines = [
        f"shift: {shift:.4f}\n",
        f"num_worms: {len(per_worm_selected)}\n",
        f"avg_total_nuclei: {avg_selected:.2f}\n",
        f"std_total_nuclei: {std_selected:.2f}\n",
        f"avg_per_slice_overall: {avg_per_slice_overall:.2f}\n",
        "avg_per_slice_counts: " + avg_counts_text + "\n",
        "std_per_slice_counts: " + std_counts_text + "\n",
    ]
    stats_path = output_dir / "slice_stats.txt"
    with stats_path.open("w") as fh:
        fh.writelines(report_lines)

# Alignment check

In [None]:
import numpy as np
from sklearn.decomposition import PCA

def check_alignment_all_with_ids(
    worms_iterable,
    *,
    z_axis=np.array([0.0, 0.0, 1.0]),
    x_axis=np.array([1.0, 0.0, 0.0]),
    y_axis=np.array([0.0, 1.0, 0.0]),
    good_thresh: float = 0.98,     # |cos(AP,Z)| >= 0.98 ≈ tilt <= ~11.5°
    warn_thresh: float = 0.95      # |cos(AP,Z)| >= 0.95 ≈ tilt <= ~18.2°
):
    """
    Measure how well each worm's first principal axis lines up with the Z axis.

    For every worm the first PCA component is taken as the AP axis and its
    cosine with `z_axis`, `y_axis` and `x_axis` is computed. One line per worm
    is printed, followed by a summary over all worms.

    Parameters:
        worms_iterable: iterable of (canonical_ids, coords)
            canonical_ids: array-like of IDs; the first three are echoed as a
                           preview when the IDs form a 1-D array
            coords:        (N, 3) array-like; expected layout is
                           (x=LR, y=DV, z=AP) when the worm is aligned
        z_axis, x_axis, y_axis: reference directions the AP axis is compared to
        good_thresh: |cos(AP,Z)| at or above this counts as "well aligned"
        warn_thresh: |cos(AP,Z)| in [warn_thresh, good_thresh) counts as
                     "moderate tilt"; anything below is "poor alignment"

    Returns:
        {
          "rows": list of per-worm dicts,
          "cos_ap_z": np.ndarray of signed cos(AP,Z), one per worm,
          "summary": dict of mean/std/min/max and the three percentages
        }

    Raises:
        ValueError: if a worm's coords are not a 2-D array with 3 columns.
    """
    per_worm_rows = []
    signed_cosines = []

    for worm_no, (canonical_ids, coords) in enumerate(worms_iterable):
        xyz = np.asarray(coords, dtype=float)
        if not (xyz.ndim == 2 and xyz.shape[1] == 3):
            raise ValueError(f"Worm {worm_no}: coords must be (N,3), got {xyz.shape}")
        n_nuclei = xyz.shape[0]

        # First principal component = direction of largest variance, used as AP.
        principal = PCA(n_components=3).fit(xyz).components_[0]
        principal = principal / np.linalg.norm(principal)

        along_z = float(np.dot(principal, z_axis))
        along_y = float(np.dot(principal, y_axis))
        along_x = float(np.dot(principal, x_axis))

        # The absolute value ignores head/tail orientation; clamp before arccos.
        bounded = min(1.0, max(-1.0, abs(along_z)))
        tilt = float(np.degrees(np.arccos(bounded)))

        # Short ID fingerprint so a printed line can be matched back to a worm.
        id_array = np.asarray(canonical_ids)
        if id_array.ndim == 1:
            first_ids = id_array[:3].tolist()
        else:
            first_ids = None

        per_worm_rows.append(
            {
                "worm_index": worm_no,
                "n_points": int(n_nuclei),
                "cos_ap_z": along_z,
                "cos_ap_y": along_y,
                "cos_ap_x": along_x,
                "tilt_deg_from_z": tilt,
                "id_preview": first_ids,
            }
        )
        signed_cosines.append(along_z)

        print(
            f"Worm {worm_no:03d} | N={n_nuclei:4d} | "
            f"AP•Z={along_z:+.4f} | tilt={tilt:6.2f}° | "
            f"AP•X={along_x:+.4f} AP•Y={along_y:+.4f} | ids={first_ids}"
        )

    cos_arr = np.asarray(signed_cosines, dtype=float)
    magnitude = np.abs(cos_arr)

    if cos_arr.size:
        mean_cos = float(np.mean(cos_arr))
        std_cos = float(np.std(cos_arr))
        min_cos = float(np.min(cos_arr))
        max_cos = float(np.max(cos_arr))
        share_good = float(np.mean(magnitude >= good_thresh) * 100.0)
        share_warn = float(np.mean((magnitude >= warn_thresh) & (magnitude < good_thresh)) * 100.0)
        share_bad = 100.0 - share_good - share_warn
    else:
        # No worms at all: statistics are undefined, percentages are zero.
        mean_cos = std_cos = min_cos = max_cos = float("nan")
        share_good = share_warn = share_bad = 0.0

    print("\nSummary over worms")
    print(f"Mean cos(AP,Z): {mean_cos:+.4f} | Std: {std_cos:.4f} | Min: {min_cos:+.4f} | Max: {max_cos:+.4f}")
    print(f"Well aligned   (|cos| ≥ {good_thresh:.2f}): {share_good:5.1f}%")
    print(f"Moderate tilt  (|cos| ≥ {warn_thresh:.2f} & < {good_thresh:.2f}): {share_warn:5.1f}%")
    print(f"Poor alignment (|cos| <  {warn_thresh:.2f}): {share_bad:5.1f}%")

    return {
        "rows": per_worm_rows,
        "cos_ap_z": cos_arr,
        "summary": {
            "mean_cos_ap_z": mean_cos,
            "std_cos_ap_z": std_cos,
            "min_cos_ap_z": min_cos,
            "max_cos_ap_z": max_cos,
            "pct_well_aligned": share_good,
            "pct_moderate_tilt": share_warn,
            "pct_poor_alignment": share_bad,
            "good_thresh": good_thresh,
            "warn_thresh": warn_thresh,
        },
    }


In [None]:
import numpy as np
from sklearn.decomposition import PCA

def check_alignment_xy_and_plane_with_ids(
    worms_iterable,
    *,
    z_axis=np.array([0.0, 0.0, 1.0]),
    x_axis=np.array([1.0, 0.0, 0.0]),
    y_axis=np.array([0.0, 1.0, 0.0]),
    ap_good_thresh: float = 0.98   # |cos(AP,Z)| >= 0.98 considered well-aligned
):
    """
    Per-worm alignment diagnostics covering the AP axis and the two in-plane axes.

    Naming used here: `ap` is the first PCA component (largest variance), while
    "PC1" and "PC2" are the second and third PCA components.

    Input:
      worms_iterable: iterable of (canonical_ids, coords) with coords (N,3).
                      The IDs are unpacked but not used by this function.

    Prints (per worm):
      - AP·Z and the tilt angle from Z
      - AP·X and AP·Y (should be ~0 if AP ≡ Z)
      - |PC1·x|, |PC1·y|, |PC2·x|, |PC2·y| (in-plane relationships)
      - handedness of the basis [PC1, PC2, AP] ("RH" when its determinant is > 0)

    Returns:
      dict with the per-worm rows, the signed AP·Z values, the percentage of
      worms with |AP·Z| >= ap_good_thresh, and the threshold itself.

    Raises:
      ValueError: if a worm's coords are not a 2-D array with 3 columns.
    """
    diagnostics = []
    ap_z_values = []

    for worm_no, (canonical_ids, coords) in enumerate(worms_iterable):
        cloud = np.asarray(coords, dtype=float)
        if not (cloud.ndim == 2 and cloud.shape[1] == 3):
            raise ValueError(f"Worm {worm_no}: coords must be (N,3), got {cloud.shape}")
        n_nuclei = cloud.shape[0]

        # sklearn orders components by decreasing explained variance.
        basis = PCA(n_components=3).fit(cloud).components_
        ap, pc1, pc2 = (row / np.linalg.norm(row) for row in (basis[0], basis[1], basis[2]))

        ap_z = float(np.dot(ap, z_axis))
        tilt = float(np.degrees(np.arccos(np.clip(abs(ap_z), -1.0, 1.0))))

        # Leakage of the AP axis into the global x / y directions.
        ap_x = float(np.dot(ap, x_axis))
        ap_y = float(np.dot(ap, y_axis))

        # Absolute cosines between the in-plane components and global x / y.
        pc1_x = float(abs(np.dot(pc1, x_axis)))
        pc1_y = float(abs(np.dot(pc1, y_axis)))
        pc2_x = float(abs(np.dot(pc2, x_axis)))
        pc2_y = float(abs(np.dot(pc2, y_axis)))

        # Sign of det([PC1 PC2 AP]) with the vectors as columns.
        orientation = np.linalg.det(np.column_stack([pc1, pc2, ap]))
        if orientation > 0:
            handed_sign = +1
        else:
            handed_sign = -1

        diagnostics.append(
            dict(
                worm_index=worm_no,
                n_points=int(n_nuclei),
                cos_ap_z=ap_z,
                tilt_deg_from_z=tilt,
                ap_dot_x=ap_x,
                ap_dot_y=ap_y,
                abs_pc1_dot_x=pc1_x,
                abs_pc1_dot_y=pc1_y,
                abs_pc2_dot_x=pc2_x,
                abs_pc2_dot_y=pc2_y,
                handed_sign=handed_sign,
            )
        )

        handed_label = "RH" if handed_sign > 0 else "LH"
        print(
            f"Worm {worm_no:03d} | N={n_nuclei:4d} | "
            f"AP•Z={ap_z:+.4f} (tilt={tilt:5.2f}°) | "
            f"AP•X={ap_x:+.4e} AP•Y={ap_y:+.4e} | "
            f"|PC1•X|={pc1_x:.3f} |PC1•Y|={pc1_y:.3f} | "
            f"|PC2•X|={pc2_x:.3f} |PC2•Y|={pc2_y:.3f} | "
            f"handed={handed_label}"
        )

        ap_z_values.append(ap_z)

    ap_z_arr = np.asarray(ap_z_values, dtype=float)
    if ap_z_arr.size:
        pct_aligned = float(np.mean(np.abs(ap_z_arr) >= ap_good_thresh) * 100.0)
    else:
        pct_aligned = 0.0
    print(f"\nWell aligned to Z (|AP•Z| >= {ap_good_thresh:.2f}): {pct_aligned:.1f}%")

    return {
        "rows": diagnostics,
        "cos_ap_z": ap_z_arr,
        "pct_well_aligned": pct_aligned,
        "ap_good_thresh": ap_good_thresh,
    }


In [None]:
# Run the extended alignment diagnostics over every worm. One line per worm is
# printed and the returned dict (rows, cos_ap_z, pct_well_aligned, ap_good_thresh)
# is kept in `report`. `all_worms` is not defined in this cell, so it must
# already exist in the kernel.
report = check_alignment_xy_and_plane_with_ids(all_worms)

# New slicing function with cross-section cropping

In [None]:
# Configuration for the slice-and-crop dataset generation below.
NUM_SLICES = 20                     # number of axial slabs per worm
SLICE_THICKNESS = 0.008             # slab thickness along the AP axis
NUM_SHIFTS = 12                     # 12 shifts × 200 worms = 2400 samples

center_spacing = 1 / (NUM_SLICES - 1)   # distance between consecutive slice centres. It is fixed to 1 / ... because all worms are scaled to length 1
# Largest shift for which a shifted slab still ends where the next unshifted
# slab begins (spacing minus one full thickness); clamped at 0 when slabs
# are thicker than the spacing.
max_non_overlap_shift = max(0.0, center_spacing - SLICE_THICKNESS)

SHIFT_STEPS = np.linspace(0.0, max_non_overlap_shift, NUM_SHIFTS)
# With the values above: [0.0000, 0.0041, 0.0081, …, 0.0446]

# NOTE(review): absolute, machine-specific output location; it is created on cell execution.
subgraph_output_directory = Path("/fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/new_data/20_0008_05")
subgraph_output_directory.mkdir(parents=True, exist_ok=True)

specimens_per_file = 2**14 # should be changed, carries no meaning

# 'random' makes get_slice_indices crop along a random direction in the xy-plane.
CROP_AXIS = 'random'
# (low, high) passed to RNG.uniform for the per-worm crop fraction. The upper
# bound is the next float above 1.0, so that 1.0 itself can be drawn from the
# half-open interval.
CROP_FRACTION_RANGE = (0.0, np.nextafter(1.0, np.inf))

# Single shared generator: the results depend on how many draws were made
# since this cell was last executed.
seed = 42
RNG = np.random.default_rng(seed)

# Earlier fixed-crop settings, kept for reference (currently unused):
# CROP_SIDE = 'positive'
# CROP_FRACTION = 0.25

# Echo the effective configuration so it is recorded in the cell output.
print(f"NUM_SLICES = {NUM_SLICES}")
print(f"SLICE_THICKNESS = {SLICE_THICKNESS}")
print(f"NUM_SHIFTS = {NUM_SHIFTS}")
print(f"center_spacing = {center_spacing:.6f}")
print(f"max_non_overlap_shift = {max_non_overlap_shift:.6f}")
print("SHIFT_STEPS =", ", ".join(f"{s:.6f}" for s in SHIFT_STEPS))
print(f"specimens_per_file = {specimens_per_file}")
print(f"CROP_AXIS = {CROP_AXIS}")
# print(f"CROP_SIDE = {CROP_SIDE}")
print(f"CROP_FRACTION_RANGE = {CROP_FRACTION_RANGE}")
print(f"RNG seed = {seed}")

In [None]:
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from sklearn.decomposition import PCA

def _crop_keep_mask(vals: np.ndarray, crop_side: str, crop_fraction: float) -> np.ndarray:
    """Return the boolean mask of in-plane coordinates that survive the half-plane crop."""
    # Mirror the coordinates for "negative" so both sides share one code path:
    # after mirroring, the removed side is always the strictly positive one.
    signed = vals if crop_side == "positive" else -vals
    keep = signed <= 0.0
    if crop_fraction > 0.0:
        removed_side = signed[signed > 0.0]
        if removed_side.size > 0:
            # Retain a band next to the centreline, sized relative to the
            # farthest point on the removed side of this slice.
            allowed = removed_side.max() * crop_fraction
            keep |= (signed > 0.0) & (signed <= allowed)
    return keep


def get_slice_indices(
    points: np.ndarray,
    n_slices: int = 40,
    slice_thickness: float = 0.005,
    shift: float = 0.0,
    *,
    crop_axis: Optional[str] = None,   # 'x', 'y', 'random', or None
    crop_side: str = "positive",
    crop_fraction: float = 0.0,
    random_state: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, List[int], np.ndarray]:
    """
    Slice a worm‑shaped point cloud into thin slabs and optionally remove one half
    of each slice. The main (AP) axis is estimated via PCA. `crop_axis` selects the
    in‑plane direction used for cropping (x, y, a random direction in the xy‑plane,
    or no cropping at all), `crop_side` selects which side of that direction is
    removed, and `crop_fraction` retains a band near the centreline on the
    removed side.

    Parameters
    ----------
    points : (N, 3) array
        Cartesian nucleus centres.  Axes are assumed to be (x=LR, y=DV, z=AP).
    n_slices : int
        Number of slices along the worm’s length.
    slice_thickness : float
        Thickness of each axial slab (distance along the AP axis).
    shift : float
        Uniform shift applied to all slice centres along the AP axis.
    crop_axis : {"x", "y", "random", None}
        Direction along which to crop within each slice:
          • None: no cropping.
          • 'x' : remove LR half (positive or negative).
          • 'y' : remove DV half.
          • 'random': remove half along a random direction in the xy‑plane.
    crop_side : {"positive", "negative"}
        Which half to remove (“positive” removes values > 0 and keeps ≤ 0).
        Only checked when `crop_axis` is not None.
    crop_fraction : float
        Fraction of the removed side’s span (measured per slice) to retain near
        zero. 0.0 keeps nothing; 1.0 keeps the entire removed side.
    random_state : np.random.Generator, optional
        Source of randomness for 'random' crop_axis; use to make behaviour reproducible.

    Returns
    -------
    indices : (M,) array of int
        Sorted indices of nuclei that survived slicing and cropping.
    per_slice_counts : list of int
        Number of nuclei assigned to each of the n_slices axial slabs (a nucleus
        that falls into several slabs is counted once, in the nearest one).
    projected_coords : (M, 3) array
        The 3D coordinates projected onto their slice planes, in the order of
        `indices`.

    Raises
    ------
    ValueError
        If `crop_axis` is not one of the accepted values, or if cropping is
        requested and `crop_side` is neither "positive" nor "negative".
    """
    # Validate options up front so a typo fails fast instead of silently
    # cropping the wrong side (or only failing after the PCA fit).
    if crop_axis not in (None, "x", "y", "random"):
        raise ValueError(f"crop_axis must be 'x', 'y', 'random' or None; got {crop_axis}")
    if crop_axis is not None and crop_side not in ("positive", "negative"):
        raise ValueError(f"crop_side must be 'positive' or 'negative'; got {crop_side}")

    points = np.asarray(points)

    # Step 1: fit PCA to find the AP axis; use only PC0 for slicing.
    pca = PCA(n_components=3)
    pca.fit(points)
    axial_axis = pca.components_[0]           # unit vector for AP
    mean = pca.mean_

    # Centre the cloud once; used for the axial position and all in‑plane coordinates.
    centered = points - mean
    t = centered @ axial_axis                 # scalar position along AP for each point

    # Step 2: choose in‑plane direction based on crop_axis.
    slice_coord: Optional[np.ndarray] = None  # signed distances used for cropping
    if crop_axis == "x":
        # crop along LR; positive x values lie on one side, negative on the other
        slice_coord = centered[:, 0]
    elif crop_axis == "y":
        # crop along DV
        slice_coord = centered[:, 1]
    elif crop_axis == "random":
        # sample a random unit vector in the xy‑plane (perpendicular to z)
        rng = random_state if random_state is not None else np.random.default_rng()
        angle = rng.uniform(0.0, 2.0 * np.pi)
        cos_orient = np.cos(angle)
        sin_orient = np.sin(angle)
        global_x_unit = np.array([1.0, 0.0, 0.0], dtype=points.dtype)
        global_y_unit = np.array([0.0, 1.0, 0.0], dtype=points.dtype)
        rand_vec = cos_orient * global_x_unit + sin_orient * global_y_unit
        slice_coord = centered @ rand_vec

    # Step 3: create slice centres along AP and assign points to slabs.
    centres = np.linspace(t.min(), t.max(), n_slices) + shift
    selected: Dict[int, Dict[str, Any]] = {}
    half_thickness = slice_thickness / 2.0

    for slice_idx, centre_val in enumerate(centres):
        hits = np.where(np.abs(t - centre_val) <= half_thickness)[0]

        # Apply optional in‑plane cropping.
        if slice_coord is not None and hits.size > 0:
            hits = hits[_crop_keep_mask(slice_coord[hits], crop_side, crop_fraction)]

        # Project survivors onto their slice plane; deduplicate by keeping the nearest slice.
        for idx in hits:
            offset = t[idx] - centre_val
            abs_offset = abs(offset)
            prev = selected.get(idx)
            if prev is None or abs_offset < prev["abs_offset"]:
                selected[idx] = {
                    "abs_offset": abs_offset,
                    # move the point along the AP axis onto the slice-centre plane
                    "projected": points[idx] - offset * axial_axis,
                    "slice_index": slice_idx,
                }

    if not selected:
        return np.array([], dtype=int), [0] * n_slices, np.zeros((0, 3), dtype=points.dtype)

    per_slice_counts: List[int] = [0] * n_slices
    for v in selected.values():
        per_slice_counts[v["slice_index"]] += 1

    indices = np.array(sorted(selected.keys()), dtype=int)
    projected_coords = np.stack([selected[i]["projected"] for i in indices], axis=0)

    return indices, per_slice_counts, projected_coords


In [None]:
# Driver: for every shift, slice and crop each worm, write the resulting
# subgraphs to HDF5 and store per-shift statistics next to them.
# Relies on kernel state from other cells: all_worms, RNG, SHIFT_STEPS,
# NUM_SLICES, SLICE_THICKNESS, CROP_AXIS, CROP_FRACTION_RANGE,
# subgraph_output_directory and specimens_per_file.
# NOTE(review): `convert_specimens_to_hdf5` is defined more than once in this
# notebook; the on-disk layout depends on which definition was executed last.
for shift in SHIFT_STEPS:
    sliced_samples: List[np.ndarray] = []            # one [M, 4] array per worm
    per_worm_selected: List[int] = []                # kept nuclei per worm
    per_worm_slice_counts: List[List[int]] = []      # per-slice counts per worm

    for canonical_ids, coords in all_worms:
        coords_arr = np.asarray(coords, dtype=np.float32)
        ids_arr = np.asarray(canonical_ids, dtype=np.int64)

        if coords_arr.ndim != 2 or coords_arr.shape[1] != 3:
            raise ValueError(f"Expected coords to be [N, 3], got {coords_arr.shape}")

        # Three draws from the shared RNG per worm, in this fixed order; changing
        # the order or number of draws changes every sample generated afterwards.
        crop_side = RNG.choice(("positive", "negative"))
        crop_fraction = RNG.uniform(*CROP_FRACTION_RANGE)
        # Child generator handed to get_slice_indices for the random crop direction.
        random_state = np.random.default_rng(RNG.integers(0, 2**32))

        selected_idx, per_slice_counts, projected_coords = get_slice_indices(
            coords_arr,
            n_slices=NUM_SLICES,
            slice_thickness=SLICE_THICKNESS,
            shift=shift,
            crop_axis=CROP_AXIS,
            crop_side=crop_side,
            crop_fraction=crop_fraction,
            random_state=random_state,
        )

        # Worms with no surviving nuclei produce no sample and are left out of the stats.
        if selected_idx.size == 0:
            continue

        per_worm_selected.append(len(selected_idx))
        per_worm_slice_counts.append(per_slice_counts)

        sliced_ids = ids_arr[selected_idx]             # canonical IDs of the kept nuclei, in ascending original-index order
        sample = np.zeros((len(selected_idx), 4), dtype=np.float32)
        sample[:, 0] = sliced_ids.astype(np.float32)   # column 0: canonical IDs
        sample[:, 1:] = projected_coords               # columns 1–3: 3-D coordinates lying on the slice planes
        sliced_samples.append(sample)

        print(
            f"Sliced worm (len={len(ids_arr)}) @ shift {shift:.4f}: "
            f"{len(selected_idx)} nuclei | per-slice counts = {per_slice_counts}"
            f"| crop_side={crop_side} | crop_fraction={crop_fraction:.3f}"
        )

    if not sliced_samples:
        print(f"No samples produced for shift {shift:.4f}; skipping file output.")
        continue

    # NOTE(review): the directory name keeps only three decimals of the shift, so
    # shifts closer together than 0.001 would map to the same directory.
    output_dir = subgraph_output_directory / f"sliced_subgraphs_shift_{shift:.3f}"
    convert_specimens_to_hdf5(
        sliced_samples,
        output_dir=output_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(sliced_samples)} sliced subgraphs to {output_dir}")

    # Statistics over the worms that produced a sample for this shift.
    avg_selected = float(np.mean(per_worm_selected))
    std_selected = float(np.std(per_worm_selected))

    # Rows = worms, columns = slices; column-wise mean/std give per-slice statistics.
    slice_matrix = np.asarray(per_worm_slice_counts, dtype=float)
    avg_slice_counts = slice_matrix.mean(axis=0)
    std_slice_counts = slice_matrix.std(axis=0)
    avg_per_slice_overall = float(avg_slice_counts.mean())

    # Plain-text summary stored alongside the HDF5 output of this shift.
    stats_path = output_dir / "slice_stats.txt"
    with stats_path.open("w") as fh:
        fh.write(f"shift: {shift:.4f}\n")
        fh.write(f"num_worms: {len(per_worm_selected)}\n")
        fh.write(f"avg_total_nuclei: {avg_selected:.2f}\n")
        fh.write(f"std_total_nuclei: {std_selected:.2f}\n")
        fh.write(f"avg_per_slice_overall: {avg_per_slice_overall:.2f}\n")
        fh.write("avg_per_slice_counts: " +
                 ", ".join(f"{v:.2f}" for v in avg_slice_counts) + "\n")
        fh.write("std_per_slice_counts: " +
                 ", ".join(f"{v:.2f}" for v in std_slice_counts) + "\n")

# Saving the whole worm as well

In [4]:
def convert_specimens_to_hdf5(
    specimens_data: Union[np.ndarray, Sequence[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]],
    output_dir: Union[str, Path],
    specimens_per_file: int = 2**14,
) -> None:
    """
    Write specimens into `<output_dir>/test/test_XXXX.h5` chunks plus a JSON manifest.

    Each element of `specimens_data` may be:
      • a single array shaped [M, 4]          → only a subgraph (legacy behaviour)
      • a tuple (subgraph, full_worm), both [*, 4] arrays with canonical_id,x,y,z.

    The specimens are written in their given order, `specimens_per_file` at a
    time, through `create_hdf5_file_with_full_worms`; specimen numbering
    continues across files. `dataset_info.json` is written to `output_dir`
    even when there are no specimens.
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)

    entries = list(specimens_data)
    total = len(entries)

    test_dir = root / "test"
    test_dir.mkdir(exist_ok=True)

    # Running specimen counter handed from one file to the next.
    next_specimen = 0
    for file_no, chunk_start in enumerate(range(0, total, specimens_per_file)):
        next_specimen = create_hdf5_file_with_full_worms(
            entries[chunk_start:chunk_start + specimens_per_file],
            test_dir / f"test_{file_no:04d}.h5",
            file_prefix="specimen",
            start_idx=next_specimen,
        )

    column_layout = "[canonical_id, x, y, z]"
    manifest = {
        "total_specimens": total,
        "split": "test",
        "specimens_per_file": specimens_per_file,
        "format": {
            "subgraph": column_layout,
            "full_worm": column_layout,
        },
    }

    with open(root / "dataset_info.json", "w") as manifest_file:
        json.dump(manifest, manifest_file, indent=2)


In [5]:
def _as_specimen_pair(
    entry: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
) -> Tuple[np.ndarray, np.ndarray]:
    """Coerce one entry to a validated (subgraph, full_worm) pair of float32 [*, 4] arrays."""
    if isinstance(entry, tuple) and len(entry) == 2:
        # Coerce both parts the same way as the single-array path, so lists and
        # non-float32 arrays are accepted and stored with one dtype.
        subgraph = np.asarray(entry[0], dtype=np.float32)
        full_worm = np.asarray(entry[1], dtype=np.float32)
    else:
        subgraph = np.asarray(entry, dtype=np.float32)
        full_worm = subgraph  # legacy input: the subgraph doubles as the full worm

    if subgraph.ndim != 2 or subgraph.shape[1] != 4:
        raise ValueError(f"subgraph must be [N,4], got {subgraph.shape}")
    if full_worm.ndim != 2 or full_worm.shape[1] != 4:
        raise ValueError(f"full_worm must be [M,4], got {full_worm.shape}")
    return subgraph, full_worm


def create_hdf5_file_with_full_worms(
    specimens: Sequence[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]],
    file_path: Union[str, Path],
    *,
    file_prefix: str = "specimen",
    start_idx: int = 0,
) -> int:
    """
    Write a chunk of specimens into one HDF5 file.

    Accepts either raw subgraphs or (subgraph, full_worm) pairs. Every specimen
    becomes a group `specimens/<file_prefix>_<id:06d>` holding two gzip-compressed
    datasets, "subgraph" and "full_worm"; for a raw subgraph the same array is
    stored under both names. All arrays are stored as float32.

    All entries are validated before the target file is opened, so an invalid
    entry raises without creating or truncating `file_path`.

    Parameters:
        specimens: entries to write, in order.
        file_path: destination file; parent directories are created as needed
            and an existing file is overwritten.
        file_prefix: prefix of each specimen group name.
        start_idx: id given to the first specimen of this chunk.

    Returns:
        The id following the last specimen written (start_idx + number of
        specimens), suitable as `start_idx` for the next file.

    Raises:
        ValueError: if a subgraph or full worm is not a 2-D array with 4 columns.
    """
    file_path = Path(file_path)

    # Normalise and validate first: opening with "w" truncates an existing file,
    # so nothing on disk should be touched until the whole chunk is known good.
    pairs = [_as_specimen_pair(entry) for entry in specimens]

    file_path.parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(file_path, "w") as h5f:
        group = h5f.create_group("specimens")
        for specimen_id, (subgraph, full_worm) in enumerate(pairs, start=start_idx):
            specimen_group = group.create_group(f"{file_prefix}_{specimen_id:06d}")
            specimen_group.create_dataset("subgraph", data=subgraph, compression="gzip")
            specimen_group.create_dataset("full_worm", data=full_worm, compression="gzip")

    return start_idx + len(pairs)


In [None]:
# Driver: same slice-and-crop procedure as the earlier slicing cell, but each
# sample is stored as a (subgraph, full_worm) pair so the uncropped worm is
# saved next to its subgraph.
# Relies on kernel state from other cells: all_worms, RNG, SHIFT_STEPS,
# NUM_SLICES, SLICE_THICKNESS, CROP_AXIS, CROP_FRACTION_RANGE,
# subgraph_output_directory and specimens_per_file.
# NOTE(review): the output directory names are built exactly as in the earlier
# slicing cell, so with an unchanged `subgraph_output_directory` this run writes
# into the same directories. The shared RNG also continues from wherever earlier
# cells left it, so the crops differ from a previous run unless RNG is re-created.
for shift in SHIFT_STEPS:
    sliced_samples: List[Tuple[np.ndarray, np.ndarray]] = []  # (subgraph, full_worm)
    per_worm_selected: List[int] = []                # kept nuclei per worm
    per_worm_slice_counts: List[List[int]] = []      # per-slice counts per worm

    for canonical_ids, coords in all_worms:
        coords_arr = np.asarray(coords, dtype=np.float32)
        ids_arr = np.asarray(canonical_ids, dtype=np.int64)

        if coords_arr.ndim != 2 or coords_arr.shape[1] != 3:
            raise ValueError(f"Expected coords to be [N, 3], got {coords_arr.shape}")

        # Three draws from the shared RNG per worm, in this fixed order; changing
        # the order or number of draws changes every sample generated afterwards.
        crop_side = RNG.choice(("positive", "negative"))
        crop_fraction = RNG.uniform(*CROP_FRACTION_RANGE)
        # Child generator handed to get_slice_indices for the random crop direction.
        random_state = np.random.default_rng(RNG.integers(0, 2**32))

        selected_idx, per_slice_counts, projected_coords = get_slice_indices(
            coords_arr,
            n_slices=NUM_SLICES,
            slice_thickness=SLICE_THICKNESS,
            shift=shift,
            crop_axis=CROP_AXIS,
            crop_side=crop_side,
            crop_fraction=crop_fraction,
            random_state=random_state,
        )

        # Worms with no surviving nuclei produce no sample and are left out of the stats.
        if selected_idx.size == 0:
            continue

        per_worm_selected.append(len(selected_idx))
        per_worm_slice_counts.append(per_slice_counts)

        sliced_ids = ids_arr[selected_idx]             # canonical IDs of the kept nuclei, in ascending original-index order
        sample = np.zeros((len(selected_idx), 4), dtype=np.float32)
        sample[:, 0] = sliced_ids.astype(np.float32)   # column 0: canonical IDs
        sample[:, 1:] = projected_coords               # columns 1–3: 3-D coordinates lying on the slice planes

        # Full worm in the same [canonical_id, x, y, z] layout, with the original
        # (unprojected, uncropped) coordinates.
        full_worm_sample = np.zeros((len(ids_arr), 4), dtype=np.float32)
        full_worm_sample[:, 0] = ids_arr.astype(np.float32)
        full_worm_sample[:, 1:] = coords_arr

        sliced_samples.append((sample, full_worm_sample))

        print(
            f"Sliced worm (len={len(ids_arr)}) @ shift {shift:.4f}: "
            f"{len(selected_idx)} nuclei | per-slice counts = {per_slice_counts}"
            f"| crop_side={crop_side} | crop_fraction={crop_fraction:.3f}"
        )

    if not sliced_samples:
        print(f"No samples produced for shift {shift:.4f}; skipping file output.")
        continue

    # NOTE(review): the directory name keeps only three decimals of the shift, so
    # shifts closer together than 0.001 would map to the same directory.
    output_dir = subgraph_output_directory / f"sliced_subgraphs_shift_{shift:.3f}"
    convert_specimens_to_hdf5(
        sliced_samples,
        output_dir=output_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(sliced_samples)} sliced subgraphs to {output_dir}")

    # Statistics over the worms that produced a sample for this shift.
    avg_selected = float(np.mean(per_worm_selected))
    std_selected = float(np.std(per_worm_selected))

    # Rows = worms, columns = slices; column-wise mean/std give per-slice statistics.
    slice_matrix = np.asarray(per_worm_slice_counts, dtype=float)
    avg_slice_counts = slice_matrix.mean(axis=0)
    std_slice_counts = slice_matrix.std(axis=0)
    avg_per_slice_overall = float(avg_slice_counts.mean())

    # Plain-text summary stored alongside the HDF5 output of this shift.
    stats_path = output_dir / "slice_stats.txt"
    with stats_path.open("w") as fh:
        fh.write(f"shift: {shift:.4f}\n")
        fh.write(f"num_worms: {len(per_worm_selected)}\n")
        fh.write(f"avg_total_nuclei: {avg_selected:.2f}\n")
        fh.write(f"std_total_nuclei: {std_selected:.2f}\n")
        fh.write(f"avg_per_slice_overall: {avg_per_slice_overall:.2f}\n")
        fh.write("avg_per_slice_counts: " +
                 ", ".join(f"{v:.2f}" for v in avg_slice_counts) + "\n")
        fh.write("std_per_slice_counts: " +
                 ", ".join(f"{v:.2f}" for v in std_slice_counts) + "\n")

In [6]:
# Configuration for the random-subgraph baseline generated in the next cell.
# NOTE(review): this rebinds `specimens_per_file` and `subgraph_output_directory`,
# which the slicing cells also read; re-running those cells after this one
# sends their output to the directory set here.
specimens_per_file = 2**14          # maximum number of specimens per HDF5 file
rng_seed = 42
rng = np.random.default_rng(rng_seed)   # generator for subgraph sizes and node sampling (separate from `RNG`)
min_size = 111                      # smallest subgraph size drawn (inclusive)
max_size = 117                      # largest subgraph size drawn (inclusive)
# NOTE(review): absolute, machine-specific output location; it is created on cell execution.
subgraph_output_directory = Path("/fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40")
subgraph_output_directory.mkdir(parents=True, exist_ok=True)
sample_count = 12                   # random subgraphs drawn per worm
num_datasets = 1                    # number of dataset_XX directories to generate

In [7]:
# Driver: build `num_datasets` datasets of uniformly random subgraphs. Every
# worm contributes `sample_count` subgraphs, each paired with the full worm.
# Relies on kernel state from other cells: all_worms, rng, min_size, max_size,
# sample_count, num_datasets, subgraph_output_directory, specimens_per_file.
for dataset_idx in range(num_datasets):
    # NOTE: no per-dataset reseeding is done here. All datasets draw from the
    # single shared `rng`, so dataset k depends on the draws made for the
    # datasets before it.

    specimens = []        # will hold (subgraph, full_worm) tuples
    subgraph_sizes = []   # realised subgraph sizes, for the statistics below

    for answer, coords in all_worms:
        answer = np.asarray(answer, dtype=np.int64)          # [N]
        coords = np.asarray(coords, dtype=np.float32)        # [N, 3]
        nodes = len(answer)
        if nodes == 0:
            continue

        # Build full_worm once per worm; the same array object is shared by
        # all of this worm's specimens.
        full_worm = np.zeros((nodes, 4), dtype=np.float32)
        full_worm[:, 0] = answer.astype(np.float32)          # canonical_id
        full_worm[:, 1:] = coords                            # x,y,z

        # Sample subgraphs from this worm
        for _ in range(sample_count):
            # Size is drawn from [min_size, max_size] inclusive, then capped at
            # the number of nuclei this worm actually has.
            subgraph_size = int(rng.integers(min_size, max_size + 1))
            sample_size = min(subgraph_size, nodes)
            subgraph_sizes.append(sample_size)

            # Distinct node indices, in the random order returned by the generator.
            sampled_indices = rng.choice(nodes, size=sample_size, replace=False)

            subgraph = np.zeros((sample_size, 4), dtype=np.float32)
            subgraph[:, 0] = answer[sampled_indices].astype(np.float32)
            subgraph[:, 1:] = coords[sampled_indices]

            # Pair the subgraph with the full worm
            specimens.append((subgraph, full_worm))

    subgraph_dir = subgraph_output_directory / f"dataset_{dataset_idx:02d}"
    convert_specimens_to_hdf5(
        specimens,
        output_dir=subgraph_dir,
        specimens_per_file=specimens_per_file,
    )

    # Size statistics; NaN / 0 placeholders when no subgraph was produced.
    avg_size = float(np.mean(subgraph_sizes)) if subgraph_sizes else float("nan")
    std_size = float(np.std(subgraph_sizes)) if subgraph_sizes else float("nan")
    min_size_obs = int(np.min(subgraph_sizes)) if subgraph_sizes else 0
    max_size_obs = int(np.max(subgraph_sizes)) if subgraph_sizes else 0

    print(
        f"Saved {len(specimens)} subgraphs "
        f"(avg size {avg_size:.2f} ± {std_size:.2f}) to {subgraph_dir}"
    )

    # Plain-text summary stored alongside the HDF5 output of this dataset.
    stats_path = subgraph_dir / "subgraph_stats.txt"
    with stats_path.open("w") as fh:
        fh.write(f"dataset_index: {dataset_idx}\n")
        fh.write(f"num_samples: {len(specimens)}\n")
        fh.write(f"size_range_config: [{min_size}, {max_size}]\n")
        fh.write(f"size_min_observed: {min_size_obs}\n")
        fh.write(f"size_max_observed: {max_size_obs}\n")
        fh.write(f"size_mean: {avg_size:.2f}\n")
        fh.write(f"size_std: {std_size:.2f}\n")


Saved 2400 subgraphs (avg size 113.99 ± 2.00) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40/dataset_00
Saved 2400 subgraphs (avg size 113.99 ± 2.00) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40/dataset_01
Saved 2400 subgraphs (avg size 113.97 ± 1.99) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40/dataset_02
Saved 2400 subgraphs (avg size 113.99 ± 1.99) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40/dataset_03
Saved 2400 subgraphs (avg size 113.97 ± 2.02) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40/dataset_04
Saved 2400 subgraphs (avg size 114.06 ± 2.00) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions/random/comparison_to_real_test_set_40/dataset_05
Saved 2400 subgraphs (avg size 113.98 ± 2.02) to /fs/pool/pool-mlsb/bulat/Wormologist/mispredictions