# In [2]:
import numpy as np
from scipy.stats import qmc
import limpy.lines as ll
import limpy.powerspectra as lp
from pathlib import Path
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.gridspec as gridspec

class LIMpyGenerator:
    """
    Generate Line Intensity Maps with flexible parameter combinations.
    
    Two operational modes:
    1. VARYING COSMOLOGY: Different halo catalogs for each parameter set (cosmo + astro params)
    2. FIXED COSMOLOGY: Single halo catalog with varying astrophysical parameters only
    """
    
    def __init__(self, boxsize=80.0, ngrid=256, redshift=3.6, 
                 line_name='CII158', model_name='Alma_scaling',
                 param_names=['sigma_8', 'omega_m'], seed=42, 
                 noise_level=1e1, noise_mode='fixed',
                 noise_range=(1e-5, 1e3),
                 log_noise_range=(0, 4),
                 mode='varying_cosmology'):
        """
        Initialize generator
        
        Args:
            mode: 'varying_cosmology' or 'fixed_cosmology'
            param_names: List of parameters to vary
                - For varying_cosmology: can include cosmo + astro params
                - For fixed_cosmology: should only include astro params (cosmo will be fixed from halocat)
        """
        self.boxsize = boxsize
        self.ngrid = ngrid
        self.redshift = redshift
        self.line_name = line_name
        self.model_name = model_name
        self.mmin = 1e9
        self.seed = seed
        self.mode = mode
        
        # Noise configuration
        self.noise_mode = noise_mode
        self.noise_level = noise_level
        self.noise_range = noise_range
        self.log_noise_range = log_noise_range
        
        # Check if noise is in param_names
        self.noise_as_param = 'noise_level' in param_names
        self.log_noise_as_param = 'log_noise_level' in param_names
        
        if self.noise_as_param or self.log_noise_as_param:
            self.noise_mode = 'varying'
            param_names_filtered = [p for p in param_names if p not in ['noise_level', 'log_noise_level']]
        else:
            param_names_filtered = param_names
        
        # Parameter configuration
        self.param_names_internal = param_names_filtered
        self.param_names = param_names
        self.cosmological_params = ['sigma_8', 'omega_m', 'h']
        self.astrophysical_params = [p for p in param_names_filtered if p not in self.cosmological_params]
        
        # Validate mode and parameters
        if mode == 'fixed_cosmology':
            cosmo_in_params = [p for p in param_names_filtered if p in self.cosmological_params]
            if cosmo_in_params:
                print(f"WARNING: In fixed_cosmology mode, cosmological parameters {cosmo_in_params} will be ignored")
                print("They will be taken from the fixed halo catalog instead")
        
        print(f"Mode: {mode}")
        print(f"Parameters to vary: {self.param_names}")
        print(f"  Cosmological: {[p for p in param_names_filtered if p in self.cosmological_params]}")
        print(f"  Astrophysical: {self.astrophysical_params}")
        print(f"Noise mode: {self.noise_mode}")
        
        # Storage
        self.params_array = None
        self.param_ranges = {}
        self.latin_hypercube_samples = None
        self.noise_values = None
        self.log_noise_values = None
        self.intensity_maps = []
        self.power_spectra = []
        self.pdfs = []
        self.k_values = None
        self.pdf_bins = None
        
        # For fixed cosmology mode
        self.fixed_halocat = None
        self.fixed_cosmo_params = None
        self.fixed_halocat_file = None
    
    def setup_parameter_ranges(self, param_ranges=None):
        """Setup parameter ranges for astrophysical parameters"""
        default_ranges = {
            'a_off': (4, 10.0),
            'b_off': (0.0, 2.0)
        }
        
        if param_ranges:
            self.param_ranges.update(param_ranges)
        else:
            for param in self.astrophysical_params:
                if param in default_ranges:
                    self.param_ranges[param] = default_ranges[param]
        
        # Add noise range if treating as parameter
        if self.noise_as_param:
            self.param_ranges['noise_level'] = self.noise_range
        elif self.log_noise_as_param:
            self.param_ranges['log_noise_level'] = self.log_noise_range
    
    def generate_latin_hypercube_samples(self, n_samples):
        """Generate Latin hypercube samples for astrophysical parameters"""
        n_astro_params = len(self.astrophysical_params)
        
        if n_astro_params == 0:
            print("No astrophysical parameters to sample")
            self.latin_hypercube_samples = None
            return None
        
        sampler = qmc.LatinHypercube(d=n_astro_params, seed=self.seed)
        samples_unit = sampler.random(n=n_samples)
        
        samples_scaled = np.zeros((n_samples, n_astro_params))
        for i, param in enumerate(self.astrophysical_params):
            low, high = self.param_ranges[param]
            samples_scaled[:, i] = qmc.scale(samples_unit[:, i:i+1], low, high).flatten()
        
        self.latin_hypercube_samples = samples_scaled
        
        print(f"Generated {n_samples} Latin hypercube samples for {self.astrophysical_params}")
        for i, param in enumerate(self.astrophysical_params):
            print(f"  {param}: [{samples_scaled[:, i].min():.3f}, {samples_scaled[:, i].max():.3f}]")
        
        return samples_scaled
    
    def generate_noise_levels(self, n_samples, method='uniform'):
        """Generate noise levels for varying noise mode"""
        if self.noise_mode == 'fixed':
            self.noise_values = np.full(n_samples, self.noise_level)
            self.log_noise_values = np.log10(self.noise_values)
            return self.noise_values
        
        rng = np.random.RandomState(self.seed + 1000)
        
        if self.log_noise_as_param:
            log_min, log_max = self.log_noise_range
            
            if method == 'uniform' or method == 'log_uniform':
                log_noise_levels = rng.uniform(log_min, log_max, n_samples)
            elif method == 'latin_hypercube':
                sampler = qmc.LatinHypercube(d=1, seed=self.seed + 1000)
                samples = sampler.random(n=n_samples)
                log_noise_levels = qmc.scale(samples, log_min, log_max).flatten()
            
            self.log_noise_values = log_noise_levels
            self.noise_values = 10**log_noise_levels
        else:
            min_noise, max_noise = self.noise_range
            
            if method == 'uniform':
                noise_levels = rng.uniform(min_noise, max_noise, n_samples)
            elif method == 'log_uniform':
                log_min = np.log10(min_noise)
                log_max = np.log10(max_noise)
                noise_levels = 10**rng.uniform(log_min, log_max, n_samples)
            elif method == 'latin_hypercube':
                sampler = qmc.LatinHypercube(d=1, seed=self.seed + 1000)
                samples = sampler.random(n=n_samples)
                log_min = np.log10(min_noise)
                log_max = np.log10(max_noise)
                noise_levels = 10**qmc.scale(samples, log_min, log_max).flatten()
            
            self.noise_values = noise_levels
            self.log_noise_values = np.log10(noise_levels)
        
        return self.noise_values
    
    def load_fixed_halocat(self, halocat_file):
        """Load a single halo catalog for fixed cosmology mode"""
        self.fixed_halocat_file = halocat_file
        data = np.load(halocat_file, allow_pickle=True)
        
        # Extract cosmological parameters
        if 'params' in data:
            params = data['params']
            if hasattr(params, 'item'):
                params = params.item()
        else:
            params = {}
        
        self.fixed_cosmo_params = {
            'sigma_8': params.get('sigma_8', params.get('sigma8', 0.8)),
            'omega_m': params.get('omega_m', params.get('Omega_m', 0.3)),
            'h': params.get('h', 0.7)
        }
        
        self.fixed_halocat = data
        
        print(f"Fixed cosmology loaded from {Path(halocat_file).name}:")
        for key, val in self.fixed_cosmo_params.items():
            print(f"  {key} = {val:.3f}")
        
        return self.fixed_cosmo_params
    
    def process_batch(self, halocat_files=None, n_samples=None, store_maps=False):
        """
        Main processing method that handles both modes
        
        For varying_cosmology mode:
            - Requires halocat_files (list of paths)
            - Each halocat provides its own cosmology
            
        For fixed_cosmology mode:
            - Requires n_samples (number of realizations)
            - Must call load_fixed_halocat() first
        """
        if self.mode == 'fixed_cosmology':
            if self.fixed_halocat is None:
                raise ValueError("For fixed_cosmology mode, call load_fixed_halocat() first")
            return self._process_fixed_cosmology_batch(n_samples, store_maps)
        else:
            if halocat_files is None:
                raise ValueError("For varying_cosmology mode, provide halocat_files")
            return self._process_varying_cosmology_batch(halocat_files, store_maps)
    
    def _process_fixed_cosmology_batch(self, n_samples, store_maps):
        """
        Process with fixed cosmology, varying only astrophysics (and noise).

        For each of the n_samples realizations the fixed halo catalog is run
        through process_single_halocat with that sample's astrophysical
        parameters, Gaussian noise is added, and the 2D power spectrum and
        PDF are computed.  A sample that raises is reported and skipped.

        Args:
            n_samples: Number of realizations to generate.
            store_maps: If True, keep each noisy 2D map in
                self.intensity_maps.

        Returns:
            Number of successfully processed samples.

        Raises:
            ValueError: If n_samples is None.

        Results are written to self.params_array, self.power_spectra and
        self.pdfs.  Stored maps and the noise arrays are kept row-aligned
        with params_array: maps are reset at the start of the run and only
        recorded for samples that fully succeed, and noise values of failed
        samples are dropped.
        """
        import traceback

        if n_samples is None:
            raise ValueError("Specify n_samples for fixed_cosmology mode")

        # Generate astrophysical parameters if needed.  A cached set that is
        # too short for this run would make every later index fail inside
        # the loop, so regenerate in that case as well.
        if self.astrophysical_params and (
                self.latin_hypercube_samples is None
                or len(self.latin_hypercube_samples) < n_samples):
            self.generate_latin_hypercube_samples(n_samples)

        # Generate noise levels
        noise_levels = self.generate_noise_levels(n_samples, method='latin_hypercube')

        rng = np.random.RandomState(self.seed)

        # params/pk/pdf below replace earlier results, so maps must too
        self.intensity_maps = []

        all_params = []
        all_pk = []
        all_pdf = []
        kept_indices = []

        print(f"\nProcessing {n_samples} realizations with fixed cosmology")
        print(f"Fixed cosmology: {self.fixed_cosmo_params}")
        print(f"Store maps: {store_maps}")

        samples = self.latin_hypercube_samples

        for i in tqdm(range(n_samples), desc="Processing"):
            try:
                astro_params = samples[i] if samples is not None else None

                # Build parameters: fixed cosmology plus this sample's astro values
                params_internal = self.fixed_cosmo_params.copy()
                if astro_params is not None:
                    for j, param in enumerate(self.astrophysical_params):
                        params_internal[param] = astro_params[j]

                # Process with LIMpy
                intensity_2d, _, _ = self.process_single_halocat(
                    self.fixed_halocat_file,
                    index=i,
                    astro_params=astro_params
                )

                # Add noise
                current_noise_level = noise_levels[i]
                noise = rng.normal(0, current_noise_level, intensity_2d.shape)
                intensity_2d_noisy = intensity_2d + noise

                # Compute observables
                k, pk = self.compute_power_spectrum_2d(intensity_2d_noisy)
                bins, pdf = self.compute_pdf(intensity_2d_noisy, n_bins=40)

                # Store parameters; cosmological entries (if requested) carry
                # the constant values of the fixed catalog.
                param_values = [params_internal[name] for name in self.param_names_internal]

                # Add noise parameter if varying
                if self.log_noise_as_param:
                    param_values.append(self.log_noise_values[i])
                elif self.noise_as_param:
                    param_values.append(current_noise_level)

                # Commit only once everything for this sample succeeded, so
                # maps, parameters and observables never get out of step.
                if store_maps:
                    self.intensity_maps.append(intensity_2d_noisy.copy())
                all_params.append(param_values)
                all_pk.append(pk)
                all_pdf.append(pdf)
                kept_indices.append(i)

                # NOTE(review): only the first sample's k and PDF bin edges
                # are kept; compute_pdf bins each map on its own min/max, so
                # later PDFs may not share these edges.  Confirm intended.
                if self.k_values is None:
                    self.k_values = k
                if self.pdf_bins is None:
                    self.pdf_bins = bins

            except Exception as e:
                # Deliberate best-effort: report and move on to the next sample
                print(f"Error processing sample {i}: {e}")
                traceback.print_exc()
                continue

        self.params_array = np.array(all_params)
        self.power_spectra = all_pk
        self.pdfs = all_pdf

        # Drop noise entries of failed samples to keep them row-aligned
        if len(kept_indices) != n_samples:
            kept = np.array(kept_indices, dtype=int)
            self.noise_values = np.asarray(self.noise_values)[kept]
            self.log_noise_values = np.asarray(self.log_noise_values)[kept]

        print(f"\nStored {len(self.intensity_maps)} maps in memory")

        return len(all_params)
    
    def _process_varying_cosmology_batch(self, halocat_files, store_maps):
        """
        Process with varying cosmology using a different halocat per sample.

        Each catalog is run through process_single_halocat (which reads the
        cosmology from the file) together with that sample's astrophysical
        parameters; Gaussian noise is added and the 2D power spectrum and
        PDF are computed.  A catalog that raises is reported and skipped.

        Args:
            halocat_files: Sequence of catalog paths (str or Path).
            store_maps: If True, keep each noisy 2D map in
                self.intensity_maps.

        Returns:
            Number of successfully processed catalogs.

        Results are written to self.params_array, self.power_spectra and
        self.pdfs.  Stored maps and the noise arrays are kept row-aligned
        with params_array: maps are reset at the start of the run and only
        recorded for catalogs that fully succeed, and noise values of failed
        catalogs are dropped.
        """
        import traceback

        n_files = len(halocat_files)

        # Generate astrophysical parameters if needed.  A cached set that is
        # too short for this run would make every later index fail inside
        # the loop, so regenerate in that case as well.
        if self.astrophysical_params and (
                self.latin_hypercube_samples is None
                or len(self.latin_hypercube_samples) < n_files):
            self.generate_latin_hypercube_samples(n_files)

        # Generate noise levels
        noise_levels = self.generate_noise_levels(n_files, method='latin_hypercube')

        rng = np.random.RandomState(self.seed)

        # params/pk/pdf below replace earlier results, so maps must too
        self.intensity_maps = []

        all_params = []
        all_pk = []
        all_pdf = []
        kept_indices = []

        print(f"\nProcessing {n_files} halocats with varying cosmology")
        print(f"Store maps: {store_maps}")

        for i, halocat_file in enumerate(tqdm(halocat_files, desc="Processing")):
            try:
                astro_params = None
                if self.latin_hypercube_samples is not None:
                    astro_params = self.latin_hypercube_samples[i]

                intensity_2d, _, params = self.process_single_halocat(
                    halocat_file, index=i, astro_params=astro_params
                )

                # Add noise
                current_noise_level = noise_levels[i]
                noise = rng.normal(0, current_noise_level, intensity_2d.shape)
                intensity_2d_noisy = intensity_2d + noise

                # Compute observables
                k, pk = self.compute_power_spectrum_2d(intensity_2d_noisy)
                bins, pdf = self.compute_pdf(intensity_2d_noisy, n_bins=40)

                # Store parameters
                param_values = [params[p] for p in self.param_names_internal]
                if self.log_noise_as_param:
                    param_values.append(self.log_noise_values[i])
                elif self.noise_as_param:
                    param_values.append(current_noise_level)

                # Commit only once everything for this catalog succeeded, so
                # maps, parameters and observables never get out of step.
                if store_maps:
                    self.intensity_maps.append(intensity_2d_noisy.copy())
                all_params.append(param_values)
                all_pk.append(pk)
                all_pdf.append(pdf)
                kept_indices.append(i)

                # NOTE(review): only the first sample's k and PDF bin edges
                # are kept; compute_pdf bins each map on its own min/max, so
                # later PDFs may not share these edges.  Confirm intended.
                if self.k_values is None:
                    self.k_values = k
                if self.pdf_bins is None:
                    self.pdf_bins = bins

            except Exception as e:
                # Deliberate best-effort: report and move on.  Path() makes
                # the message work for plain-string paths too (a bare
                # `.name` on a str would raise inside this handler).
                print(f"Error processing {Path(halocat_file).name}: {e}")
                traceback.print_exc()
                continue

        self.params_array = np.array(all_params)
        self.power_spectra = all_pk
        self.pdfs = all_pdf

        # Drop noise entries of failed catalogs to keep them row-aligned
        if len(kept_indices) != n_files:
            kept = np.array(kept_indices, dtype=int)
            self.noise_values = np.asarray(self.noise_values)[kept]
            self.log_noise_values = np.asarray(self.log_noise_values)[kept]

        print(f"\nStored {len(self.intensity_maps)} maps in memory")

        return len(all_params)
    
    def process_single_halocat(self, halocat_file, index=0, astro_params=None):
        """
        Build the intensity map for one halo catalog and parameter set.

        Args:
            halocat_file: Path of the catalog archive to load.
            index: Unused by this method; accepted for the callers'
                bookkeeping.
            astro_params: Optional sequence of values, one per entry of
                self.astrophysical_params and in the same order.

        Returns:
            (intensity_2d, intensity_3d, params_internal): the 3D grid
            returned by LIMpy, its mean along axis 2, and the parameter dict
            that was passed to LIMpy (sigma_8, omega_m, h from the catalog's
            'params' entry with fallbacks 0.8 / 0.3 / 0.7, plus the
            astrophysical values).

        The catalog is first rewritten to a temporary file with wrapped
        positions when it contains position arrays; that temporary file is
        removed again whether or not LIMpy succeeds.
        """
        import os

        # NOTE(review): allow_pickle=True unpickles archive contents; only
        # point this at trusted catalog files.
        data = np.load(halocat_file, allow_pickle=True)
        try:
            if 'params' in data:
                params = data['params']
                if hasattr(params, 'item'):
                    params = params.item()
            else:
                params = {}

            # Build parameter dictionary, starting with cosmology
            params_internal = {
                'sigma_8': params.get('sigma_8', params.get('sigma8', 0.8)),
                'omega_m': params.get('omega_m', params.get('Omega_m', 0.3)),
                'h': params.get('h', 0.7),
            }

            # Add astrophysical parameters
            if astro_params is not None:
                for i, param in enumerate(self.astrophysical_params):
                    params_internal[param] = astro_params[i]

            # Fix boundary conditions (may write a temporary catalog)
            halocat_file_to_use = self._fix_boundary_conditions(data, halocat_file)
        finally:
            # The archive stays open until closed explicitly; without this
            # every processed sample leaks a file handle.
            if hasattr(data, 'close'):
                data.close()

        try:
            lim_sim = ll.lim_sims(
                halocat_file_to_use,
                self.redshift,
                model_name=self.model_name,
                line_name=self.line_name,
                halo_cutoff_mass=self.mmin,
                halocat_type="input_cat",
                parameters=params_internal,
                ngrid_x=self.ngrid,
                ngrid_y=self.ngrid,
                ngrid_z=self.ngrid,
                boxsize_x=self.boxsize,
                boxsize_y=self.boxsize,
                boxsize_z=self.boxsize,
                # Hard-coded observation settings
                nu_obs=220,
                theta_fwhm=1,
                dnu_obs=2.2
            )

            intensity_3d = lim_sim.make_intensity_grid()
            intensity_2d = np.mean(intensity_3d, axis=2)

        finally:
            # Only remove the file if a temporary copy was actually created
            if halocat_file_to_use != str(halocat_file):
                try:
                    os.unlink(halocat_file_to_use)
                except OSError:
                    # Best-effort cleanup; a leftover temp file is harmless
                    pass

        return intensity_2d, intensity_3d, params_internal
    
    def _fix_boundary_conditions(self, data, halocat_file):
        """Fix boundary conditions in halo positions"""
        import tempfile
        
        if 'x' in data or 'X' in data or 'pos' in data:
            temp_file = tempfile.NamedTemporaryFile(suffix='.npz', delete=False)
            save_dict = {key: data[key] for key in data.keys()}
            
            if 'pos' in data:
                positions = data['pos'].copy()
                positions = positions % self.boxsize
                positions[positions >= self.boxsize] = self.boxsize - 1e-6
                save_dict['pos'] = positions
            else:
                for coord in ['x', 'X', 'y', 'Y', 'z', 'Z']:
                    if coord in data:
                        pos = data[coord].copy()
                        pos = pos % self.boxsize
                        pos[pos >= self.boxsize] = self.boxsize - 1e-6
                        save_dict[coord] = pos
            
            np.savez(temp_file.name, **save_dict)
            return temp_file.name
        return str(halocat_file)
    
    def compute_power_spectrum_2d(self, intensity_2d):
        """
        Return (k, pk) for a 2D map via limpy's get_pk2d.

        The box size and grid count of this generator are passed for both
        map axes.
        """
        side = self.boxsize
        cells = self.ngrid
        k, pk = lp.get_pk2d(intensity_2d, side, side, cells, cells)
        return k, pk
    
    def compute_pdf(self, intensity_map, n_bins=40):
        """
        Histogram the map values into a normalized PDF.

        Bins are n_bins equal-width intervals spanning the map's own minimum
        and maximum.  Returns (bin_edges, pdf) with n_bins + 1 edges and
        n_bins density values.
        """
        values = intensity_map.flatten()
        lowest, highest = values.min(), values.max()
        edges = np.linspace(lowest, highest, n_bins + 1)
        density = np.histogram(values, bins=edges, density=True)[0]
        return edges, density
    
    def save_results(self, output_dir, seed=None, save_intensity_maps=False):
        """
        Save results with metadata
        
        Args:
            output_dir: Directory to save results
            seed: Random seed for file naming
            save_intensity_maps: If True, save intensity maps to disk (can be large!)
        """
        if seed is None:
            seed = self.seed
        
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        
        # Create structured array
        dtype_list = [(name, '<f8') for name in self.param_names]
        params_structured = np.zeros(len(self.params_array), dtype=dtype_list)
        
        for i, param_vals in enumerate(self.params_array):
            params_structured[i] = tuple(param_vals)
        
        # File naming
        ps_file = output_path / f"ps_80_256.npz"
        pdf_file = output_path / f"pdf_80_256.npz"
        
        all_k = np.array([self.k_values] * len(self.power_spectra))
        
        # Metadata
        metadata = {
            'seed': seed,
            'mode': self.mode,
            'noise_mode': self.noise_mode,
        }
        
        if self.noise_mode == 'fixed':
            metadata['noise_level'] = self.noise_level
        else:
            if self.log_noise_as_param:
                metadata['log_noise_range'] = self.log_noise_range
                metadata['log_noise_values'] = self.log_noise_values
            else:
                metadata['noise_range'] = self.noise_range
            metadata['noise_values'] = self.noise_values
        
        if self.mode == 'fixed_cosmology' and self.fixed_cosmo_params:
            metadata['fixed_cosmology'] = self.fixed_cosmo_params
        
        np.savez_compressed(
            ps_file,
            parameters=params_structured,
            k_2d=all_k,
            pk_2d=np.array(self.power_spectra),
            **metadata
        )
        
        np.savez_compressed(
            pdf_file,
            parameters=params_structured,
            lim_bin_edges_linear_2d=self.pdf_bins,
            lim_pdf_linear_2d=np.array(self.pdfs),
            **metadata
        )
        
        # Save intensity maps if requested and available
        if save_intensity_maps and len(self.intensity_maps) > 0:
            maps_file = output_path / f"intensity_maps_80_256.npz"
            print(f"\nSaving {len(self.intensity_maps)} intensity maps to disk...")
            np.savez_compressed(
                maps_file,
                parameters=params_structured,
                intensity_maps=np.array(self.intensity_maps),
                **metadata
            )
            print(f"  Intensity maps: {maps_file}")
            print(f"  File size: ~{maps_file.stat().st_size / 1e9:.2f} GB")
        
        # Save parameter configuration
        config = {
            'mode': self.mode,
            'param_names': self.param_names,
            'param_ranges': self.param_ranges,
            'seed': seed,
            'noise_mode': self.noise_mode,
            'noise_level': self.noise_level if self.noise_mode == 'fixed' else None,
            'noise_range': self.noise_range if self.noise_as_param else None,
            'log_noise_range': self.log_noise_range if self.log_noise_as_param else None,
            'boxsize': self.boxsize,
            'ngrid': self.ngrid,
            'redshift': self.redshift,
            'line_name': self.line_name,
            'model_name': self.model_name,
            'n_maps_stored': len(self.intensity_maps)
        }
        
        if self.mode == 'fixed_cosmology' and self.fixed_cosmo_params:
            config['fixed_cosmology'] = self.fixed_cosmo_params
        
        config_file = output_path / 'parameter_config.json'
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        
        print(f"\nResults saved:")
        print(f"  Power spectra: {ps_file}")
        print(f"  PDFs: {pdf_file}")
        print(f"  Config: {config_file}")
        print(f"\nMode: {self.mode}")
        print(f"Seed: {seed}")
        print(f"Maps stored in memory: {len(self.intensity_maps)}")
        print(f"\nParameter summary ({len(self.params_array)} samples):")
        for i, name in enumerate(self.param_names):
            values = self.params_array[:, i]
            print(f"  {name}: [{values.min():.3f}, {values.max():.3f}]")


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

