In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, SymLogNorm
from pathlib import Path
import os

# Configuration
SIM_RES = 2500  # resolution tag used in directory names: L205n{SIM_RES}TNG
# Root directory of the projected-map products. Set the environment variable
# HYDRO_REPLACE_FIELDS_DIR to point the notebook at another location without
# editing the code; the original cluster path is the default.
OUTPUT_DIR = Path(os.environ.get('HYDRO_REPLACE_FIELDS_DIR',
                                 '/mnt/home/mlee1/ceph/hydro_replace_fields'))
SNAPSHOTS = [29, 31, 33, 35, 38, 41, 43, 46, 49, 52, 56, 59, 63, 67, 71, 76, 80, 85, 90, 96, 99]
# Mass-threshold labels as they appear in the replace_* / bcm_* file names
MASS_THRESHOLDS = ['Mgt12.5', 'Mgt13.0', 'Mgt13.5', 'Mgt14.0']
# Baryonic-correction model names as they appear in the bcm_* file names
BCM_MODELS = ['arico20', 'schneider19', 'schneider25']

In [None]:
# Check file inventory
def check_files(snap):
    """Report which projected-map files are present for one snapshot.

    Parameters
    ----------
    snap : int
        Snapshot number; zero-padded to three digits in the directory name.

    Returns
    -------
    dict[str, bool]
        Maps each expected product key ('dmo', 'hydro', 'replace_<mt>' for
        every mass threshold, 'bcm_<model>_<mt>' for every BCM model and
        threshold) to whether its .npz file exists on disk.
    """
    projected = OUTPUT_DIR.joinpath(f'L205n{SIM_RES}TNG', f'snap{snap:03d}', 'projected')

    # (key, file name) pairs, in the order they are reported.
    expected = [('dmo', 'dmo.npz'), ('hydro', 'hydro.npz')]
    expected += [(f'replace_{mt}', f'replace_{mt}.npz') for mt in MASS_THRESHOLDS]
    expected += [
        (f'bcm_{bcm}_{mt}', f'bcm_{bcm}_{mt}.npz')
        for bcm in BCM_MODELS
        for mt in MASS_THRESHOLDS
    ]

    return {key: (projected / fname).exists() for key, fname in expected}

# Check all snapshots
# Expected file counts follow the configuration lists instead of being
# hardcoded, so the report stays correct if thresholds/models are edited.
n_replace_expected = len(MASS_THRESHOLDS)
n_bcm_expected = len(BCM_MODELS) * len(MASS_THRESHOLDS)

print("File Inventory:")
print("=" * 80)
for snap in SNAPSHOTS:
    status = check_files(snap)

    # Summarize
    has_dmo = '✓' if status['dmo'] else '✗'
    has_hydro = '✓' if status['hydro'] else '✗'
    n_replace = sum(1 for k, v in status.items() if k.startswith('replace') and v)
    n_bcm = sum(1 for k, v in status.items() if k.startswith('bcm') and v)

    print(f"Snap {snap:3d}: DMO {has_dmo} | Hydro {has_hydro} | "
          f"Replace {n_replace}/{n_replace_expected} | BCM {n_bcm}/{n_bcm_expected}")

In [None]:
# Load and display maps for a specific snapshot
SNAP = 99             # snapshot to inspect; change as needed
MASS_CUT = 'Mgt12.5'  # mass-threshold label (expected to be one of MASS_THRESHOLDS); change as needed

# Directory holding this snapshot's projected .npz maps
snap_dir = OUTPUT_DIR.joinpath(f'L205n{SIM_RES}TNG', f'snap{SNAP:03d}', 'projected')

def load_map(filename, directory=None):
    """Load the 'field' array from an .npz map file, if it exists.

    Parameters
    ----------
    filename : str
        Name of the .npz file inside ``directory``.
    directory : str or Path, optional
        Folder to look in. Defaults to the module-level ``snap_dir`` of the
        currently selected snapshot.

    Returns
    -------
    numpy.ndarray or None
        The array stored under the 'field' key, or None if the file is absent.

    Raises
    ------
    KeyError
        If the file exists but contains no 'field' array.
    """
    base = snap_dir if directory is None else Path(directory)
    path = base / filename
    if not path.exists():
        return None
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the array has been read into memory.
    with np.load(path) as data:
        return data['field']

