In [None]:
import numpy as np

# Human-readable descriptions of the input parameters, keyed by their short names.
input_params_keys = dict(
    MHL="lightest Higgs mass",
    MHH="heavy Higgs mass",
    MHA="pseudo-scalar Higgs mass",
    MHp="charged Higgs mass",
    lam5="Split..",
    cba="cos(beta-alpha)",
    tb="tan(beta)",
    vev="vacuum expectation value",
)


In [None]:
from src.thdm_param_scan import calculate_random_sampling, check_vacuum_conditions


In [None]:
# --- Scan configuration (masses in GeV) ------------------------------------
vev, higgs_mass = 246.22, 125.35  # vacuum expectation value, Higgs mass

# Sampling windows handed to calculate_random_sampling.
mass_range = (0, higgs_mass * 8)
epsilon_range = (1e-12, 1e-1)
delta_range = (-2 / 9, 0.6 * (higgs_mass / vev) ** 2)

n_points = 500_000  # number of random samples

results_df = calculate_random_sampling(
    mass_range,
    epsilon_range,
    delta_range,
    n_samples=n_points,
)

In [None]:
# Attach the per-point stability / unitarity pass-fail columns used by the plots below.
results_with_conditions = check_vacuum_conditions(
    results_df,
    verbose=True,
)

In [None]:
# Condition categories to plot. Each entry is a (DataFrame.query expression,
# human-readable panel label) pair; the unitarity bounds come in ± pairs.
conditions = {
    "vacuum_stable": [
        ("stability_lam1_positive", "λ₁ > 0"),
        ("stability_lam2_positive", "λ₂ > 0"),
        ("stability_lam3_condition", "λ₃ + √(λ₁λ₂) > 0"),
        ("stability_lam4_condition", "λ₃ + λ₄ - |λ₅| + √(λ₁λ₂) > 0"),
        ("global_minimum_condition", "Global Minimum"),
        ("vacuum_stable", "vacuum_stable"),
    ],
    "unitarity": [
        *(
            (f"unitarity_{letter}_plus & unitarity_{letter}_minus", f"|{letter}±| ≤ 8π")
            for letter in "abcefg"
        ),
        ("unitarity_satisfied", "Unitarity satisfied"),
        ("all_conditions_satisfied", "ALL conditions satisfied"),
    ],
}

In [None]:

import os
import matplotlib.pyplot as plt
# Root directory for every figure and the scan report written below.
pdfs_dir = os.path.join(os.getcwd(), "pdfs/THDM")
os.makedirs(pdfs_dir, exist_ok=True)

# Turn off matplotlib interactive mode; the plotting cells save figures to disk.
plt.ioff()