def example_varying_cosmology(n_maps=100, input_dir=None, output_dir=None, 
                              store_maps=False, seed=42, noise_level=1e-5):
    """
    EXAMPLE 1: Varying Cosmology Mode

    Runs one halo catalog per sample (each supplying its own cosmology)
    with Latin-hypercube sampled a_off / b_off and a fixed noise level,
    saves the results and returns the generator.  The number of maps is
    reduced to the number of catalogs found when fewer are available.
    """
    rule = "="*80
    print(rule)
    print("EXAMPLE 1: VARYING COSMOLOGY MODE")
    print("Uses different halo catalogs for each parameter combination")
    print(rule)

    # Fall back to the author's local directories when none are given
    input_dir = Path("/Users/anirbanroy/Desktop/halocats/") if input_dir is None else Path(input_dir)
    if output_dir is None:
        output_dir = f"/Users/anirbanroy/Desktop/processed_varying_cosmo_{n_maps}maps/"

    generator = LIMpyGenerator(
        param_names=['sigma_8', 'omega_m', 'a_off', 'b_off'],
        mode='varying_cosmology',
        seed=seed,
        noise_level=noise_level,
        noise_mode='fixed'
    )

    astro_ranges = {'a_off': (4, 10), 'b_off': (0, 2)}
    generator.setup_parameter_ranges(astro_ranges)

    catalogs = sorted(input_dir.glob("halocat_*.npz"))
    n_available = len(catalogs)

    if n_available < n_maps:
        print(f"WARNING: Only {n_available} halo catalogs available, but {n_maps} requested")
        n_maps = n_available

    halocat_files = catalogs[:n_maps]

    for line in (
        f"\nProcessing {len(halocat_files)} halo catalogs",
        "Each catalog provides its own cosmology (sigma_8, omega_m)",
        "Astrophysical parameters (a_off, b_off) are sampled via Latin hypercube",
        f"Store maps in memory: {store_maps}",
        f"Output directory: {output_dir}",
    ):
        print(line)

    n_processed = generator.process_batch(
        halocat_files=halocat_files,
        store_maps=store_maps
    )

    generator.save_results(output_dir, seed=seed, save_intensity_maps=store_maps)

    print(f"\nProcessed {n_processed} parameter combinations")

    return generator


