In [None]:
import sys
import os
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path('/mnt/home/mlee1/vdm_BIND')
sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.gridspec as gridspec
from tqdm import tqdm
import h5py

# Import BIND and utilities
from bind.bind import BIND
from bind.workflow_utils import ConfigLoader, load_normalization_stats

# Publication-quality matplotlib defaults, applied once for every figure below.
PUBLICATION_RC = {
    'font.size': 14,
    'font.family': 'serif',
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 11,
    'figure.figsize': (12, 8),
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
}
plt.rcParams.update(PUBLICATION_RC)

# All figures are written below the project tree; the folder is created
# (including parents) if it does not exist yet.
FIGURE_DIR = PROJECT_ROOT.joinpath('analysis', 'figures', 'bind')
FIGURE_DIR.mkdir(parents=True, exist_ok=True)
print(f"Figures will be saved to: {FIGURE_DIR}")

## Configuration

In [None]:
# ============================================================================
# CONFIGURATION - MODIFY THIS SECTION
# ============================================================================

# 'clean' selects the 3-channel joint model; any other value selects the
# triple (3 separate VDMs) config.
MODEL_TYPE = 'clean'

# Config file for the selected model.
_config_name = (
    'clean_vdm_aggressive_stellar.ini'
    if MODEL_TYPE == 'clean'
    else 'clean_vdm_triple.ini'
)
CONFIG_PATH = PROJECT_ROOT / 'configs' / _config_name

# CV simulation to process.
SIM_NUM = 12  # Choose from 0-26 (excluding 17)

# The DMO and hydro runs share the same sub-path below their suite folder.
CAMELS_ROOT = Path('/mnt/ceph/users/camels')
_cv_subdir = Path('L50n512') / 'CV' / f'CV_{SIM_NUM}'
DMO_PATH = CAMELS_ROOT / 'Sims' / 'IllustrisTNG_DM' / _cv_subdir
HYDRO_PATH = CAMELS_ROOT / 'Sims' / 'IllustrisTNG' / _cv_subdir

# Snapshot number (90 = z=0 for CAMELS).
SNAPNUM = 90

# Grid parameters.
BOX_SIZE = 50000.0  # kpc/h
GRID_SIZE = 1024

# Number of realizations requested when pasting halos back into the box.
N_REALIZATIONS = 1

# Output directory for BIND results.
OUTPUT_DIR = Path(f'/mnt/home/mlee1/ceph/BIND2d/CV/sim_{SIM_NUM}')

# Echo the resolved settings so they are recorded with the cell output.
for _line in (
    f"Model type: {MODEL_TYPE}",
    f"Config: {CONFIG_PATH}",
    f"CV Simulation: {SIM_NUM}",
    f"DMO path: {DMO_PATH}",
    f"Hydro path: {HYDRO_PATH}",
):
    print(_line)

## Load Cosmological Parameters

In [None]:
import pandas as pd

# CSV of parameter bounds; only its 'FiducialVal' column is read here.
param_path = '/mnt/home/mlee1/Sims/IllustrisTNG_extras/L50n512/SB35/SB35_param_minmax.csv'
if not Path(param_path).exists():
    # Later cells check for None and pass it through as the conditional
    # parameters instead of failing here.
    fiducial_params = None
    print("Warning: Parameter file not found, using None")
else:
    param_df = pd.read_csv(param_path)
    fiducial_params = list(param_df['FiducialVal'])
    print(f"Loaded {len(fiducial_params)} cosmological parameters")

## Initialize BIND Pipeline