# Load maps: display label -> file inside snap_dir (None when the file is missing)
maps = {
    label: load_map(fname)
    for label, fname in (
        ('DMO', 'dmo.npz'),
        ('Hydro', 'hydro.npz'),
        ('Replace', f'replace_{MASS_CUT}.npz'),
        ('Arico20', f'bcm_arico20_{MASS_CUT}.npz'),
        ('Schneider19', f'bcm_schneider19_{MASS_CUT}.npz'),
        ('Schneider25', f'bcm_schneider25_{MASS_CUT}.npz'),
    )
}

# Report what we have
print(f"Snapshot {SNAP}, Mass cut: {MASS_CUT}")
for name, m in maps.items():
    if m is None:
        print(f"  {name}: NOT FOUND")
        continue
    print(f"  {name}: {m.shape}, range [{m.min():.2e}, {m.max():.2e}]")

In [None]:
# Visual comparison of all maps
available = {k: v for k, v in maps.items() if v is not None}
n_maps = len(available)

if n_maps == 0:
    print("No maps available!")
else:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    # Common colorscale: 1st-99th percentile of the positive pixels of all
    # maps, so every panel shares one LogNorm. Filtering each map before
    # concatenating avoids a temporary copy of every pixel.
    positive_vals = np.concatenate([v[v > 0] for v in available.values()])
    if positive_vals.size > 0:
        vmin, vmax = np.percentile(positive_vals, [1, 99])
    else:
        # No positive pixels at all: fall back to a trivial range so LogNorm
        # can still be constructed.
        vmin, vmax = 1e-3, 1.0
    vmin = max(vmin, 1e-3)  # floor keeps the log scale away from tiny values
    vmax = max(vmax, vmin)  # LogNorm rejects vmin > vmax

    half = 500  # half-width in pixels of the central cut-out
    # zip stops after the 6 panels even if more maps were available.
    for ax, (name, m) in zip(axes, available.items()):
        # Show the centre region for detail. Bounds are clipped to the map so
        # maps smaller than 2*half are not sliced with negative (wrap-around)
        # start indices, and each axis uses its own size.
        ny, nx = m.shape
        rows = slice(max(ny // 2 - half, 0), min(ny // 2 + half, ny))
        cols = slice(max(nx // 2 - half, 0), min(nx // 2 + half, nx))
        region = m[rows, cols]

        im = ax.imshow(region, norm=LogNorm(vmin=vmin, vmax=vmax),
                       cmap='magma', origin='lower')
        ax.set_title(name)
        ax.set_xlabel('x [pixels]')
        ax.set_ylabel('y [pixels]')
        plt.colorbar(im, ax=ax, label='Surface density')

    # Turn off unused axes
    for ax in axes[n_maps:]:
        ax.axis('off')

    plt.suptitle(f'Snapshot {SNAP}, {MASS_CUT} (center 1000×1000 pixels)', fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Ratio maps (relative to DMO)
if 'DMO' in available and len(available) > 1:
    dmo_map = available['DMO']
    other_maps = {k: v for k, v in available.items() if k != 'DMO'}

    n_others = len(other_maps)
    ncols = min(3, n_others)
    nrows = (n_others + ncols - 1) // ncols  # ceil(n_others / ncols)

    # squeeze=False always returns a 2-D array of Axes, so a single panel
    # needs no special case.
    fig, axes = plt.subplots(nrows, ncols, figsize=(5*ncols, 4*nrows), squeeze=False)
    axes = axes.ravel()

    # The DMO mask is the same for every model; compute it once. Pixels where
    # DMO is empty have no defined ratio and stay at 1 (centre of the colormap).
    mask = dmo_map > 0
    dmo_positive = dmo_map[mask]
    half = 500  # half-width in pixels of the central cut-out

    for ax, (name, m) in zip(axes, other_maps.items()):
        if m.shape != dmo_map.shape:
            raise ValueError(
                f"{name} map shape {m.shape} does not match DMO shape {dmo_map.shape}")

        # Floating-point result even if a map is stored as integers.
        ratio = np.ones(m.shape, dtype=np.promote_types(m.dtype, np.float32))
        ratio[mask] = m[mask] / dmo_positive

        # Show the centre region, clipped to the map so small maps are not
        # sliced with negative (wrap-around) start indices.
        ny, nx = m.shape
        rows = slice(max(ny // 2 - half, 0), min(ny // 2 + half, ny))
        cols = slice(max(nx // 2 - half, 0), min(nx // 2 + half, nx))
        region = ratio[rows, cols]

        im = ax.imshow(region, norm=LogNorm(vmin=0.5, vmax=2.0),
                       cmap='RdBu_r', origin='lower')
        ax.set_title(f'{name} / DMO')
        plt.colorbar(im, ax=ax, label='Ratio')

    # Turn off unused axes
    for ax in axes[n_others:]:
        ax.axis('off')

    plt.suptitle(f'Ratio to DMO - Snap {SNAP}, {MASS_CUT}', fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("Need DMO map to compute ratios")

In [None]:
# 2D Power spectrum comparison
def compute_2d_power_spectrum(field):
    """Compute the azimuthally averaged 2D power spectrum of a map.

    The squared magnitude of the (unnormalized) 2D FFT is averaged over
    annuli of integer radius ``k = floor(sqrt(kx**2 + ky**2))``, where kx and
    ky are integer wavenumbers (number of cycles across the map along each
    axis).

    Parameters
    ----------
    field : array_like
        2D map. The radial binning combines the two axes' integer
        wavenumbers directly, so it is only physically meaningful for square
        maps.

    Returns
    -------
    k : numpy.ndarray
        Integer wavenumbers ``0 .. min(nx, ny)//2 - 1``.
    radial_power : numpy.ndarray
        Mean ``|FFT|**2`` in each bin (0 for a bin that contains no modes).

    Raises
    ------
    ValueError
        If ``field`` is not two-dimensional.
    """
    field = np.asarray(field)
    if field.ndim != 2:
        raise ValueError(f"expected a 2D field, got shape {field.shape}")

    # FFT, with the zero frequency moved to the array centre
    power = np.abs(np.fft.fftshift(np.fft.fft2(field)))**2

    # Radial coordinate of every mode. fftshift puts the zero frequency at
    # index n//2 for both even and odd n, so the centred integer wavenumbers
    # are arange(n) - n//2 (np.ogrid[-n//2:n//2] is off by one for odd n).
    ny, nx = field.shape
    ky = np.arange(ny) - ny // 2
    kx = np.arange(nx) - nx // 2
    r = np.sqrt(ky[:, None]**2 + kx[None, :]**2).astype(int)

    # Bin by radius; minlength guarantees at least r_max bins.
    r_max = min(nx, ny) // 2
    tbin = np.bincount(r.ravel(), weights=power.ravel(), minlength=r_max)
    nr = np.bincount(r.ravel(), minlength=r_max)
    radial_power = tbin / np.maximum(nr, 1)

    k = np.arange(r_max)
    return k, radial_power[:r_max]

if len(available) > 0:
    fig, ax = plt.subplots(figsize=(10, 6))

    # The i-th map in `available` always gets the i-th tab10 colour, so the
    # colour of a model does not depend on how many curves a figure draws.
    colors = plt.cm.tab10(np.arange(len(available)) % 10)

    for (name, m), color in zip(available.items(), colors):
        k, pk = compute_2d_power_spectrum(m)
        # Skip k = 0 (the map mean), which cannot be drawn on a log axis.
        ax.loglog(k[1:], pk[1:], label=name, color=color, alpha=0.8)

    # k is an integer wavenumber: the number of cycles across the map side,
    # not a frequency in 1/pixel.
    ax.set_xlabel('k [cycles per map side]')
    ax.set_ylabel('P(k)')
    ax.set_title(f'2D Power Spectrum - Snap {SNAP}, {MASS_CUT}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Power spectrum ratios to DMO
if 'DMO' in available and len(available) > 1:
    k_dmo, pk_dmo = compute_2d_power_spectrum(available['DMO'])

    fig, ax = plt.subplots(figsize=(10, 6))

    # Colour by position in `available` (DMO included) so each model keeps the
    # colour it would get in a plot of all maps.
    colors = plt.cm.tab10(np.arange(len(available)) % 10)

    for (name, m), color in zip(available.items(), colors):
        if name == 'DMO':
            continue
        k_model, pk_model = compute_2d_power_spectrum(m)
        if k_model.shape != k_dmo.shape:
            raise ValueError(
                f"{name} spectrum has {k_model.size} bins, DMO has {k_dmo.size}; "
                "maps must have the same shape")
        # Bins with zero DMO power have no defined ratio: leave them as NaN
        # (gaps in the curve) rather than dividing by zero.
        ratio = np.full_like(pk_model, np.nan, dtype=float)
        np.divide(pk_model, pk_dmo, out=ratio, where=pk_dmo > 0)
        ax.semilogx(k_model[1:], ratio[1:], label=name, color=color, alpha=0.8)

    ax.axhline(1.0, color='gray', linestyle='--', alpha=0.5)
    # k is an integer wavenumber: the number of cycles across the map side.
    ax.set_xlabel('k [cycles per map side]')
    ax.set_ylabel('P(k) / P_DMO(k)')
    ax.set_title(f'Power Spectrum Ratio to DMO - Snap {SNAP}, {MASS_CUT}')
    ax.legend()
    ax.set_ylim(0.8, 1.2)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Pixel value distribution comparison
if len(available) > 0:
    fig, ax = plt.subplots(figsize=(10, 6))

    # The i-th map in `available` always gets the i-th tab10 colour.
    colors = plt.cm.tab10(np.arange(len(available)) % 10)

    # First pass: global range of the positive pixels, so every histogram
    # uses the same bin edges and the overlaid curves are directly comparable.
    lo, hi = np.inf, -np.inf
    for m in available.values():
        pos = m[m > 0]
        if pos.size:
            lo = min(lo, pos.min())
            hi = max(hi, pos.max())
    if np.isfinite(lo) and hi > lo:
        bins = np.linspace(np.log10(lo), np.log10(hi), 101)  # 100 shared bins
    else:
        bins = 100  # degenerate range: let matplotlib choose per map

    # Second pass: histogram of log values (zeros have no logarithm).
    for (name, m), color in zip(available.items(), colors):
        vals = m[m > 0]
        if vals.size == 0:
            print(f"  {name}: no positive pixels, skipped in histogram")
            continue
        ax.hist(np.log10(vals), bins=bins, alpha=0.5, label=name, color=color, density=True)

    ax.set_xlabel('log10(pixel value)')
    ax.set_ylabel('Density')
    ax.set_title(f'Pixel Value Distribution - Snap {SNAP}, {MASS_CUT}')
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Summary statistics
header = f"{'Model':<15} {'Shape':<12} {'Min':>12} {'Max':>12} {'Mean':>12} {'Std':>12}"
rule_width = len(header)  # rules match the table width instead of a fixed 70

print("=" * rule_width)
print(f"MAP STATISTICS - Snapshot {SNAP}, {MASS_CUT}")
print("=" * rule_width)
print(header)
print("-" * rule_width)

for name, m in available.items():
    print(f"{name:<15} {str(m.shape):<12} {m.min():>12.2e} {m.max():>12.2e} {m.mean():>12.2e} {m.std():>12.2e}")

# Total mass check. Sums use a float64 accumulator so the ratios printed to
# 4 decimals are not limited by the maps' storage precision.
print("\nTotal mass (sum of pixels):")
if 'DMO' in available:
    dmo_total = available['DMO'].sum(dtype=np.float64)
    for name, m in available.items():
        total = m.sum(dtype=np.float64)
        # Undefined (not zero) when the DMO reference has no mass.
        ratio = total / dmo_total if dmo_total > 0 else float('nan')
        print(f"  {name}: {total:.4e} ({ratio:.4f} of DMO)")
else:
    print("  DMO map not available; skipping mass comparison")