In [None]:
import pickle
import numpy as np
from pathlib import Path

from typing import List, Union, Tuple
import h5py
import json

In [None]:
def sample_from_joint(mu: np.ndarray, Sigma: np.ndarray, n_samples: int, rng=None):
    """Draw samples from a multivariate normal over the flattened entries of ``mu``.

    Args:
        mu: Mean array of any shape (for example ``(N, D)``). It is flattened
            in C order with ``ravel()`` to form the mean vector.
        Sigma: Covariance matrix of shape ``(mu.size, mu.size)``, indexed
            consistently with ``mu.ravel()``.
        n_samples: Number of independent draws.
        rng: ``None`` (fresh entropy), an integer seed, or an existing
            ``np.random.Generator``; anything ``np.random.default_rng`` accepts.

    Returns:
        Array of shape ``(n_samples, *mu.shape)``; entry ``[k]`` is the k-th
        draw reshaped back to the layout of ``mu``.

    Raises:
        ValueError: If ``Sigma`` is not a ``(mu.size, mu.size)`` matrix.
    """
    mu = np.asarray(mu)
    Sigma = np.asarray(Sigma)
    if Sigma.shape != (mu.size, mu.size):
        raise ValueError(
            f"Sigma must have shape ({mu.size}, {mu.size}) to match mu with shape "
            f"{mu.shape}, got {Sigma.shape}"
        )

    # default_rng already treats None as "seed from fresh entropy" and passes an
    # existing Generator through unchanged, so no special-casing is required.
    generator = np.random.default_rng(rng)
    flat_draws = generator.multivariate_normal(
        mu.ravel(), Sigma, size=n_samples, method="cholesky"
    )
    # Undo the ravel() so every draw has the same layout as mu.
    return flat_draws.reshape((n_samples,) + mu.shape)

In [None]:
# Load the parameters of the joint Gaussian that are handed to
# sample_from_joint further down: the covariance (Sigma) and the mean (mu).
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# these files must come from a trusted source.
# NOTE(review): presumably mu is (558, 3) and Sigma is (1674, 1674), given the
# (558, 3) shape check applied when the samples are written; confirm.
with Path("Sigma.pkl").open("rb") as sigma_file:
    Sigma = pickle.load(sigma_file)

with Path("mu.pkl").open("rb") as mu_file:
    mu = pickle.load(mu_file)

In [None]:
# Root directory for the generated dataset; it is passed as output_dir to
# convert_specimens_to_hdf5 in the last cell.
# NOTE(review): hard-coded absolute path, presumably specific to one cluster
# filesystem; consider making it configurable before sharing the notebook.
data_directory = Path("/fs/gpfs41/lv11/fileset01/pool/pool-smola/pythonny/data/long_train")

In [None]:
# Draw 10 * 2**15 = 327,680 synthetic specimens from the joint Gaussian.
# No rng is passed, so every run produces a different set of samples.
# NOTE(review): pass a seed via rng= if the dataset has to be reproducible.
n_draws = 10 * (2 ** 15)
sims = sample_from_joint(mu=mu, Sigma=Sigma, n_samples=n_draws)

In [None]:
def create_hdf5_file(specimens_data: List[np.ndarray], output_path: Path, file_prefix: str = "specimens", start_idx: int = 0) -> int:
    """Write a batch of specimens to one HDF5 file, one dataset per specimen.

    Each specimen is stored under the ``specimens`` group as a ``(558, 4)``
    array whose columns are ``[canonical_id, x, y, z]``, where
    ``canonical_id`` is simply the row index ``0..557``. Dataset keys are
    ``f"{file_prefix}_{start_idx + i:06d}"``.

    Args:
        specimens_data: Sequence of arrays, each of shape ``(558, 3)``.
        output_path: Destination file; it is overwritten if it already exists
            and missing parent directories are created.
        file_prefix: Prefix for the per-specimen dataset keys.
        start_idx: Number assigned to the first specimen of this batch.

    Returns:
        The index following the last specimen written, i.e.
        ``start_idx + len(specimens_data)``, ready to be used as the next
        ``start_idx``.

    Raises:
        ValueError: If any specimen does not have shape ``(558, 3)``. The
            check runs before anything is created on disk, so no partial
            file is left behind.
    """
    expected_shape = (558, 3)

    # Validate the whole batch up front: opening the file in 'w' mode truncates
    # it, so failing halfway through would leave a corrupt/partial file.
    for i, specimen_data in enumerate(specimens_data):
        actual_shape = np.shape(specimen_data)
        if actual_shape != expected_shape:
            raise ValueError(f"Expected shape (558, 3), got {actual_shape} for specimen {i}")

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The id column is identical for every specimen, so build it only once.
    canonical_ids = np.arange(expected_shape[0]).reshape(-1, 1)

    with h5py.File(output_path, 'w') as f:
        specimens_group = f.create_group('specimens')

        for i, specimen_data in enumerate(specimens_data):
            specimen_key = f"{file_prefix}_{start_idx + i:06d}"
            # Prepend the id column -> rows of [canonical_id, x, y, z].
            specimen_full = np.hstack([canonical_ids, specimen_data])
            specimens_group.create_dataset(specimen_key, data=specimen_full)

        f.attrs['num_specimens'] = len(specimens_data)
        f.attrs['format_version'] = '1.1'
        f.attrs['description'] = 'C. elegans nuclei data: [canonical_id, x, y, z]'

    return start_idx + len(specimens_data)


