# In [None]:
# === Imports ===
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit, OptimizeWarning
from scipy.special import i0, i1, k0, k1, gammainc
import warnings
from astropy.io import fits
import emcee  # For MCMC fitting
import corner # For posterior plots
from tqdm import tqdm # For progress bars

# === Constants ===
# Newtonian gravitational constant in kpc * (km/s)^2 / Msun: with radii in kpc
# and masses in Msun, sqrt(G * M / r) is a velocity in km/s.
G = 4.3009e-6

# === Component Velocity Profiles ===
def de_vaucouleurs_velocity(r, M_bulge, R_e):
    """Return the circular velocity contributed by a de Vaucouleurs bulge.

    The mass inside ``r`` is taken as ``M_bulge * P(8, x)``, where ``P`` is
    SciPy's regularized lower incomplete gamma function (``gammainc``) and
    ``x = 7.669 * (r / R_e) ** 0.25``.

    Args:
        r: Radius or array of radii; values below 1e-6 are raised to 1e-6 so
            the division by ``r`` stays finite.
        M_bulge: Total bulge mass.
        R_e: Effective radius, in the same length unit as ``r``.

    Returns:
        ``sqrt(G * M(<r) / r)``, with negative arguments clipped to zero
        before the square root.
    """
    radius = np.maximum(r, 1e-6)
    profile_arg = 7.669 * (radius / R_e)**0.25
    enclosed_mass = M_bulge * gammainc(8, profile_arg)
    v_squared = G * enclosed_mass / radius
    return np.sqrt(np.clip(v_squared, 0, None))

def hernquist_velocity(r, M_bulge, a_b):
    """Return the circular velocity contributed by a Hernquist bulge.

    Evaluates ``v^2 = G * M_bulge * r / (r + a_b)^2``.

    Args:
        r: Radius or array of radii; values below 1e-6 are raised to 1e-6.
        M_bulge: Total bulge mass.
        a_b: Hernquist scale length, in the same length unit as ``r``.

    Returns:
        The square root of ``v^2``, with negative values clipped to zero
        first.
    """
    radius = np.maximum(r, 1e-6)
    denominator = (radius + a_b)**2
    v_squared = G * M_bulge * radius / denominator
    return np.sqrt(np.clip(v_squared, 0, None))

def disk_velocity(r, M_disk, R_disk):
    """Return the circular velocity of a razor-thin exponential disk.

    Uses the Freeman (1970) solution

        v^2(r) = (2 * G * M_disk / R_disk) * y^2 * [I0(y) K0(y) - I1(y) K1(y)],

    with ``y = r / (2 * R_disk)`` and ``M_disk = 2 * pi * Sigma_0 * R_disk^2``
    the total disk mass.  The factor of 2 comes from rewriting the textbook
    prefactor ``4 * pi * G * Sigma_0 * R_disk`` in terms of ``M_disk``; with it
    the curve approaches the Keplerian ``G * M_disk / r`` at large radii.

    Args:
        r: Array of radii.
        M_disk: Total disk mass.
        R_disk: Exponential scale length (floored at 1e-9 to avoid division
            by zero), in the same length unit as ``r``.

    Returns:
        Array of velocities; entries with ``y <= 1e-9`` or non-finite ``y``
        evaluate to zero.
    """
    scale_length = np.maximum(R_disk, 1e-9)
    y = r / (2 * scale_length)
    # K0/K1 diverge at y = 0, so the Bessel term is only evaluated where y is
    # safely positive; everywhere else it stays at zero.
    term = np.zeros_like(y)
    valid_mask = (y > 1e-9) & np.isfinite(y)
    y_valid = y[valid_mask]
    term[valid_mask] = i0(y_valid) * k0(y_valid) - i1(y_valid) * k1(y_valid)
    v_sq = 2 * G * M_disk / scale_length * y**2 * term
    return np.sqrt(np.clip(v_sq, 0, None))

