In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import seaborn as sns
import warnings
import traceback
from astropy.constants import G
import astropy.units as u
import symlib
import colossus.cosmology as cosmology

# Plot colour for each Symphony suite, taken from seaborn's colour-blind-safe palette.
# Keys must match the suite directory names passed as `suite_names`; the plotting
# code looks colours up by that name.
_colorblind_palette = sns.color_palette("colorblind")
sim_colors = {
    'SymphonyLMC': _colorblind_palette[4],
    'SymphonyMilkyWay': _colorblind_palette[0],
    'SymphonyGroup': _colorblind_palette[2],  # was misspelled 'SymphonygGroup'
    'SymphonyLCluster': _colorblind_palette[1],
    'SymphonyCluster': _colorblind_palette[3],
}

In [None]:
def plot_combined_ppsd_and_slopes(base_dir, suite_names):
    fig_ppsd, axes_ppsd = plt.subplots(2, 2, figsize=(14, 10), dpi=500)
    fig_slope, axes_slope = plt.subplots(2, 2, figsize=(14, 10), dpi=500)
    ax_ppsd = axes_ppsd.flatten()
    ax_slope = axes_slope.flatten()

    def load_ppsd_profiles(suite):
        dir_path = os.path.join(base_dir, "output", suite, "ppsd_profiles")
        files = sorted([f for f in os.listdir(dir_path) if f.endswith(".csv")])
        r_list, m_list, Qr_list, Qtot_list = [], [], [], []
        for f in files:
            df = pd.read_csv(os.path.join(dir_path, f))
            r_list.append(df["r_scaled"].values)
            m_list.append(df["m_scaled"].values)
            Qr_list.append(df["Q_r"].values)
            Qtot_list.append(df["Q_tot"].values)
        return r_list, m_list, Qr_list, Qtot_list

    def load_slope_profiles(suite):
        dir_r = os.path.join(base_dir, "output", suite, "ppsd_slope_profiles_r")
        dir_m = os.path.join(base_dir, "output", suite, "ppsd_slope_profiles_m")
        files_r = sorted([f for f in os.listdir(dir_r) if f.endswith(".csv")])
        files_m = sorted([f for f in os.listdir(dir_m) if f.endswith(".csv")])
        r_list, m_list, slope_r_Qr, slope_r_Qtot, slope_m_Qr, slope_m_Qtot = [], [], [], [], [], []
        for fr, fm in zip(files_r, files_m):
            dfr = pd.read_csv(os.path.join(dir_r, fr))
            dfm = pd.read_csv(os.path.join(dir_m, fm))
            r_list.append(dfr["r_scaled"].values)
            m_list.append(dfm["m_scaled"].values)
            slope_r_Qr.append(dfr["slope_Q_r"].values)
            slope_r_Qtot.append(dfr["slope_Q_tot"].values)
            slope_m_Qr.append(dfm["slope_Q_r"].values)
            slope_m_Qtot.append(dfm["slope_Q_tot"].values)
        return r_list, m_list, slope_r_Qr, slope_r_Qtot, slope_m_Qr, slope_m_Qtot

    def mean_std(ax, x_list, y_list, color, label):
        if len(x_list) == 0:
            return
        x_common = np.logspace(np.log10(np.nanmin([x[0] for x in x_list if len(x) > 0])),
                               np.log10(np.nanmax([x[-1] for x in x_list if len(x) > 0])), 200)
        y_interp = []
        for x, y in zip(x_list, y_list):
            if np.all(np.isfinite(x)) and np.all(np.isfinite(y)):
                y_interp.append(np.interp(x_common, x, y, left=np.nan, right=np.nan))
        y_arr = np.array(y_interp)
        y_mean = np.nanmean(y_arr, axis=0)
        y_std = np.nanstd(y_arr, axis=0)
        sigma_val = np.nanmean(y_std)
        ax.plot(x_common, y_mean, color=color, lw=1.5, label=f"{label} (σ={sigma_val:.2f})")
        ax.fill_between(x_common, y_mean - y_std, y_mean + y_std, color=color, alpha=0.3)

    for suite in suite_names:
        color = sim_colors.get(suite, "gray")
        # PPSD profiles
        r, m, Qr, Qtot = load_ppsd_profiles(suite)
        mean_std(ax_ppsd[0], r, Qr, color, suite)
        mean_std(ax_ppsd[1], r, Qtot, color, suite)
        mean_std(ax_ppsd[2], m, Qr, color, suite)
        mean_std(ax_ppsd[3], m, Qtot, color, suite)

        # PPSD slopes
        r, m, sr_Qr, sr_Qtot, sm_Qr, sm_Qtot = load_slope_profiles(suite)
        mean_std(ax_slope[0], r, sr_Qr, color, suite)
        mean_std(ax_slope[1], r, sr_Qtot, color, suite)
        mean_std(ax_slope[2], m, sm_Qr, color, suite)
        mean_std(ax_slope[3], m, sm_Qtot, color, suite)

    labels = ["$Q_r$ vs $r$", "$Q_{tot}$ vs $r$", "$Q_r$ vs $M$", "$Q_{tot}$ vs $M$"]
    for i in range(4):
        for ax in [ax_ppsd[i], ax_slope[i]]:
            ax.set_xscale("log")
            ax.set_xlabel("r / $R_{vir}$" if i < 2 else "M(<r) / $M_{vir}$")
            ax.set_ylabel("Q" if ax in ax_ppsd else "Slope")
            ax.set_title(labels[i])
            ax.legend()
            ax.grid(True, which="both", linestyle=":")

    fig_ppsd.suptitle("PPSD Profiles: Mean ± 1σ across Suites", fontsize=16)
    fig_slope.suptitle("PPSD Slope Profiles: Mean ± 1σ across Suites", fontsize=16)

    out_dir = os.path.join(base_dir, "output", "Combined")
    os.makedirs(out_dir, exist_ok=True)
    fig_ppsd.tight_layout(rect=[0, 0, 1, 0.96])
    fig_ppsd.savefig(os.path.join(out_dir, "ppsd_profiles_comparison.png"))
    fig_slope.tight_layout(rect=[0, 0, 1, 0.96])
    fig_slope.savefig(os.path.join(out_dir, "ppsd_slope_profiles_comparison.png"))
    plt.show()