In [None]:
# Initialize BIND
# Builds the pipeline object used by every later step (voxelize -> extract ->
# generate -> paste). It is pointed at the DMO run; the hydro run is only
# loaded later as a comparison target.
# NOTE(review): the meaning and units of r_in_factor, r_out_factor and
# mass_threshold are defined inside bind.bind and are not visible here —
# confirm before changing them (mass_threshold is presumably in Msun/h).
bind = BIND(
    simulation_path=str(DMO_PATH),
    snapnum=SNAPNUM,
    boxsize=BOX_SIZE,  # same value and units as the config cell (kpc/h)
    gridsize=GRID_SIZE,
    config_path=str(CONFIG_PATH),
    output_dir=str(OUTPUT_DIR),
    verbose=True,
    dim='2d',
    axis=2,  # Project along z-axis
    r_in_factor=2,
    r_out_factor=3,
    mass_threshold=1e13
)

## Run BIND Pipeline

In [None]:
# Step 1: Voxelize DMO simulation
# Must run before the later steps. The power-spectrum cell reads
# `bind.sim_grid`, which is presumably populated by this call — TODO confirm.
print("Step 1: Voxelizing DMO simulation...")
bind.voxelize_simulation()

In [None]:
# Step 2: Extract halos with large-scale conditioning
# NOTE(review): omega_m is hard-coded to 0.3 rather than taken from the loaded
# fiducial parameters — confirm it matches the CV set.
print("\nStep 2: Extracting halos...")
bind.extract_halos(omega_m=0.3, use_large_scale=True, num_large_scales=3)

# Get number of halos
# `num_halos` is reused below to size the conditional-parameter array and in
# the summary cell.
num_halos = len(bind.extracted['metadata'])
print(f"Extracted {num_halos} halos")

In [None]:
# One copy of the fiducial parameter vector per extracted halo, giving an array
# of shape (num_halos, n_params). None is passed through unchanged when the
# parameter file could not be loaded.
conditional_params = (
    np.tile(fiducial_params, (num_halos, 1))
    if fiducial_params is not None
    else None
)

In [None]:
# Step 3: Generate baryonic fields
# `conditional_params` is either a (num_halos, n_params) array or None, as
# prepared in the previous cell.
# NOTE(review): `generated_halos` is not read again in this notebook; the
# paste step presumably uses state kept on `bind` — TODO confirm.
print("\nStep 3: Generating baryonic fields...")
generated_halos = bind.generate_halos(
    batch_size=16,
    conditional_params=conditional_params,
    use_large_scale=True,
    conserve_mass=True
)

In [None]:
# Step 4: Paste halos back to full box
# `final_maps` is indexed per realization below; later cells use only
# final_maps[0].
print("\nStep 4: Pasting halos to full box...")
final_maps = bind.paste_halos(
    realizations=N_REALIZATIONS,
    use_enhanced=True
)

print(f"Generated {len(final_maps)} realization(s)")
print(f"Final map shape: {final_maps[0].shape}")

## Load Hydro Ground Truth

In [None]:
def load_hydro_maps(hydro_path, snapnum, box_size, grid_size, axis=2):
    """
    Load a hydro snapshot and return projected 2D mass maps per component.

    Each particle type present in the snapshot is deposited onto a
    ``grid_size**3`` mesh with mass-weighted CIC assignment and then summed
    along ``axis``. Every component is projected as soon as it is deposited,
    so only one 3D mesh is held in memory at a time.

    Args:
        hydro_path: Directory containing the snapshot.
        snapnum: Snapshot number, formatted as a zero-padded 3-digit suffix.
        box_size: Not used. Kept for call compatibility; the box size given to
            the mass-assignment routine is read from the snapshot header
            (``Header`` attribute ``BoxSize``).
        grid_size: Number of mesh cells per dimension.
        axis: Axis of the 3D mesh that is summed over.

    Returns:
        dict mapping 'dm', 'gas' and 'stars' (only the components present in
        the file, in that order) to 2D float32 arrays.

    Raises:
        FileNotFoundError: If neither ``snap_XXX.hdf5`` nor
            ``snapdir_XXX/snap_XXX.0.hdf5`` exists under ``hydro_path``.
        KeyError: If a component has no ``Masses`` dataset and its header
            ``MassTable`` entry is zero, so no particle mass is available.
    """
    import MAS_library as MASL

    hydro_dir = Path(hydro_path)
    candidates = (
        hydro_dir / f'snap_{snapnum:03d}.hdf5',
        # Alternative naming: first chunk inside a snapdir folder.
        # NOTE(review): only this one file is read; if the snapshot is split
        # into several chunks the remaining ones are not loaded.
        hydro_dir / f'snapdir_{snapnum:03d}' / f'snap_{snapnum:03d}.0.hdf5',
    )
    snap_path = next((path for path in candidates if path.exists()), None)
    if snap_path is None:
        tried = ', '.join(str(path) for path in candidates)
        raise FileNotFoundError(f"No hydro snapshot found; tried: {tried}")

    print(f"Loading hydro from: {snap_path}")

    # (output key, particle-type index), in the order the maps are returned.
    components = (('dm', 1), ('gas', 0), ('stars', 4))

    fields_2d = {}
    with h5py.File(snap_path, 'r') as snap:
        header = dict(snap['Header'].attrs)
        box_size_file = header['BoxSize']  # Usually in kpc/h

        for name, ptype in components:
            group = f'PartType{ptype}'
            if group not in snap:
                continue
            fields_2d[name] = _deposit_and_project(
                MASL, snap, header, group, ptype, box_size_file, grid_size, axis
            )

    return fields_2d