def example_fixed_cosmology(n_maps=100, input_dir=None, output_dir=None, 
                              store_maps=False, seed=42, noise_level=1e-5):
    """
    EXAMPLE 2: Fixed Cosmology Mode
    Uses ONE halo catalog repeatedly with different astrophysical parameters

    Args:
        n_maps: Number of realizations to generate.
        input_dir: Directory containing halocat_1396.npz; defaults to the
            author's local halocats directory.
        output_dir: Where to save results; defaults to the author's local
            processed_fixed_cosmo directory.
        store_maps: Keep the noisy maps in memory and save them to disk.
        seed: Seed passed to the generator.
        noise_level: Fixed noise level passed to the generator.

    Returns:
        The LIMpyGenerator after processing and saving.
    """
    print("="*80)
    print("EXAMPLE 2: FIXED COSMOLOGY MODE")
    print("Uses single halo catalog with varying astrophysical parameters only")
    print("="*80)
    
    generator = LIMpyGenerator(
        param_names=['a_off', 'b_off'],
        mode='fixed_cosmology',
        seed=seed,
        noise_level=noise_level,
        noise_mode='fixed'
    )
    
    generator.setup_parameter_ranges({
        'a_off': (4, 10),
        'b_off': (0, 2)
    })
    
    # Honour input_dir (it was previously accepted but silently ignored)
    if input_dir is None:
        single_halocat = "/Users/anirbanroy/Desktop/halocats/halocat_1396.npz"
    else:
        single_halocat = str(Path(input_dir) / "halocat_1396.npz")
    generator.load_fixed_halocat(single_halocat)
    
    print(f"\nGenerating {n_maps} realizations with fixed cosmology")
    print("Only astrophysical parameters (a_off, b_off) will vary")
    
    n_processed = generator.process_batch(
        n_samples=n_maps,
        store_maps=store_maps
    )
    
    if output_dir is None:
        output_dir = "/Users/anirbanroy/Desktop/processed_fixed_cosmo/"
    generator.save_results(output_dir, seed=seed, save_intensity_maps=store_maps)
    
    print(f"\nProcessed {n_processed} parameter combinations")
    
    return generator