In [None]:
def plot_ppsd_split_by_concentration_and_accretion(base_dir, suite_names):
    
    def calculate_accretion_rate(sim_dir):
        try:
            scale = symlib.scale_factors(sim_dir)
            r, _ = symlib.read_rockstar(sim_dir)
            snap = len(scale) - 1
            m_now = r[0, snap]["m"]

            sim_params = symlib.simulation_parameters(sim_dir)
            cosmo = cosmology.setCosmology("custom", {
                "flat": sim_params["flat"],
                "H0": sim_params["H0"],
                "Om0": sim_params["Om0"],
                "Ob0": sim_params["Ob0"],
                "sigma8": sim_params["sigma8"],
                "ns": sim_params["ns"]
            })

            rho_m0 = cosmo.rho_m(0)
            Delta = 99
            rho_vir = Delta * rho_m0 * 1e9 / sim_params["h100"]**2
            G_val = G.to(u.Mpc**3 / (u.Msun * u.Gyr**2)).value
            t_dyn = 1.0 / np.sqrt((4 / 3) * np.pi * G_val * rho_vir)
            times = cosmo.age(1 / scale - 1)
            t0 = times[snap]
            t_past = t0 - t_dyn
            snap_past = np.argmin(np.abs(times - t_past))
            m_past = r[0, snap_past]["m"]

            gamma = (m_now - m_past) / t_dyn
            return gamma
        except Exception as e:
            traceback.print_exc()
            warnings.warn(f"Error calculating accretion rate: {e}")
            return np.nan

    def mean_std_plot(ax, x_list, y_list, label, color):
        x_ref = x_list[0]
        y_stack = np.stack(y_list)
        y_mean = np.nanmean(y_stack, axis=0)
        y_std = np.nanstd(y_stack, axis=0)
        ax.plot(x_ref, y_mean, color=color, lw=1.5, label=f"{label} (σ={np.nanmean(y_std):.2f})")
        ax.fill_between(x_ref, y_mean - y_std, y_mean + y_std, color=color, alpha=0.3)

    all_r_list, all_m_list = [], []
    all_Qr_list, all_Qtot_list = [], []
    all_gamma_list, all_cvir_list = [], []

    for suite in suite_names:
        ppsd_dir = os.path.join(base_dir, "output", suite, "ppsd_profiles")
        files = sorted([f for f in os.listdir(ppsd_dir) if f.endswith(".csv")])

        cvir_list, gamma_list = [], []
        r_list, m_list, Qr_list, Qtot_list = [], [], [], []

        for i, f in enumerate(files):
            try:
                sim_dir = symlib.get_host_directory(base_dir, suite, i)
                df = pd.read_csv(os.path.join(ppsd_dir, f))
                r = df["r_scaled"].values
                m = df["m_scaled"].values
                Qr = df["Q_r"].values
                Qtot = df["Q_tot"].values
                r_list.append(r)
                m_list.append(m)
                Qr_list.append(Qr)
                Qtot_list.append(Qtot)

                r_data, _ = symlib.read_rockstar(sim_dir)
                cvir = r_data[0, -1]["cvir"]
                gamma = calculate_accretion_rate(sim_dir)

                cvir_list.append(cvir)
                gamma_list.append(gamma)
            except:
                continue

        cvir_list = np.array(cvir_list)
        gamma_list = np.array(gamma_list)
        median_c = np.nanmedian(cvir_list)
        median_gamma = np.nanmedian(gamma_list)

        for r, m, Qr, Qtot, c, g in zip(r_list, m_list, Qr_list, Qtot_list, cvir_list, gamma_list):
            if np.any(~np.isfinite(Qr)) or np.any(~np.isfinite(Qtot)):
                continue
            all_r_list.append(r)
            all_m_list.append(m)
            all_Qr_list.append(Qr)
            all_Qtot_list.append(Qtot)
            all_cvir_list.append("high" if c > median_c else "low")
            all_gamma_list.append("high" if g > median_gamma else "low")

    # --- Concentration Split ---
    fig_c, axes_c = plt.subplots(2, 2, figsize=(14, 10), dpi=500)
    axc = axes_c.flatten()
    r_low, r_high, Qr_low, Qr_high = [], [], [], []
    m_low, m_high, Qtot_low, Qtot_high = [], [], [], []

    for r, m, Qr, Qtot, label in zip(all_r_list, all_m_list, all_Qr_list, all_Qtot_list, all_cvir_list):
        if label == "low":
            r_low.append(r)
            m_low.append(m)
            Qr_low.append(Qr)
            Qtot_low.append(Qtot)
        else:
            r_high.append(r)
            m_high.append(m)
            Qr_high.append(Qr)
            Qtot_high.append(Qtot)

    mean_std_plot(axc[0], r_low, Qr_low, "Low c", "steelblue")
    mean_std_plot(axc[0], r_high, Qr_high, "High c", "firebrick")
    axc[0].set_title(r"$Q_r$ vs $r$")

    mean_std_plot(axc[1], r_low, Qtot_low, "Low c", "steelblue")
    mean_std_plot(axc[1], r_high, Qtot_high, "High c", "firebrick")
    axc[1].set_title(r"$Q_{\rm tot}$ vs $r$")

    mean_std_plot(axc[2], m_low, Qr_low, "Low c", "steelblue")
    mean_std_plot(axc[2], m_high, Qr_high, "High c", "firebrick")
    axc[2].set_title(r"$Q_r$ vs $M$")

    mean_std_plot(axc[3], m_low, Qtot_low, "Low c", "steelblue")
    mean_std_plot(axc[3], m_high, Qtot_high, "High c", "firebrick")
    axc[3].set_title(r"$Q_{\rm tot}$ vs $M$")

    for ax in axc:
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel("r / Rvir" if "r" in ax.get_title() else "M / Mvir")
        ax.set_ylabel("PPSD")
        ax.grid(True, which="both", linestyle=":")
        ax.legend()

    fig_c.suptitle("PPSD Profiles by Concentration (Combined Suites)", fontsize=16)
    fig_c.tight_layout(rect=[0, 0, 1, 0.96])
    os.makedirs(os.path.join(base_dir, "output", "combined", "figures"), exist_ok=True)
    fig_c.savefig(os.path.join(base_dir, "output", "combined", "figures", "ppsd_profiles_by_concentration.png"))

    # --- Accretion Split ---
    fig_a, axes_a = plt.subplots(2, 2, figsize=(14, 10), dpi=500)
    axa = axes_a.flatten()
    r_low, r_high, Qr_low, Qr_high = [], [], [], []
    m_low, m_high, Qtot_low, Qtot_high = [], [], [], []

    for r, m, Qr, Qtot, label in zip(all_r_list, all_m_list, all_Qr_list, all_Qtot_list, all_gamma_list):
        if label == "low":
            r_low.append(r)
            m_low.append(m)
            Qr_low.append(Qr)
            Qtot_low.append(Qtot)
        else:
            r_high.append(r)
            m_high.append(m)
            Qr_high.append(Qr)
            Qtot_high.append(Qtot)

    mean_std_plot(axa[0], r_low, Qr_low, "Low Accretion", "seagreen")
    mean_std_plot(axa[0], r_high, Qr_high, "High Accretion", "orange")
    axa[0].set_title(r"$Q_r$ vs $r$")

    mean_std_plot(axa[1], r_low, Qtot_low, "Low Accretion", "seagreen")
    mean_std_plot(axa[1], r_high, Qtot_high, "High Accretion", "orange")
    axa[1].set_title(r"$Q_{\rm tot}$ vs $r$")

    mean_std_plot(axa[2], m_low, Qr_low, "Low Accretion", "seagreen")
    mean_std_plot(axa[2], m_high, Qr_high, "High Accretion", "orange")
    axa[2].set_title(r"$Q_r$ vs $M$")

    mean_std_plot(axa[3], m_low, Qtot_low, "Low Accretion", "seagreen")
    mean_std_plot(axa[3], m_high, Qtot_high, "High Accretion", "orange")
    axa[3].set_title(r"$Q_{\rm tot}$ vs $M$")

    for ax in axa:
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel("r / Rvir" if "r" in ax.get_title() else "M / Mvir")
        ax.set_ylabel("PPSD")
        ax.grid(True, which="both", linestyle=":")
        ax.legend()

    fig_a.suptitle("PPSD Profiles by Accretion Rate (Combined Suites)", fontsize=16)
    fig_a.tight_layout(rect=[0, 0, 1, 0.96])
    fig_a.savefig(os.path.join(base_dir, "output", "Combined", "ppsd_profiles_by_accretion.png"))
    plt.show()

In [None]:
# Project root. os.listdir / os.makedirs do not expand "~", so expand it here.
base_dir = os.path.expanduser("~/Projects/Symphony-PPSD")
# Suite directory names under <base_dir>/output; also the keys of `sim_colors`.
suite_names = [
    "SymphonyLMC",
    "SymphonyMilkyWay",
    "SymphonyGroup",
    "SymphonyLCluster",
    "SymphonyCluster",
]

plot_combined_ppsd_and_slopes(base_dir, suite_names)
plot_ppsd_split_by_concentration_and_accretion(base_dir, suite_names)