# In [None]:
import sys
import os
import numpy as np
import pandas as pd
import torch
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage
from sklearn.preprocessing import Normalizer
import joblib
sys.path.append("/disk/xray15/aem2/camels/proj2")
from setup_params_alice import *

# Note from Alice:
# In Chris's code, I think he loads the photometry directly from the CAMELS sims:
# he selects galaxies first, then gets their photometry for fluxes/luminosities.
# I only have the fluxes/luminosities, so I cannot use his get_x / get_theta
# versions, which rely on the CAMELS loading method.


# Path-safe name for passing in and for naming files/directories: spaces are replaced with underscores. See more below (get_colour_dir_name).
def get_safe_name(name, filter_system_only=False):
    """
    Return a path-safe version of ``name`` (every space becomes an underscore).

    When ``filter_system_only`` is true, only the part before the first
    underscore of that converted string is returned, e.g.
    "GALEX FUV" -> "GALEX".
    """
    underscored = '_'.join(name.split(' '))
    if not filter_system_only:
        return underscored
    # Everything before the first underscore (the whole string if there is none).
    system, _, _ = underscored.partition('_')
    return system

def load_uvlf_data(base_dir, start_group=0, end_group=49, redshift_dict=None, filters=None):
    """
    Load attenuated UVLF ``phi`` values from tab-separated text files.

    For each simulation ``SB28_<n>`` with ``start_group <= n <= end_group``,
    the ``phi`` column of every (redshift, filter) file is read and the values
    are concatenated, in ``redshift_dict`` order then ``filters`` order, into
    one row of the returned array.

    Files are looked up at::

        <base_dir>/LFs/attenuated/<filter system>/<z label>/
            UVLF_SB28_<n>_<filter>_<z label>_attenuated.txt

    where names are made path-safe with ``get_safe_name``.

    Parameters
    ----------
    base_dir : str
        Root directory containing the ``LFs`` tree.
    start_group, end_group : int
        Inclusive range of simulation indices to load.
    redshift_dict : dict, optional
        Maps snapshot id -> {'redshift': float, 'label': str}; only ``label``
        is used here. Defaults to snapshot '044' labelled 'z2.0'.
    filters : list of str, optional
        Band names such as "GALEX FUV". Defaults to GALEX FUV and NUV.

    Returns
    -------
    numpy.ndarray or None
        Float array with one row per simulation, or ``None`` if any file
        could not be read (the error is printed).

    Raises
    ------
    ValueError
        If the simulations yield different numbers of ``phi`` values, which
        would otherwise produce a ragged (non-rectangular) array.
    """
    data = []

    # Default to GALEX filters if none specified
    filters = filters or ["GALEX FUV", "GALEX NUV"]
    redshift_dict = redshift_dict or {'044': {'redshift': 2.00, 'label': 'z2.0'}}

    for group_num in range(start_group, end_group + 1):
        sim_name = f"SB28_{group_num}"
        group_data = []

        # Only the label is needed to build paths; the snapshot key is unused.
        for redshift_info in redshift_dict.values():
            z_label = get_safe_name(redshift_info['label'])
            for band in filters:
                # Construct file path using same structure as when saving
                filter_system = get_safe_name(band, filter_system_only=True)
                file_name = f"UVLF_{sim_name}_{get_safe_name(band)}_{z_label}_attenuated.txt"
                file_path = os.path.join(base_dir, "LFs", "attenuated", filter_system,
                                         z_label, file_name)

                try:
                    df = pd.read_csv(file_path, sep='\t')
                    # phi column, presumably log10 number density -- confirm
                    group_data.extend(df['phi'].values)
                except Exception as e:
                    # Deliberate best-effort: report and signal failure to the caller.
                    print(f"Error loading {file_path}: {e}")
                    return None

        data.append(group_data)

    # Every simulation must contribute the same number of values, otherwise
    # the rows cannot form a rectangular feature matrix.
    row_lengths = {len(row) for row in data}
    if len(row_lengths) > 1:
        raise ValueError(
            f"Simulations returned different numbers of phi values {sorted(row_lengths)}; "
            "check that the UVLF binning is identical across files."
        )

    return np.array(data, dtype=float)