def example_fixed_cosmology_with_noise(n_maps=100, seed=42, store_maps=True):
    """
    EXAMPLE 3: Fixed Cosmology with Noise as Parameter
    Uses ONE halo catalog with varying astrophysical parameters AND noise

    Args:
        n_maps: Number of realizations to generate.
        seed: Seed passed to the generator.
        store_maps: Keep the noisy maps in memory and save them to disk
            (defaults to True, the value this example always used for
            processing).

    Returns:
        The LIMpyGenerator after processing and saving.
    """
    print("="*80)
    print("EXAMPLE 3: FIXED COSMOLOGY WITH NOISE AS PARAMETER")
    print("Uses single halo catalog with varying astro params + noise level")
    print("="*80)
    
    generator = LIMpyGenerator(
        param_names=['sigma_8', 'omega_m', 'a_off', 'b_off', 'log_noise_level'],
        mode='fixed_cosmology',
        seed=seed,
        log_noise_range=(0, 4),
    )
    
    generator.setup_parameter_ranges({
        'a_off': (4, 10),
        'b_off': (0, 2)
    })
    
    single_halocat = "/Users/anirbanroy/Desktop/halocats/halocat_1396.npz"
    generator.load_fixed_halocat(single_halocat)
    
    print(f"\nGenerating {n_maps} realizations with:")
    print("  - Fixed cosmology from halocat")
    print("  - Varying a_off and b_off")
    print("  - Varying log_noise_level from 0 to 4")
    
    n_processed = generator.process_batch(
        n_samples=n_maps,
        store_maps=store_maps
    )
    
    output_dir = "/Users/anirbanroy/Desktop/processed_fixed_cosmo_with_noise/"
    # store_maps used to be an undefined name here and raised NameError
    generator.save_results(output_dir, seed=seed, save_intensity_maps=store_maps)
    
    print(f"\nProcessed {n_processed} parameter combinations")
    
    return generator