def plot_condition_histogram_on_ax(
    ax,
    df,
    condition_query,
    column_name,
    bins=50,
    alpha=0.5,
    all_points_color="blue",
    condition_color="red",
    xlabel=None,
    title=None,
    xlim=None,
):
    """
    Overlay two normalized histograms of ``column_name`` on an existing axis.

    The first histogram covers every row of ``df``. The second covers only the
    rows selected by ``condition_query`` (a ``DataFrame.query`` expression).
    Both use the same bin edges and are weighted by ``1 / len(df)``, so each
    bar height is the fraction of *all* points falling in that bin.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df : DataFrame containing ``column_name`` and the columns used by the query.
    condition_query : str expression passed to ``df.query``.
    column_name : name of the column to histogram.
    bins, alpha : forwarded to ``ax.hist``.
    all_points_color, condition_color : bar colours of the two histograms.
    xlabel : x-axis label; defaults to a pretty name for known columns.
    title : panel title; defaults to ``"Histogram of <xlabel> Values"``.
    xlim : optional ``(min, max)``; defaults to the column's data range.

    Notes
    -----
    The lower y-limit is half the smallest non-empty bar. This keeps every
    populated bin visible (even a single point, height ``1 / len(df)``) when
    the caller switches the axis to a log scale. A fixed fraction of the
    tallest bar would hide rare condition-satisfying points.
    """
    values = df[column_name]
    x_min, x_max = values.min(), values.max()
    total_points = len(df)

    # Subset of the column whose rows satisfy the condition.
    data_satisfy = df.query(condition_query)[column_name]

    # Shared settings keep both histograms on identical bins.
    common = dict(bins=bins, range=(x_min, x_max), alpha=alpha, edgecolor="black")

    n_all, _, _ = ax.hist(
        values,
        color=all_points_color,
        label="All points",
        weights=np.ones(len(values)) / total_points,
        **common,
    )
    n_pass, _, _ = ax.hist(
        data_satisfy,
        color=condition_color,
        label=f"Condition satisfied ({len(data_satisfy):,} points)",
        weights=np.ones(len(data_satisfy)) / total_points,
        **common,
    )

    if xlabel is None:
        # Auto-format common parameter names with subscripts
        xlabel_map = {
            "lam5": "λ₅",
            "lam1": "λ₁",
            "lam2": "λ₂",
            "lam3": "λ₃",
            "lam4": "λ₄",
            "delta_samples": "δ (λ₅ = (MHH² - MHA²)/v² + δ)",
            "MHH": "MHH [GeV]",
            "MHA": "MHA [GeV]",
            "MHp": "MHp [GeV]",
            "epsilon": "ε",
        }
        xlabel = xlabel_map.get(column_name, column_name)

    ax.set_xlabel(xlabel)
    # Bars are weighted by 1/len(df): they show fractions, not raw counts.
    ax.set_ylabel("Fraction of points")
    ax.set_title(f"Histogram of {xlabel} Values" if title is None else title)

    if xlim is None:
        ax.set_xlim(x_min, x_max)
    else:
        ax.set_xlim(xlim)

    heights = np.concatenate([np.ravel(n_all), np.ravel(n_pass)])
    nonzero = heights[heights > 0]
    if nonzero.size:
        ax.set_ylim(0.5 * nonzero.min(), 1.5 * nonzero.max())
    else:
        # Nothing was binned; use a valid (non-inverted) default window.
        ax.set_ylim(1e-5, 1.0)

    ax.legend()
    ax.grid(True, alpha=0.3)
    