def nfw_halo_velocity(r, M_halo, R_s, c_halo):
    """Return the circular velocity contributed by an NFW halo.

    With ``m(t) = ln(1 + t) - t / (1 + t)`` the model is
    ``v^2 = G * M_halo * m(r / R_s) / (r * m(c_halo))``.

    Args:
        r: Radius or array of radii; values below 1e-6 are raised to 1e-6.
        M_halo: Halo mass normalisation.
        R_s: NFW scale radius, in the same length unit as ``r``.
        c_halo: Concentration used to normalise the mass profile.

    Returns:
        The square root of ``v^2``, with negative values clipped to zero
        first.
    """
    def _mass_shape(t):
        # Dimensionless NFW enclosed-mass function.
        return np.log(1 + t) - t / (1 + t)

    radius = np.maximum(r, 1e-6)
    shape_at_r = _mass_shape(radius / R_s)
    shape_at_c = _mass_shape(c_halo)
    v_squared = G * M_halo * shape_at_r / (radius * shape_at_c)
    return np.sqrt(np.clip(v_squared, 0, None))

def burkert_halo_velocity(r, rho_0, r_0):
    """Return the circular velocity contributed by a Burkert halo.

    For the density ``rho(r) = rho_0 * r_0^3 / ((r + r_0) * (r^2 + r_0^2))``
    the enclosed mass is

        M(<r) = pi * rho_0 * r_0^3 * [ln(1 + x^2) + 2 ln(1 + x) - 2 arctan(x)],

    with ``x = r / r_0``.  The ``2 ln(1 + x)`` term (i.e. ``(1 + x)^2`` inside
    the logarithm) is required: without the square the bracket is negative
    for small ``x`` and the core velocity is clipped to zero.  With it,
    ``M(<r)`` tends to ``(4/3) * pi * rho_0 * r^3`` as ``r -> 0``.

    Args:
        r: Radius or array of radii; values below 1e-6 are raised to 1e-6.
        rho_0: Central density (mass per length unit cubed).
        r_0: Core radius, in the same length unit as ``r``.

    Returns:
        ``sqrt(G * M(<r) / r)``, with negative arguments clipped to zero.
    """
    safe_r = np.maximum(r, 1e-6)
    x = safe_r / r_0
    # log1p keeps precision for r << r_0, where the terms nearly cancel.
    log_term = np.log1p(x**2) + 2.0 * np.log1p(x)
    mass_enclosed = np.pi * rho_0 * r_0**3 * (log_term - 2.0 * np.arctan(x))
    v_sq = G * mass_enclosed / safe_r
    return np.sqrt(np.clip(v_sq, 0, None))

# === Combined Model & MCMC Functions ===
def model_wrapper(r, params, bulge_choice, halo_choice, param_names):
    """Evaluate the total rotation curve for the selected components.

    ``params`` and ``param_names`` are paired positionally.  The bulge and
    halo terms are added only when the requested profile's parameters are all
    present; the disk term is added whenever ``'M_d'`` is present (and then
    requires ``'R_d'``).  Component velocities are combined in quadrature.

    Args:
        r: Radii at which to evaluate the model.
        params: Parameter values, ordered like ``param_names``.
        bulge_choice: ``'hernquist'`` or ``'devaucouleurs'``; any other value
            contributes no bulge.
        halo_choice: ``'nfw'`` or ``'burkert'``; any other value contributes
            no halo.
        param_names: Names matching ``params``.

    Returns:
        Total circular velocity at ``r``.
    """
    values = dict(zip(param_names, params))

    def _has(*keys):
        # True when every listed parameter was supplied.
        return all(key in values for key in keys)

    bulge_sq = 0
    if bulge_choice == 'hernquist' and _has('M_b', 'a_b'):
        bulge_sq = hernquist_velocity(r, values['M_b'], values['a_b'])**2
    elif bulge_choice == 'devaucouleurs' and _has('M_b', 'R_e'):
        bulge_sq = de_vaucouleurs_velocity(r, values['M_b'], values['R_e'])**2

    disk_sq = 0
    if 'M_d' in values:
        disk_sq = disk_velocity(r, values['M_d'], values['R_d'])**2

    halo_sq = 0
    if halo_choice == 'nfw' and _has('M_h', 'R_s', 'c_h'):
        halo_sq = nfw_halo_velocity(r, values['M_h'], values['R_s'], values['c_h'])**2
    elif halo_choice == 'burkert' and _has('rho_0', 'r_0'):
        halo_sq = burkert_halo_velocity(r, values['rho_0'], values['r_0'])**2

    total_sq = bulge_sq + disk_sq + halo_sq
    return np.sqrt(np.clip(total_sq, 0, None))