def example_hybrid_mode(n_maps=50, seed=42, store_maps=False):
    """
    EXAMPLE 4: Hybrid - Include cosmo params in output even with fixed cosmology
    This is useful when you want the output to have cosmo params for compatibility

    Args:
        n_maps: Number of realizations to generate.
        seed: Seed passed to the generator.
        store_maps: Keep the noisy maps in memory and save them to disk
            (defaults to False, the value this example always used for
            processing).

    Returns:
        The LIMpyGenerator after processing and saving.
    """
    print("="*80)
    print("EXAMPLE 4: FIXED COSMOLOGY BUT INCLUDE COSMO PARAMS IN OUTPUT")
    print("Useful for maintaining consistent output format")
    print("="*80)
    
    generator = LIMpyGenerator(
        param_names=['sigma_8', 'omega_m', 'a_off', 'b_off'],
        mode='fixed_cosmology',
        seed=seed,
        noise_level=1e-5,
        noise_mode='fixed'
    )
    
    generator.setup_parameter_ranges({
        'a_off': (4, 10),
        'b_off': (0, 2)
    })
    
    single_halocat = "/Users/anirbanroy/Desktop/halocats/halocat_1396.npz"
    generator.load_fixed_halocat(single_halocat)
    
    print("\nNote: sigma_8 and omega_m will be fixed from the halocat")
    print("They will appear in the output but with constant values")
    
    n_processed = generator.process_batch(
        n_samples=n_maps,
        store_maps=store_maps
    )
    
    output_dir = "/Users/anirbanroy/Desktop/processed_hybrid/"
    # store_maps used to be an undefined name here and raised NameError
    generator.save_results(output_dir, seed=seed, save_intensity_maps=store_maps)
    
    print(f"\nProcessed {n_processed} parameter combinations")
    print("Output will have 4 parameters, but sigma_8 and omega_m are constant")
    
    return generator


