In [None]:
import numpy as np
import plotly.graph_objects as go
from scipy.ndimage import minimum_filter
import math
import os


def load_precomputed_dissonance_map(
    base_freq=220,
    n_points=150,
    r_low=1.0,
    r_high=2.0,
    dataset_path="dataset/EigenSpace_Data",
):
    """Load a pre-computed 3D dissonance map from raw float32 chunk files.

    Chunk files in ``dataset_path`` whose names start with
    ``harmonic-{base_freq}Hz-{n_points}nodes-chunk`` are read in ascending
    numeric order of the chunk index that follows that prefix, concatenated,
    and reshaped into a cube of shape ``(n_points, n_points, n_points)``.

    Parameters:
    - base_freq: base frequency (Hz) that appears in the chunk file names.
    - n_points: grid resolution per axis; the chunks must hold exactly
      ``n_points**3`` float32 values in total.
    - r_low, r_high: ratio range used to rebuild the axis coordinates. The
      coordinates are not stored in the files, so these must match the values
      the dataset was generated with.
    - dataset_path: directory containing the chunk files.

    Returns:
    (alpha_range, beta_range, gamma_range, dissonance_3d)

    Raises:
    - FileNotFoundError: the directory is missing or holds no matching chunks.
    - ValueError: the chunks do not contain exactly ``n_points**3`` values.
    """
    import re

    prefix = f"harmonic-{base_freq}Hz-{n_points}nodes-chunk"

    print(f"Loading {n_points}³ pre-computed dataset...")

    def chunk_order(name):
        # Order by the numeric chunk index so "chunk10" follows "chunk9";
        # a plain string sort would place it right after "chunk1" and
        # silently scramble the cube. Names without an index go last.
        match = re.match(r"\d+", name[len(prefix):])
        return (int(match.group()) if match else float("inf"), name)

    chunk_files = sorted(
        (f for f in os.listdir(dataset_path) if f.startswith(prefix)),
        key=chunk_order,
    )

    if not chunk_files:
        raise FileNotFoundError(
            f"No dataset chunks found for {base_freq}Hz, {n_points} nodes in {dataset_path}/"
        )

    # Load all chunks as raw float32 values.
    all_data = []
    for chunk_file in chunk_files:
        chunk_path = os.path.join(dataset_path, chunk_file)
        chunk_data = np.fromfile(chunk_path, dtype=np.float32)
        all_data.append(chunk_data)
        print(f"  Loaded {chunk_file}: {len(chunk_data):,} values")

    flat_data = np.concatenate(all_data)
    expected = n_points**3
    if flat_data.size != expected:
        # Fail with a clear message instead of an opaque reshape error.
        raise ValueError(
            f"Dataset chunks hold {flat_data.size:,} values but {expected:,} "
            f"({n_points}³) were expected; chunks may be missing or corrupt"
        )
    dissonance_3d = flat_data.reshape((n_points, n_points, n_points))

    # Recreate the coordinate ranges (not stored in the dataset).
    alpha_range = np.linspace(r_low, r_high, n_points)
    beta_range = np.linspace(r_low, r_high, n_points)
    gamma_range = np.linspace(r_low, r_high, n_points)

    print(f"✓ Loaded shape: {dissonance_3d.shape}")
    print(f"✓ Range: {dissonance_3d.min():.4f} to {dissonance_3d.max():.4f}")

    return alpha_range, beta_range, gamma_range, dissonance_3d


def dissmeasure(fvec, amp, model="min"):
    """
    Total sensory dissonance of a set of partials (Plomp-Levelt curve with
    Sethares' parameterisation).

    Parameters:
    - fvec: partial frequencies, in any order.
    - amp: one amplitude per entry of ``fvec``.
    - model: how a pair's weight is formed; "min" uses the smaller of the two
      amplitudes, "product" multiplies them.

    Returns:
    The curve value summed over every unordered pair of partials (0 when
    there are fewer than two partials).

    Raises:
    ValueError if ``model`` is neither "min" nor "product".
    """
    # Curve constants.
    d_star, s1, s2 = 0.24, 0.0207, 18.96
    c1, c2 = 5, -5
    a1, a2 = -3.51, -5.75

    # Sort partials by frequency so the first member of each pair is the lower one.
    order = np.argsort(fvec)
    amps = np.asarray(amp)[order]
    freqs = np.asarray(fvec)[order]

    # (n_pairs, 2) index array of every pair (lo, hi) with lo < hi.
    pair_idx = np.column_stack(np.triu_indices(len(freqs), 1))
    pair_freqs = freqs[pair_idx]
    pair_amps = amps[pair_idx]

    lower = pair_freqs[:, 0]
    spread = pair_freqs[:, 1] - pair_freqs[:, 0]
    # Frequency difference rescaled by the critical-band factor of the lower partial.
    scaled = d_star / (s1 * lower + s2) * spread

    if model not in ("min", "product"):
        raise ValueError('model should be "min" or "product"')
    reduce_pair = np.amin if model == "min" else np.prod
    weights = reduce_pair(pair_amps, axis=1)

    return np.sum(weights * (c1 * np.exp(a1 * scaled) + c2 * np.exp(a2 * scaled)))