In [None]:
# For each sampled column (except tb_param) and each condition category, save
# one PNG holding a two-column grid of histograms, one panel per condition.
for variable in [col for col in results_df.columns if col != "tb_param"]:
    for condition_category, condition_list in conditions.items():
        n_conditions = len(condition_list)
        subplots_per_row = 2
        n_rows = -(-n_conditions // subplots_per_row)  # ceiling division

        fig, axes = plt.subplots(n_rows, subplots_per_row, figsize=(12, 6 * n_rows))
        axes = axes.flatten()

        for ax, (condition, condition_label) in zip(axes, condition_list):
            highlight = "green" if "all_conditions" in condition else "red"
            plot_condition_histogram_on_ax(
                ax,
                results_with_conditions,
                condition,
                variable,
                title=f"{condition_label}",
                condition_color=highlight,
                bins=50,
            )
            ax.set_yscale("log")

        # Blank out grid cells left over when the condition count is odd.
        for spare_ax in axes[n_conditions:]:
            spare_ax.set_visible(False)

        plt.tight_layout()
        png_filename = os.path.join(pdfs_dir, f"{condition_category}", f"histograms_for_{variable}.png")
        os.makedirs(os.path.dirname(png_filename), exist_ok=True)
        plt.savefig(png_filename, dpi=300)
        plt.close(fig)

In [None]:
def plot_correlation_scatter(
    df,
    x_var,
    y_var,
    condition_query,
    condition_label,
    figsize=(8, 6),
    alpha=0.6,
    s=1,
    failed_color="red",
    passed_color="green",
    save_path=None,
):
    """
    Draw a standalone scatter plot of ``y_var`` against ``x_var``, colored by a condition.

    Rows where ``condition_query`` (a ``DataFrame.query`` expression) is false
    are drawn first in ``failed_color``. Rows where it is true are drawn on
    top in ``passed_color``. Empty groups are skipped. An ``"epsilon"`` axis
    uses a log scale, as in ``plot_correlation_scatter_on_ax``.

    If ``save_path`` is given, the figure is written there (dpi=150) and
    closed. Otherwise ``plt.show()`` is called.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Split data based on condition
    df_failed = df.query(f"not ({condition_query})")
    df_passed = df.query(condition_query)

    # Failed points first so passing points are drawn on top of them.
    n_groups_drawn = 0
    for subset, color, tag in (
        (df_failed, failed_color, "Failed"),
        (df_passed, passed_color, "Passed"),
    ):
        if len(subset) == 0:
            continue
        ax.scatter(
            subset[x_var],
            subset[y_var],
            c=color,
            alpha=alpha,
            s=s,
            label=f"{tag} ({len(subset):,} points)",
            edgecolors='none',
        )
        n_groups_drawn += 1

    # Format labels using the same mapping as histograms
    xlabel_map = {
        "lam5": "λ₅",
        "lam1": "λ₁",
        "lam2": "λ₂",
        "lam3": "λ₃",
        "lam4": "λ₄",
        "delta_samples": "δ (λ₅ = (MHH² - MHA²)/v² + δ)",
        "MHH": "MHH [GeV]",
        "MHA": "MHA [GeV]",
        "MHp": "MHp [GeV]",
        "epsilon": "ε",
    }
    x_label = xlabel_map.get(x_var, x_var)
    y_label = xlabel_map.get(y_var, y_var)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(f"{condition_label}\n{x_label} vs {y_label}")

    # Same convention as the subplot variant: this notebook samples epsilon
    # over many orders of magnitude, so a linear axis collapses it.
    if x_var == "epsilon":
        ax.set_xscale("log")
    if y_var == "epsilon":
        ax.set_yscale("log")

    # Only add a legend when something labelled was drawn (avoids a warning).
    if n_groups_drawn:
        ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Columns plotted against each other: every sampled column except tb_param.
variables = list(filter(lambda column: column != "tb_param", results_df.columns))

print(f"Variables to plot: {variables}")
print(
    f"Total number of PDFs per condition: {len(variables)} "
    "(one per variable vs all others)"
)

def plot_correlation_scatter_on_ax(
    ax,
    df,
    x_var,
    y_var,
    condition_query,
    alpha=0.3,
    s=0.8,
    failed_color="red",
    passed_color="green",
):
    """
    Scatter ``y_var`` against ``x_var`` on an existing axis, split by a condition.

    Rows for which ``condition_query`` (a ``DataFrame.query`` expression) is
    false are drawn first in ``failed_color``. Rows for which it is true are
    then drawn on top in ``passed_color``. An empty group is not drawn. An
    axis showing ``"epsilon"`` is switched to a log scale.
    """
    pretty_names = {
        "lam5": "λ₅",
        "lam1": "λ₁",
        "lam2": "λ₂",
        "lam3": "λ₃",
        "lam4": "λ₄",
        "delta_samples": "δ (λ₅ = (MHH² - MHA²)/v² + δ)",
        "MHH": "MHH [GeV]",
        "MHA": "MHA [GeV]",
        "MHp": "MHp [GeV]",
        "epsilon": "ε",
    }

    groups = (
        ("Failed", failed_color, df.query(f"not ({condition_query})")),
        ("Passed", passed_color, df.query(condition_query)),
    )
    for tag, colour, subset in groups:
        count = len(subset)
        if not count:
            continue
        ax.scatter(
            subset[x_var],
            subset[y_var],
            c=colour,
            alpha=alpha,
            s=s,
            label=f"{tag} ({count:,})",
            edgecolors='none',
        )

    horizontal = pretty_names.get(x_var, x_var)
    vertical = pretty_names.get(y_var, y_var)
    ax.set_xlabel(horizontal)
    ax.set_ylabel(vertical)
    ax.set_title(f"{horizontal} vs {vertical}")

    if x_var == "epsilon":
        ax.set_xscale("log")
    if y_var == "epsilon":
        ax.set_yscale("log")

    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)


In [None]:
from tqdm.auto import tqdm

# Generate correlation plots with subfigures: for every condition and every
# variable x, save one PNG holding x-vs-y scatter panels for all other variables y.

# Short symbols used only in each figure's suptitle (axis labels come from
# plot_correlation_scatter_on_ax). Loop-invariant, so built once.
xlabel_map = {
    "lam5": "λ₅",
    "lam1": "λ₁",
    "lam2": "λ₂",
    "lam3": "λ₃",
    "lam4": "λ₄",
    "delta_samples": "δ",
    "MHH": "MHH",
    "MHA": "MHA",
    "MHp": "MHp",
    "epsilon": "ε",
}

for condition_category, condition_list in conditions.items():
    print(f"\nGenerating correlation plots for {condition_category}...")

    for condition, condition_label in condition_list:
        print(f"  Processing condition: {condition_label}")

        # Create directory for this condition
        condition_dir = os.path.join(
            pdfs_dir,
            "correlations",
            condition_category,
            condition.replace(" & ", "_and_").replace(" ", "_"),
        )
        os.makedirs(condition_dir, exist_ok=True)

        n_saved = 0
        for x_var in variables:
            other_variables = [v for v in variables if v != x_var]
            n_plots = len(other_variables)
            if n_plots == 0:
                continue

            # Roughly square grid of panels.
            n_cols = int(np.ceil(np.sqrt(n_plots)))
            n_rows = int(np.ceil(n_plots / n_cols))

            # squeeze=False always yields a 2-D array of Axes, so flattening also
            # works for a 1x1 grid (the previous `[axes].flatten()` raised
            # AttributeError when exactly one other variable existed).
            fig, axes = plt.subplots(
                n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
            )
            axes = axes.flatten()

            progress = tqdm(other_variables, desc=f"Plotting {x_var} vs others", total=n_plots)
            for y_var, ax in zip(progress, axes):
                plot_correlation_scatter_on_ax(
                    ax,
                    results_with_conditions,
                    x_var,
                    y_var,
                    condition,
                    alpha=0.3,
                    s=0.8,
                )

            # Hide unused subplots
            for ax in axes[n_plots:]:
                ax.set_visible(False)

            x_label = xlabel_map.get(x_var, x_var)
            fig.suptitle(f"{condition_label}\n{x_label} vs All Other Variables", fontsize=14, y=0.98)
            fig.tight_layout()
            fig.subplots_adjust(top=0.92)  # Make room for suptitle

            png_save_path = os.path.join(condition_dir, f"{x_var}_vs_all_others.png")
            fig.savefig(png_save_path, dpi=300, bbox_inches='tight')
            plt.close(fig)
            n_saved += 1

        print(f"    Saved {n_saved} correlation plots (one per variable vs all others)")

print("\nAll correlation plots have been generated and saved!")

In [None]:
# Re-define conditions dictionary to avoid variable overwriting issues
conditions_dict = {
    "vacuum_stable": [
        ("stability_lam1_positive", "λ₁ > 0"),
        ("stability_lam2_positive", "λ₂ > 0"),
        ("stability_lam3_condition", "λ₃ + √(λ₁λ₂) > 0"),
        ("stability_lam4_condition", "λ₃ + λ₄ - |λ₅| + √(λ₁λ₂) > 0"),
        ("global_minimum_condition", "Global Minimum"),
        ("vacuum_stable", "vacuum_stable"),
    ],
    "unitarity": [
        ("unitarity_a_plus & unitarity_a_minus", "|a±| ≤ 8π"),
        ("unitarity_b_plus & unitarity_b_minus", "|b±| ≤ 8π"),
        ("unitarity_c_plus & unitarity_c_minus", "|c±| ≤ 8π"),
        ("unitarity_e_plus & unitarity_e_minus", "|e±| ≤ 8π"),
        ("unitarity_f_plus & unitarity_f_minus", "|f±| ≤ 8π"),
        ("unitarity_g_plus & unitarity_g_minus", "|g±| ≤ 8π"),
        ("unitarity_satisfied", "Unitarity satisfied"),
        ("all_conditions_satisfied", "ALL conditions satisfied"),
    ],
}


def _format_pass_rate(mask):
    """Return ``"<count> (<percent>%)"`` for a boolean Series of per-point pass flags."""
    return f"{mask.sum():,} ({100 * mask.mean():.2f}%)"


def _vacuum_summary_lines(df):
    """
    Build the VACUUM CONDITIONS SUMMARY section as a list of text lines.

    The same lines are written to the report file and echoed to the console,
    so the two outputs cannot drift apart.
    """
    lines = [
        "=" * 60,
        "VACUUM CONDITIONS SUMMARY",
        "=" * 60,
        f"Total points analyzed: {len(df):,}",
        "",
        "Individual stability conditions:",
        f"  λ₁ > 0:                           {_format_pass_rate(df['stability_lam1_positive'])}",
        f"  λ₂ > 0:                           {_format_pass_rate(df['stability_lam2_positive'])}",
        f"  λ₃ + √(λ₁λ₂) > 0:                 {_format_pass_rate(df['stability_lam3_condition'])}",
    ]

    # Check if the column name exists (could be either stability_lam345_condition or stability_lam4_condition)
    for column in ("stability_lam345_condition", "stability_lam4_condition"):
        if column in df.columns:
            lines.append(f"  λ₃ + λ₄ - |λ₅| + √(λ₁λ₂) > 0:     {_format_pass_rate(df[column])}")
            break

    lines.append(f"  Global minimum satisfied:         {_format_pass_rate(df['global_minimum_condition'])}")

    lines += ["", "Unitarity conditions:"]
    for letter in "abcefg":
        both_signs = df[f"unitarity_{letter}_plus"] & df[f"unitarity_{letter}_minus"]
        lines.append(f"  |{letter}±| ≤ 8π:                        {_format_pass_rate(both_signs)}")

    lines += [
        "",
        "Overall conditions:",
        f"  Vacuum stable:                    {_format_pass_rate(df['vacuum_stable'])}",
        f"  Unitarity satisfied:              {_format_pass_rate(df['unitarity_satisfied'])}",
        f"  ALL conditions satisfied:         {_format_pass_rate(df['all_conditions_satisfied'])}",
    ]
    return lines


total_points = len(results_with_conditions)
summary_lines = _vacuum_summary_lines(results_with_conditions)

# Write a report of the scan to a file. The encoding is explicit because the
# report contains λ, √, ≤ and π, which a non-UTF-8 platform default cannot encode.
report_path = os.path.join(pdfs_dir, "scan_report.txt")
with open(report_path, "w", encoding="utf-8") as f:
    f.write("THDM Parameter Scan Report\n")
    f.write("===========================\n\n")

    f.write(f"Total points sampled: {n_points:,}\n")
    f.write(f"Mass range: {mass_range[0]} to {mass_range[1]} GeV\n")
    f.write(f"Epsilon range: {epsilon_range[0]} to {epsilon_range[1]}\n")
    f.write(f"Delta range: {delta_range[0]:.6f} to {delta_range[1]:.6f}\n\n")

    # Dataset information
    f.write("Dataset Information:\n")
    f.write(f"Generated DataFrame with shape: {results_df.shape}\n")
    f.write(f"Columns: {list(results_df.columns)}\n\n")

    f.write("DataFrame with vacuum conditions:\n")
    f.write(f"New shape: {results_with_conditions.shape}\n")
    f.write(f"New columns added: {[col for col in results_with_conditions.columns if col not in results_df.columns]}\n\n")

    # Summary statistics
    f.write("\n".join(summary_lines) + "\n")

print(f"\nScan report saved to {report_path}")

# Also print the detailed summary to console
print("\n" + "\n".join(summary_lines))