def _deposit_and_project(MASL, snap, header, group, ptype, box_size_file,
                         grid_size, axis):
    """Deposit one particle type on a 3D CIC mesh and sum it along ``axis``."""
    positions = snap[f'{group}/Coordinates'][:]

    mass_key = f'{group}/Masses'
    if mass_key in snap:
        masses = snap[mass_key][:]
    else:
        # Species without a per-particle dataset take their single mass from
        # the header table.
        table_mass = header['MassTable'][ptype]
        if table_mass == 0:
            raise KeyError(
                f"{mass_key} is missing and MassTable[{ptype}] is 0"
            )
        masses = np.ones(len(positions)) * table_mass

    mesh = np.zeros((grid_size, grid_size, grid_size), dtype=np.float32)
    MASL.MA(positions.astype(np.float32), mesh, box_size_file, MAS='CIC',
            W=masses.astype(np.float32), verbose=False)

    projected = mesh.sum(axis=axis)
    # Drop the 3D mesh before the next component allocates its own.
    del mesh, positions, masses
    return projected

In [None]:
# Load hydro ground truth
# The broad `except` is deliberate best-effort: on any failure `hydro_fields`
# is set to None, and every later cell checks `hydro_fields is not None`
# before using the hydro comparison.
# Note that load_hydro_maps does not use the BOX_SIZE argument; it reads the
# box size from the snapshot header.
print("Loading hydro ground truth...")
try:
    hydro_fields = load_hydro_maps(HYDRO_PATH, SNAPNUM, BOX_SIZE, GRID_SIZE, axis=2)
    print(f"Loaded hydro fields: {list(hydro_fields.keys())}")
    for name, field in hydro_fields.items():
        print(f"  {name}: {field.shape}, sum={field.sum():.2e}")
except Exception as e:
    print(f"Failed to load hydro: {e}")
    hydro_fields = None

## Power Spectrum Computation