def example_varying_all_parameters(n_maps=3000, input_dir=None, output_dir=None, 
                                  store_maps=False, seed=42):
    """
    EXAMPLE 5: Varying Cosmology + Astrophysics + Noise

    Runs a 'varying_cosmology' generator over the ``halocat_*.npz`` files found
    in ``input_dir`` with ``a_off``, ``b_off`` and ``log_noise_level`` listed
    as parameters, saves the results and prints a summary.

    Args:
        n_maps: Number of halo catalogs requested; reduced (with a printed
            warning) to the number of catalogs actually found.
        input_dir: Directory holding the halo catalogs. A default desktop
            path is used when None.
        output_dir: Where results are saved. When None, a default path is
            built from the *requested* ``n_maps``.
        store_maps: Forwarded to ``process_batch`` and ``save_results``.
        seed: Seed handed to the generator and to ``save_results``.

    Returns:
        The ``LIMpyGenerator`` instance used for the run.
    """
    banner = "=" * 80
    print(banner)
    print("EXAMPLE 5: VARYING ALL PARAMETERS (COSMO + ASTRO + NOISE)")
    print("Uses different halo catalogs with varying everything")
    print(banner)

    # Resolve the input/output locations, falling back to the defaults.
    input_dir = Path(
        "/Users/anirbanroy/Desktop/halocats/" if input_dir is None else input_dir
    )
    if output_dir is None:
        # Built before n_maps is clamped below, so it carries the requested count.
        output_dir = f"/Users/anirbanroy/Desktop/processed_all_params_noise_{n_maps}maps/"

    varied_params = ['sigma_8', 'omega_m', 'a_off', 'b_off', 'log_noise_level']
    generator = LIMpyGenerator(
        param_names=varied_params,
        mode='varying_cosmology',
        seed=seed,
        log_noise_range=(-5, 3),
    )

    astro_ranges = {'a_off': (4, 10), 'b_off': (0, 2)}
    generator.setup_parameter_ranges(astro_ranges)

    # Collect the catalogs in sorted order and cap the request at what exists.
    catalogs_found = sorted(input_dir.glob("halocat_*.npz"))
    n_found = len(catalogs_found)
    if n_found < n_maps:
        print(f"WARNING: Only {n_found} halo catalogs available, but {n_maps} requested")
        n_maps = n_found
    halocat_files = catalogs_found[:n_maps]

    log_lo = generator.log_noise_range[0]
    log_hi = generator.log_noise_range[1]

    print(f"\nProcessing {len(halocat_files)} halo catalogs")
    print("Parameters varying:")
    print("  - sigma_8, omega_m: from halo catalogs")
    print("  - a_off: [4, 10]")
    print("  - b_off: [0, 2]")
    print(f"  - log_noise_level: [{log_lo}, {log_hi}]")
    print(f"    (noise level: [10^{log_lo:.1f}, 10^{log_hi:.1f}])")
    print(f"Store maps in memory: {store_maps}")
    print(f"Output directory: {output_dir}")

    n_processed = generator.process_batch(
        halocat_files=halocat_files,
        store_maps=store_maps
    )

    generator.save_results(output_dir, seed=seed, save_intensity_maps=store_maps)

    print(f"\nProcessed {n_processed} parameter combinations")
    print("\nParameter summary:")
    # One line per parameter, reading column `col` of params_array.
    for col, name in enumerate(generator.param_names):
        column = generator.params_array[:, col]
        lo, hi = column.min(), column.max()
        mean, std = column.mean(), column.std()
        print(f"  {name}: [{lo:.3f}, {hi:.3f}] (mean={mean:.3f}, std={std:.3f})")

    log_noise = generator.log_noise_values
    if log_noise is None:
        return generator

    # Stored values are log10 of the noise level; convert back for display.
    print("\nNoise level distribution:")
    print(f"  Min noise: {10**log_noise.min():.2e}")
    print(f"  Max noise: {10**log_noise.max():.2e}")
    print(f"  Median noise: {10**np.median(log_noise):.2e}")

    return generator