def calculate_3d_dissonance_map(
    base_freq=500, r_low=1.0, r_high=2.0, n_points=40, num_harmonics=4, method="min"
):
    """Compute a 3D dissonance map for a four-note chord by brute force.

    Each note has ``num_harmonics`` unit-amplitude partials at integer
    multiples of its fundamental. The first note's fundamental is
    ``base_freq``. The other three are ``alpha``, ``beta`` and ``gamma``
    times ``base_freq``, with each ratio sampled on
    ``np.linspace(r_low, r_high, n_points)``. Every grid point is scored
    with ``dissmeasure(..., method)``.

    Returns:
    (alpha_range, beta_range, gamma_range, dissonance_3d), where
    ``dissonance_3d[i, j, k]`` is the dissonance for
    (alpha_range[i], beta_range[j], gamma_range[k]).
    """
    alpha_range = np.linspace(r_low, r_high, n_points)
    beta_range = np.linspace(r_low, r_high, n_points)
    gamma_range = np.linspace(r_low, r_high, n_points)

    freq_base = base_freq * np.arange(1, num_harmonics + 1)
    amp_base = np.ones_like(freq_base)

    # The amplitude vector is identical for every chord, so build it once
    # instead of once per grid point.
    amps = np.concatenate((amp_base, amp_base, amp_base, amp_base))
    # The gamma note's partials depend only on k, so precompute them.
    gamma_partials = [gamma * freq_base for gamma in gamma_range]

    dissonance_3d = np.zeros((n_points, n_points, n_points))

    print("Calculating 3D dissonance map...")
    for i, alpha in enumerate(alpha_range):
        if i % 10 == 0:
            print(f"Progress: {i}/{n_points}")
        freq_alpha = alpha * freq_base
        for j, beta in enumerate(beta_range):
            # Partials of the base, alpha and beta notes are fixed for the
            # whole innermost loop.
            fixed = np.concatenate((freq_base, freq_alpha, beta * freq_base))
            for k, freq_gamma in enumerate(gamma_partials):
                f = np.concatenate((fixed, freq_gamma))
                dissonance_3d[i, j, k] = dissmeasure(f, amps, method)

    return alpha_range, beta_range, gamma_range, dissonance_3d