In [None]:
def compute_power_spectrum_2d(field, box_size_mpc, n_bins=50):
    """
    Compute the azimuthally averaged 2D power spectrum of a square map.

    The estimator is ``P(k) = L^2 * <|F_k|^2> / N^4``, where ``F`` is the
    unnormalized FFT of ``field``, ``L`` the box size and ``N`` the number of
    cells per side. This normalization does not depend on the grid resolution.
    The field is used as given (it is not converted to an overdensity), so the
    amplitude scales with the square of the field's own units. The k = 0 mode
    lies below the first bin and never contributes.

    Modes are averaged in ``n_bins`` logarithmic bins between the fundamental
    mode ``2*pi/L`` (inclusive) and the Nyquist frequency ``pi*N/L``
    (exclusive). Bins that contain no modes are dropped from the output.

    Args:
        field: Square 2D array.
        box_size_mpc: Box size in Mpc/h
        n_bins: Number of logarithmic k bins before empty bins are removed.

    Returns:
        k: Geometric centers of the non-empty bins, in h/Mpc
        Pk: Mean power in each of those bins, in (Mpc/h)^2 times field units^2

    Raises:
        ValueError: If ``field`` is not a square 2D array.
    """
    field = np.asarray(field)
    if field.ndim != 2 or field.shape[0] != field.shape[1]:
        raise ValueError(
            f"field must be a square 2D array, got shape {field.shape}"
        )
    N = field.shape[0]

    power = np.abs(np.fft.fft2(field))**2

    # Build |k| from exact integer mode numbers so that the fundamental modes
    # compare equal to the lowest bin edge below.
    k_fund = 2 * np.pi / box_size_mpc
    mode_numbers = np.rint(np.fft.fftfreq(N) * N)
    k_axis = k_fund * mode_numbers
    k_mag = np.hypot(k_axis[:, None], k_axis[None, :])

    k_min = k_fund
    k_max = np.pi * N / box_size_mpc
    k_bins = np.logspace(np.log10(k_min), np.log10(k_max), n_bins + 1)
    # 10**log10(x) can differ from x by one ulp, which would push the
    # fundamental modes out of the first bin; pin the outer edges exactly.
    k_bins[0], k_bins[-1] = k_min, k_max
    k_centers = np.sqrt(k_bins[:-1] * k_bins[1:])

    # Bin i holds modes with k_bins[i] <= |k| < k_bins[i+1]; indices outside
    # [0, n_bins) are modes below the first or at/above the last edge.
    bin_index = np.digitize(k_mag.ravel(), k_bins) - 1
    in_range = (bin_index >= 0) & (bin_index < n_bins)
    counts = np.bincount(bin_index[in_range], minlength=n_bins)
    power_sum = np.bincount(
        bin_index[in_range],
        weights=power.ravel()[in_range],
        minlength=n_bins,
    )

    # Remove empty bins, then average and normalize.
    valid = counts > 0
    Pk = power_sum[valid] / counts[valid]
    Pk *= box_size_mpc**2 / N**4

    return k_centers[valid], Pk

In [None]:
# The power-spectrum helper takes Mpc/h; BOX_SIZE is stored in kpc/h.
box_size_mpc = BOX_SIZE / 1000.0

print("Computing power spectra...")

# DMO: the grid held by the BIND object (presumably the voxelized DMO map —
# TODO confirm against bind.bind).
dmo_field = bind.sim_grid
k_dmo, Pk_dmo = compute_power_spectrum_2d(dmo_field, box_size_mpc)
print(f"  DMO: k range [{k_dmo.min():.3f}, {k_dmo.max():.1f}] h/Mpc")

# BIND: first realization only; a 3D array is summed over its leading axis
# to obtain the total map.
bind_field = final_maps[0]
bind_total = bind_field.sum(axis=0) if bind_field.ndim == 3 else bind_field
k_bind, Pk_bind = compute_power_spectrum_2d(bind_total, box_size_mpc)
print(f"  BIND: computed")

# Hydro: one (k, Pk) pair per loaded component plus their sum under 'total'.
# The dict stays empty when the hydro snapshot could not be loaded.
Pk_hydro = {}
if hydro_fields is not None:
    Pk_hydro.update(
        (component, compute_power_spectrum_2d(component_map, box_size_mpc))
        for component, component_map in hydro_fields.items()
    )

    hydro_total = sum(hydro_fields.values())
    k_hydro, Pk_hydro_total = compute_power_spectrum_2d(hydro_total, box_size_mpc)
    Pk_hydro['total'] = (k_hydro, Pk_hydro_total)
    print(f"  Hydro: computed all channels")

## Plot Power Spectra