if __name__ == "__main__":
    # ========================================================================
    # SCRIPT ENTRY POINT
    # Exactly one example call is active below; the alternatives are kept as
    # commented-out calls. Comment out the active call and uncomment another
    # to switch examples.
    # ========================================================================

    # Active run: varying-cosmology example, 3000 maps requested, maps stored.
    generator = example_varying_cosmology(
        seed=42,
        store_maps=True,
        n_maps=3000,
        input_dir="/Users/anirbanroy/Desktop/halocats/",
        output_dir="/Users/anirbanroy/Desktop/test_maps_all_params/",
    )

    # --- Alternatives -------------------------------------------------------
    # Varying cosmology, 10 maps:
    # generator = example_varying_cosmology(n_maps=10, store_maps=True)
    #
    # Varying cosmology, 3000 maps with store_maps=False:
    # generator = example_varying_cosmology(n_maps=3000, store_maps=False)
    #
    # Varying cosmology + astrophysics + noise:
    # generator = example_varying_all_parameters(
    #     n_maps=3000,
    #     input_dir="/Users/anirbanroy/Desktop/halocats/",
    #     output_dir="/Users/anirbanroy/Desktop/test_maps_all_params_varying_noise/",
    #     store_maps=True,
    #     seed=42
    # )
    #
    # Fixed cosmology:
    # generator = example_fixed_cosmology(n_maps=100, store_maps=True)
    #
    # Fixed cosmology with noise as a parameter:
    # generator = example_fixed_cosmology_with_noise(n_maps=100)
    #
    # Hybrid mode (cosmo params kept in the output for compatibility):
    # generator = example_hybrid_mode(n_maps=50)

    # Report on the in-memory maps only when some were actually stored.
    if generator:
        stored_maps = generator.intensity_maps
        if len(stored_maps) > 0:
            first_map = stored_maps[0]
            print(f"\n✓ Successfully stored {len(stored_maps)} maps in memory")
            print(f"  First map shape: {first_map.shape}")
            # Estimate: map count times the byte size of the first map.
            print(f"  Memory usage: ~{len(stored_maps) * first_map.nbytes / 1e9:.2f} GB")