def convert_specimens_to_hdf5(specimens_data: List[np.ndarray], output_dir: Union[str, Path], split_ratios: Tuple[float, float, float] = (0.8, 0.10, 0.10), specimens_per_file: int = 2**14, shuffle: bool = True, rng=None) -> None:
    """Split specimens into train/val/test and write each split as chunked HDF5 files.

    Files are written to ``output_dir/<split>/<split>_NNNN.h5`` with at most
    ``specimens_per_file`` specimens each, and a ``dataset_info.json`` summary
    is written to ``output_dir``. Specimen numbering restarts at 0 for every
    split and continues across the files of that split.

    Args:
        specimens_data: Sequence of specimens (a list of arrays or an array
            whose first axis indexes specimens).
        output_dir: Root directory of the dataset; created if missing.
        split_ratios: ``(train, val, test)`` fractions; must be non-negative
            and sum to 1. Train and val sizes are truncated with ``int()``;
            the remainder goes to test.
        specimens_per_file: Maximum number of specimens per HDF5 file (>= 1).
        shuffle: Whether to randomly permute the specimens before splitting.
        rng: Optional seed or ``np.random.Generator`` for a reproducible
            shuffle. ``None`` keeps the previous behaviour of drawing the
            permutation from NumPy's global random state.

    Raises:
        ValueError: If ``split_ratios`` does not have three non-negative
            entries summing to 1, or if ``specimens_per_file`` is below 1.
    """
    # Validate every argument before anything is created on disk.
    if len(split_ratios) != 3:
        raise ValueError(f"Expected 3 split ratios (train, val, test), got {len(split_ratios)}")
    if any(ratio < 0 for ratio in split_ratios):
        raise ValueError(f"Split ratios must be non-negative, got {tuple(split_ratios)}")
    if abs(sum(split_ratios) - 1.0) > 1e-6:
        raise ValueError(f"Split ratios must sum to 1.0, got {sum(split_ratios)}")
    if specimens_per_file < 1:
        raise ValueError(f"specimens_per_file must be at least 1, got {specimens_per_file}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    n_total = len(specimens_data)

    print(f"Converting {n_total} specimens to HDF5...")
    print(f"Split ratios: train={split_ratios[0]:.1%}, val={split_ratios[1]:.1%}, test={split_ratios[2]:.1%}")
    print(f"Specimens per file: {specimens_per_file}")

    if shuffle:
        if rng is None:
            # Legacy path: uses (and advances) NumPy's global random state.
            indices = np.random.permutation(n_total)
        else:
            indices = np.random.default_rng(rng).permutation(n_total)
        # A list of per-specimen references avoids copying a large input array.
        specimens_data = [specimens_data[i] for i in indices]
        print("Data shuffled")

    n_train = int(n_total * split_ratios[0])
    n_val = int(n_total * split_ratios[1])
    n_test = n_total - n_train - n_val

    print(f"Split sizes: train={n_train}, val={n_val}, test={n_test}")

    splits = [
        ("train", specimens_data[:n_train]),
        ("val", specimens_data[n_train:n_train + n_val]),
        ("test", specimens_data[n_train + n_val:]),
    ]

    for split_name, split_data in splits:
        # Use len() rather than truthiness: with shuffle=False the split may
        # still be a NumPy array, whose truth value is ambiguous.
        if len(split_data) == 0:
            print(f"Warning: No data for {split_name} split")
            continue

        split_dir = output_dir / split_name
        split_dir.mkdir(exist_ok=True)

        specimen_idx = 0
        chunk_starts = range(0, len(split_data), specimens_per_file)

        for file_idx, start_idx in enumerate(chunk_starts):
            # Slicing clamps at the end of the split, so the last file may be shorter.
            batch_data = split_data[start_idx:start_idx + specimens_per_file]

            file_path = split_dir / f"{split_name}_{file_idx:04d}.h5"
            specimen_idx = create_hdf5_file(batch_data, file_path, file_prefix="specimen", start_idx=specimen_idx)

            print(f"  Created {file_path} with {len(batch_data)} specimens")

    info = {
        "total_specimens": n_total,
        "splits": {
            "train": n_train,
            "val": n_val,
            "test": n_test
        },
        "specimens_per_file": specimens_per_file,
        "format": "[canonical_id, x, y, z]",
        "canonical_id_range": "[0, 557]"
    }

    with open(output_dir / "dataset_info.json", 'w') as f:
        json.dump(info, f, indent=2)

    print(f"\nDataset created in {output_dir}")
    print(f"Dataset info saved to {output_dir / 'dataset_info.json'}")

In [None]:
# Write the sampled specimens to disk with the function's defaults: shuffled,
# split 80/10/10 into train/val/test, and chunked into files of 2**14 specimens.
convert_specimens_to_hdf5(specimens_data=sims, output_dir=data_directory)