In [None]:
from scipy.interpolate import interp1d

fig, (ax_abs, ax_ratio) = plt.subplots(1, 2, figsize=(14, 5))
hydro_loaded = hydro_fields is not None

# ---- Left panel: absolute power spectra ----
ax_abs.loglog(k_dmo, Pk_dmo, 'k-', linewidth=2, label='DMO')
ax_abs.loglog(k_bind, Pk_bind, 'b-', linewidth=2, label='BIND (total)')
if hydro_loaded:
    k_h, Pk_h = Pk_hydro['total']
    ax_abs.loglog(k_h, Pk_h, 'r--', linewidth=2, label='Hydro (total)')

ax_abs.set_xlabel(r'$k$ [h/Mpc]')
ax_abs.set_ylabel(r'$P(k)$ [(Mpc/h)$^2$]')
ax_abs.set_title('2D Power Spectrum')
ax_abs.legend()
ax_abs.grid(True, alpha=0.3)

# ---- Right panel: ratio to DMO (suppression) ----
# Every spectrum is interpolated onto the DMO k values; points outside an
# interpolant's range become NaN instead of raising.
k_ref = k_dmo
nan_outside = dict(bounds_error=False, fill_value=np.nan)

interp_bind = interp1d(k_bind, Pk_bind, **nan_outside)
interp_dmo = interp1d(k_dmo, Pk_dmo, **nan_outside)

ratio_bind = interp_bind(k_ref) / interp_dmo(k_ref)
ax_ratio.semilogx(k_ref, ratio_bind, 'b-', linewidth=2, label='BIND / DMO')

if hydro_loaded:
    k_h, Pk_h = Pk_hydro['total']
    interp_hydro = interp1d(k_h, Pk_h, **nan_outside)
    ratio_hydro = interp_hydro(k_ref) / interp_dmo(k_ref)
    ax_ratio.semilogx(k_ref, ratio_hydro, 'r--', linewidth=2, label='Hydro / DMO')

ax_ratio.axhline(y=1, color='gray', linestyle='--', linewidth=1)
ax_ratio.set_xlabel(r'$k$ [h/Mpc]')
ax_ratio.set_ylabel(r'$P(k) / P_{\rm DMO}(k)$')
ax_ratio.set_title('Power Spectrum Suppression')
ax_ratio.set_ylim(0, 1.5)
ax_ratio.legend()
ax_ratio.grid(True, alpha=0.3)

fig.suptitle(f'CV Simulation {SIM_NUM}', fontsize=16, y=1.02)
fig.tight_layout()

fig.savefig(FIGURE_DIR / f'cv{SIM_NUM}_power_spectrum.png')
plt.show()

## Per-Channel Power Spectra

In [None]:
channel_names = ['DM Hydro', 'Gas', 'Stars']
channel_colors = ['#1f77b4', '#2ca02c', '#d62728']
# Key in Pk_hydro that corresponds to each BIND channel index.
hydro_keys = ['dm', 'gas', 'stars']

has_three_channels = bind_field.ndim == 3 and bind_field.shape[0] == 3

if not has_three_channels:
    print("BIND output is not 3-channel, skipping per-channel analysis")
else:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ch_idx, ax in enumerate(axes):
        # BIND spectrum for this channel.
        k_b, Pk_b = compute_power_spectrum_2d(bind_field[ch_idx], box_size_mpc)
        ax.loglog(k_b, Pk_b, '-', color=channel_colors[ch_idx], linewidth=2, label='BIND')

        # Matching hydro spectrum, when it was loaded and computed.
        hydro_key = hydro_keys[ch_idx]
        if hydro_fields is not None and hydro_key in Pk_hydro:
            k_h, Pk_h = Pk_hydro[hydro_key]
            ax.loglog(k_h, Pk_h, '--', color='gray', linewidth=2, label='Hydro')

        ax.set_xlabel(r'$k$ [h/Mpc]')
        ax.set_ylabel(r'$P(k)$')
        ax.set_title(channel_names[ch_idx])
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.suptitle(f'Per-Channel Power Spectra - CV {SIM_NUM}', fontsize=16, y=1.02)
    fig.tight_layout()

    fig.savefig(FIGURE_DIR / f'cv{SIM_NUM}_power_spectrum_channels.png')
    plt.show()