EXAMPLE 1: VARYING COSMOLOGY MODE
Uses different halo catalogs for each parameter combination
Mode: varying_cosmology
Parameters to vary: ['sigma_8', 'omega_m', 'a_off', 'b_off']
  Cosmological: ['sigma_8', 'omega_m']
  Astrophysical: ['a_off', 'b_off']
Noise mode: fixed

Processing 3000 halo catalogs
Each catalog provides its own cosmology (sigma_8, omega_m)
Astrophysical parameters (a_off, b_off) are sampled via Latin hypercube
Store maps in memory: True
Output directory: /Users/anirbanroy/Desktop/test_maps_all_params/
Generated 3000 Latin hypercube samples for ['a_off', 'b_off']
  a_off: [4.001, 9.999]
  b_off: [0.000, 2.000]

Processing 3000 halocats with varying cosmology
Store maps: True


Processing: 100%|███████████████████████████| 3000/3000 [05:11<00:00,  9.64it/s]



Stored 3000 maps in memory

Saving 3000 intensity maps to disk...
  Intensity maps: /Users/anirbanroy/Desktop/test_maps_all_params/intensity_maps_80_256.npz
  File size: ~1.50 GB

Results saved:
  Power spectra: /Users/anirbanroy/Desktop/test_maps_all_params/ps_80_256.npz
  PDFs: /Users/anirbanroy/Desktop/test_maps_all_params/pdf_80_256.npz
  Config: /Users/anirbanroy/Desktop/test_maps_all_params/parameter_config.json

Mode: varying_cosmology
Seed: 42
Maps stored in memory: 3000

Parameter summary (3000 samples):
  sigma_8: [0.400, 1.200]
  omega_m: [0.100, 0.600]
  a_off: [4.001, 9.999]
  b_off: [0.000, 2.000]

Processed 3000 parameter combinations

✓ Successfully stored 3000 maps in memory
  First map shape: (256, 256)
  Memory usage: ~1.57 GB