def setup_sbi(input_dir, start_group=0, end_group=49, redshift_dict=None):
    """
    Train an ensemble of NPE posteriors on UVLF data and plot coverage diagnostics.

    Loads the UVLF data with ``load_uvlf_data``, normalises it, draws a random
    ~90/10 train/test split, and trains three neural-spline-flow NPE networks
    with ltu-ili. It then runs ``PosteriorCoverage`` on the held-out
    simulations.

    Relies on the names ``theta``, ``prior`` and ``param_labels``, which are not
    defined in this cell (presumably provided by ``setup_params_alice`` via
    the star import -- confirm). ``theta`` must have one row per loaded
    simulation, in the same order.

    Parameters
    ----------
    input_dir : str
        Base directory forwarded to ``load_uvlf_data``.
    start_group, end_group : int
        Inclusive range of simulation indices to load.
    redshift_dict : dict, optional
        Snapshot -> {'redshift', 'label'} mapping forwarded to ``load_uvlf_data``.

    Returns
    -------
    tuple
        ``(posterior_ensemble, summaries)`` from the inference runner, or
        ``(None, None)`` if the UVLF data could not be loaded.

    Raises
    ------
    ValueError
        If ``theta`` does not have the same number of rows as the loaded data.
    """
    # Only submodules of ``ili`` are imported at file level, which does not
    # bind the name ``ili`` itself; import it here so ``ili.utils`` resolves.
    import ili.utils

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)

    # Load UVLF data
    x_all = load_uvlf_data(
        input_dir,
        start_group=start_group,
        end_group=end_group,
        redshift_dict=redshift_dict,
    )

    if x_all is None:
        print("Failed to load data")
        # Return a pair so ``posterior, summaries = setup_sbi(...)`` does not
        # fail with an unpacking TypeError.
        return None, None

    # Convert once: accepts a numpy array or a tensor, kept on the CPU and in
    # float32 to match x_all.
    theta_all = torch.as_tensor(theta, dtype=torch.float32, device="cpu")
    if len(theta_all) != len(x_all):
        raise ValueError(
            f"theta has {len(theta_all)} rows but {len(x_all)} simulations were loaded"
        )

    # Initialize model parameters
    name = "SB28_GALEX_UVLF_test"
    hidden_features = 30
    num_transforms = 4

    # NOTE(review): sklearn's Normalizer rescales each *sample* (row) to unit
    # L2 norm, which discards the overall amplitude of each UVLF. If
    # per-feature standardisation was intended, StandardScaler is the usual
    # choice. Left unchanged so saved scalers/models stay compatible.
    norm = Normalizer()
    x_all = torch.tensor(
        norm.fit_transform(X=x_all),
        dtype=torch.float32,
        device=device,
    )

    # Save normalizer (absolute path, whereas the runner below writes to a
    # cwd-relative "models/" directory).
    joblib.dump(norm, f'/disk/xray15/aem2/models/{name}_scaler.save')

    # Setup test/train split (~10% test, unseeded)
    test_mask = np.random.rand(len(x_all)) > 0.9
    np.savetxt('/disk/xray15/aem2/data/test_mask_galex.txt', test_mask, fmt='%i')

    # Initialize networks
    nets = [
        ili.utils.load_nde_sbi(
            engine="NPE",
            model="nsf",
            hidden_features=hidden_features,
            num_transforms=num_transforms,
        )
        for _ in range(3)
    ]

    # Setup training parameters
    train_args = {
        "training_batch_size": 4,
        "learning_rate": 5e-4,
        'stop_after_epochs': 20,
    }

    # Initialize data loader
    loader = NumpyLoader(
        x=x_all[~test_mask],
        theta=theta_all[~test_mask].to(device),
    )

    # Setup and run inference
    runner = InferenceRunner.load(
        backend="sbi",
        engine="NPE",
        prior=prior,
        nets=nets,
        device=device,
        train_args=train_args,
        proposal=None,
        out_dir="models/",
        name=name,
    )

    posterior_ensemble, summaries = runner(loader=loader)

    # With few simulations the random split can leave the test set empty,
    # which would make the coverage step fail.
    if not test_mask.any():
        print("No simulations were assigned to the test set; skipping coverage plots")
        return posterior_ensemble, summaries

    # Generate coverage plots (written to out_dir as a side effect)
    metric = PosteriorCoverage(
        num_samples=int(4e3),
        sample_method='direct',
        labels=param_labels,
        plot_list=["coverage", "histogram", "predictions", "tarp"],
        out_dir="../plots/",
    )

    metric(
        posterior=posterior_ensemble,
        x=x_all[test_mask].cpu(),
        theta=theta_all[test_mask],
        signature=f"coverage_{name}_",
    )

    return posterior_ensemble, summaries

if __name__ == "__main__":
    # Test with just z=2.0
    redshift_test = {'044': {'redshift': 2.00, 'label': 'z2.0'}}

    # Run SBI. setup_sbi may return None when the data fails to load, so
    # guard before unpacking to avoid a TypeError.
    result = setup_sbi(
        input_dir="/disk/xray15/aem2/data/28pams/IllustrisTNG/SB",
        start_group=0,
        end_group=49,
        redshift_dict=redshift_test,
    )
    posterior, summaries = result if result is not None else (None, None)