def log_likelihood(theta, r, v, v_err, bulge_choice, halo_choice, param_names):
    """Return the Gaussian log-likelihood of the data given ``theta``.

    Args:
        theta: Parameter vector, ordered like ``param_names``.
        r: Observed radii.
        v: Observed velocities.
        v_err: Per-point velocity uncertainties.
        bulge_choice: Bulge profile name passed to ``model_wrapper``.
        halo_choice: Halo profile name passed to ``model_wrapper``.
        param_names: Names matching ``theta``.

    Returns:
        The log-likelihood, or ``-inf`` when the model contains a NaN.
    """
    predicted = model_wrapper(r, theta, bulge_choice, halo_choice, param_names)
    if np.any(np.isnan(predicted)):
        return -np.inf

    variance = v_err**2
    misfit = (v - predicted)**2 / variance
    normalisation = np.log(2 * np.pi * variance)
    return -0.5 * np.sum(misfit + normalisation)

def log_prior(theta, bounds):
    """Return the log of a flat (uniform) prior over a box.

    Args:
        theta: Parameter vector.
        bounds: Sequence of ``(lower, upper)`` pairs, one per parameter;
            both limits are inclusive.

    Returns:
        ``0.0`` when every parameter lies inside its bounds, ``-inf``
        otherwise (including when a parameter is NaN).
    """
    inside_box = all(
        bounds[idx][0] <= value <= bounds[idx][1]
        for idx, value in enumerate(theta)
    )
    return 0.0 if inside_box else -np.inf

def log_probability(theta, r, v, v_err, bounds, bulge_choice, halo_choice, param_names):
    """Return the log-posterior (log-prior plus log-likelihood) for MCMC.

    The likelihood is only evaluated when the prior is finite, so parameter
    vectors outside ``bounds`` return ``-inf`` without running the model.

    Args:
        theta: Parameter vector, ordered like ``param_names``.
        r: Observed radii.
        v: Observed velocities.
        v_err: Per-point velocity uncertainties.
        bounds: ``(lower, upper)`` pairs defining the flat prior.
        bulge_choice: Bulge profile name.
        halo_choice: Halo profile name.
        param_names: Names matching ``theta``.

    Returns:
        The log-posterior, or ``-inf`` outside the prior support.
    """
    prior = log_prior(theta, bounds)
    if np.isfinite(prior):
        likelihood = log_likelihood(theta, r, v, v_err, bulge_choice, halo_choice, param_names)
        return prior + likelihood
    return -np.inf