def visualize_3d_contour_stack(
    alpha_range,
    beta_range,
    gamma_range,
    dissonance_3d,
    cut_view=True,
    cut_axis="tetrahedron",
):
    """
    Render the dissonance cube as a stack of semi-transparent slices.

    One ``go.Surface`` is drawn for EVERY gamma value (no sampling). It sits
    at height gamma and is coloured by that slice's dissonance.

    Parameters:
    - alpha_range, beta_range, gamma_range: 1D coordinates for axes 0, 1 and
      2 of ``dissonance_3d``.
    - dissonance_3d: array indexed ``[alpha_index, beta_index, gamma_index]``.
    - cut_view: if True, apply the cut selected by ``cut_axis``.
    - cut_axis: 'tetrahedron' keeps only the α ≤ β ≤ γ region (no
      redundancy); cells outside it are not drawn. Any other value draws the
      full slices and only changes the title.

    Returns:
    The assembled ``go.Figure``.
    """
    font_family = "Source Code Pro"

    def text_font(size):
        return dict(size=size, color="black", family=font_family)

    def scene_axis(title_text):
        # Shared styling for the three scene axes: black text on white.
        return dict(
            title=dict(text=title_text, font=text_font(12)),
            backgroundcolor="rgb(255, 255, 255)",
            gridcolor="rgb(230, 230, 230)",
            tickfont=text_font(10),
        )

    dissonance_3d = np.asarray(dissonance_3d)
    alpha = np.asarray(alpha_range)
    beta = np.asarray(beta_range)
    apply_tetrahedron = cut_view and cut_axis == "tetrahedron"

    n_levels = len(gamma_range)
    print(f"Creating {n_levels} continuous layers...")

    fig = go.Figure()

    # Colour limits from percentiles so a few extreme values do not wash out
    # the map. The nan-aware version keeps NaN cells from making the limits NaN.
    vmin = np.nanpercentile(dissonance_3d, 5)
    vmax = np.nanpercentile(dissonance_3d, 80)

    # Lowest values are fully transparent; blue -> white -> orange -> red above.
    custom_colorscale = [
        [0.0, "rgba(0, 0, 0, 0.0)"],
        [0.1, "rgba(0, 0, 180, 1.0)"],
        [0.2, "rgba(0, 150, 255, 1.0)"],
        [0.3, "rgba(0, 200, 255, 1.0)"],
        [0.4, "rgba(100, 220, 255, 1.0)"],
        [0.5, "rgba(255, 255, 255, 1.0)"],
        [0.6, "rgba(255, 220, 100, 1.0)"],
        [0.7, "rgba(255, 200, 0, 1.0)"],
        [0.8, "rgba(255, 170, 0, 1.0)"],
        [0.9, "rgba(255, 140, 0, 1.0)"],
        [0.95, "rgba(255, 100, 0, 1.0)"],
        [1.0, "rgba(255, 0, 0, 1.0)"],
    ]

    colorbar = dict(
        title=dict(text="Dissonance Level", font=text_font(12)),
        thickness=20,
        len=0.65,
        x=1.0,
        tickfont=text_font(10),
        tickformat=".1f",
    )

    # A Plotly Surface with 1D x/y takes 2D arrays shaped (len(y), len(x)):
    # rows follow beta (y) and columns follow alpha (x).
    beta_col = beta[:, np.newaxis]
    alpha_gt_beta = alpha[np.newaxis, :] > beta_col

    for idx in range(n_levels):
        if idx % 10 == 0:
            print(f"  Layer {idx}/{n_levels}")

        gamma_val = gamma_range[idx]

        # Transpose the (alpha, beta) slice into Plotly's (beta, alpha)
        # layout. Promote it to float so NaN can mark removed cells.
        surface_data = np.array(dissonance_3d[:, :, idx].T, dtype=float)
        if apply_tetrahedron:
            # Keep only α ≤ β ≤ γ.
            surface_data[alpha_gt_beta | (beta_col > gamma_val)] = np.nan

        # NaN heights make Plotly leave a gap (connectgaps=False), so removed
        # cells are truly not drawn rather than merely uncoloured.
        z_data = np.where(np.isnan(surface_data), np.nan, float(gamma_val))

        first = idx == 0
        fig.add_trace(
            go.Surface(
                x=alpha,
                y=beta,
                z=z_data,
                surfacecolor=surface_data,
                colorscale=custom_colorscale,
                cmin=vmin,
                cmax=vmax,
                showscale=first,
                opacity=0.1,
                connectgaps=False,
                reversescale=False,
                hovertemplate="α=%{x:.4f}<br>β=%{y:.4f}<br>γ=%{z:.4f}<extra></extra>",
                # Only the first layer carries the (single) colour bar.
                colorbar=colorbar if first else None,
            )
        )

    # The title must reflect cut_view: with cut_view=False the full cube is
    # shown even though cut_axis defaults to "tetrahedron".
    if not cut_view:
        cut_description = "Full cube"
    elif cut_axis == "tetrahedron":
        cut_description = "α ≤ β ≤ γ"
    else:
        cut_description = f"Cut view: {cut_axis}"

    # Journal-style layout: white background, black text, no margins.
    fig.update_layout(
        scene=dict(
            xaxis=scene_axis("α (2nd frequency ratio)"),
            yaxis=scene_axis("β (3rd frequency ratio)"),
            zaxis=scene_axis("γ (4th frequency ratio)"),
            camera=dict(eye=dict(x=1.0, y=1.0, z=1.0)),
            bgcolor="rgba(255, 255, 255, 1.0)",
            aspectmode="cube",
        ),
        width=1200,
        height=900,
        paper_bgcolor="rgb(255, 255, 255)",
        showlegend=False,
        margin=dict(l=0, r=0, t=0, b=0),
    )

    # Title as a floating annotation (the layout has no top margin).
    fig.add_annotation(
        text=f"EigenSpace Topology ({cut_description})",
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.98,
        xanchor="center",
        yanchor="top",
        showarrow=False,
        font=text_font(16),
    )

    return fig


# Script entry point (also runs when the cell is executed in a notebook).
if __name__ == "__main__":
    print("=== 3D HARMONIC SERIES TOPOLOGY ===\n")
    print("Focus: Tetrahedron view\n")

    # Load the pre-computed 150³ grid from disk (base note A3 = 220 Hz,
    # ratios from 1.0 to 2.0) rather than computing it here.
    (
        alpha_range,
        beta_range,
        gamma_range,
        dissonance_3d,
    ) = load_precomputed_dissonance_map(
        base_freq=220, n_points=150, r_low=1.0, r_high=2.0
    )

    print(f"\nMap calculated: {dissonance_3d.shape}")
    print(f"Dissonance range: {dissonance_3d.min():.4f} to {dissonance_3d.max():.4f}\n")

    print("Generating TETRAHEDRON view\n")

    # Stacked-slice rendering restricted to the α ≤ β ≤ γ region; the figure
    # shows only the dissonance map, with no chord markers.
    fig4 = visualize_3d_contour_stack(
        alpha_range, beta_range, gamma_range, dissonance_3d,
        cut_view=True, cut_axis="tetrahedron",
    )
    fig4.show()

    print("\n=== Visualization complete ===")