In [None]:
import json
import pickle
from pathlib import Path
from typing import List, Sequence, Union
import h5py
import numpy as np

In [None]:
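# Key difference from the other version of this helper: the global
# `canonical_ids = np.arange(558).reshape(-1, 1)` array is removed; each specimen
# instead carries its own canonical IDs in column 0 of its [canonical_id, x, y, z] rows.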
def create_hdf5_file(
    specimens_data: List[np.ndarray],
    output_path: Union[str, Path],
    file_prefix: str = "specimen",
    start_idx: int = 0,
) -> int:
    """Write a batch of specimens into a single HDF5 file.

    Each specimen becomes its own float32 dataset inside the ``specimens`` group,
    keyed ``f"{file_prefix}_{index:06d}"`` with indices counting up from
    ``start_idx``. File-level attributes record the specimen count, a format
    version and a short description of the row layout.

    Args:
        specimens_data: Per-specimen arrays (anything ``np.asarray`` accepts).
            Rows are expected to be ``[canonical_id, x, y, z]``, matching the
            ``description`` attribute written below.
        output_path: Destination ``.h5`` file (``str`` or ``Path``). Missing parent
            directories are created; an existing file is overwritten (mode ``"w"``).
        file_prefix: Prefix for each specimen's dataset name.
        start_idx: Index assigned to the first specimen in this file.

    Returns:
        ``start_idx + len(specimens_data)``: the next unused specimen index, so a
        caller can continue numbering in the next file.
    """
    # Accept str as well as Path, consistent with convert_specimens_to_hdf5.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with h5py.File(output_path, "w") as f:
        specimens_group = f.create_group("specimens")

        for i, specimen_data in enumerate(specimens_data):
            specimen_key = f"{file_prefix}_{start_idx + i:06d}"
            # np.asarray skips the copy that .astype always makes when the data is
            # already float32, and also accepts plain (nested) sequences.
            specimens_group.create_dataset(
                specimen_key, data=np.asarray(specimen_data, dtype=np.float32)
            )

        f.attrs["num_specimens"] = len(specimens_data)
        f.attrs["format_version"] = "1.1"
        f.attrs["description"] = "C. elegans nuclei data: [canonical_id, x, y, z]"

    return start_idx + len(specimens_data)

In [None]:
# No shuffling or train/val/test splitting here: every specimen belongs to the test set,
# so all shards are written under `test/`.
def convert_specimens_to_hdf5(
    specimens_data: Union[np.ndarray, Sequence[np.ndarray]],
    output_dir: Union[str, Path],
    specimens_per_file: int = 2**14,
) -> None:
    """Shard specimens into ``test/test_XXXX.h5`` files and write ``dataset_info.json``.

    Specimens are written in their given order (no shuffling), at most
    ``specimens_per_file`` per file. Specimen keys are numbered globally across
    shards via ``create_hdf5_file``'s returned next index.

    Args:
        specimens_data: A sequence of per-specimen arrays, or an ndarray whose first
            axis indexes specimens.
        output_dir: Root directory; shards go to ``output_dir / "test"`` and the
            metadata JSON to ``output_dir / "dataset_info.json"``.
        specimens_per_file: Maximum number of specimens per HDF5 shard.

    Raises:
        ValueError: If ``specimens_per_file`` is smaller than 1. A negative value
            would otherwise silently write no shards while the JSON still reported
            every specimen.
    """
    if specimens_per_file < 1:
        raise ValueError(
            f"specimens_per_file must be >= 1, got {specimens_per_file}"
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    specimens_list = list(specimens_data)

    split_dir = output_dir / "test"
    split_dir.mkdir(exist_ok=True)

    # NOTE(review): existing test_XXXX.h5 files from a previous, larger run in this
    # directory are not removed and would sit alongside the new shards.
    next_specimen_idx = 0
    for file_idx, start in enumerate(range(0, len(specimens_list), specimens_per_file)):
        # Slicing past the end is safe, so no explicit min() is needed.
        batch_data = specimens_list[start : start + specimens_per_file]
        file_path = split_dir / f"test_{file_idx:04d}.h5"
        next_specimen_idx = create_hdf5_file(
            batch_data,
            file_path,
            file_prefix="specimen",
            start_idx=next_specimen_idx,
        )

    info = {
        "total_specimens": len(specimens_list),
        "split": "test",
        "specimens_per_file": specimens_per_file,
        "format": "[canonical_id, x, y, z]",
    }

    with open(output_dir / "dataset_info.json", "w") as f:
        json.dump(info, f, indent=2)

In [None]:
# Every input and output of this notebook lives under one shared project root.
# NOTE(review): absolute cluster path; consider making the root configurable.
wormologist_root = Path("/fs/pool/pool-mlsb/bulat/Wormologist")
generator_dir = wormologist_root / "synthetic_data_generator"

test_worm_1_pickle_path = generator_dir / "test1worms.pkl"
test_worm_2_pickle_path = generator_dir / "test2worms.pkl"

subgraph_output_directory = wormologist_root / "new_subgraph_testing_data"
subgraph_output_directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Maximum number of specimens written into a single HDF5 shard (2**14).
specimens_per_file = 16_384
# Seed for the subgraph sampler below, so every run draws identical subgraphs.
rng_seed = 42
# Subgraph sizes 10, 20, ..., 550, followed by 558 (presumably the full worm).
test_sizes = np.append(np.arange(10, 560, 10), 558)

In [None]:
def _load_pickle(path: Path):
    """Unpickle and return the single object stored at ``path``.

    NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


worm1 = _load_pickle(test_worm_1_pickle_path)
worm2 = _load_pickle(test_worm_2_pickle_path)
# Presumably both pickles hold lists of (answer, coords) pairs, which are
# concatenated here — TODO confirm.
all_worms = worm1 + worm2

In [None]:
# Reference size of a complete worm. It only scales how many subgraphs are drawn
# per worm (ceil(full_worm_size / subgraph_size) * 2), so smaller sizes get more draws.
full_worm_size = 558

rng = np.random.default_rng(rng_seed)


def sample_subgraphs(
    worms: Sequence,
    subgraph_size: int,
    rng: np.random.Generator,
) -> List[np.ndarray]:
    """Draw random node subsets from every worm and pack them as specimens.

    For each ``(answer, coords)`` pair in ``worms``, draws
    ``ceil(full_worm_size / subgraph_size) * 2`` index subsets without replacement.
    Each specimen is a float32 array of shape ``(k, 4)`` whose rows are
    ``[canonical_id, x, y, z]``, where ``k = min(subgraph_size, len(answer))``.

    NOTE: a worm with fewer than ``subgraph_size`` nodes contributes its whole node
    set, so specimens saved under ``subgraph_XXX`` can be smaller than XXX.
    """
    # Loop-invariant: the same number of draws is taken from every worm.
    sample_count = int(np.ceil(full_worm_size / subgraph_size)) * 2
    samples: List[np.ndarray] = []

    for answer, coords in worms:
        # Convert once per worm rather than once per draw.
        answer_arr = np.asarray(answer)
        coords_arr = np.asarray(coords)
        nodes = len(answer)
        draw_size = min(subgraph_size, nodes)

        for _ in range(sample_count):
            sampled_indices = rng.choice(nodes, size=draw_size, replace=False)
            sample = np.zeros((len(sampled_indices), 4), dtype=np.float32)
            sample[:, 0] = answer_arr[sampled_indices].astype(np.float32)
            # coords rows must broadcast into the 3 trailing columns (x, y, z).
            sample[:, 1:] = coords_arr[sampled_indices].astype(np.float32)
            samples.append(sample)

    return samples


for subgraph_size in test_sizes:
    subgraph_samples = sample_subgraphs(all_worms, subgraph_size, rng)
    subgraph_dir = subgraph_output_directory / f"subgraph_{int(subgraph_size):03d}"

    convert_specimens_to_hdf5(
        subgraph_samples,
        output_dir=subgraph_dir,
        specimens_per_file=specimens_per_file,
    )
    print(f"Saved {len(subgraph_samples)} subgraphs of size {subgraph_size} to {subgraph_dir}")