# === Data Loading & Input ===
def load_rc_from_fits(fits_file, x0, y0, incl_deg, pa_deg, pixel_scale_kpc, r_max_kpc):
    """Extract a radially binned velocity profile from a 2-D FITS velocity map.

    Pixels are assigned a galactocentric radius by rotating the pixel grid by
    ``pa_deg`` about ``(x0, y0)`` and stretching the rotated ``dy`` axis by
    ``1 / cos(incl_deg)``.  The map is then split into 39 equal-width annuli
    between 0 and ``r_max_kpc``; each annulus yields the mean of ``|v|`` over
    its finite pixels and the standard error of that mean.

    Args:
        fits_file: Path to the FITS file; the primary HDU must hold a 2-D map.
        x0: Centre x coordinate in pixels.
        y0: Centre y coordinate in pixels.
        incl_deg: Inclination in degrees.
        pa_deg: Rotation angle in degrees applied to the pixel grid.
        pixel_scale_kpc: Size of one pixel in kpc.
        r_max_kpc: Outer edge of the last annulus in kpc.

    Returns:
        ``(radii, velocities, errors)`` arrays restricted to annuli with a
        finite velocity and a finite, strictly positive error.  Three empty
        arrays are returned (after printing the reason) if anything fails.
    """
    try:
        with fits.open(fits_file) as hdul:
            data = hdul[0].data
        pa_rad, incl_rad = np.radians(pa_deg), np.radians(incl_deg)
        ny, nx = data.shape
        Y, X = np.ogrid[:ny, :nx]
        # NOTE(review): the rotation measures pa_deg from the +x pixel axis and
        # treats the rotated dy as the foreshortened (minor) axis, while the
        # interactive prompt describes the angle as "E from N" -- confirm the
        # convention matches the map orientation.
        dx = (X - x0) * np.cos(pa_rad) + (Y - y0) * np.sin(pa_rad)
        dy = -(X - x0) * np.sin(pa_rad) + (Y - y0) * np.cos(pa_rad)
        r_kpc_map = np.sqrt(dx**2 + (dy / np.cos(incl_rad))**2) * pixel_scale_kpc
        r_bins = np.linspace(0, r_max_kpc, 40)
        r_centers = 0.5 * (r_bins[:-1] + r_bins[1:])
        n_bins = len(r_centers)
        v_means = np.full(n_bins, np.nan)
        v_errors = np.full(n_bins, np.nan)

        for i in range(n_bins):
            mask = (r_kpc_map >= r_bins[i]) & (r_kpc_map < r_bins[i + 1])
            # NOTE(review): |v| is averaged exactly as stored; no systemic
            # velocity subtraction or sin(i)/azimuth correction is applied
            # here -- confirm the map is already in rotation-velocity terms.
            speeds = np.abs(data[mask])
            speeds = speeds[np.isfinite(speeds)]
            if speeds.size == 0:
                continue  # empty annulus: leave NaN so it is filtered below
            v_means[i] = speeds.mean()
            if speeds.size > 1:
                # Scatter of the same |v| values that define the mean, so the
                # error is not inflated by the sign flip across the galaxy.
                v_errors[i] = speeds.std() / np.sqrt(speeds.size)
            else:
                # A single pixel has no scatter: fall back to 10% of the mean
                # velocity of the annuli so far, ignoring empty (NaN) ones.
                v_errors[i] = 0.1 * np.nanmean(v_means[:i + 1])

        valid = np.isfinite(v_means) & np.isfinite(v_errors) & (v_errors > 0)
        return r_centers[valid], v_means[valid], v_errors[valid]
    except Exception as e:
        # Deliberately broad: any failure is reported and signalled to the
        # caller through empty arrays rather than an exception.
        print(f"FITS read error: {e}")
        return np.array([]), np.array([]), np.array([])

def get_param_with_error(prompt_name, default_val, default_err, is_positive=True):
    """Prompt for a parameter's initial guess and error, then derive bounds.

    The bounds are ``val -/+ 5 * err``.  Input is sanitised so the returned
    bounds always satisfy ``lower < upper`` for sensible defaults: unparsable,
    non-finite or missing input falls back to the default, the sign of the
    error estimate is ignored, and a zero error estimate falls back to the
    default error.

    Args:
        prompt_name: Label printed above the two prompts.
        default_val: Value used when the initial guess is invalid.
        default_err: Value used when the error estimate is invalid or zero.
        is_positive: When true, a non-positive guess is replaced by
            ``default_val`` and the lower bound is floored at 1e-9.

    Returns:
        ``(val, (lower_bound, upper_bound))``.
    """
    def _read_float(prompt, fallback):
        # Read one finite float from stdin; on bad input report and fall back.
        try:
            answer = float(input(prompt))
        except (ValueError, TypeError, EOFError):
            answer = None
        if answer is None or not np.isfinite(answer):
            print(f"    Invalid input. Using default: {fallback}")
            return fallback
        return answer

    print(f"\n--- {prompt_name} ---")
    val = _read_float(f"    Initial guess (e.g., {default_val}): ", default_val)
    err = _read_float(f"    Error estimate (+/-) (e.g., {default_err}): ", default_err)

    if is_positive and val <= 0:
        # A non-positive guess would sit below the 1e-9 floor on the lower
        # bound, i.e. outside its own bounds.
        print(f"    Value must be positive. Using default: {default_val}")
        val = default_val

    # The estimate is a +/- half-width, so its sign carries no information; a
    # negative value would otherwise swap the lower and upper bounds.
    err = abs(err)
    if err == 0:
        # Zero width gives lower == upper, which bounded fitters reject.
        err = abs(default_err)
        print(f"    Error estimate must be non-zero. Using default: {err}")

    n_sigma = 5.0
    lower_bound = val - n_sigma * err
    if is_positive:
        lower_bound = max(1e-9, lower_bound)
    upper_bound = val + n_sigma * err

    return val, (lower_bound, upper_bound)