## Visual Comparison: Full Box Maps

In [None]:
def _show_map(ax, image, cmap, title):
    """Draw one log-scaled map with a title, hidden axes and its own colorbar."""
    im = ax.imshow(image, cmap=cmap, norm=LogNorm())
    ax.set_title(title)
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046)


fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax_dmo, ax_bind, ax_hydro = axes

_show_map(ax_dmo, dmo_field, 'viridis', 'DMO')
_show_map(ax_bind, bind_total, 'inferno', 'BIND (total)')

if hydro_fields is not None:
    hydro_total = sum(hydro_fields.values())
    _show_map(ax_hydro, hydro_total, 'inferno', 'Hydro (total)')
else:
    # Placeholder panel so the figure layout is the same without hydro data.
    ax_hydro.text(0.5, 0.5, 'Hydro not loaded', ha='center', va='center',
                  transform=ax_hydro.transAxes)
    ax_hydro.set_title('Hydro (not available)')
    ax_hydro.axis('off')

fig.suptitle(f'Full Box Comparison - CV {SIM_NUM}', fontsize=16, y=1.02)
fig.tight_layout()

fig.savefig(FIGURE_DIR / f'cv{SIM_NUM}_full_box_comparison.png')
plt.show()

## Summary Statistics

In [None]:
hydro_loaded = hydro_fields is not None

print(f"\n{'='*60}")
print(f"BIND Analysis Summary - CV {SIM_NUM}")
print(f"{'='*60}\n")

# Model and checkpoint, as recorded in the config file.
config = ConfigLoader(str(CONFIG_PATH), verbose=False)
print(f"Model: {config.model_name}")
print(f"Checkpoint: {Path(config.best_ckpt).name}")
print()

# Simulation set-up.
print(f"Simulation:")
print(f"  CV number: {SIM_NUM}")
print(f"  Box size: {BOX_SIZE/1000:.1f} Mpc/h")
print(f"  Grid: {GRID_SIZE}^2")
print(f"  Halos processed: {num_halos}")
print()

# Sum of each full-box map.
print(f"Total Mass Comparison:")
print(f"  DMO:   {dmo_field.sum():.4e}")
print(f"  BIND:  {bind_total.sum():.4e}")
if hydro_loaded:
    hydro_total_mass = sum(f.sum() for f in hydro_fields.values())
    print(f"  Hydro: {hydro_total_mass:.4e}")
print()

# Ratio to DMO at the tabulated k closest to each requested value; requested
# values outside the tabulated range are skipped.
k_ref_vals = [0.5, 1.0, 5.0, 10.0]  # h/Mpc
k_lo, k_hi = k_ref.min(), k_ref.max()
print(f"Power Spectrum Ratio at specific k:")
for k_val in k_ref_vals:
    if not (k_lo < k_val < k_hi):
        continue
    idx = np.abs(k_ref - k_val).argmin()
    bind_ratio = ratio_bind[idx]
    if hydro_loaded:
        hydro_ratio = ratio_hydro[idx]
        print(f"  k={k_val:.1f}: BIND/DMO={bind_ratio:.3f}, Hydro/DMO={hydro_ratio:.3f}")
    else:
        print(f"  k={k_val:.1f}: BIND/DMO={bind_ratio:.3f}")

## Compare Multiple Models (Optional)

In [None]:
# ============================================================================
# MODEL COMPARISON - Set to True to enable
# ============================================================================
# Re-runs the full BIND pipeline once per config listed below and plots each
# model's P(k)/P_DMO(k) on the DMO k grid. Requires the power-spectrum cells
# above to have run: it reuses k_ref, interp_dmo, box_size_mpc and, when hydro
# was loaded, ratio_hydro.

