# ANDI Dataset Generation Pipeline

Generates training data for the Dual CNN Cross-Attention architecture.

Dataset specifications:
- 2,000,000 trajectories (67.5% train / 7.5% val / 25% test)
- 5 diffusion models: ATTM, CTRW, FBM, LW, SBM
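- Discrete α grid per model (step 0.1, ANDI Table 2): ATTM and CTRW 0.10 to 1.00, FBM and SBM 0.10 to 1.90, LW 1.00 to 1.90 (68 model/α combinations)
- Variable trajectory lengths for train/val: [10, 1000] across 12 bins; fixed ANDI lengths (13 values from 10 to 1000) for test
- Ground truth α from ANDI generation, D from MSD-based Langevin estimator
- Dual preprocessing: scaled displacements (alpha-branch) + raw displacements (D-branch)
- HDF5 output with gzip compression (float16 for size reduction)
- Clean data in all splits (SNR noise helpers are defined but not applied)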

Estimated runtime: 4-7 hours for full 2M dataset.
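The split sizes follow directly from the ratios in the specification list. A minimal sketch of that arithmetic, assuming the test split takes whatever remains after rounding (`split_counts` is a hypothetical helper, not part of the pipeline):

```python
def split_counts(total, train_ratio=0.675, val_ratio=0.075):
    """Return (n_train, n_val, n_test) for the given split ratios.

    Counts are rounded rather than truncated so that a float product landing
    just below an integer is not cut down by one; the test split takes the
    remainder, so the three counts always sum to `total`.
    """
    n_train = int(round(total * train_ratio))
    n_val = int(round(total * val_ratio))
    n_test = total - n_train - n_val
    return n_train, n_val, n_test


# Full dataset: 67.5% / 7.5% / 25%
print(split_counts(2_000_000))  # (1350000, 150000, 500000)
# Test-mode size used by the configuration cell
print(split_counts(100_000))    # (67500, 7500, 25000)
```

Deriving every count from the ratios keeps the printed percentages and the actual array sizes in agreement when `TOTAL_TRAJECTORIES` changes.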

References:
- Firbas et al. (2023): ConvTransformer baseline
- Korabel & Waigh (2023): MSD-based D estimation
- ANDI Challenge specifications


In [None]:
# Installation and environment setup

import os
import sys

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

# Mount Google Drive if in Colab
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_DIR = '/content/drive/MyDrive/ERP_Shrey'
    WORK_DIR = BASE_DIR
else:
    BASE_DIR = os.getcwd()
    WORK_DIR = BASE_DIR

# Create data directories
os.makedirs(os.path.join(BASE_DIR, 'data/andi'), exist_ok=True)
if IN_COLAB:
    os.makedirs(os.path.join(WORK_DIR, 'temp_data'), exist_ok=True)

# Install packages
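# andi-datasets requires numpy<=1.26.4. If numpy was already imported in this
# session, restart the runtime after this cell so the pinned version is loaded.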
if IN_COLAB:
    print("Installing compatible numpy version...")
    !pip install -q "numpy<=1.26.4" --force-reinstall --no-deps
    !pip install -q "numpy<=1.26.4"

print("Installing andi-datasets...")
!pip install -q andi-datasets

print("Installing hurst...")
!pip install -q hurst

print("Installing additional packages...")
if IN_COLAB:
    !pip install -q tqdm h5py scikit-learn pandas matplotlib seaborn
else:
    !pip install -q tqdm h5py scikit-learn pandas numpy matplotlib seaborn

print("Packages installed")

✓ Running on local machine
✓ Base directory: /home/magjun/Documents/ERP_Shrey/Report_V2_Preprocessing_and_training

INSTALLING REQUIRED PACKAGES

Installing andi-datasets...

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Installing hurst...

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Installing additional packages...

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

✓ All packages ins

In [None]:
# Import libraries and setup paths

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
import time
import os
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_split

try:
    from andi_datasets.datasets_theory import datasets_theory
    import hurst as hurst_module
    andi = datasets_theory()
except ImportError as e:
    raise RuntimeError(f"Required library not found: {e}. Run Cell 1 first.")

from hurst import compute_Hc

# Setup output directories
if IN_COLAB:
    TEMP_OUTPUT_DIR = os.path.join(WORK_DIR, 'temp_data')
    FINAL_OUTPUT_DIR = os.path.join(BASE_DIR, 'data/andi')
else:
    TEMP_OUTPUT_DIR = os.path.join(BASE_DIR, 'data/andi')
    FINAL_OUTPUT_DIR = TEMP_OUTPUT_DIR

os.makedirs(TEMP_OUTPUT_DIR, exist_ok=True)
os.makedirs(FINAL_OUTPUT_DIR, exist_ok=True)

✓ All libraries imported successfully

PATH CONFIGURATION

Data directory: /home/magjun/Documents/ERP_Shrey/Report_V2_Preprocessing_and_training/data/andi

✓ Directories configured


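Each call to `andi.create_dataset(...)` returns one row per trajectory, and the generation cells slice the label columns off that row. A sketch of the parsing step as a standalone function, assuming the row layout used by `datasets_theory.create_dataset`: the model label, then the anomalous exponent, then the T positions (for `dimension > 1`, assumed to be all x coordinates followed by all y coordinates). `parse_andi_row` and `N_LABEL_COLUMNS` are hypothetical names. Under this layout, slicing from index 3 would drop the first position and leave T - 1 points, which is consistent with the stored lengths (9, 19, ..., 999) reported in the test-set output.

```python
import numpy as np

N_LABEL_COLUMNS = 2  # assumption: [model label, exponent] precede the positions


def parse_andi_row(row, dim=1):
    """Split one generated row into (model_id, alpha, positions[T, dim])."""
    row = np.asarray(row, dtype=float)
    model_id = int(row[0])
    alpha = float(row[1])
    coords = row[N_LABEL_COLUMNS:]
    if coords.size % dim != 0:
        raise ValueError("row length is not consistent with dim")
    # Assumed block layout: x_1..x_T, then y_1..y_T, ... -> shape (T, dim)
    positions = coords.reshape(dim, -1).T
    return model_id, alpha, positions


# Synthetic row: FBM (label 2), alpha = 0.5, four positions
m, a, pos = parse_andi_row([2.0, 0.5, 0.0, 1.0, 3.0, 2.5])
print(m, a, pos.shape)  # 2 0.5 (4, 1)
```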

In [None]:
# Dataset configuration

SEED = 42
np.random.seed(SEED)

# Model mapping (ANDI standard)
# 0: ATTM, 1: CTRW, 2: FBM, 3: LW, 4: SBM
MODELS = [0, 1, 2, 3, 4]
MODEL_NAMES = {0: "ATTM", 1: "CTRW", 2: "FBM", 3: "LW", 4: "SBM"}

# Dataset size
TEST_MODE = True  # Set False for production

if TEST_MODE:
    print("TEST MODE: generating 100K trajectories")
    TOTAL_TRAJECTORIES = 100_000
else:
    TOTAL_TRAJECTORIES = 2_000_000
BATCH_SIZE = 10_000

# Split ratios (Firbas et al. 2023)
TRAIN_RATIO = 0.675  # 67.5%
VAL_RATIO = 0.075    # 7.5%
TEST_RATIO = 0.25    # 25.0%

# Split counts derived from the ratios in both modes
# (production: 1,350,000 train / 150,000 val / 500,000 test).
# Rounding guards against float products landing just below an integer;
# the test split takes the remainder so the counts sum to TOTAL_TRAJECTORIES.
N_TRAIN = int(round(TOTAL_TRAJECTORIES * TRAIN_RATIO))
N_VAL = int(round(TOTAL_TRAJECTORIES * VAL_RATIO))
N_TEST = TOTAL_TRAJECTORIES - N_TRAIN - N_VAL

# --------------------------------------------------------------------------- #
# ALPHA CONFIGURATION - ANDI TABLE 2 SPECIFICATION
# --------------------------------------------------------------------------- #
# Different alpha ranges per model (step 0.1, stored with 2 decimal precision)
# This follows ANDI Challenge Table 2 exactly

ANDI_ALPHA_SPECS = {
    0: np.round(np.arange(0.10, 1.01, 0.1), 2),   # ATTM: [0.10 to 1.00] → 10 values
    1: np.round(np.arange(0.10, 1.01, 0.1), 2),   # CTRW: [0.10 to 1.00] → 10 values
    2: np.round(np.arange(0.10, 1.91, 0.1), 2),   # FBM:  [0.10 to 1.90] → 19 values
    3: np.round(np.arange(1.00, 1.91, 0.1), 2),   # LW:   [1.00 to 1.90] → 10 values
    4: np.round(np.arange(0.10, 1.91, 0.1), 2),   # SBM:  [0.10 to 1.90] → 19 values
}

# Total unique (model, alpha) combinations
N_MODEL_ALPHA_COMBINATIONS = sum(len(alphas) for alphas in ANDI_ALPHA_SPECS.values())  # 68

# Model-specific alpha constraints (for validation)
ALPHA_CONSTRAINTS = {
    0: (0.10, 1.00),   # ATTM: sub-diffusive only
    1: (0.10, 1.00),   # CTRW: sub-diffusive only
    2: (0.10, 1.90),   # FBM: sub & super-diffusive
    3: (1.00, 1.90),   # LW: super-diffusive only
    4: (0.10, 1.90),   # SBM: full range
}

# --------------------------------------------------------------------------- #
# TRAJECTORY LENGTH CONFIGURATION
# --------------------------------------------------------------------------- #
# For TRAIN/VAL: Variable lengths with weighted distribution (Firbas approach)
LENGTH_BINS = [
    (10, 20), (21, 30), (31, 40), (41, 50),
    (51, 100), (101, 200), (201, 300), (301, 400),
    (401, 500), (501, 600), (601, 800), (801, 1000)
]

# Length distribution: Emphasize short trajectories (biologically relevant)
LENGTH_BIN_WEIGHTS = np.array([
    0.10, 0.10, 0.10, 0.10,  # Short (40%)
    0.10, 0.10, 0.10, 0.10,  # Medium (40%)
    0.05, 0.05, 0.05, 0.05   # Long (20%)
])

# For TEST: ANDI Table 2 fixed lengths
ANDI_TEST_LENGTHS = [10, 20, 30, 40, 50, 100, 200, 300, 400, 500, 600, 800, 1000]  # 13 values

MAX_LENGTH = 1000  # For padding in HDF5

# Dimensionality
DIM = 1  # 1D trajectories (Firbas approach for benchmarking)

# --------------------------------------------------------------------------- #
# SNR CONFIGURATION - ANDI TABLE 2 SPECIFICATION
# --------------------------------------------------------------------------- #
# SNR levels for the test set. SNR = 0 is a sentinel meaning "no noise"; it is
# never passed to add_noise_snr(). Because both entries are 0, every
# (model, length, α) cell appears twice in the permutation list and therefore
# receives 2 × N_REPS_PER_TEST_PERMUTATION clean trajectories in total.
ANDI_SNR_LEVELS = [0, 0]

# TRAIN/VAL and TEST: clean data (no noise)

# Total test permutations:
# ATTM: 13 lengths × 2 SNR × 10 α = 260
# CTRW: 13 lengths × 2 SNR × 10 α = 260
# FBM:  13 lengths × 2 SNR × 19 α = 494
# LW:   13 lengths × 2 SNR × 10 α = 260
# SBM:  13 lengths × 2 SNR × 19 α = 494
# TOTAL: 1,768 permutations (884 distinct, since the two SNR entries are identical)

N_TEST_PERMUTATIONS = (
    len(ANDI_TEST_LENGTHS) * len(ANDI_SNR_LEVELS) * N_MODEL_ALPHA_COMBINATIONS
)

N_REPS_PER_TEST_PERMUTATION = N_TEST // N_TEST_PERMUTATIONS  # 282 for N_TEST = 500,000

# --------------------------------------------------------------------------- #
# TRAIN/VAL GENERATION STRATEGY
# --------------------------------------------------------------------------- #
# Balanced generation: Equal trajectories per (model, alpha) combination
N_TRAIN_VAL_TOTAL = N_TRAIN + N_VAL  # 1,500,000
N_TRAJS_PER_MODEL_ALPHA = N_TRAIN_VAL_TOTAL // N_MODEL_ALPHA_COMBINATIONS  # ~22,058

# --------------------------------------------------------------------------- #
# GROUND TRUTH CALCULATION METHODS
# --------------------------------------------------------------------------- #
# α: taken directly from the ANDI generator (exact, no estimation).
# The earlier H-based approach is deprecated:
# - FBM: theoretical H = α/2 (exact)
# - Others: R/S (Rescaled Range) analysis (Korabel et al.)
# USE_THEORETICAL_H_FOR_FBM = True   # DEPRECATED: using alpha directly from ANDI
# USE_RS_ANALYSIS_FOR_OTHERS = True  # DEPRECATED: using alpha directly from ANDI

# D estimation: MSD-based Langevin method. The log-MSD fit uses the known α as
# its slope, so it remains valid for any α (unlike the displacement-variance method)
D_ESTIMATION_MAX_TAU = 50  # Maximum lag time for MSD calculation

# --------------------------------------------------------------------------- #
# PRINT CONFIGURATION SUMMARY
# --------------------------------------------------------------------------- #
print("=" * 70)
print("DATASET CONFIGURATION - ANDI TABLE 2 COMPLIANT")
print("=" * 70)
print(f"Mode: {'TEST' if TEST_MODE else 'PRODUCTION'}")
print(f"Total trajectories: {TOTAL_TRAJECTORIES:,}")
print(f"  - Train: {N_TRAIN:,} ({N_TRAIN/TOTAL_TRAJECTORIES*100:.1f}%)")
print(f"  - Val: {N_VAL:,} ({N_VAL/TOTAL_TRAJECTORIES*100:.1f}%)")
print(f"  - Test: {N_TEST:,} ({N_TEST/TOTAL_TRAJECTORIES*100:.1f}%)")
print(f"\nModels: {list(MODEL_NAMES.values())}")
print()
print("ANDI Table 2 Test Set Specifications:")
print(f"  - Unique permutations: {N_TEST_PERMUTATIONS:,}")
print(f"  - Replications per permutation: {N_REPS_PER_TEST_PERMUTATION}")
print(f"  - Lengths: {len(ANDI_TEST_LENGTHS)} values {ANDI_TEST_LENGTHS}")
print(f"  - SNR levels: {ANDI_SNR_LEVELS}")
print(f"  - Alpha ranges:")
for model_id, model_name in MODEL_NAMES.items():
    alphas = ANDI_ALPHA_SPECS[model_id]
    print(f"      {model_name}: [{alphas[0]:.2f} to {alphas[-1]:.2f}] ({len(alphas)} values)")
print()
print("Train/Val Set Specifications:")
print(f"  - Total (model, α) combinations: {N_MODEL_ALPHA_COMBINATIONS}")
print(f"  - Trajectories per combination: ~{N_TRAJS_PER_MODEL_ALPHA:,}")
print(f"  - Length sampling: Weighted distribution from bins")
print(f"  - Noise: None (clean data)")
print()
print(f"Max length (padding): {MAX_LENGTH}")
print(f"Dimensionality: {DIM}D")
print()
print("Ground truth methods:")
print("  - α (all): From ANDI generation (exact)")
print("  - D (all): MSD-based Langevin estimator (slope fixed to α)")
print()
print("Noise protocol:")
print("  - Train/Val: Clean data (no noise)")
print(f"  - Test: Clean data (no noise)")
print()
print(f"Batch size: {BATCH_SIZE:,} trajectories")
print(f"Random seed: {SEED}")
print("=" * 70)
print()

# Save configuration for reference
CONFIG = {
    "seed": SEED,
    "test_mode": TEST_MODE,
    "andi_table2_compliant": True,
    "total_trajectories": TOTAL_TRAJECTORIES,
    "split_counts": {"train": N_TRAIN, "val": N_VAL, "test": N_TEST},
    "models": {int(k): v for k, v in MODEL_NAMES.items()},
    "andi_alpha_specs": {int(k): v.tolist() for k, v in ANDI_ALPHA_SPECS.items()},
    "alpha_constraints": {int(k): v for k, v in ALPHA_CONSTRAINTS.items()},
    "test_config": {
        "lengths": ANDI_TEST_LENGTHS,
        "snr_levels": ANDI_SNR_LEVELS,
        "n_permutations": N_TEST_PERMUTATIONS,
        "n_reps_per_permutation": N_REPS_PER_TEST_PERMUTATION
    },
    "train_val_config": {
        "n_model_alpha_combinations": N_MODEL_ALPHA_COMBINATIONS,
        "n_trajs_per_combination": N_TRAJS_PER_MODEL_ALPHA,
        "length_bins": LENGTH_BINS,
        "length_bin_weights": LENGTH_BIN_WEIGHTS.tolist()
    },
    "max_length": MAX_LENGTH,
    "dimension": DIM,
    "ground_truth_methods": {
        "alpha": "andi_generation",
        "D": "msd_langevin_fixed_alpha",
        "D_max_tau": D_ESTIMATION_MAX_TAU
    },
    "batch_size": BATCH_SIZE,
    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}


Generating small dataset for testing (100K trajectories)
Set TEST_MODE = False for production 2M dataset

DATASET CONFIGURATION - ANDI TABLE 2 COMPLIANT
Mode: TEST
Total trajectories: 100,000
  - Train: 0 (0.0%)
  - Val: 0 (0.0%)
  - Test: 100,000 (100.0%)

Models: ['ATTM', 'CTRW', 'FBM', 'LW', 'SBM']

ANDI Table 2 Test Set Specifications:
  - Unique permutations: 1,768
  - Replications per permutation: 56
  - Lengths: 13 values [10, 20, 30, 40, 50, 100, 200, 300, 400, 500, 600, 800, 1000]
  - SNR levels: [0, 0]
  - Alpha ranges:
      ATTM: [0.10 to 1.00] (10 values)
      CTRW: [0.10 to 1.00] (10 values)
      FBM: [0.10 to 1.90] (19 values)
      LW: [1.00 to 1.90] (10 values)
      SBM: [0.10 to 1.90] (19 values)

Train/Val Set Specifications:
  - Total (model, α) combinations: 68
  - Trajectories per combination: ~0
  - Length sampling: Weighted distribution from bins
  - Noise: None (clean data)

Max length (padding): 1000
Dimensionality: 1D

Ground truth methods:
  - H (FBM): T
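The permutation and replication counts printed above follow from the per-model α grids. A short sketch that recomputes them, mirroring `ANDI_ALPHA_SPECS` with 13 test lengths and 2 SNR entries. Because of floor division, the production test set fills 1,768 × 282 = 498,576 of its 500,000 slots, and test mode yields 100,000 // 1,768 = 56 replications, matching the value in the output.

```python
import numpy as np

# Per-model α grids (step 0.1), mirroring ANDI_ALPHA_SPECS in the config cell
alpha_specs = {
    0: np.round(np.arange(0.10, 1.01, 0.1), 2),  # ATTM
    1: np.round(np.arange(0.10, 1.01, 0.1), 2),  # CTRW
    2: np.round(np.arange(0.10, 1.91, 0.1), 2),  # FBM
    3: np.round(np.arange(1.00, 1.91, 0.1), 2),  # LW
    4: np.round(np.arange(0.10, 1.91, 0.1), 2),  # SBM
}
n_lengths = 13  # len(ANDI_TEST_LENGTHS)
n_snr = 2       # len(ANDI_SNR_LEVELS), both entries are 0

n_combos = sum(len(a) for a in alpha_specs.values())  # (model, α) pairs
n_perms = n_lengths * n_snr * n_combos                 # test permutations
reps_test = 500_000 // n_perms                         # production replications
per_combo_trainval = 1_500_000 // n_combos             # train+val per (model, α)

print(n_combos, n_perms, reps_test, per_combo_trainval)  # 68 1768 282 22058
```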

In [None]:
# Helper functions for ground truth calculation

def estimate_H_rescaled_range(trajectory, min_window=8):
    """
    [DEPRECATED: α is taken directly from ANDI generation]
    Estimate the Hurst exponent using R/S analysis.
    Originally used for the non-FBM models (CTRW, ATTM, SBM, LW).
    """
    try:
        traj = np.squeeze(trajectory)
        if traj.ndim != 1:
            traj = traj[:, 0]
        
        T = len(traj)
        if T < 2 * min_window:
            return np.nan
        
        H, c, data = hurst_module.compute_Hc(traj, kind='random_walk', simplified=False)
        
        if not (0 <= H <= 1):
            H_simple, _, _ = hurst_module.compute_Hc(traj, kind='random_walk', simplified=True)
            H = H_simple if (0 <= H_simple <= 1) else np.nan
        
        return float(H)
    except Exception:
        return np.nan


def estimate_H_theoretical(alpha, model_id):
    """Calculate theoretical H for FBM: H = α/2"""
    if model_id == 2:
        return alpha / 2.0
    else:
        return np.nan


def estimate_D_from_variance(trajectory, delta_t=1.0):
    """
    [DEPRECATED: biased for α ≠ 1]
    Estimate D from displacement variance: D = Var(Δx) / (2 Δt)
    """
    try:
        traj = np.squeeze(trajectory)
        if traj.ndim != 1:
            traj = traj[:, 0]
        
        # Calculate displacements
        delta_x = np.diff(traj)
        
        if len(delta_x) < 2:
            return np.nan
        
        # Variance of displacements
        sigma_sq = np.var(delta_x)
        
        # D = σ² / (2Δt)
        D = sigma_sq / (2.0 * delta_t)
        
        return float(D) if D > 0 else np.nan
    
    except Exception:
        return np.nan


def estimate_D_from_msd(trajectory, alpha, max_tau=50, delta_t=1.0):
    """
    Estimate D from MSD curve fitting: MSD(τ) = 2D τ^α
    Fits log(MSD) against log(τ) with the slope fixed to the known α,
    so only the intercept log(2D) is estimated.
    """
    try:
        traj = np.squeeze(trajectory)
        if traj.ndim != 1:
            traj = traj[:, 0]
        
        T = len(traj)
        if T < 10:
            return np.nan
        
        # Limit max_tau to available data
        max_tau = min(max_tau, T // 4)
        if max_tau < 3:
            return np.nan
        
        # Calculate MSD for different lag times
        taus = np.arange(1, max_tau + 1)
        msds = []
        
        for tau in taus:
            displacements = traj[tau:] - traj[:-tau]
            if len(displacements) > 0:
                msds.append(np.mean(displacements ** 2))
            else:
                break
        
        msds = np.array(msds)
        taus = taus[:len(msds)]
        
        # Filter valid MSD values
        valid = np.isfinite(msds) & (msds > 0)
        if valid.sum() < 3:
            return np.nan
        
        # MSD(τ) = 2D τ^α → log(MSD) = log(2D) + α log(τ)
        log_msd = np.log(msds[valid])
        log_tau = np.log(taus[valid] * delta_t)
        
        # Intercept = log(2D) → D = exp(intercept) / 2
        intercept = np.mean(log_msd - alpha * log_tau)
        D = 0.5 * np.exp(intercept)
        
        return float(D) if D > 0 else np.nan
    
    except Exception:
        return np.nan


def compute_dual_preprocessing(trajectory):
    """
    Generate dual displacement arrays for Dual CNN architecture.
    Returns raw displacements (D-branch) and scaled displacements (alpha-branch).
    """
    traj = np.squeeze(trajectory)
    if traj.ndim != 1:
        traj = traj[:, 0]
    
    delta_x_raw = np.diff(traj)
    traj_range = np.ptp(traj)
    
    if traj_range > 0:
        delta_x_scaled = delta_x_raw / traj_range
    else:
        delta_x_scaled = np.zeros_like(delta_x_raw)
        traj_range = 1.0
    
    return delta_x_raw, delta_x_scaled, traj_range


def sample_trajectory_length(length_bin_idx, rng):
    """Sample trajectory length from given bin."""
    min_len, max_len = LENGTH_BINS[length_bin_idx]
    return rng.randint(min_len, max_len + 1)


def sample_valid_alpha_for_model(model_id, rng):
    """Sample a valid alpha for the model from its discrete ANDI grid."""
    alpha_min, alpha_max = ALPHA_CONSTRAINTS[model_id]
    alphas = ANDI_ALPHA_SPECS[model_id]
    valid_alphas = alphas[(alphas >= alpha_min) & (alphas <= alpha_max)]
    
    if len(valid_alphas) == 0:
        raise ValueError(f"No valid alpha values for model {model_id}")
    
    return rng.choice(valid_alphas)


def assign_length_bin(length):
    """Assign trajectory to length bin."""
    for bin_idx, (min_len, max_len) in enumerate(LENGTH_BINS):
        if min_len <= length <= max_len:
            return bin_idx
    return -1

HELPER FUNCTIONS LOADED
✓ [DEPRECATED] estimate_H_rescaled_range() - Use alpha directly from ANDI
✓ [DEPRECATED] estimate_H_theoretical() - Use alpha directly from ANDI
✓ [DEPRECATED] estimate_D_from_variance() - Biased for α≠1
✓ estimate_D_from_msd() - MSD-based Langevin estimator (ACTIVE)
✓ compute_dual_preprocessing() - Scaled + Raw displacements
✓ Trajectory sampling helpers


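`estimate_D_from_msd()` fixes the slope of the log-MSD fit to the known α, so only the intercept log(2D) is estimated. One way to sanity-check the idea is to apply it to simulated Brownian motion (α = 1), where D is known. The sketch below reimplements the estimator in simplified form (`estimate_D_fixed_alpha` is a hypothetical name) and uses NumPy's `default_rng` rather than the notebook's global seed.

```python
import numpy as np


def estimate_D_fixed_alpha(x, alpha, max_tau=50, dt=1.0):
    """Simplified MSD-based D estimate with the slope fixed to a known alpha.

    Model: MSD(tau) = 2 D tau^alpha, so log MSD = log(2D) + alpha * log(tau).
    With alpha fixed, the least-squares intercept is the mean residual
    mean(log MSD - alpha * log tau), and D = exp(intercept) / 2.
    """
    x = np.asarray(x, dtype=float).ravel()
    max_tau = min(max_tau, len(x) // 4)
    if max_tau < 3:
        return np.nan
    taus = np.arange(1, max_tau + 1)
    # Time-averaged MSD for each lag
    msd = np.array([np.mean((x[t:] - x[:-t]) ** 2) for t in taus])
    keep = np.isfinite(msd) & (msd > 0)
    if keep.sum() < 3:
        return np.nan
    intercept = np.mean(np.log(msd[keep]) - alpha * np.log(taus[keep] * dt))
    return 0.5 * np.exp(intercept)


rng = np.random.default_rng(0)
D_true = 0.5
# Brownian motion (alpha = 1): increments ~ N(0, 2 * D * dt) with dt = 1
x = np.cumsum(rng.normal(0.0, np.sqrt(2 * D_true), size=20_000))
D_hat = estimate_D_fixed_alpha(x, alpha=1.0)
print(f"D_true = {D_true}, D_hat = {D_hat:.3f}")
```

Very short trajectories return NaN here, as in the notebook helper, which is one source of the missing D values in the test-set summary.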

In [None]:
# SNR and ANDI test set helper functions

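def add_noise_snr(trajectory, snr_target, random_state=None):
    """
    Add Gaussian noise to achieve target SNR.
    SNR = σ_signal / σ_noise, where σ_signal is the std of the positions.
    snr_target must be > 0: SNR = 0 is the "no noise" sentinel and is not valid here.
    """
    if snr_target <= 0:
        raise ValueError(f"snr_target must be > 0, got {snr_target}")
    rng = np.random.RandomState(random_state)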
    
    # Calculate signal standard deviation
    sigma_signal = np.std(trajectory)
    
    # Calculate required noise level to achieve target SNR
    sigma_noise = sigma_signal / snr_target
    
    # Generate Gaussian noise
    noise = rng.normal(0, sigma_noise, size=trajectory.shape)
    
    # Add noise to trajectory
    noisy_trajectory = trajectory + noise
    
    return noisy_trajectory, sigma_noise


def verify_snr(clean_trajectory, noisy_trajectory):
    """Verify actual SNR achieved after adding noise."""
    sigma_signal = np.std(clean_trajectory)
    noise = noisy_trajectory - clean_trajectory
    sigma_noise = np.std(noise)
    
    if sigma_noise == 0:
        return np.inf
    
    snr_actual = sigma_signal / sigma_noise
    return snr_actual


def generate_andi_test_permutations(verbose=True):
    """
    Generate all permutations for ANDI Table 2 test set.
    Each permutation: (model, length, snr, alpha)
    """
    permutations = []
    
    # ANDI Table 2 specifications
    lengths = ANDI_TEST_LENGTHS
    snr_levels = ANDI_SNR_LEVELS
    alpha_specs = ANDI_ALPHA_SPECS
    
    # Generate all permutations
    for model_id in range(5):
        for length in lengths:
            for snr in snr_levels:
                for alpha in alpha_specs[model_id]:
                    permutations.append({
                        'model_id': model_id,
                        'length': length,
                        'snr': snr,
                        'alpha': round(float(alpha), 2)  # Ensure 2 decimals
                    })
    
    if verbose:
        print(f"Generated {len(permutations):,} test permutations")
        for model_id in range(5):
            model_perms = [p for p in permutations if p['model_id'] == model_id]
            n_alphas = len(alpha_specs[model_id])
            n_lengths = len(lengths)
            n_snrs = len(snr_levels)
            expected = n_alphas * n_lengths * n_snrs
            print(f"  {MODEL_NAMES[model_id]}: {len(model_perms):,} "
                  f"({n_alphas} α × {n_lengths} T × {n_snrs} SNR = {expected})")
    
    return permutations


def sample_alpha_for_model_balanced(model_id, random_state=None):
    """Sample alpha for model from ANDI-compliant ranges."""
    rng = np.random.RandomState(random_state)
    alpha_values = ANDI_ALPHA_SPECS[model_id]
    alpha = rng.choice(alpha_values)
    return round(float(alpha), 2)


def get_model_alpha_pairs():
    """Get all (model_id, alpha) pairs for balanced train/val generation."""
    pairs = []
    for model_id in range(5):
        for alpha in ANDI_ALPHA_SPECS[model_id]:
            pairs.append((model_id, round(float(alpha), 2)))
    return pairs

SNR AND ANDI TABLE 2 HELPER FUNCTIONS LOADED
✓ add_noise_snr() - Add SNR-based Gaussian noise to trajectories
✓ verify_snr() - Verify actual SNR achieved
✓ generate_andi_test_permutations() - Generate all test set permutations
✓ sample_alpha_for_model_balanced() - Sample alpha for train/val
✓ get_model_alpha_pairs() - Get all (model, α) pairs for balanced generation

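`add_noise_snr()` sets σ_noise = σ_signal / SNR and `verify_snr()` measures the ratio back. A compact sketch of the same round trip, using NumPy's `Generator` API instead of `RandomState` (the names `add_noise_at_snr` and `achieved_snr` are hypothetical). It also shows why the SNR = 0 sentinel must never reach the noise function: dividing by zero would make σ_noise infinite.

```python
import numpy as np


def add_noise_at_snr(trajectory, snr_target, rng):
    """Add Gaussian noise so that std(trajectory) / std(noise) ≈ snr_target."""
    if snr_target <= 0:
        # sigma_signal / 0 would give an infinite noise level
        raise ValueError("snr_target must be > 0; use SNR = 0 only as a 'no noise' flag")
    sigma_noise = np.std(trajectory) / snr_target
    noise = rng.normal(0.0, sigma_noise, size=trajectory.shape)
    return trajectory + noise, sigma_noise


def achieved_snr(clean, noisy):
    """Empirical SNR, computed the same way as verify_snr()."""
    sigma_noise = np.std(noisy - clean)
    return np.inf if sigma_noise == 0 else np.std(clean) / sigma_noise


rng = np.random.default_rng(1)
clean = np.cumsum(rng.normal(size=(1000, 1)), axis=0)  # toy 1D random walk, shape (T, 1)
noisy, sigma = add_noise_at_snr(clean, snr_target=2.0, rng=rng)
print(f"target SNR = 2.0, achieved SNR = {achieved_snr(clean, noisy):.2f}")
```

The achieved SNR fluctuates around the target because the noise standard deviation is estimated from a finite sample.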


In [None]:
# Generate ANDI Table 2 test set

import time as time_module

test_permutations = generate_andi_test_permutations(verbose=True)

# Storage for test set data
test_trajectories = []
test_displacements_raw = []
test_displacements_scaled = []
test_metadata = []

# Progress tracking
total_test_trajs = N_TEST
trajs_generated = 0
start_time = time_module.time()

print(f"Generating {total_test_trajs:,} test trajectories...")
print(f"  - {len(test_permutations):,} unique permutations")
print(f"  - {N_REPS_PER_TEST_PERMUTATION} replications per permutation")
print()

# Generate test set
batch_progress = 0
for perm_idx, perm in enumerate(tqdm(test_permutations, desc="Test Permutations")):
    model_id = perm['model_id']
    length = perm['length']
    snr = perm['snr']
    alpha = perm['alpha']  # Alpha comes directly from ANDI
    
    # Generate N replications for this permutation
    for rep in range(N_REPS_PER_TEST_PERMUTATION):
        try:
            # Generate CLEAN trajectory first
            traj_clean = andi.create_dataset(
                T=length,
                N_models=1,
                exponents=[alpha],
                models=[model_id],
                dimension=DIM
            )
            
            # Extract trajectory positions: the first two columns of each row are
            # the model label and the exponent
            traj_clean = traj_clean[0][2:]
            
            # Ensure correct shape [T, 1]
            if traj_clean.ndim == 1:
                traj_clean = traj_clean.reshape(-1, 1)
            
            actual_length = len(traj_clean)
            
            # =================================================================
            # GROUND TRUTH CALCULATION (on CLEAN trajectory)
            # =================================================================
            
            # Alpha: Already known from ANDI generation (no estimation needed)
            # D: MSD-based Langevin estimator with the slope fixed to the known α
            D = estimate_D_from_msd(traj_clean, alpha, max_tau=D_ESTIMATION_MAX_TAU)
            
            # =================================================================
            # NO NOISE - Use clean trajectory
            # =================================================================
            
            traj_noisy = traj_clean  # No noise added
            sigma_noise = 0.0  # No noise standard deviation
            
            # =================================================================
            # Store test trajectory and metadata
            # =================================================================
            
            # Calculate raw displacements: dx = x[t+1] - x[t]
            displacements_raw = np.diff(traj_noisy, axis=0)
            
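            # Scaled displacements (alpha-branch): dx / (max(x) - min(x)) from
            # compute_dual_preprocessing(); dividing by sqrt(dt) with dt = 1 would
            # leave them identical to the raw displacements
            _, displacements_scaled, _ = compute_dual_preprocessing(traj_noisy)
            displacements_scaled = displacements_scaled.reshape(-1, DIM)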
            
            # Pad trajectory to MAX_LENGTH
            traj_padded = np.zeros((MAX_LENGTH, DIM), dtype=np.float32)
            traj_padded[:actual_length] = traj_noisy
            
            # Pad displacements to MAX_LENGTH-1
            disp_raw_padded = np.zeros((MAX_LENGTH - 1, DIM), dtype=np.float32)
            disp_raw_padded[:actual_length-1] = displacements_raw
            
            disp_scaled_padded = np.zeros((MAX_LENGTH - 1, DIM), dtype=np.float32)
            disp_scaled_padded[:actual_length-1] = displacements_scaled
            
            # Store data
            test_trajectories.append(traj_padded)
            test_displacements_raw.append(disp_raw_padded)
            test_displacements_scaled.append(disp_scaled_padded)
            
            # Store metadata (alpha from ANDI, D from MSD estimation, SNR fields)
            test_metadata.append({
                'model_id': model_id,
                'alpha': alpha,  # From ANDI generation
                'D': D,          # From MSD estimation
                'length': actual_length,
                'snr': snr,
                'sigma_noise': sigma_noise
            })
            
            trajs_generated += 1
            
        except Exception as e:
            print(f"\nError generating trajectory for "
                  f"{MODEL_NAMES[model_id]}, α={alpha:.2f}, T={length}, SNR={snr}")
            print(f"Error: {e}")
            continue

# Convert to arrays
test_trajectories = np.array(test_trajectories, dtype=np.float32)
test_displacements_raw = np.array(test_displacements_raw, dtype=np.float32)
test_displacements_scaled = np.array(test_displacements_scaled, dtype=np.float32)

# Convert metadata to DataFrame
df_test_metadata = pd.DataFrame(test_metadata)

# Summary
elapsed_time = time_module.time() - start_time
print(f"\nTest set complete: {len(test_trajectories):,} trajectories in {elapsed_time/60:.2f} min")
print(f"Shapes: trajectories {test_trajectories.shape}, metadata {df_test_metadata.shape}")
print(f"Models: {df_test_metadata['model_id'].value_counts().to_dict()}")

GENERATING ANDI TABLE 2 TEST SET

ANDI TABLE 2 TEST PERMUTATIONS GENERATED
Total unique permutations: 1,768

Breakdown by model:
  ATTM: 260 (10 α × 13 T × 2 SNR = 260)
  CTRW: 260 (10 α × 13 T × 2 SNR = 260)
  FBM: 494 (19 α × 13 T × 2 SNR = 494)
  LW: 260 (10 α × 13 T × 2 SNR = 260)
  SBM: 494 (19 α × 13 T × 2 SNR = 494)

Generating 100,000 test trajectories...
  - 1,768 unique permutations
  - 56 replications per permutation



Test Permutations: 100%|██████████| 1768/1768 [01:05<00:00, 27.01it/s]



TEST SET GENERATION COMPLETE
Total trajectories generated: 99,008
Time elapsed: 66.41 seconds (1.11 minutes)

Data shapes:
  Trajectories: (99008, 1000, 1)
  Displacements (raw): (99008, 999, 1)
  Displacements (scaled): (99008, 999, 1)
  Metadata: (99008, 6)

Metadata summary:
  Models: {2: 27664, 4: 27664, 0: 14560, 1: 14560, 3: 14560}
  Lengths: [9, 19, 29, 39, 49, 99, 199, 299, 399, 499, 599, 799, 999]
  SNR levels: {0: 99008}

           model_id         alpha             D        length      snr  \
count  99008.000000  99008.000000  8.887700e+04  99008.000000  99008.0   
mean       2.264706      0.933824  4.349667e+00    310.538462      0.0   
std        1.389227      0.537593  1.560327e+01    315.470625      0.0   
min        0.000000      0.100000  1.876581e-60      9.000000      0.0   
25%        1.000000      0.500000  9.997586e-03     39.000000      0.0   
50%        2.000000      0.900000  1.075369e-01    199.000000      0.0   
75%        4.000000      1.400000  8.366643e-
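The generation cells store two displacement arrays per trajectory. `compute_dual_preprocessing()` divides the raw displacements by the trajectory's range (max - min), which removes the overall amplitude: scaling a path by c scales its MSD, and hence D, by c², but leaves the range-scaled displacements unchanged. That is why the scaled copy suits the α-branch while the raw copy keeps the scale information the D-branch needs. A minimal check, with a hypothetical standalone `dual_displacements()` that mirrors the helper:

```python
import numpy as np


def dual_displacements(x):
    """Raw and range-scaled displacements, as in compute_dual_preprocessing()."""
    x = np.asarray(x, dtype=float).ravel()
    dx = np.diff(x)
    span = np.ptp(x)  # max(x) - min(x)
    scaled = dx / span if span > 0 else np.zeros_like(dx)
    return dx, scaled


rng = np.random.default_rng(2)
x = np.cumsum(rng.normal(size=200))
raw_a, scaled_a = dual_displacements(x)
raw_b, scaled_b = dual_displacements(3.0 * x)  # same path, D larger by a factor of 9

print(np.allclose(scaled_a, scaled_b))    # True: amplitude removed
print(np.allclose(raw_b, 3.0 * raw_a))    # True: raw keeps the scale
```

Since no single step can exceed the full range, the scaled values always lie in [-1, 1].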

In [None]:
# Generate balanced train/val set

import time as time_module

model_alpha_pairs = get_model_alpha_pairs()
print(f"Generating {N_TRAIN + N_VAL:,} train/val trajectories from {len(model_alpha_pairs)} (model, α) pairs")

# Storage for train/val data
train_val_trajectories = []
train_val_displacements_raw = []
train_val_displacements_scaled = []
train_val_metadata = []

# Progress tracking
trajs_generated = 0
start_time = time_module.time()

print(f"Generating {N_TRAIN + N_VAL:,} train/val trajectories...")
print()

# Generate trajectories for each (model, alpha) pair
for pair_idx, (model_id, alpha) in enumerate(tqdm(model_alpha_pairs, desc="(Model, α) Pairs")):
    
    # Generate N_TRAJS_PER_MODEL_ALPHA trajectories for this (model, alpha) pair
    for rep in range(N_TRAJS_PER_MODEL_ALPHA):
        try:
            # Sample trajectory length from weighted distribution
            length_bin_idx = np.random.choice(len(LENGTH_BINS), p=LENGTH_BIN_WEIGHTS)
            length_min, length_max = LENGTH_BINS[length_bin_idx]
            length = np.random.randint(length_min, length_max + 1)
            
            # Generate CLEAN trajectory
            traj_clean = andi.create_dataset(
                T=length,
                N_models=1,
                exponents=[alpha],
                models=[model_id],
                dimension=DIM
            )
            
            # Extract trajectory positions: the first two columns of each row are
            # the model label and the exponent
            traj_clean = traj_clean[0][2:]
            
            # Ensure correct shape [T, 1]
            if traj_clean.ndim == 1:
                traj_clean = traj_clean.reshape(-1, 1)
            
            actual_length = len(traj_clean)
            
            # =================================================================
            # GROUND TRUTH CALCULATION (on CLEAN trajectory)
            # =================================================================
            
            # Alpha: from ANDI generation (ground truth)
            # D: MSD-based Langevin estimator with the slope fixed to the known α
            D = estimate_D_from_msd(traj_clean, alpha, max_tau=D_ESTIMATION_MAX_TAU)
            
            # =================================================================
            # NO NOISE - Use clean trajectory
            # =================================================================
            
            snr = 0  # Indicator for no noise
            traj_noisy = traj_clean  # No noise added
            sigma_noise = 0.0  # No noise standard deviation
            
            # =================================================================
            # Calculate displacements and pad data
            # =================================================================
            
            # Calculate raw displacements: dx = x[t+1] - x[t]
            displacements_raw = np.diff(traj_noisy, axis=0)
            
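            # Scaled displacements (alpha-branch): dx / (max(x) - min(x)) from
            # compute_dual_preprocessing(); dividing by sqrt(dt) with dt = 1 would
            # leave them identical to the raw displacements
            _, displacements_scaled, _ = compute_dual_preprocessing(traj_noisy)
            displacements_scaled = displacements_scaled.reshape(-1, DIM)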
            
            # Pad trajectory to MAX_LENGTH
            traj_padded = np.zeros((MAX_LENGTH, DIM), dtype=np.float32)
            traj_padded[:actual_length] = traj_noisy
            
            # Pad displacements
            disp_raw_padded = np.zeros((MAX_LENGTH - 1, DIM), dtype=np.float32)
            disp_raw_padded[:actual_length-1] = displacements_raw
            
            disp_scaled_padded = np.zeros((MAX_LENGTH - 1, DIM), dtype=np.float32)
            disp_scaled_padded[:actual_length-1] = displacements_scaled
            
            # Store data
            train_val_trajectories.append(traj_padded)
            train_val_displacements_raw.append(disp_raw_padded)
            train_val_displacements_scaled.append(disp_scaled_padded)
            
            # Store metadata (alpha from ANDI, D from MSD estimation, SNR fields)
            train_val_metadata.append({
                'traj_id': trajs_generated,
                'model_id': model_id,
                'model_name': MODEL_NAMES[model_id],
                'alpha': alpha,
                'D': D,
                'length': actual_length,
                'snr': snr,
                'sigma_noise': sigma_noise
            })
            
            trajs_generated += 1
            
            # Progress update every 10K trajectories
            if trajs_generated % 10_000 == 0:
                elapsed = time_module.time() - start_time
                rate = trajs_generated / elapsed
                remaining = ((N_TRAIN + N_VAL) - trajs_generated) / rate
                print(f"  Generated {trajs_generated:,}/{N_TRAIN + N_VAL:,} "
                      f"({trajs_generated/(N_TRAIN + N_VAL)*100:.1f}%) - "
                      f"Rate: {rate:.0f} traj/s - "
                      f"ETA: {remaining/60:.1f} min")
        
        except Exception as e:
            print(f"\n⚠ Warning: Failed to generate trajectory for "
                  f"{MODEL_NAMES[model_id]}, α={alpha:.2f}, T={length}")
            print(f"  Error: {e}")
            continue

# Convert to arrays
train_val_trajectories = np.array(train_val_trajectories, dtype=np.float32)
train_val_displacements_raw = np.array(train_val_displacements_raw, dtype=np.float32)
train_val_displacements_scaled = np.array(train_val_displacements_scaled, dtype=np.float32)

# Summary
elapsed_total = time_module.time() - start_time
df_train_val_metadata = pd.DataFrame(train_val_metadata)

print(f"\nTrain/val set complete: {trajs_generated:,} trajectories in {elapsed_total/60:.1f} min")
print(f"Rate: {trajs_generated/elapsed_total:.1f} traj/s")
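if df_train_val_metadata.empty:
    # pd.DataFrame([]) has no columns, so the summaries below would raise KeyError
    print("⚠ No train/val trajectories were generated; check N_TRAIN and N_VAL")
else:
    print(f"Models: {df_train_val_metadata['model_name'].value_counts().to_dict()}")
    print(f"Length range: [{df_train_val_metadata['length'].min()}, {df_train_val_metadata['length'].max()}]")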


GENERATING BALANCED TRAIN/VAL SET

Total (model, α) combinations: 68
Trajectories per combination: 0
Total train/val trajectories: 0

Generating 0 train/val trajectories...



(Model, α) Pairs: 100%|██████████| 68/68 [00:00<00:00, 953888.54it/s]


TRAIN/VAL SET GENERATION COMPLETE
✓ Generated: 0 trajectories
✓ Time elapsed: 0.0 minutes
✓ Average rate: 0.0 trajectories/second

Train/Val set distribution:





KeyError: 'model_name'
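The ground-truth D comes from `estimate_D_from_msd`, which is defined outside this excerpt. As an illustration only, here is a generic fixed-α MSD fit that may differ from the notebook's Langevin estimator. It assumes the time-averaged MSD follows MSD(τ) = 2dDτ^α with a unit time step, holds α at its generator value, and solves the one-parameter least-squares problem over the first few lags. The name `estimate_D_msd_sketch` and the 10-lag cap are hypothetical choices.

```python
import numpy as np


def estimate_D_msd_sketch(traj, alpha, max_lag=None):
    """Hypothetical fixed-alpha MSD fit (illustration only).

    Assumes MSD(tau) = 2 * d * D * tau**alpha with dt = 1 and alpha known.
    traj: positions with shape (T,) or (T, d).
    """
    traj = np.asarray(traj, dtype=np.float64)
    if traj.ndim == 1:
        traj = traj.reshape(-1, 1)
    T, d = traj.shape
    if T < 2:
        raise ValueError("need at least two positions")
    if max_lag is None:
        # Short lags are averaged over the most windows; cap at 10 (assumed choice)
        max_lag = max(1, min(10, T // 4))
    lags = np.arange(1, max_lag + 1)
    # Time-averaged MSD at each lag, summed over spatial dimensions
    msd = np.array([
        np.mean(np.sum((traj[lag:] - traj[:-lag]) ** 2, axis=1)) for lag in lags
    ])
    # Least squares for msd ~ K * x with x = lags**alpha: K = sum(m * x) / sum(x * x)
    x = lags.astype(np.float64) ** alpha
    K = np.sum(msd * x) / np.sum(x * x)
    return K / (2 * d)


# Usage: Brownian motion (alpha = 1) with D = 0.5, i.e. step variance 2 * D * dt = 1
rng = np.random.default_rng(0)
traj = np.cumsum(rng.normal(0.0, 1.0, size=(100_000, 1)), axis=0)
print(estimate_D_msd_sketch(traj, alpha=1.0))  # close to 0.5
```

Fixing α leaves a single linear unknown, so the fit has a closed form and tends to be more stable on short trajectories than fitting α and D jointly.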

In [None]:
# =============================================================================
# DATA QUALITY VALIDATION: Check for gaps in alpha and model coverage
# =============================================================================

print("=" * 70)
print("DATA QUALITY VALIDATION")
print("=" * 70)
print()

# Combine all metadata
df_all_metadata = pd.concat([
    df_test_metadata,
    df_train_val_metadata
], ignore_index=True)

print(f"Total trajectories: {len(df_all_metadata):,}")
print()

# =============================================================================
# 1. CHECK ALPHA COVERAGE PER MODEL
# =============================================================================

print("1. ALPHA COVERAGE CHECK")
print("-" * 70)

all_gaps_found = False

for model_id, model_name in MODEL_NAMES.items():
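    # Expected alphas for this model, rounded like the observed values below
    # so that float grids (e.g. 0.30000000000000004) compare equal to 0.3
    expected_alphas = set(np.round(ANDI_ALPHA_SPECS[model_id], 2))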
    
    # Get actual alphas in dataset
    model_data = df_all_metadata[df_all_metadata['model_id'] == model_id]
    actual_alphas = set(np.round(model_data['alpha'].unique(), 2))
    
    # Find gaps
    missing_alphas = expected_alphas - actual_alphas
    extra_alphas = actual_alphas - expected_alphas
    
    # Report
    print(f"\n{model_name} (model_id={model_id}):")
    print(f"  Expected alpha range: [{min(expected_alphas):.2f}, {max(expected_alphas):.2f}] ({len(expected_alphas)} values)")
    if actual_alphas:
        print(f"  Actual alpha range:   [{min(actual_alphas):.2f}, {max(actual_alphas):.2f}] ({len(actual_alphas)} values)")
    else:
        print("  Actual alpha range:   none (no trajectories for this model)")
    print(f"  Trajectories: {len(model_data):,}")
    
    if missing_alphas:
        print(f"  ⚠️  MISSING ALPHAS: {sorted(missing_alphas)}")
        all_gaps_found = True
    else:
        print(f"  ✓ All expected alphas present")
    
    if extra_alphas:
        print(f"  ⚠️  EXTRA ALPHAS (not in ANDI spec): {sorted(extra_alphas)}")
        all_gaps_found = True

print()
print("-" * 70)

# =============================================================================
# 2. CHECK MODEL BALANCE
# =============================================================================

print("\n2. MODEL BALANCE CHECK")
print("-" * 70)

model_counts = df_all_metadata['model_name'].value_counts()
print("\nTrajectories per model:")
for model_name, count in model_counts.items():
    percentage = 100 * count / len(df_all_metadata)
    print(f"  {model_name:10s}: {count:8,} ({percentage:5.2f}%)")

# Check model balance. Note: the generator balances per (model, α) pair, so
# models with more valid α values (FBM, SBM) hold proportionally more
# trajectories and an even 20% share per model is not expected.
min_percentage = 100 * model_counts.min() / len(df_all_metadata)
max_percentage = 100 * model_counts.max() / len(df_all_metadata)
imbalance = max_percentage - min_percentage

print(f"\nImbalance: {imbalance:.2f}%")
if imbalance > 5.0:
    print(f"  ⚠️  WARNING: Models are imbalanced ({imbalance:.1f} percentage points apart, threshold 5.0)")
    all_gaps_found = True
else:
    print(f"  ✓ Models are reasonably balanced")

# =============================================================================
# FINAL VERDICT
# =============================================================================

print("\n" + "=" * 70)
if all_gaps_found:
    print("⚠️  VALIDATION FAILED: Gaps or issues found in dataset coverage")
else:
    print("✅ VALIDATION PASSED: Dataset has complete coverage")
print("=" * 70)
print()


DATA QUALITY VALIDATION

Total trajectories: 99,688

1. ALPHA COVERAGE CHECK
----------------------------------------------------------------------

ATTM (model_id=0):
  Expected alpha range: [0.10, 1.00] (10 values)
  Actual alpha range:   [0.10, 1.00] (10 values)
  Trajectories: 14,660
  ✓ All expected alphas present

CTRW (model_id=1):
  Expected alpha range: [0.10, 1.00] (10 values)
  Actual alpha range:   [0.10, 1.00] (10 values)
  Trajectories: 14,660
  ✓ All expected alphas present

FBM (model_id=2):
  Expected alpha range: [0.10, 1.90] (19 values)
  Actual alpha range:   [0.10, 1.90] (19 values)
  Trajectories: 27,854
  ✓ All expected alphas present

LW (model_id=3):
  Expected alpha range: [1.00, 1.90] (10 values)
  Actual alpha range:   [1.00, 1.90] (10 values)
  Trajectories: 14,660
  ✓ All expected alphas present

SBM (model_id=4):
  Expected alpha range: [0.10, 1.90] (19 values)
  Actual alpha range:   [0.10, 1.90] (19 values)
  Trajectories: 27,854
  ✓ All expec
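The coverage check above compares sets of floats, which only works when both sides are rounded the same way: a grid built as `0.1 * 3` holds `0.30000000000000004`, not `0.3`. Below is a stdlib-only sketch of the same check that rounds both expected and observed α values and also counts trajectories per (model, α) pair. The `(model_id, alpha)` record format and the name `check_alpha_coverage` are assumptions made for illustration.

```python
from collections import Counter


def check_alpha_coverage(records, alpha_specs, decimals=2):
    """Hypothetical stdlib version of the alpha-coverage check.

    records:     iterable of (model_id, alpha) pairs, one per trajectory
    alpha_specs: dict mapping model_id -> expected alpha grid
    Both sides are rounded to `decimals` before the set comparison.
    """
    observed = {}
    pair_counts = Counter()
    for model_id, alpha in records:
        a = round(float(alpha), decimals)
        observed.setdefault(model_id, set()).add(a)
        pair_counts[(model_id, a)] += 1

    report = {}
    for model_id, grid in alpha_specs.items():
        expected = {round(float(a), decimals) for a in grid}
        actual = observed.get(model_id, set())
        report[model_id] = {
            "missing": sorted(expected - actual),
            "extra": sorted(actual - expected),
        }
    return report, pair_counts


specs = {0: [0.1 * k for k in range(1, 4)], 2: [0.1, 0.2]}  # 0.1 * 3 == 0.30000000000000004
records = [(0, 0.1), (0, 0.2), (0, 0.3), (2, 0.1), (2, 0.5)]
report, counts = check_alpha_coverage(records, specs)
print(report)
# {0: {'missing': [], 'extra': []}, 2: {'missing': [0.2], 'extra': [0.5]}}
```

Per-pair counts also explain the model totals in the output above: each (model, α) pair in the combined dataset holds 1,466 trajectories, so models with 19 α values (FBM, SBM) get 19 × 1,466 = 27,854 trajectories and models with 10 (ATTM, CTRW, LW) get 14,660.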

In [None]:
# Cell 7: Split Train/Val Set

from sklearn.model_selection import train_test_split

# =============================================================================
# STRATIFIED TRAIN/VAL SPLIT
# =============================================================================

print("=" * 70)
print("SPLITTING TRAIN/VAL SET")
print("=" * 70)
print()

# Test set is already separate, we only need to split train/val
print(f"Test set: {len(test_metadata):,} trajectories (already separated)")
print(f"Train/Val set: {len(train_val_metadata):,} trajectories (needs splitting)")
print(f"Target split: {N_TRAIN:,} train / {N_VAL:,} val")
print()

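# An empty train/val set yields a DataFrame with no columns (KeyError below)
if df_train_val_metadata.empty:
    raise RuntimeError("Train/val metadata is empty; check N_TRAIN / N_VAL and the generation cell")

# Stratification labels: one stratum per (model_id, alpha) combination.
# Round alpha first so float noise cannot split one stratum into several.
df_train_val_metadata['strat_label'] = (
    df_train_val_metadata['model_id'].astype(str) + "_" + 
    df_train_val_metadata['alpha'].round(2).astype(str)
)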

# Split train/val with stratification
train_idx, val_idx = train_test_split(
    np.arange(len(df_train_val_metadata)),
    test_size=N_VAL / (N_TRAIN + N_VAL),  # 10% for val
    random_state=SEED,
    stratify=df_train_val_metadata['strat_label']
)

# Assign split labels
df_train_val_metadata['split'] = 'train'
df_train_val_metadata.loc[val_idx, 'split'] = 'val'

print("=" * 70)
print("SPLIT COMPLETE")
print("=" * 70)
print(f"Train: {len(train_idx):,} ({len(train_idx)/len(df_train_val_metadata)*100:.2f}%)")
print(f"Val:   {len(val_idx):,} ({len(val_idx)/len(df_train_val_metadata)*100:.2f}%)")
print("=" * 70)
print()

# Verify stratification
print("Model distribution across train/val splits:")
print(pd.crosstab(df_train_val_metadata['model_name'], df_train_val_metadata['split'], normalize='columns') * 100)
print()

# Combine metadata
# Add 'split' column to test metadata
df_test_metadata['split'] = 'test'

# Combine all metadata
all_metadata = pd.concat([
    df_train_val_metadata,
    df_test_metadata
], ignore_index=True)

# Reassign traj_ids to be sequential
all_metadata['traj_id'] = np.arange(len(all_metadata))

print(f"Total dataset: {len(all_metadata):,} trajectories")
print(f"  Train: {(all_metadata['split'] == 'train').sum():,}")
print(f"  Val:   {(all_metadata['split'] == 'val').sum():,}")
print(f"  Test:  {(all_metadata['split'] == 'test').sum():,}")
print()
print("=" * 70)
print()

SPLITTING TRAIN/VAL SET

Test set: 99,008 trajectories (already separated)
Train/Val set: 0 trajectories (needs splitting)
Target split: 0 train / 0 val



KeyError: 'model_id'
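The export cell that follows stores boolean padding masks next to the zero-padded arrays. Because the generation cell pads each trajectory to `MAX_LENGTH` before storing it, the valid region has to come from the recorded `length` rather than from the array's own length. A vectorized sketch of that mask construction, assuming (as with `np.diff` in the generation cell) that T positions yield T − 1 displacements; `lengths_to_masks` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def lengths_to_masks(lengths, max_length):
    """Hypothetical helper: padding masks from true trajectory lengths.

    Returns
      mask_positions:     (N, max_length) bool, True for the first `length` steps
      mask_displacements: (N, max_length - 1) bool, True for the first `length - 1`
    """
    lengths = np.asarray(lengths, dtype=np.int64)
    if lengths.min() < 1 or lengths.max() > max_length:
        raise ValueError("lengths must lie in [1, max_length]")
    steps = np.arange(max_length)
    # Broadcast (1, max_length) against (N, 1): valid where step index < length
    mask_positions = steps[None, :] < lengths[:, None]
    # np.diff turns T positions into T - 1 displacements
    mask_displacements = steps[None, :-1] < (lengths - 1)[:, None]
    return mask_positions, mask_displacements


mask_pos, mask_disp = lengths_to_masks([3, 5], max_length=5)
print(mask_pos.sum(axis=1), mask_disp.sum(axis=1))  # [3 5] [2 4]
```

Inside the export loop the equivalent call would be `lengths_to_masks(split_meta['length'].values, MAX_LENGTH)`, replacing the per-row Python loops for the masks.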

In [None]:
# Cell 8: Export to HDF5 with ANDI Table 2 Structure

import h5py
import os
import time as time_module
import shutil

# =============================================================================
# EXPORT TO HDF5 FORMAT
# =============================================================================

print("=" * 70)
print("EXPORTING TO HDF5 FORMAT")
print("=" * 70)
print("Creating PyTorch-ready HDF5 dataset with:")
print("  - Padded trajectories (positions)")
print("  - Dual-preprocessed displacements (raw + scaled)")
print("  - Padding masks for variable lengths")
print("  - Ground truth labels (D, alpha, model_id)")
print("  - SNR and sigma_noise for test set")
print("  - Separate train/val/test groups")
print("  - ANDI Table 2 compliant")
print("=" * 70)
print()

# Determine output path based on environment
if IN_COLAB:
    # Generate in fast local storage first
    output_dir = TEMP_OUTPUT_DIR
else:
    # Local machine - use final directory directly
    output_dir = FINAL_OUTPUT_DIR

os.makedirs(output_dir, exist_ok=True)

# Generate filename with timestamp
timestamp = time_module.strftime("%Y%m%d_%H%M%S")
h5_filename = f'andi_dataset_table2_{timestamp}.h5'
h5_path = os.path.join(output_dir, h5_filename)

print(f"Output file: {h5_path}")
print()

# =============================================================================
# PREPARE DATA ARRAYS
# =============================================================================

print("Preparing data arrays...")

# Combine trajectories and metadata by split
train_indices = all_metadata[all_metadata['split'] == 'train'].index.tolist()
val_indices = all_metadata[all_metadata['split'] == 'val'].index.tolist()
test_indices = all_metadata[all_metadata['split'] == 'test'].index.tolist()

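# Combine trajectory data in the same order as all_metadata (train/val first, then test).
# list(...) concatenates element-wise; '+' on NumPy arrays would add them instead.
all_trajectories = list(train_val_trajectories) + list(test_trajectories)
all_displacements_raw = list(train_val_displacements_raw) + list(test_displacements_raw)
all_displacements_scaled = list(train_val_displacements_scaled) + list(test_displacements_scaled)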

print(f"✓ Total trajectories: {len(all_trajectories):,}")
print(f"  Train: {len(train_indices):,}")
print(f"  Val:   {len(val_indices):,}")
print(f"  Test:  {len(test_indices):,}")
print()

# =============================================================================
# CREATE HDF5 FILE
# =============================================================================

start_time = time_module.time()

with h5py.File(h5_path, 'w') as hf:
    
    # Process each split
    for split_name in ['train', 'val', 'test']:
        print(f"Processing {split_name} split...")
        
        # Get indices for this split
        split_indices = all_metadata[all_metadata['split'] == split_name].index.tolist()
        n_split = len(split_indices)
        
        if n_split == 0:
            print(f"  ⚠ Warning: No data for {split_name} split, skipping...")
            continue
        
        # Create group for this split
        split_group = hf.create_group(split_name)
        
        # Get metadata for this split
        split_meta = all_metadata.iloc[split_indices]
        
        # True lengths come from the metadata: stored arrays may already be
        # zero-padded to MAX_LENGTH, so len(traj) would mark padding as valid.
        split_lengths = split_meta['length'].values.astype(int)
        
        # 1. POSITIONS (padded to MAX_LENGTH)
        # float16 halves the uncompressed size vs float32 but keeps only ~3
        # significant digits, and magnitudes above 65504 overflow to inf.
        print(f"  Writing positions...")
        positions_padded = np.zeros((n_split, MAX_LENGTH, DIM), dtype=np.float16)
        mask_positions = np.zeros((n_split, MAX_LENGTH), dtype=bool)
        
        for i, idx in enumerate(split_indices):
            T = split_lengths[i]
            # reshape(-1, DIM) turns 1-D inputs into (T, DIM); slice to the valid region
            traj = np.asarray(all_trajectories[idx]).reshape(-1, DIM)[:T]
            positions_padded[i, :T, :] = traj
            mask_positions[i, :T] = True
        
        split_group.create_dataset(
            'positions',
            data=positions_padded,
            compression="gzip",  # Enable compression
            compression_opts=4   # Level 4 = good balance
        )
        split_group.create_dataset('mask_positions', data=mask_positions, compression='gzip', compression_opts=4)
        
        # 2. DISPLACEMENTS RAW (padded to MAX_LENGTH - 1)
        print(f"  Writing raw displacements...")
        disp_raw_padded = np.zeros((n_split, MAX_LENGTH - 1, DIM), dtype=np.float16)
        mask_disp = np.zeros((n_split, MAX_LENGTH - 1), dtype=bool)
        
        for i, idx in enumerate(split_indices):
            # T positions give T - 1 displacements
            T_disp = split_lengths[i] - 1
            disp = np.asarray(all_displacements_raw[idx]).reshape(-1, DIM)[:T_disp]
            disp_raw_padded[i, :T_disp, :] = disp
            mask_disp[i, :T_disp] = True
        
        split_group.create_dataset(
            'displacements_raw',
            data=disp_raw_padded,
            compression="gzip",
            compression_opts=4
        )
        
        # 3. DISPLACEMENTS SCALED (same valid region as the raw displacements)
        print(f"  Writing scaled displacements...")
        disp_scaled_padded = np.zeros((n_split, MAX_LENGTH - 1, DIM), dtype=np.float16)
        
        for i, idx in enumerate(split_indices):
            T_disp = split_lengths[i] - 1
            disp = np.asarray(all_displacements_scaled[idx]).reshape(-1, DIM)[:T_disp]
            disp_scaled_padded[i, :T_disp, :] = disp
        
        split_group.create_dataset(
            'displacements_scaled',
            data=disp_scaled_padded,
            compression="gzip",
            compression_opts=4
        )
        split_group.create_dataset('mask_displacements', data=mask_disp, compression='gzip', compression_opts=4)
        
        # 4. GROUND TRUTH LABELS
        print(f"  Writing ground truth labels...")
        
        # Handle D values for float32 safety (prevent underflow/overflow)
        D_values = split_meta['D'].values.copy()
        
        # Replace NaN and inf with a safe default value
        D_values = np.nan_to_num(D_values, nan=1e-10, posinf=1e10, neginf=1e-10)
        
        # Clip to safe range before casting (float32 min ~1e-38, but use 1e-10 for safety)
        D_values = np.clip(D_values, 1e-10, 1e10)
        
        # Now safe to cast to float32
        D_values = D_values.astype(np.float32)
        
        # H field removed - using alpha directly
        split_group.create_dataset('D', data=D_values)
        split_group.create_dataset('alpha', data=split_meta['alpha'].values.astype(np.float32))
        split_group.create_dataset('model_id', data=split_meta['model_id'].values.astype(np.int32))
        split_group.create_dataset('length', data=split_meta['length'].values.astype(np.int32))
        split_group.create_dataset('traj_id', data=split_meta['traj_id'].values.astype(np.int32))
        
        # 5. SNR FIELDS (train/val are clean: snr = 0, sigma_noise = 0;
        #    noise is added only to the test set)
        print(f"  Writing SNR metadata...")
        # Replace any np.inf SNR (if present) with a large finite value for float32 storage
        snr_values = split_meta['snr'].values.copy()
        snr_values = np.where(np.isinf(snr_values), 1e6, snr_values)
        split_group.create_dataset('snr', data=snr_values.astype(np.float32))
        split_group.create_dataset('sigma_noise', data=split_meta['sigma_noise'].values.astype(np.float32))
        
        print(f"✓ {split_name}: {n_split:,} trajectories written")
        print()
    
    # =============================================================================
    # METADATA
    # =============================================================================
    
    print("Writing metadata...")
    metadata_group = hf.create_group('metadata')
    
    # Model information
    metadata_group.create_dataset('model_names', data=np.array(list(MODEL_NAMES.values()), dtype='S10'))
    metadata_group.create_dataset('model_ids', data=np.array(list(MODEL_NAMES.keys()), dtype=np.int32))
    
    # ANDI Table 2 specifications
    metadata_group.create_dataset('andi_test_lengths', data=np.array(ANDI_TEST_LENGTHS, dtype=np.int32))
    metadata_group.create_dataset('andi_snr_levels', data=np.array(ANDI_SNR_LEVELS, dtype=np.int32))
    
    # Alpha ranges per model
    for model_id in range(5):
        metadata_group.create_dataset(
            f'alpha_range_model_{model_id}',
            data=ANDI_ALPHA_SPECS[model_id].astype(np.float32)
        )
    
    # Configuration attributes
    metadata_group.attrs['seed'] = SEED
    metadata_group.attrs['max_length'] = MAX_LENGTH
    metadata_group.attrs['dimension'] = DIM
    metadata_group.attrs['n_models'] = len(MODELS)
    metadata_group.attrs['andi_table2_compliant'] = True
    metadata_group.attrs['test_permutations'] = N_TEST_PERMUTATIONS
    metadata_group.attrs['test_reps_per_permutation'] = N_REPS_PER_TEST_PERMUTATION
    metadata_group.attrs['timestamp'] = timestamp
    
    print("✓ Metadata written")
    print()

elapsed = time_module.time() - start_time

print("=" * 70)
print("HDF5 EXPORT COMPLETE")
print("=" * 70)
print(f"✓ File: {h5_path}")
print(f"✓ Size: {os.path.getsize(h5_path) / 1e9:.2f} GB")
print(f"✓ Time: {elapsed/60:.1f} minutes")
print("=" * 70)
print()

# =============================================================================
# VERIFICATION
# =============================================================================

print("Verifying HDF5 structure...")
with h5py.File(h5_path, 'r') as hf:
    print("\nGroups:")
    for key in hf.keys():
        print(f"  /{key}")
        if key in ['train', 'val', 'test']:
            print(f"    Datasets:")
            for dset in hf[key].keys():
                shape = hf[key][dset].shape
                dtype = hf[key][dset].dtype
                print(f"      {dset}: {shape} {dtype}")
    
    print("\n✓ Verification complete")
    print("=" * 70)

# =============================================================================
# COPY TO GOOGLE DRIVE (COLAB ONLY)
# =============================================================================

if IN_COLAB:
    print()
    print("=" * 70)
    print("COPYING TO GOOGLE DRIVE FOR PERSISTENCE")
    print("=" * 70)
    print()
    
    final_path = os.path.join(FINAL_OUTPUT_DIR, h5_filename)
    os.makedirs(FINAL_OUTPUT_DIR, exist_ok=True)  # ensure the Drive folder exists before copying
    
    if h5_path != final_path:
        print(f"Copying from: {h5_path}")
        print(f"Copying to:   {final_path}")
        print("⏳ This may take 5-10 minutes for large files...")
        print()
        
        copy_start = time_module.time()
        shutil.copy(h5_path, final_path)
        copy_elapsed = time_module.time() - copy_start
        
        print(f"✓ File copied to Google Drive in {copy_elapsed/60:.1f} minutes")
        print(f"✓ File will persist after Colab session ends")
        print()
        
        # Clean up local copy to save space
        print("Cleaning up local temporary file...")
        os.remove(h5_path)
        print("✓ Local temporary file removed")
        print()
        
        # Update path reference
        h5_path = final_path
    else:
        print("✓ File already in Google Drive")
    
    print()
    print("📥 TO DOWNLOAD TO YOUR LOCAL MACHINE:")
    print("Run this in a new cell:")
    print("---")
    print("from google.colab import files")
    print(f"files.download('{final_path}')")
    print("---")
    print()
    print("=" * 70)
    print()
else:
    print()
    print(f"✓ Dataset saved locally at: {h5_path}")
    print()

print()
print("=" * 70)
print("🎉 DATA GENERATION PIPELINE COMPLETE!")
print("=" * 70)
print()
print("Next steps:")
print("1. If on Colab: Download the HDF5 file to your local machine")
print("2. Update training notebook dataset path")
print("3. Run training on local machine or Colab")
print()
print(f"Final dataset location: {h5_path}")
print("=" * 70)

EXPORTING TO HDF5 FORMAT
Creating PyTorch-ready HDF5 dataset with:
  - Padded trajectories (positions)
  - Dual-preprocessed displacements (raw + scaled)
  - Padding masks for variable lengths
  - Ground truth labels (D, alpha, model_id)
  - SNR and sigma_noise for test set
  - Separate train/val/test groups
  - ANDI Table 2 compliant

Output file: /home/magjun/Documents/ERP_Shrey/Report_V2_Preprocessing_and_training/data/andi/andi_dataset_table2_20251130_165225.h5

Preparing data arrays...
✓ Total trajectories: 99,688
  Train: 67,442
  Val:   7,494
  Test:  24,752

Processing train split...
  Writing positions...
  Writing raw displacements...
  Writing scaled displacements...
  Writing ground truth labels...
  Writing SNR metadata...
✓ train: 67,442 trajectories written

Processing val split...
  Writing positions...
  Writing raw displacements...
  Writing scaled displacements...
  Writing ground truth labels...
  Writing SNR metadata...
✓ val: 7,494 trajectories written

Proc