In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Corner and angular plots for dynamite M32 models.

Author: Andrés Beamuz Mingote
Date: 2025-09-18
"""

# =============================================================================
# Imports
# =============================================================================
import os
import glob
import sys
from datetime import datetime
import logging

import numpy as np
from astropy.table import Table, vstack, Column
from scipy.spatial import ConvexHull

import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import colormaps
from matplotlib.legend_handler import HandlerTuple
from matplotlib.patches import PathPatch, Polygon
from matplotlib.path import Path
from matplotlib.ticker import AutoMinorLocator
import matplotlib.gridspec as gridspec

# Dynamite path
sys.path.append('/home/andres-beamuz/Documents/TFM/dynamite-master/')
import dynamite as dyn

# =============================================================================
# Configuration
# =============================================================================
DO_ANGLES = True   # <-- Set to False if you don’t want to generate the angle plot

BASE_DIR = "/home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3"

# Collect all all_models.ecsv inside subfolders starting with dynamite_results_m3
DATA_FILES = sorted(
    glob.glob(os.path.join(BASE_DIR, "dynamite_results_m3*/all_models.ecsv"))
)

print(f"Found {len(DATA_FILES)} files:")
for f in DATA_FILES:
    print(" -", f)

OUTPUT_DIR = "Plots/Param_search"
os.makedirs(OUTPUT_DIR, exist_ok=True)
CORNER_PLOT_PATH = os.path.join(OUTPUT_DIR, "Param_corner_plot.pdf")
ANGLES_PLOT_PATH = os.path.join(OUTPUT_DIR, "Angles_plot.pdf")

# Path to your DYNAMITE config (used to compute angles)
CONFIG_PATH = "/home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/M32_UCM_PC_config.yaml"

plt.rcParams.update({
    "text.usetex": False,
    "font.family": "serif",
    "font.size": 9,
    "axes.titlesize": 9,
    "axes.labelsize": 9,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 9,
    "figure.titlesize": 9,
})

time_start = datetime.now()
print("Started:", time_start)

# =============================================================================
# Helper functions
# =============================================================================
def confidence_region(x, y, chi2_red, delta_chi2, shade="in", ax=None):
    """Return convex hull polygon of points within chi2_red <= delta_chi2."""
    mask = chi2_red <= delta_chi2
    pts = np.column_stack([x[mask], y[mask]])
    if len(pts) < 3:
        raise ValueError("Not enough points below threshold to define a region")

    hull = ConvexHull(pts)
    poly = pts[hull.vertices]

    if shade == "in":
        return poly
    elif shade == "out":
        if ax is None:
            raise ValueError("ax must be provided when shade='out'")
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        rect = np.array([
            [xlim[0], ylim[0]],
            [xlim[1], ylim[0]],
            [xlim[1], ylim[1]],
            [xlim[0], ylim[1]],
        ])
        verts = np.vstack([rect, [rect[0]], poly[::-1], [poly[0]]])
        codes_rect = [Path.MOVETO] + [Path.LINETO] * 3 + [Path.CLOSEPOLY]
        codes_poly = [Path.MOVETO] + [Path.LINETO] * (len(poly) - 1) + [Path.CLOSEPOLY]
        codes = codes_rect + codes_poly
        return verts, codes
    else:
        raise ValueError("shade must be 'in' or 'out'")


class HandlerCircleCross(HandlerTuple):
    """Custom legend handler to show a circle + cross symbol."""
    def create_artists(self, legend, orig_handle, xdescent, ydescent,
                       width, height, fontsize, trans):
        circle, cross = orig_handle
        c = mpatches.Circle(
            (xdescent + width / 2, ydescent + height / 2),
            radius=min(width, height) / 1.5,
            edgecolor=circle.get_edgecolor(),
            facecolor="none",
            lw=circle.get_linewidths()[0],
            transform=trans,
        )
        cross_size = min(width, height)
        hline = mlines.Line2D(
            [xdescent + width/2 - cross_size/2, xdescent + width/2 + cross_size/2],
            [ydescent + height/2, ydescent + height/2],
            color=cross.get_facecolor()[0],
            lw=cross.get_linewidths()[0],
            transform=trans,
        )
        vline = mlines.Line2D(
            [xdescent + width/2, xdescent + width/2],
            [ydescent + height/2 - cross_size/2, ydescent + height/2 + cross_size/2],
            color=cross.get_facecolor()[0],
            lw=cross.get_linewidths()[0],
            transform=trans,
        )
        return [c, hline, vline]

# =============================================================================
# Plotting functions
# =============================================================================
def plot_corner(table, savepath, chi_lim_val = 3, region=True, axis_bool=True):
    """Generate corner plot of parameter constraints.
            Parameters
        ----------
        table : astropy.table.Table
            The results table containing p-stars, q-stars, u-stars, chi2_red, etc.
        savepath : str
            Path where the plot will be saved.
        chi_lim_val : float, optional
            Maximum chi2_red value for color scaling and masking. Default = 3.
        region : bool, optional
            Wether to draw the confidence uncertainties regions or not
        axis_bool : bool, optional
            Use the predefined ticks for the axis
    """
    param_names = ['m-bh', 'u-stars', 'a-bh', 'f-dh', 'c-dh',
                   'p-stars', 'q-stars', 'ml']

    ax_ticks = [[6.4, 6.5, 6.6, 6.7],
                [0.985, 0.990, 0.995, 1.0],
                [-3.3, -3.1, -2.9, -2.7],
                [0, 2, 4, 6],
                [0, 8, 16, 24, 32],
                [0.82, 0.88, 0.94, 1],
                [0.3, 0.45, 0.6, 0.75],
                [0.4, 0.8, 1.2, 1.6]]

    n_params = len(param_names)
    fig_size = n_params / 8 * 12
    fig, axes = plt.subplots(n_params - 1, n_params - 1, figsize=(fig_size, fig_size))
    plt.subplots_adjust(wspace=0, hspace=0)

    pad_frac_x, pad_frac_y = 0.17, 0.18
    cmap = 'inferno_r'
    min_color, min_line_width = 'cyan', 0.8

    for i in range(n_params - 1):
        for j in range(n_params - 1):
            ax = axes[i, j]
            if (i + j) >= n_params - 1:
                ax.set_visible(False)
                continue

            param_y = param_names[i]
            param_x = param_names[-j-1]
            x, y = table[param_x], table[param_y]

            # Limits with padding
            x_range, y_range = np.ptp(x), np.ptp(y)
            ax.set_xlim(np.nanmin(x) - pad_frac_x * x_range,
                        np.nanmax(x) + pad_frac_x * x_range)
            ax.set_ylim(np.nanmin(y) - pad_frac_y * y_range,
                        np.nanmax(y) + pad_frac_y * y_range)

            # Good/bad points
            mask_good = table['chi2_red'] <= chi_lim_val
            mask_bad = ~mask_good
            ax.scatter(x[mask_bad], y[mask_bad],
                       facecolors='none', edgecolors='black', marker='D',
                       s=3, linewidths=0.2)
            sc = ax.scatter(x[mask_good], y[mask_good],
                            c=table['chi2_red'][mask_good], cmap=cmap,
                            s=15, edgecolors='none')

            # Best-fit marker
            c_cross = ax.scatter(x[-1], y[-1], marker='+', s=55,
                                 color=min_color, linewidth=min_line_width)
            c_circle = ax.scatter(x[-1], y[-1], marker='o', s=45,
                                  facecolors='none', edgecolors=min_color,
                                  linewidth=min_line_width)

            # Ticks
            if axis_bool:
                ax.set_yticks(ax_ticks[i])
                ax.set_xticks(ax_ticks[-j-1])
            if i < n_params - 2 - (i + j):
                ax.set_xticklabels([])
            else:
                ax.set_xlabel(param_x)
            if j > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel(param_y)

            # Tick formatting
            ax.tick_params(axis='both', direction='inout', right=True, top=True, length=4)
            ax.tick_params(which='minor', direction='inout', length=2.25, width=0.8,
                           top=True, right=True)
            ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2))
            ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))

            # Confidence regions
            if region:
                verts, codes = confidence_region(x, y, table['chi2_red'],
                                                 delta_chi2=5, shade="out", ax=ax)
                ax.add_patch(PathPatch(Path(verts, codes),
                                       facecolor='gray', alpha=0.0, zorder=0))

                for delta, face, edge, lw in [
                    (11.8, (0.7, 0.7, 0.7, 0.25), (.4, .4, .4, .4), 1.2),
                    (2.3, (0.2, 0.2, 0.2, 0.25), (.1, .1, .1, 1), 0.75),
                ]:
                    poly = confidence_region(x, y, table['chi2_red'], delta_chi2=delta, shade="in")
                    ax.add_patch(Polygon(poly, closed=True, facecolor=face,
                                         edgecolor=edge, linewidth=lw, zorder=0))
    
                    
    # Legend
    ax = axes[0, 0]
    ax.legend(handles=[(c_circle, c_cross)], labels=[r"Minimum $\chi^2$"],
              handler_map={tuple: HandlerCircleCross()},
              loc=(3.9 * fig_size / 12, -5.4 * fig_size / 12))

    # Colorbar
    norm = colors.Normalize(vmin=0, vmax=chi_lim_val)
    cmap = colormaps[cmap]
    sm = cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])

    tick_vals = np.arange(0, chi_lim_val + 1, 1)
    cbar_ax = fig.add_axes([0.5, 0.11, 0.22, 0.014])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal',
                        extend='both', ticks=tick_vals)
    cbar.ax.set_title(r'$\Delta\chi^2_\mathrm{red}<$' + f'{chi_lim_val}',
                      fontsize=11, pad=8, loc='center')
    cbar.ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(5))
    cbar.ax.tick_params(which='minor', length=1.5, width=0.8,
                        direction='out', bottom=True)

    # Save
    plt.savefig(savepath, bbox_inches='tight')
    plt.close(fig)
    print(f"Corner plot saved: {savepath}")


# --- Angle helpers and plotting (your exact routine integrated) ---
def compute_angles(table, cfg):
    """
    Compute theta, psi, phi with DYNAMITE triax_pqu2tpp and add to table.
    """
    triax_cmp = cfg.system.get_component_from_class(
        dyn.physical_system.TriaxialVisibleComponent)

    angles = [
        triax_cmp.triax_pqu2tpp(p, q, u)
        for p, q, u in zip(table['p-stars'], table['q-stars'], table['u-stars'])
    ]
    theta, psi, phi = zip(*angles)

    table['theta'] = Column(theta)
    table['psi'] = Column(psi)
    table['phi'] = Column(phi)

    return table


def plot_angles(table, savepath, chi_lim_val=3):
    """
    Generate angular plots (polar + scatter) for model parameters.

    Computes theta, phi, psi from p, q, u using the TriaxialVisibleComponent.
    Adds confidence regions and colorbar.

    Parameters
    ----------
    table : astropy.table.Table
        The results table containing p-stars, q-stars, u-stars, chi2_red, etc.
    savepath : str
        Path where the plot will be saved.
    chi_lim_val : float, optional
        Maximum chi2_red value for color scaling and masking. Default = 3.
    """

    # --- Suppress DYNAMITE info logging ---
    logging.getLogger("dynamite").setLevel(logging.WARNING)

    # --- Load configuration for Triaxial component ---
    cfg = dyn.config_reader.Configuration(
        "/home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/M32_UCM_PC_config.yaml",
        reset_logging=True,
    )
    triax_cmp = cfg.system.get_component_from_class(
        dyn.physical_system.TriaxialVisibleComponent
    )

    # --- Compute angles ---
    angles = [
        triax_cmp.triax_pqu2tpp(p, q, u)
        for p, q, u in zip(table["p-stars"], table["q-stars"], table["u-stars"])
    ]
    theta, psi, phi = zip(*angles)
    table["theta"] = Column(theta)
    table["psi"] = Column(psi)
    table["phi"] = Column(phi)

    # --- Derived arrays ---
    theta = np.deg2rad(table["theta"])
    phi = np.deg2rad(table["phi"])
    R = 2 * np.sin(theta / 2)
    chi2 = table["chi2_red"]

    pad_frac = 0.1
    min_color, min_line_width = "cyan", 0.8
    mask_good = chi2 <= chi_lim_val
    mask_bad = ~mask_good

    # --- Figure + GridSpec ---
    fig = plt.figure(figsize=(7, 7))
    gs = gridspec.GridSpec(2, 2, figure=fig, wspace=0, hspace=0.1)

    # --- Polar plot ---
    ax0 = fig.add_subplot(gs[0, 0], projection="polar")
    ax0.set_theta_zero_location("N")
    ax0.set_theta_offset(np.pi / 2)
    ax0.set_theta_direction(1)

    # Data points
    ax0.scatter(
        phi[mask_bad],
        R[mask_bad],
        facecolors="none",
        edgecolors="black",
        marker="D",
        s=5,
        linewidths=0.2,
        zorder=10,
    )
    sc = ax0.scatter(
        phi[mask_good],
        R[mask_good],
        c=chi2[mask_good],
        cmap="inferno_r",
        s=12,
        edgecolors="none",
        zorder=10,
    )
    c_cross = ax0.scatter(phi[-1], R[-1], marker="+", 
                          color=min_color, s=55, lw=min_line_width, zorder=10)
    c_circle = ax0.scatter(phi[-1],R[-1], marker="o", 
                           facecolors="none", edgecolors=min_color, s=45, lw=min_line_width, zorder=10)

    # Confidence regions (polar coordinates)
    poly_cart = confidence_region(phi, R, chi2, delta_chi2=11.8, shade="in")
    x_poly, y_poly = poly_cart[:, 0], poly_cart[:, 1]
    R_poly = np.sin((y_poly / 2)) * 2
    ax0.fill(
        x_poly,
        R_poly,
        facecolor=(0.7, 0.7, 0.7, 0.25),
        edgecolor=(0.4, 0.4, 0.4, 0.4),
        linewidth=1.2,
        alpha=0.3,
        zorder=0,
    )

    poly_cart = confidence_region(phi, R, chi2, delta_chi2=2.3, shade="in")
    x_poly, y_poly = poly_cart[:, 0], poly_cart[:, 1]
    R_poly = np.sin((y_poly / 2)) * 2
    ax0.fill(
        x_poly,
        R_poly,
        facecolor=(0.2, 0.2, 0.2, 0.3),
        edgecolor=(0.1, 0.1, 0.1, 1),
        linewidth=0.75,
        zorder=0,
    )

    # Axis formatting
    rticks = np.sin(np.deg2rad([0, 30, 60, 90]) / 2) * 2
    ax0.set_rticks(rticks)
    ax0.set_yticklabels(["0", "30", "60", "90"])
    ax0.set_xticks(np.deg2rad([0, 30, 60, 90]))
    ax0.set_xticklabels(["0", "30", "60", "90"])
    ax0.set_xlabel(r"$\theta$ [degrees]", labelpad=15)
    ax0.xaxis.set_label_coords(1.175, 0.68)
    ax0.set_ylabel(r"$\phi$ [degrees]", labelpad=15)
    ax0.yaxis.set_label_coords(0.3, 1.0)
    ax0.yaxis.label.set_rotation(0)
    ax0.xaxis.label.set_rotation(90)
    ax0.set_thetamin(0)
    ax0.set_thetamax(90)
    ax0.minorticks_on()

    # --- Empty placeholder top-right ---
    ax1 = fig.add_subplot(gs[0, 1])
    ax1.axis("off")

    # --- Bottom scatter plots ---
    pair_info = [("theta", "psi"), ("phi", "psi")]
    axes = [fig.add_subplot(gs[1, i]) for i in range(2)]

    for ax, (xcol, ycol), invert_x in zip(axes, pair_info, [True, False]):
        x = table[xcol]
        y = table[ycol]

        ax.scatter(
            x[mask_bad], y[mask_bad], facecolors="none", edgecolors="black", marker="D", s=5, lw=0.2
        )
        sc_tmp = ax.scatter(
            x[mask_good], y[mask_good], c=chi2[mask_good], cmap="inferno_r", s=12
        )

        ax.scatter(x[-1], y[-1], marker="+", color="cyan", s=55, lw=min_line_width)
        ax.scatter(
            x[-1],
            y[-1],
            marker="o",
            facecolors="none",
            edgecolors="cyan",
            s=45,
            lw=min_line_width,
        )

        poly = confidence_region(x, y, chi2, delta_chi2=11.8, shade="in")
        patch = Polygon(
            poly,
            closed=True,
            facecolor=(0.8, 0.8, 0.8, 0.15),
            edgecolor=(0.5, 0.5, 0.5, 0.5),
            linewidth=1,
            alpha=0.2,
            zorder=0,
        )
        ax.add_patch(patch)

        poly = confidence_region(x, y, chi2, delta_chi2=2.3, shade="in")
        patch = Polygon(
            poly,
            closed=True,
            facecolor=(0.2, 0.2, 0.2, 0.3),
            edgecolor=(0.1, 0.1, 0.1, 1),
            linewidth=0.5,
            zorder=0,
        )
        ax.add_patch(patch)

        if invert_x:
            ax.set_xlim([90, 0])
            ax.set_ylabel(rf"$\{ycol}$ [degrees]")
            ax.set_yticks([90, 95, 100, 105], labels=["90", "95", "100", "105"])
            ax.set_xticks([20, 40, 60, 80], labels=["20", "40", "60", "80"])
            ax.tick_params(axis="both", direction="inout", right=True, top=True, length=5)
        else:
            ax.invert_xaxis()
            ax.set_yticks([90, 95, 100, 105], labels=[])
            ax.tick_params(axis="both", direction="inout", right=True, top=True, length=6)
            ax.set_xticks([30, 45, 60, 75, 90], labels=["30", "45", "60", "75", "90"])
            x_min, x_max = np.nanmin(x), np.nanmax(x)
            x_range = x_max - x_min
            x_pad = pad_frac * x_range
            ax.set_xlim(x_min - x_pad, x_max + x_pad)

        ax.set_xlabel(rf"$\{xcol}$ [degrees]")

        y_min, y_max = np.nanmin(y), np.nanmax(y)
        y_range = y_max - y_min
        y_pad = pad_frac * y_range
        ax.set_ylim(y_min - y_pad, y_max + y_pad)
        ax.minorticks_on()
        ax.tick_params(which="minor", direction="inout", length=2.5, width=0.8, top=True, right=True)
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2))
        ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))

    # --- Legend ---
    ax = ax1
    # ax.legend(
    #     handles=[(sc,)],
    #     labels=[r"Minimum $\chi^2$"],
    #     loc=(0.25, 0.05),
    # )
    ax.legend(handles=[(c_circle, c_cross)], labels=[r"Minimum $\chi^2$"],
          handler_map={tuple: HandlerCircleCross()},
          loc=(0.25, 0.05))

    # --- Colorbar ---
    norm = colors.Normalize(vmin=0, vmax=chi_lim_val)
    sm = cm.ScalarMappable(norm=norm, cmap="inferno_r")
    sm.set_array([])
    tick_vals = np.arange(0, chi_lim_val + 1, 1)
    tick_labels = [str(int(t)) for t in tick_vals]
    cax = fig.add_axes([0.84, 0.53, 0.02, 0.35])
    cbar = fig.colorbar(sm, cax=cax, orientation="vertical", extend="both", ticks=tick_vals)
    cbar.ax.set_title(r"$\Delta\chi^2_\mathrm{red}<$" + f"{chi_lim_val}", fontsize=11, rotation=90, x=-1.5, y=0.37)
    cbar.ax.minorticks_on()
    cbar.ax.tick_params(which="minor", length=1.5, width=0.8, right=True)

    # --- Save ---
    plt.savefig(savepath, bbox_inches="tight", dpi=600)
    plt.close(fig)
    print(f"Angles plot saved: {savepath}")

# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
    # Load + clean tables
    tables = [Table.read(f, format="ascii.basic", delimiter=" ", comment="#") for f in DATA_FILES]
    table = vstack(tables)
    table = table[~np.isnan(table["chi2"])]

    dof = 4 * 2326 - 8
    table["chi2_red"] = (table["chi2"] - np.nanmin(table["chi2"])) / (2 * np.sqrt(dof))
    for param in ["m-bh", "a-bh", "f-dh"]:
        table[param] = np.log10(table[param])

    table = table[np.argsort(table['chi2_red'])[::-1]]

    # Run plots
    plot_corner(table, CORNER_PLOT_PATH)
    if DO_ANGLES:
        plot_angles(table, ANGLES_PLOT_PATH)

    print("Finished. Time of run:", datetime.now()-time_start)



Found 34 files:
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v1/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v10/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v12/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v13/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v14/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v15/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v16/all_models.ecsv
 - /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/Results/M3/dynamite_results_m3_full_v17/all_models.ecsv
 - /home/andres-beamuz/Do

[INFO] 11:02:02 - dynamite.config_reader.Configuration - Config file /home/andres-beamuz/Documents/TFM/M32/Data_and_results_25_06/M32_UCM_PC_config.yaml read.
[INFO] 11:02:02 - dynamite.config_reader.Configuration - io_settings...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - Output directory tree: dynamite_m3_full_v1/.
[INFO] 11:02:02 - dynamite.config_reader.Configuration - system_attributes...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - model_components...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - system_parameters...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - orblib_settings...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - weight_solver_settings...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - Will attempt to recover partially run models.
[INFO] 11:02:02 - dynamite.config_reader.Configuration - parameter_space_settings...
[INFO] 11:02:02 - dynamite.config_reader.Configuration - multiprocessing_settings...
[I

Corner plot saved: Plots/Param_search/Param_corner_plot.pdf
Angles plot saved: Plots/Param_search/Angles_plot.pdf
Finished. Time of run: 0:00:04.430329