COMPARE_MODELS = False

if COMPARE_MODELS:
    # Imported here so this optional cell does not depend on the import made
    # as a side effect of the plotting cell.
    from scipy.interpolate import interp1d

    print("Model comparison enabled - this will run BIND for multiple models")
    print("This may take a while...")

    # (config file name, legend label, line color)
    MODELS_TO_COMPARE = [
        ('clean_vdm_aggressive_stellar.ini', 'Clean 3ch', '#1f77b4'),
        ('clean_vdm_triple.ini', 'Triple', '#ff7f0e'),
    ]

    comparison_results = {}

    for config_file, label, color in MODELS_TO_COMPARE:
        config_path = PROJECT_ROOT / 'configs' / config_file
        if not config_path.exists():
            print(f"Config not found: {config_path}")
            continue

        print(f"\nRunning BIND with {label}...")

        # Same settings as the main run; only the config and the output
        # sub-folder (label with spaces replaced) differ.
        bind_cmp = BIND(
            simulation_path=str(DMO_PATH),
            snapnum=SNAPNUM,
            boxsize=BOX_SIZE,
            gridsize=GRID_SIZE,
            config_path=str(config_path),
            output_dir=str(OUTPUT_DIR / label.replace(' ', '_')),
            verbose=False,
            dim='2d',
            axis=2,
            r_in_factor=2,
            r_out_factor=3,
            mass_threshold=1e13
        )

        # Run pipeline
        bind_cmp.voxelize_simulation()
        bind_cmp.extract_halos(omega_m=0.3, use_large_scale=True, num_large_scales=3)

        n_halos = len(bind_cmp.extracted['metadata'])
        # Same None test as the main pipeline cell; a plain truthiness test
        # would raise if fiducial_params were an array.
        if fiducial_params is not None:
            cond_params = np.tile(fiducial_params, (n_halos, 1))
        else:
            cond_params = None

        bind_cmp.generate_halos(batch_size=16, conditional_params=cond_params,
                               use_large_scale=True, conserve_mass=True)
        final_maps_cmp = bind_cmp.paste_halos(realizations=1, use_enhanced=True)

        # Power spectrum of the total map (3D output summed over its leading axis)
        bind_field_cmp = final_maps_cmp[0]
        if bind_field_cmp.ndim == 3:
            bind_total_cmp = bind_field_cmp.sum(axis=0)
        else:
            bind_total_cmp = bind_field_cmp

        k_cmp, Pk_cmp = compute_power_spectrum_2d(bind_total_cmp, box_size_mpc)

        comparison_results[label] = {
            'k': k_cmp,
            'Pk': Pk_cmp,
            'color': color
        }
        print(f"  Done.")

    # Plot comparison (skipped when no config file was found)
    if comparison_results:
        fig, ax = plt.subplots(figsize=(10, 6))

        for label, data in comparison_results.items():
            interp_cmp = interp1d(data['k'], data['Pk'], bounds_error=False, fill_value=np.nan)
            ratio = interp_cmp(k_ref) / interp_dmo(k_ref)
            ax.semilogx(k_ref, ratio, '-', color=data['color'], linewidth=2, label=f'{label} / DMO')

        if hydro_fields is not None:
            ax.semilogx(k_ref, ratio_hydro, 'k--', linewidth=2, label='Hydro / DMO')

        ax.axhline(y=1, color='gray', linestyle=':', linewidth=1)
        ax.set_xlabel(r'$k$ [h/Mpc]')
        ax.set_ylabel(r'$P(k) / P_{\rm DMO}(k)$')
        ax.set_title(f'Model Comparison - CV {SIM_NUM}')
        ax.set_ylim(0, 1.5)
        ax.legend()
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        fig.savefig(FIGURE_DIR / f'cv{SIM_NUM}_model_comparison.png')
        plt.show()