# === Fit Quality Metrics ===
def compute_fit_metrics(observed, fitted, errors, dof):
    """Return the reduced chi-squared, R-squared and residuals of a fit.

    Args:
        observed: Observed values.
        fitted: Model values at the same points.
        errors: Per-point uncertainties used to weight chi-squared.
        dof: Number subtracted from ``len(observed)`` to form the chi-squared
            denominator (callers pass the number of fitted parameters); the
            denominator is floored at 1.

    Returns:
        ``(reduced_chi2, r2, residuals)`` where ``residuals`` is
        ``observed - fitted`` and ``r2`` is 0 when the observations have no
        variance.
    """
    residuals = observed - fitted

    chi2 = np.sum((residuals / errors)**2)
    denominator = max(1, len(observed) - dof)

    total_ss = np.sum((observed - np.mean(observed))**2)
    if total_ss > 0:
        r2 = 1 - (np.sum(residuals**2) / total_ss)
    else:
        r2 = 0

    return chi2 / denominator, r2, residuals

# === Main Execution ===
if __name__ == '__main__':
    # Interactive driver: pick the data source and model components, collect
    # initial guesses and bounds, run the chosen fitter, then report and plot.
    mode = input("Select data mode ('tabulated' or 'fits_raw'): ").strip().lower()
    bulge_choice = input("Select bulge profile ('hernquist', 'devaucouleurs', or 'none'): ").strip().lower()
    halo_choice = input("Select halo profile ('nfw', 'burkert', or 'none'): ").strip().lower()
    fitting_method = input("Select fitting method ('mcmc', 'bootstrap', 'curve_fit'): ").strip().lower()

    # --- Load Data ---
    # Failures end the script with `raise SystemExit(1)`: the exit() helper is
    # injected by the `site` module (absent under `python -S`) and would also
    # report success (status 0) for what is an error.
    if mode == 'tabulated':
        filename = input("Enter path to tabulated data file (r, v, err): ")
        try:
            data = np.loadtxt(filename)
            r_obs, v_obs = data[:, 0], data[:, 1]
            # Without an error column, assume a flat 10% of the mean velocity.
            errors = data[:, 2] if data.shape[1] > 2 else np.ones_like(v_obs) * np.mean(v_obs) * 0.1
        except Exception as e:
            print(f"Error loading data: {e}")
            raise SystemExit(1)
    elif mode == 'fits_raw':
        fitsfile = input("FITS file: ")
        try:
            x0 = float(input("x center (px): "))
            y0 = float(input("y center (px): "))
            incl = float(input("Inclination (deg): "))
            pa = float(input("Position Angle (deg, E from N): "))
            pixel_scale = float(input("Pixel scale (kpc/pixel): "))
            rmax = float(input("Max radius to extract (kpc): "))
        except ValueError as e:
            print(f"Invalid numeric input: {e}")
            raise SystemExit(1)
        r_obs, v_obs, errors = load_rc_from_fits(fitsfile, x0, y0, incl, pa, pixel_scale, rmax)
        if len(r_obs) < 5:
            print("Not enough data points extracted.")
            raise SystemExit(1)
    else:
        print("Invalid mode selected.")
        raise SystemExit(1)

    # --- Get User Parameters with Simplified Units ---
    # Masses are entered in units of `mass_scale` Msun but stored in Msun.
    param_names, initial_params, param_bounds = [], [], []
    mass_params = ['M_b', 'M_d', 'M_h']
    mass_scale = 1e10

    def _add_param(name, label, default_val, default_err, unit_scale=1.0):
        """Prompt for one parameter and record its guess and bounds."""
        guess, limits = get_param_with_error(label, default_val, default_err)
        param_names.append(name)
        initial_params.append(guess * unit_scale)
        param_bounds.append((limits[0] * unit_scale, limits[1] * unit_scale))

    if bulge_choice in ['hernquist', 'devaucouleurs']:
        _add_param('M_b', "Bulge Mass (10^10 Msun)", 1.0, 0.5, mass_scale)
        if bulge_choice == 'hernquist':
            _add_param('a_b', "Bulge Scale 'a_b' (kpc)", 1.0, 0.5)
        else:  # devaucouleurs
            _add_param('R_e', "Bulge Effective Radius 'R_e' (kpc)", 2.0, 1.0)

    _add_param('M_d', "Disk Mass (10^10 Msun)", 5.0, 2.0, mass_scale)
    _add_param('R_d', "Disk Scale Radius 'R_d' (kpc)", 3.0, 1.5)

    if halo_choice == 'nfw':
        _add_param('M_h', "NFW Halo Mass (10^10 Msun)", 100.0, 50.0, mass_scale)  # 1e12 Msun
        _add_param('R_s', "NFW Scale Radius 'R_s' (kpc)", 20, 10)
        _add_param('c_h', "NFW Concentration 'c_h'", 10, 5)
    elif halo_choice == 'burkert':
        _add_param('rho_0', "Burkert Central Density 'rho_0' (Msun/kpc^3)", 1e7, 5e6)
        _add_param('r_0', "Burkert Core Radius 'r_0' (kpc)", 10, 5)

    n_params = len(initial_params)
    final_params, final_params_err, param_percentiles = None, None, None
    flat_samples, bootstrap_samples = None, None
    lower_bounds = np.array([b[0] for b in param_bounds], dtype=float)
    upper_bounds = np.array([b[1] for b in param_bounds], dtype=float)

    # --- Perform Fitting ---
    def fit_func(r, *params):
        """curve_fit adapter: positional parameters -> total model velocity."""
        return model_wrapper(r, params, bulge_choice, halo_choice, param_names)

    with warnings.catch_warnings():
        # One filter per category, as documented for warnings.simplefilter.
        warnings.simplefilter("ignore", OptimizeWarning)
        warnings.simplefilter("ignore", RuntimeWarning)
        if fitting_method == 'curve_fit':
            print("\nFitting with curve_fit...")
            try:
                popt, pcov = curve_fit(fit_func, r_obs, v_obs, p0=initial_params, sigma=errors,
                                       bounds=(lower_bounds, upper_bounds), maxfev=10000)
            except (RuntimeError, ValueError) as e:
                # Non-convergence or an infeasible start: fall through to the
                # "Fit failed" report instead of aborting with a traceback.
                print(f"curve_fit failed: {e}")
            else:
                final_params, final_params_err = popt, np.sqrt(np.diag(pcov))
        elif fitting_method == 'bootstrap':
            print("\nPerforming bootstrap analysis...")
            n_boot = 500
            # Rows stay NaN for resamples whose fit fails.
            bootstrap_params = np.full((n_boot, n_params), np.nan)
            for i in tqdm(range(n_boot)):
                indices = np.random.choice(len(r_obs), len(r_obs), replace=True)
                try:
                    popt, _ = curve_fit(fit_func, r_obs[indices], v_obs[indices], p0=initial_params,
                                        sigma=errors[indices], bounds=(lower_bounds, upper_bounds), maxfev=5000)
                except (RuntimeError, ValueError):
                    continue
                bootstrap_params[i, :] = popt
            valid_fits = bootstrap_params[~np.isnan(bootstrap_params).any(axis=1)]
            if len(valid_fits) > 0:
                bootstrap_samples = valid_fits
                param_percentiles = np.percentile(valid_fits, [16, 50, 84], axis=0)
                final_params = param_percentiles[1]
            else:
                print("All bootstrap fits failed.")
        elif fitting_method == 'mcmc':
            print("\nRunning MCMC sampling with emcee...")
            n_walkers = max(2 * n_params, 50)
            p0 = np.asarray(initial_params, dtype=float)
            # Scatter the walkers relative to each parameter's own magnitude:
            # an absolute 1e-4 offset is ~1e-14 of a 1e10 Msun mass, leaving
            # the walkers effectively identical in the mass dimensions.
            jitter_scale = np.where(p0 != 0, np.abs(p0), 1.0)
            pos = p0 + 1e-4 * jitter_scale * np.random.randn(n_walkers, n_params)
            pos = np.clip(pos, lower_bounds, upper_bounds)
            try:
                sampler = emcee.EnsembleSampler(n_walkers, n_params, log_probability,
                                                args=(r_obs, v_obs, errors, param_bounds, bulge_choice, halo_choice, param_names))
                sampler.run_mcmc(pos, 5000, progress=True)
            except ValueError as e:
                # e.g. a degenerate or zero-probability initial ensemble.
                print(f"MCMC sampling failed: {e}")
            else:
                flat_samples = sampler.get_chain(discard=1000, thin=15, flat=True)
                param_percentiles = np.percentile(flat_samples, [16, 50, 84], axis=0)
                final_params = param_percentiles[1]

    # === Display Results & Plotting (Robust Version) ===
    print("\n--- Fit Results ---")
    if final_params is None:
        print("\nFit failed or was not run. Nothing to display or plot.")
    else:
        n_fit_params = len(final_params)

        # Check for parameter mismatch, which can happen in interactive environments
        if n_fit_params != n_params:
            print("\n--- WARNING ---")
            print(f"Mismatch detected: {n_params} parameters were defined, but the fit produced {n_fit_params} parameters.")
            print("This can happen if notebook cells are run out of order. Please restart and run all cells sequentially for correct labels.")

        # Masses are reported back in the units they were entered in.
        mass_unit = f" (10^{int(np.log10(mass_scale))} Msun)"
        for i, name in enumerate(param_names[:n_fit_params]):
            is_mass = name in mass_params
            unit = mass_unit if is_mass else ""
            display_scale = 1.0 / mass_scale if is_mass else 1.0

            if fitting_method in ['mcmc', 'bootstrap'] and param_percentiles is not None:
                med = param_percentiles[1, i] * display_scale
                upper_err = (param_percentiles[2, i] - param_percentiles[1, i]) * display_scale
                lower_err = (param_percentiles[1, i] - param_percentiles[0, i]) * display_scale
                print(f"{name}{unit}: {med:.3f} (+{upper_err:.3f}) (-{lower_err:.3f})")
            elif fitting_method == 'curve_fit' and final_params_err is not None:
                val = final_params[i] * display_scale
                err = final_params_err[i] * display_scale
                print(f"{name}{unit}: {val:.3f} +/- {err:.3f}")

        r_plot = np.linspace(0.01, r_obs.max()*1.1, 500)
        model_at_data = model_wrapper(r_obs, final_params, bulge_choice, halo_choice, param_names)
        chi2_nu, r_sq, residuals = compute_fit_metrics(v_obs, model_at_data, errors, dof=n_fit_params)
        print(f"\nGoodness of Fit:    χ²_ν = {chi2_nu:.3f} | R² = {r_sq:.4f}")

        plt.style.use('seaborn-v0_8-whitegrid')
        fig1, ax1 = plt.subplots(figsize=(12, 7))
        ax1.errorbar(r_obs, v_obs, yerr=errors, fmt='o', color='black', label='Observed Data', capsize=3, markersize=5, zorder=5)

        # Plot each model component that has all of its parameters available.
        param_dict = dict(zip(param_names, final_params))
        bulge_label = f'{bulge_choice.capitalize()} Bulge'
        if bulge_choice == 'devaucouleurs' and 'M_b' in param_dict and 'R_e' in param_dict:
            bulge_v = de_vaucouleurs_velocity(r_plot, param_dict['M_b'], param_dict['R_e'])
            ax1.plot(r_plot, bulge_v, '--', color='orange', label=bulge_label)
        elif bulge_choice == 'hernquist' and 'M_b' in param_dict and 'a_b' in param_dict:
            bulge_v = hernquist_velocity(r_plot, param_dict['M_b'], param_dict['a_b'])
            ax1.plot(r_plot, bulge_v, '--', color='orange', label=bulge_label)

        if 'M_d' in param_dict and 'R_d' in param_dict:
            disk_v = disk_velocity(r_plot, param_dict['M_d'], param_dict['R_d'])
            ax1.plot(r_plot, disk_v, '--', color='green', label='Disk')

        halo_label = f'{halo_choice.capitalize()} Halo'
        if halo_choice == 'nfw' and all(k in param_dict for k in ('M_h', 'R_s', 'c_h')):
            halo_v = nfw_halo_velocity(r_plot, param_dict['M_h'], param_dict['R_s'], param_dict['c_h'])
            ax1.plot(r_plot, halo_v, '--', color='purple', label=halo_label)
        elif halo_choice == 'burkert' and 'rho_0' in param_dict and 'r_0' in param_dict:
            halo_v = burkert_halo_velocity(r_plot, param_dict['rho_0'], param_dict['r_0'])
            ax1.plot(r_plot, halo_v, '--', color='purple', label=halo_label)

        ax1.plot(r_plot, model_wrapper(r_plot, final_params, bulge_choice, halo_choice, param_names), color='red', lw=2, label='Total Model', zorder=4)

        # Plot uncertainty band from the posterior (MCMC) or bootstrap samples.
        if fitting_method in ['mcmc', 'bootstrap']:
            samples = flat_samples if flat_samples is not None else bootstrap_samples
            if samples is not None and len(samples) > 0:
                # Limit to a reasonable number of samples for plotting efficiency
                num_plot_samples = min(len(samples), 500)
                chosen = samples[np.random.randint(len(samples), size=num_plot_samples)]
                v_samples = np.array([model_wrapper(r_plot, s, bulge_choice, halo_choice, param_names) for s in chosen])
                lower, upper = np.percentile(v_samples, [16, 84], axis=0)
                ax1.fill_between(r_plot, lower, upper, color='red', alpha=0.2, label='1σ Confidence Region', zorder=0)

        ax1.set_xlabel("Radius (kpc)")
        ax1.set_ylabel("Velocity (km/s)")
        ax1.set_title("Galaxy Rotation Curve Fit")
        ax1.legend()
        ax1.set_xlim(0, r_obs.max()*1.05)
        ax1.set_ylim(0, max(v_obs.max()*1.2, ax1.get_ylim()[1]))
        ax1.text(0.05, 0.95, f"$\\chi^2_\\nu = {chi2_nu:.2f}$\n$R^2 = {r_sq:.3f}$", transform=ax1.transAxes, va='top', bbox=dict(boxstyle='round', fc='wheat', alpha=0.5))
        plt.tight_layout()
        plt.savefig("rotation_curve_fit.png", dpi=300)
        plt.show()

        # Residuals Plot
        fig2, ax2 = plt.subplots(figsize=(12, 4))
        ax2.errorbar(r_obs, residuals, yerr=errors, fmt='o', color='black', capsize=3, markersize=5)
        ax2.axhline(0, color='red', linestyle='--', lw=2)
        ax2.set_xlabel("Radius (kpc)")
        ax2.set_ylabel("Residuals (km/s)")
        ax2.set_title("Fit Residuals")
        ax2.set_xlim(0, r_obs.max()*1.05)
        plt.tight_layout()
        plt.savefig("residuals_plot.png", dpi=300)
        plt.show()

        # Corner Plot (MCMC only), with masses rescaled to the input units.
        if flat_samples is not None:
            print("\nGenerating MCMC corner plot...")
            plot_samples = flat_samples.copy()
            plot_labels = param_names[:n_fit_params]
            plot_truths = final_params.copy()
            for i in range(len(plot_labels)):
                if param_names[i] in mass_params:
                    plot_samples[:, i] /= mass_scale
                    plot_truths[i] /= mass_scale
                    plot_labels[i] = f"{param_names[i]} ($10^{{{int(np.log10(mass_scale))}}} M_\\odot$)"

            fig3 = corner.corner(plot_samples, labels=plot_labels, truths=plot_truths, quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})
            fig3.suptitle("MCMC Posterior Distributions", fontsize=16, y=1.02)
            plt.savefig("mcmc_corner_plot.png", dpi=300)
            plt.show()