In [1]:
import os, json
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm, colors
import ipywidgets as widgets
from IPython.display import display, clear_output


# =============================
# Utilities
# =============================
def auto_ticks(vmin, vmax, n=5):
    """Return ``n`` evenly spaced tick positions from ``vmin`` to ``vmax``.

    A degenerate range (``vmin == vmax``) yields the single tick ``[vmin]``,
    since ``linspace`` would only repeat the same value.
    """
    if vmin != vmax:
        return list(np.linspace(vmin, vmax, n))
    return [vmin]

def ensure_dir(path):
    """Create directory ``path`` and any missing parents; an existing directory is fine."""
    os.makedirs(name=path, exist_ok=True)


# =============================
# Core physics
# =============================
def r1(x, y, z, q):
    """Distance from the point (x, y, z) to source 1 at (-q, 0, 0).

    Works element-wise on NumPy arrays as well as on scalars.
    """
    shifted_x = x + q
    return np.sqrt(shifted_x**2 + y**2 + z**2)

def r2(x, y, z, q):
    """Distance from the point (x, y, z) to source 2 at (+q, 0, 0).

    Works element-wise on NumPy arrays as well as on scalars.
    """
    shifted_x = x - q
    return np.sqrt(shifted_x**2 + y**2 + z**2)

def delta_r(x, y, z, q):
    """Path difference r2 - r1 between the distances to source 2 and source 1."""
    dist_to_1 = r1(x, y, z, q)
    dist_to_2 = r2(x, y, z, q)
    return dist_to_2 - dist_to_1

# s = -1 outgoing, +1 incoming (teaching sign convention)
def delta_phi_exact(x, y, z, q, k0, dphi0, s1, s2):
    """Exact phase difference  Δφ = s2*k0*r2 - s1*k0*r1 + dphi0.

    Wave i contributes the phase s_i*k0*r_i, where r_i is the distance from the
    observation point to source i; ``dphi0`` is an extra offset on wave 2.
    Element-wise for array inputs; the result is not wrapped to [0, 2π).
    """
    wave2 = s2*k0*r2(x, y, z, q)
    wave1 = s1*k0*r1(x, y, z, q)
    return wave2 - wave1 + dphi0

def delta_phi_fraunhofer(x, y, z, q, k0, dphi0, normal_axis="y"):
    """Far-field (paraxial Fraunhofer) approximation of the phase difference.

    For sources at (-q, 0, 0) and (+q, 0, 0) seen from distance L along the
    screen normal, the path difference is

        Δr = r2 - r1 ≈ -2q * x / L,

    so the phase difference of two outgoing waves is -k0*Δr + dphi0
    ≈ 2q*k0*x/L + dphi0, which is what this returns.

    normal_axis selects L: "y" -> |y|, "z" -> |z|, anything else ->
    sqrt(y² + z²). L is floored at 1e-6 to avoid division by zero.

    Note: the propagation signs s1/s2 are not used. This expression
    matches the OUT/OUT case of ``delta_phi_exact`` only.
    """
    if normal_axis == "y":
        dist = np.abs(y)
    elif normal_axis == "z":
        dist = np.abs(z)
    else:
        dist = np.sqrt(y*y + z*z)
    L = np.maximum(dist, 1e-6)
    return k0 * (2.0*q*(x/L)) + dphi0

def intensity_general(x, y, z, q, k0, dphi0, s1, s2, A2_over_A1, gamma_eff,
                      use_1_over_r, use_fraunhofer=False, fraunhofer_normal="y"):
    """
    Two-source interference intensity with partial coherence:

        I = I1 + I2 + 2*gamma_eff*sqrt(I1 I2)*cos(Δφ)

    Parameters
    ----------
    x, y, z : observation coordinates (scalars or broadcastable arrays).
    q : half source separation; sources sit at (-q, 0, 0) and (+q, 0, 0).
    k0 : wavenumber (callers pass 2π/λ).
    dphi0 : extra phase offset of wave 2.
    s1, s2 : propagation sign of each wave (-1 outgoing, +1 incoming).
    A2_over_A1 : amplitude of wave 2; wave 1 has amplitude 1.
    gamma_eff : factor scaling the interference (cross) term.
    use_1_over_r : if True, each amplitude is divided by (r + 1e-6).
    use_fraunhofer : if True, use ``delta_phi_fraunhofer`` for Δφ instead of
        the exact ``delta_phi_exact``.
    fraunhofer_normal : ``normal_axis`` forwarded to ``delta_phi_fraunhofer``.

    Returns
    -------
    Intensity with the broadcast shape of the inputs.
    """
    A1 = 1.0
    A2 = A2_over_A1

    if use_1_over_r:
        # The 1e-6 guard avoids division by zero exactly at a source.
        a1 = A1 / (r1(x, y, z, q) + 1e-6)
        a2 = A2 / (r2(x, y, z, q) + 1e-6)
    else:
        a1 = A1
        a2 = A2

    I1 = a1*a1
    I2 = a2*a2

    if use_fraunhofer:
        dph = delta_phi_fraunhofer(x, y, z, q, k0, dphi0, normal_axis=fraunhofer_normal)
    else:
        # Reuse the one definition of the exact phase so the sign convention
        # cannot drift between the phase map and the intensity map.
        dph = delta_phi_exact(x, y, z, q, k0, dphi0, s1, s2)

    return I1 + I2 + 2.0*gamma_eff*np.sqrt(I1*I2)*np.cos(dph)

def field_snapshot_real(x, y, z, q, k0, phase_t, dphi0, s1, s2, A2_over_A1, use_1_over_r):
    """
    Real snapshot of the superposed field at time phase ``phase_t`` (= ωt):

      E = a1 cos(phase_t + s1 k r1) + a2 cos(phase_t + s2 k r2 + dphi0)

    with a1 = 1 and a2 = A2_over_A1. When ``use_1_over_r`` is set, each is
    divided by (r + 1e-6).

    Sign convention (s = -1 outgoing, +1 incoming):
    - For s = -1 a crest satisfies ωt - k r = const, so it moves to larger r
      as ``phase_t`` increases.
    - For s = +1 the crest moves toward its source.

    The per-wave phases s_i k r_i (+ dphi0 on wave 2) are exactly the terms
    that ``delta_phi_exact`` subtracts. The time average of E**2 therefore
    contains the same cos(Δφ) cross term as ``intensity_general``.
    """
    R1 = r1(x, y, z, q)
    R2 = r2(x, y, z, q)

    A1 = 1.0
    A2 = A2_over_A1

    if use_1_over_r:
        a1 = A1 / (R1 + 1e-6)
        a2 = A2 / (R2 + 1e-6)
    else:
        a1 = A1
        a2 = A2

    # "+ s*k*r" rather than "- s*k*r" is deliberate. With a minus sign:
    # - an s = -1 ("outgoing") wave would converge as phase_t grows;
    # - dphi0 would enter the time-averaged intensity with the opposite sign
    #   to delta_phi_exact / intensity_general.
    E1 = a1 * np.cos(phase_t + s1*k0*R1)
    E2 = a2 * np.cos(phase_t + s2*k0*R2 + dphi0)
    return E1 + E2

def spectral_average_intensity(X, Y, Z, lam0, dlam, nlam, q, dphi0, s1, s2,
                               A2_over_A1, gamma_user, pol_factor,
                               use_1_over_r, fraunhofer_on, fraunhofer_normal):
    """
    Average intensity over wavelength distribution (uniform sampling in λ).

    The band [lam0 - dlam/2, lam0 + dlam/2] is sampled at ``nlam`` evenly
    spaced wavelengths. ``intensity_general`` is evaluated at k = 2π/λ for
    each sample, and the results are averaged. With ``dlam <= 0`` or
    ``nlam <= 1`` this is a single evaluation at ``lam0``.

    The effective coherence ``gamma_user * pol_factor`` does not depend on λ,
    so it is computed once. (The ``pol_factor`` parameter is a number here;
    it shadows the module-level function of the same name.)

    Raises
    ------
    ValueError
        If ``lam0`` or the shortest sampled wavelength is not strictly
        positive, since k = 2π/λ would be infinite or have the wrong sign.
    """
    lam0 = float(lam0)
    dlam = float(dlam)
    nlam = int(max(1, nlam))
    gamma_eff = float(gamma_user) * float(pol_factor)

    def _intensity_at(lam):
        # Monochromatic intensity on the (X, Y, Z) grid at wavelength lam.
        return intensity_general(X, Y, Z, q, 2*np.pi/lam, dphi0, s1, s2, A2_over_A1,
                                 gamma_eff, use_1_over_r,
                                 use_fraunhofer=fraunhofer_on,
                                 fraunhofer_normal=fraunhofer_normal)

    if dlam <= 0 or nlam == 1:
        if not lam0 > 0:  # also rejects NaN
            raise ValueError(f"wavelength must be positive, got lam0={lam0}")
        return _intensity_at(lam0)

    lams = np.linspace(lam0 - dlam/2, lam0 + dlam/2, nlam)
    if not lams[0] > 0:  # lams is ascending, so lams[0] is the shortest
        raise ValueError(f"wavelength band must stay positive, got lam0 - dlam/2 = {lams[0]}")

    I_acc = 0.0
    for lam in lams:
        I_acc += _intensity_at(lam)
    return I_acc / nlam

def estimate_fringe_spacing_fft(x, y):
    """Estimate the dominant fringe period of the sampled profile ``y(x)``.

    Steps:
    1. Remove the mean and apply a Hann window.
    2. Take the strongest non-DC bin of the real FFT.
    3. Convert its frequency back to a spatial period 1/f.

    The estimate is quantised to the FFT bin spacing 1/(N*|dx|).

    Parameters
    ----------
    x : sample positions, assumed uniformly spaced; ascending or descending.
    y : sample values, one per position.

    Returns
    -------
    float or None
        Period in the units of ``x``. None when:
        - fewer than 32 samples are given;
        - the spacing is zero or non-finite;
        - there is no non-DC peak.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if len(x) < 32:
        return None

    # abs(): a descending grid has the same sample spacing as an ascending one.
    dx = abs(float(np.mean(np.diff(x))))
    if not np.isfinite(dx) or dx == 0.0:
        return None

    y0 = y - np.mean(y)
    yw = y0 * np.hanning(len(y0))
    mag = np.abs(np.fft.rfft(yw))
    f = np.fft.rfftfreq(len(yw), d=dx)
    mag[0] = 0.0  # ignore any DC leakage left after mean removal + windowing
    idx = int(np.argmax(mag))
    if idx <= 0 or f[idx] <= 0:
        return None
    return float(1.0 / f[idx])


# =============================
# Visualization helpers
# =============================
def draw_edges(ax, xlim, ylim, zlim, lw=1.0):
    """Outline the axis-aligned box xlim × ylim × zlim on ``ax`` with 12 black segments.

    Draw order: the 4 edges parallel to z, then the 4 parallel to y, then the
    4 parallel to x.
    """
    xs_ = [xlim[0], xlim[1]]
    ys_ = [ylim[0], ylim[1]]
    zs_ = [zlim[0], zlim[1]]
    line_style = dict(color="k", lw=lw)

    # One edge per (x, y) corner, spanning z.
    segments = [([xv, xv], [yv, yv], zs_) for xv in xs_ for yv in ys_]
    # One edge per (x, z) corner, spanning y.
    segments += [([xv, xv], ys_, [zv, zv]) for xv in xs_ for zv in zs_]
    # One edge per (y, z) corner, spanning x.
    segments += [(xs_, [yv, yv], [zv, zv]) for yv in ys_ for zv in zs_]

    for seg_x, seg_y, seg_z in segments:
        ax.plot(seg_x, seg_y, seg_z, **line_style)

def unit_radial(x, y, z, sx, sy, sz):
    """Unit vector(s) pointing from the source (sx, sy, sz) toward (x, y, z).

    A 1e-9 is added to the length so a point at the source yields (0, 0, 0)
    instead of a division by zero. Returns the tuple (ux, uy, uz).
    """
    deltas = (x - sx, y - sy, z - sz)
    length = np.sqrt(sum(d*d for d in deltas)) + 1e-9
    return tuple(d/length for d in deltas)


# =============================
# UI widgets
# =============================
# Domain (crop)
# Box limits in µm; the three displayed faces sit on these boundaries.
# NOTE(review): y_max's range starts at -50, so it can be set below y_min,
# which gives an inverted y-range.
x_min = widgets.FloatSlider(value=-20, min=-300, max=0, step=1, description="x min")
x_max = widgets.FloatSlider(value= 20, min=0,   max=300, step=1, description="x max")
y_min = widgets.FloatSlider(value=-20, min=-300, max=0, step=1, description="y min")
y_max = widgets.FloatSlider(value=  0, min=-50,  max=300, step=1, description="y max")
z_min = widgets.FloatSlider(value=-20, min=-300, max=0, step=1, description="z min")
z_max = widgets.FloatSlider(value= 20, min=0,   max=300, step=1, description="z max")

# Samples per axis for the face meshes, and the row/column stride used by plot_surface.
n_slider = widgets.IntSlider(value=220, min=80, max=520, step=10, description="resolution n")
stride = widgets.IntSlider(value=4, min=2, max=14, step=1, description="3D stride")

# Physics
# Central wavelength, spectral width and number of wavelength samples (step 2 from 1).
lam0 = widgets.FloatSlider(value=0.6328, min=0.4, max=1.2, step=0.0005, description="λ0 (µm)")
dlam = widgets.FloatSlider(value=0.0, min=0.0, max=0.2, step=0.001, description="Δλ (µm)")
nlam = widgets.IntSlider(value=1, min=1, max=31, step=2, description="Nλ")

# Half source separation (sources at x = -q and x = +q) and extra phase of wave 2 (rad).
q_slider = widgets.FloatSlider(value=5.0, min=0.5, max=60.0, step=0.5, description="q (µm)")
dphi0 = widgets.FloatSlider(value=0.0, min=-np.pi, max=np.pi, step=0.02, description="Δφ0")

# Amplitude ratio, user coherence γ, and polarization angle (pol_factor() = cos θ).
A2_over_A1 = widgets.FloatSlider(value=1.0, min=0.0, max=3.0, step=0.05, description="A2/A1")
gamma_user = widgets.FloatSlider(value=1.0, min=0.0, max=1.0, step=0.02, description="γ (coh)")
pol_deg = widgets.FloatSlider(value=0.0, min=0.0, max=90.0, step=1.0, description="pol θ (deg)")

# Optional 1/r amplitude fall-off, controlled separately for intensity and field views.
use_I_decay = widgets.Checkbox(value=False, description="Use 1/r in Intensity")
use_E_decay = widgets.Checkbox(value=False, description="Use 1/r in Field")

# Checked -> propagation sign s = +1 (incoming); unchecked -> s = -1 (outgoing).
incoming1 = widgets.Checkbox(value=False, description="Wave1 incoming (converging)")
incoming2 = widgets.Checkbox(value=False, description="Wave2 incoming (converging)")

# Views / models
# Quantity painted on the 3D faces: intensity, field snapshot, Δr, or Δφ mod 2π.
main_view = widgets.ToggleButtons(
    options=[("Intensity (avg)", "I"), ("Field snapshot E", "E"), ("Δr", "dr"), ("Δφ mod 2π", "dphi")],
    value="I",
    description="Main:"
)
fraunhofer_on = widgets.Checkbox(value=False, description="Fraunhofer approx (far-field)")
# When on, the field colour scale comes from compute_E_norm_reference() (phase-independent).
fixed_scale_E = widgets.Checkbox(value=True, description="Fixed scale (phase visible)")

# Phase controls (no play)
# Time phase ωt used by the field-snapshot (E) computations; the buttons'
# click handlers are not in this part of the file.
phase = widgets.FloatSlider(value=0.0, min=0.0, max=2*np.pi, step=0.12, description="phase ωt")
btn_phase_minus = widgets.Button(description="phase -")
btn_phase_plus  = widgets.Button(description="phase +")
btn_phase_zero  = widgets.Button(description="phase = 0")

# Divergence/convergence arrows
# Optional quiver of propagation directions of one chosen wave, drawn on the y-face.
arrows_on = widgets.Checkbox(value=False, description="Show propagation arrows (on y-face)")
arrows_wave = widgets.ToggleButtons(options=[("Wave1", "w1"), ("Wave2", "w2")], value="w1", description="Arrows:")

# Probe
# Face modes pin one coordinate to the chosen face and take the other two from
# u1/u2 (relabelled and re-bounded by update_probe_ui_ranges); "free" uses x0/y0s/z0s.
probe_on = widgets.Checkbox(value=True, description="Probe ON")
probe_mode = widgets.Dropdown(
    options=[("On y-face (pick x,z)", "face_y"),
             ("On z-face (pick x,y)", "face_z"),
             ("On x-face (pick y,z)", "face_x"),
             ("Free (x0,y0,z0)", "free")],
    value="face_y",
    description="Probe mode:"
)
u1 = widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5, description="x")
u2 = widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5, description="z")
x0 = widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5, description="x0")
y0s = widgets.FloatSlider(value=-10.0, min=-20, max=0, step=0.5, description="y0")
z0s = widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5, description="z0")

# Probe placement shortcuts (click handlers are not in this part of the file).
btn_probe_center  = widgets.Button(description="Probe: center")
btn_probe_s1      = widgets.Button(description="Probe: near source1")
btn_probe_s2      = widgets.Button(description="Probe: near source2")

# Line profile
# Intensity along a line on one face; zL / yL fix the remaining coordinate of that line.
profile_on = widgets.Checkbox(value=True, description="Line profile ON")
profile_face = widgets.Dropdown(
    options=[("y-face: I(x) at z=zL", "y"),
             ("z-face: I(x) at y=yL", "z"),
             ("x-face: I(y) at z=zL", "x")],
    value="y",
    description="Profile:"
)
zL = widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5, description="zL")
yL = widgets.FloatSlider(value=-10.0, min=-20, max=0, step=0.5, description="yL")

# Presets (click handlers are not in this part of the file)
btn_preset_near = widgets.Button(description="Preset: Near-field")
btn_preset_far  = widgets.Button(description="Preset: Far-field (Fraunhofer)")
btn_preset_coh  = widgets.Button(description="Preset: Coherence demo")
btn_preset_pol  = widgets.Button(description="Preset: Polarization demo")

# Intensity scaling
# lock_I_scale freezes the intensity vmax in CACHE (see get_intensity_vmax);
# btn_set_I_scale re-locks it; contrast_I min-max normalizes the 2D intensity panel.
lock_I_scale = widgets.Checkbox(value=True, description="Lock Intensity scale (see A2/A1 effect)")
btn_set_I_scale = widgets.Button(description="Set I scale now")
contrast_I = widgets.Checkbox(value=False, description="Contrast view (min-max) [hides A2/A1 effect]")
I_scale_info = widgets.Output()

# Compare A/B
# Quantity compared on the y-face between scenario A (OUT/OUT) and B (IN/OUT) in compare_AB.
compare_qty = widgets.Dropdown(
    options=[("y-face: Field snapshot E", "E"),
             ("y-face: Intensity I", "I"),
             ("y-face: Δφ mod 2π", "dphi"),
             ("y-face: Δr", "dr")],
    value="E",
    description="Compare:"
)
btn_compare = widgets.Button(description="Compare A/B", button_style="info")

# Save / Reset
# Output directory and file stem for saving (save/reset handlers are not in this part of the file).
save_dir = widgets.Text(value="outputs", description="Save dir:")
base_name = widgets.Text(value="interference", description="Base name:")
btn_save = widgets.Button(description="Save (PNG+JSON)", button_style="success")
btn_reset = widgets.Button(description="Reset", button_style="warning")

# Outputs
# Output areas that update_all() (and compare_AB) render into.
out_main = widgets.Output()
out_multi = widgets.Output()
out_probe = widgets.Output()
out_profile = widgets.Output()
out_compare = widgets.Output()
out_msg = widgets.Output()


# =============================
# IMPORTANT SPEED FIX:
# Update ONLY on mouse release (no live dragging)
# =============================
SLIDERS = [
    x_min, x_max, y_min, y_max, z_min, z_max,
    n_slider, stride,
    lam0, dlam, nlam, q_slider, dphi0,
    A2_over_A1, gamma_user, pol_deg,
    phase, u1, u2, x0, y0s, z0s, zL, yL
]
# Emit a single change event when the mouse is released rather than one per
# drag step; every change triggers a full (expensive) re-render.
for s in filter(lambda w: hasattr(w, "continuous_update"), SLIDERS):
    s.continuous_update = False


# =============================
# State / caching
# =============================
# Grid/mesh data written by rebuild_grids() plus bookkeeping such as "grid_key".
STATE = dict()

CACHE = dict(
    E_norm_key=None,     # parameter key for the cached fixed field scale
    E_norm_val=None,     # cached max |E| used as the fixed field colour range
    I_locked_vmax=None,  # locked intensity colour maximum
    I_locked_key=None,   # parameter key the locked intensity maximum belongs to
)

# Set to True during bulk widget changes (reset/presets) so update_all() returns early.
SUSPEND_UPDATES = False


def rebuild_grids():
    """Sample the domain box from the sliders and store the three face meshes in STATE.

    The faces are pinned to box boundaries:
    - y-face at y = y max, meshed over (x, z) -> X1, Y1, Z1;
    - z-face at z = z min, meshed over (x, y) -> X2, Y2, Z2;
    - x-face at x = x max, meshed over (y, z) -> X3, Y3, Z3.

    Each mesh is (n, n) with n from the resolution slider.
    """
    n = int(n_slider.value)
    xlim = (float(x_min.value), float(x_max.value))
    ylim = (float(y_min.value), float(y_max.value))
    zlim = (float(z_min.value), float(z_max.value))

    xs, ys, zs = (np.linspace(lo, hi, n) for lo, hi in (xlim, ylim, zlim))

    # Faces pinned to fixed box boundaries.
    y_face, z_face, x_face = ylim[1], zlim[0], xlim[1]

    def _face(u, v, fixed_value):
        # Mesh over the two in-plane axes plus a constant array for the pinned axis.
        uu, vv = np.meshgrid(u, v, indexing="xy")
        return uu, vv, np.full_like(uu, fixed_value)

    X1, Z1, Y1 = _face(xs, zs, y_face)
    X2, Y2, Z2 = _face(xs, ys, z_face)
    Y3, Z3, X3 = _face(ys, zs, x_face)

    STATE.update(
        n=n, xs=xs, ys=ys, zs=zs,
        xlim=xlim, ylim=ylim, zlim=zlim,
        y_face=y_face, z_face=z_face, x_face=x_face,
        X1=X1, Y1=Y1, Z1=Z1,
        X2=X2, Y2=Y2, Z2=Z2,
        X3=X3, Y3=Y3, Z3=Z3,
    )

def rebuild_grids_if_needed():
    """Run rebuild_grids() only if the domain sliders or the resolution changed.

    The last-seen settings are remembered in STATE["grid_key"].
    """
    limit_sliders = (x_min, x_max, y_min, y_max, z_min, z_max)
    key = tuple(float(w.value) for w in limit_sliders) + (int(n_slider.value),)
    if STATE.get("grid_key") != key:
        STATE["grid_key"] = key
        rebuild_grids()

def set_slider_range(widget, lo, hi):
    """Set ``widget.min`` / ``widget.max`` to span [lo, hi] without violating bound checks.

    ipywidgets sliders reject a ``min`` above the current ``max`` (and a
    ``max`` below the current ``min``). The assignment order therefore
    depends on where the new range lies relative to the old one.

    The bounds are sorted first, so an inverted domain (e.g. y max set below
    y min) still produces a valid range.
    """
    lo, hi = min(lo, hi), max(lo, hi)
    if lo > widget.max:
        # New range lies entirely above the old one: raise max before min.
        widget.max = hi
        widget.min = lo
    else:
        widget.min = lo
        widget.max = hi


def update_probe_ui_ranges():
    """Fit the probe and line-profile sliders to the current domain in STATE.

    - x0/y0s/z0s and yL/zL get the x/y/z box limits.
    - u1/u2 are the two in-plane coordinates of the face selected in
      ``probe_mode``; they are relabelled and re-bounded accordingly.
    - In "free" mode only the u1/u2 labels change.
    """
    xlim, ylim, zlim = STATE["xlim"], STATE["ylim"], STATE["zlim"]
    set_slider_range(x0, *xlim)
    set_slider_range(y0s, *ylim)
    set_slider_range(z0s, *zlim)
    set_slider_range(yL, *ylim)
    set_slider_range(zL, *zlim)

    mode = probe_mode.value
    if mode == "face_y":
        u1.description, u2.description = "x", "z"
        set_slider_range(u1, *xlim)
        set_slider_range(u2, *zlim)
    elif mode == "face_z":
        u1.description, u2.description = "x", "y"
        set_slider_range(u1, *xlim)
        set_slider_range(u2, *ylim)
    elif mode == "face_x":
        u1.description, u2.description = "y", "z"
        set_slider_range(u1, *ylim)
        set_slider_range(u2, *zlim)
    else:
        u1.description, u2.description = "u1", "u2"

def get_probe_xyz():
    """Return the probe position as a float tuple (x, y, z) for the current probe mode.

    - "face_y": y = STATE["y_face"], (x, z) from (u1, u2).
    - "face_z": z = STATE["z_face"], (x, y) from (u1, u2).
    - "face_x": x = STATE["x_face"], (y, z) from (u1, u2).
    - Any other mode uses the free sliders x0 / y0s / z0s.
    """
    yf, zf, xf = STATE["y_face"], STATE["z_face"], STATE["x_face"]
    mode = probe_mode.value
    if mode == "face_y":
        coords = (u1.value, yf, u2.value)
    elif mode == "face_z":
        coords = (u1.value, u2.value, zf)
    elif mode == "face_x":
        coords = (xf, u1.value, u2.value)
    else:
        coords = (x0.value, y0s.value, z0s.value)
    return tuple(float(c) for c in coords)

def pol_factor():
    """Polarization overlap factor cos θ for the pol-angle slider (degrees).

    1 for parallel polarizations, 0 for orthogonal ones. Callers multiply it
    by the user coherence γ to get the effective coherence.
    """
    return float(np.cos(np.deg2rad(float(pol_deg.value))))

def compute_E_norm_reference():
    """Return a fixed colour half-range for the field-snapshot view.

    The value is max |E| over the y-face at phase ωt = 0, floored at 1e-6,
    so the scale stays put while the phase slider moves. It is cached in
    CACHE and recomputed only when the domain, grid size or one of the
    field parameters in the key changes.
    """
    lam = float(lam0.value)
    q = float(q_slider.value)
    d0 = float(dphi0.value)
    A2 = float(A2_over_A1.value)
    key = (STATE["xlim"], STATE["ylim"], STATE["zlim"], STATE["n"],
           lam, q, d0,
           int(incoming1.value), int(incoming2.value),
           A2, int(use_E_decay.value))

    cached = CACHE["E_norm_val"]
    if cached is not None and CACHE["E_norm_key"] == key:
        return cached

    s1 = +1 if incoming1.value else -1
    s2 = +1 if incoming2.value else -1
    Eref = field_snapshot_real(STATE["X1"], STATE["Y1"], STATE["Z1"], q, 2*np.pi/lam,
                               phase_t=0.0, dphi0=d0, s1=s1, s2=s2,
                               A2_over_A1=A2, use_1_over_r=use_E_decay.value)
    half_range = max(float(np.max(np.abs(Eref))), 1e-6)

    CACHE["E_norm_key"] = key
    CACHE["E_norm_val"] = half_range
    return half_range

def get_intensity_vmax(I_ref):
    """Upper colour limit for the intensity maps.

    The statistic is the 99.5th percentile of ``I_ref``, floored at 1e-9,
    so a few extreme pixels do not set the scale.

    - Lock off: recomputed from ``I_ref`` on every call.
    - Lock on: frozen in CACHE and recomputed only when the lock key changes.
      The key covers domain, n, λ0, Δλ, Nλ, q, Δφ0, γ, pol θ, and the 1/r and
      Fraunhofer flags. A2/A1 is not in the key, so its effect is shown
      against a fixed scale.
    """
    lock_key = (STATE["xlim"], STATE["ylim"], STATE["zlim"], STATE["n"],
                float(lam0.value), float(dlam.value), int(nlam.value),
                float(q_slider.value), float(dphi0.value),
                float(gamma_user.value), float(pol_deg.value),
                bool(use_I_decay.value), bool(fraunhofer_on.value))

    def _robust_vmax():
        return max(float(np.percentile(I_ref, 99.5)), 1e-9)

    if lock_I_scale.value:
        stale = CACHE["I_locked_key"] != lock_key or CACHE["I_locked_vmax"] is None
        if stale:
            CACHE["I_locked_key"] = lock_key
            CACHE["I_locked_vmax"] = _robust_vmax()
        return float(CACHE["I_locked_vmax"])
    return _robust_vmax()


def set_I_scale_now(_=None):
    """Button callback: lock the intensity colour scale to the current y-face.

    Steps:
    1. Evaluate the spectrally averaged y-face intensity for the present
       parameters.
    2. Store its 99.5th percentile (floored at 1e-9) in CACHE["I_locked_vmax"].
    3. Report the value in ``I_scale_info`` and redraw via ``update_all()``.

    CACHE["I_locked_key"] is not modified.

    Parameters
    ----------
    _ : ignored (receives the clicked button when used as an on_click handler).
    """
    rebuild_grids_if_needed()
    s1 = +1 if incoming1.value else -1
    s2 = +1 if incoming2.value else -1

    Xy, Yy, Zy = STATE["X1"], STATE["Y1"], STATE["Z1"]
    Iy = spectral_average_intensity(
        Xy, Yy, Zy,
        lam0=lam0.value, dlam=dlam.value, nlam=nlam.value,
        q=float(q_slider.value), dphi0=float(dphi0.value), s1=s1, s2=s2,
        A2_over_A1=A2_over_A1.value,
        gamma_user=gamma_user.value, pol_factor=pol_factor(),
        use_1_over_r=use_I_decay.value,
        fraunhofer_on=fraunhofer_on.value, fraunhofer_normal="y"
    )
    CACHE["I_locked_vmax"] = max(float(np.percentile(Iy, 99.5)), 1e-9)

    with I_scale_info:
        clear_output(wait=True)
        print(f"Intensity scale locked: vmax ≈ {CACHE['I_locked_vmax']:.4g}")
        if not lock_I_scale.value:
            # get_intensity_vmax only consults the cached value while the lock is on.
            print("Note: 'Lock Intensity scale' is OFF, so the maps still auto-scale.")

    update_all()

btn_set_I_scale.on_click(set_I_scale_now)


def collect_metadata():
    """Return a JSON-serialisable snapshot of the domain, parameters and UI state.

    Top-level keys: "timestamp" (local time, ISO-8601 to the second),
    "domain", "params", "probe", "profile" and "display". Widget values are
    converted to plain bool/int/float, and limit tuples to lists.
    """
    stamp = datetime.now().isoformat(timespec="seconds")

    domain = {
        "xlim": list(STATE["xlim"]), "ylim": list(STATE["ylim"]), "zlim": list(STATE["zlim"]),
        "n": int(STATE["n"]), "stride": int(stride.value),
    }
    params = {
        "lambda0_um": float(lam0.value),
        "dlam_um": float(dlam.value),
        "nlam": int(nlam.value),
        "q_um": float(q_slider.value),
        "dphi0_rad": float(dphi0.value),
        "A2_over_A1": float(A2_over_A1.value),
        "gamma_user": float(gamma_user.value),
        "pol_theta_deg": float(pol_deg.value),
        "incoming1": bool(incoming1.value),
        "incoming2": bool(incoming2.value),
        "use_I_decay": bool(use_I_decay.value),
        "use_E_decay": bool(use_E_decay.value),
        "fraunhofer_on": bool(fraunhofer_on.value),
        "phase": float(phase.value),
        "fixed_scale_E": bool(fixed_scale_E.value),
        "lock_I_scale": bool(lock_I_scale.value),
        "contrast_I": bool(contrast_I.value),
    }
    probe = {
        "on": bool(probe_on.value),
        "mode": probe_mode.value,
        "u1": float(u1.value),
        "u2": float(u2.value),
        "free_xyz": [float(x0.value), float(y0s.value), float(z0s.value)],
    }
    profile = {
        "on": bool(profile_on.value),
        "face": profile_face.value,
        "zL": float(zL.value),
        "yL": float(yL.value),
    }
    display = {
        "main_view": main_view.value,
        "arrows_on": bool(arrows_on.value),
        "arrows_wave": arrows_wave.value,
    }
    return {
        "timestamp": stamp,
        "domain": domain,
        "params": params,
        "probe": probe,
        "profile": profile,
        "display": display,
    }


# =============================
# Main update
# =============================
def update_all(*args):
    """Recompute all maps and redraw every output pane from the current widgets.

    Extra positional arguments (e.g. a widget change event) are ignored.
    Nothing happens while the module-level flag ``SUSPEND_UPDATES`` is true.

    Panes rendered:
      * ``out_main``    - 3D box whose three boundary faces are coloured by
                          the quantity selected in ``main_view``;
      * ``out_multi``   - 2D y-face maps of intensity, Δr and Δφ mod 2π;
      * ``out_probe``   - numeric diagnostics at the probe point;
      * ``out_profile`` - intensity line profile with visibility, FFT fringe
                          spacing and the Fraunhofer estimate λL/(2q);
      * ``out_msg``     - usage tips.
    """
    if SUSPEND_UPDATES:
        return

    rebuild_grids_if_needed()
    update_probe_ui_ranges()

    lam = float(lam0.value)
    k0 = 2*np.pi/lam
    q = float(q_slider.value)
    d0 = float(dphi0.value)

    # Propagation sign per wave: -1 outgoing, +1 incoming.
    s1 = +1 if incoming1.value else -1
    s2 = +1 if incoming2.value else -1
    polf = pol_factor()
    view = main_view.value

    # --- y-face maps (2D panels, intensity colour scale, and the 3D y-face in
    # intensity view)
    Xy, Yy, Zy = STATE["X1"], STATE["Y1"], STATE["Z1"]
    Iy = spectral_average_intensity(
        Xy, Yy, Zy,
        lam0=lam0.value, dlam=dlam.value, nlam=nlam.value,
        q=q, dphi0=d0, s1=s1, s2=s2,
        A2_over_A1=A2_over_A1.value,
        gamma_user=gamma_user.value, pol_factor=polf,
        use_1_over_r=use_I_decay.value,
        fraunhofer_on=fraunhofer_on.value, fraunhofer_normal="y"
    )
    dry = delta_r(Xy, Yy, Zy, q)
    dphiy = np.mod(delta_phi_exact(Xy, Yy, Zy, q, k0, d0, s1, s2), 2*np.pi)

    I_vmax = get_intensity_vmax(Iy)

    X1, Y1, Z1 = STATE["X1"], STATE["Y1"], STATE["Z1"]
    X2, Y2, Z2 = STATE["X2"], STATE["Y2"], STATE["Z2"]
    X3, Y3, Z3 = STATE["X3"], STATE["Y3"], STATE["Z3"]

    def compute_face_main(X, Y, Z, fraun_normal):
        """Evaluate the selected main-view quantity on one face mesh."""
        if view == "I":
            return spectral_average_intensity(
                X, Y, Z,
                lam0=lam0.value, dlam=dlam.value, nlam=nlam.value,
                q=q, dphi0=d0, s1=s1, s2=s2,
                A2_over_A1=A2_over_A1.value,
                gamma_user=gamma_user.value, pol_factor=polf,
                use_1_over_r=use_I_decay.value,
                fraunhofer_on=fraunhofer_on.value, fraunhofer_normal=fraun_normal
            )
        if view == "E":
            return field_snapshot_real(X, Y, Z, q, k0, phase.value, d0, s1, s2, A2_over_A1.value, use_E_decay.value)
        if view == "dr":
            return delta_r(X, Y, Z, q)
        return np.mod(delta_phi_exact(X, Y, Z, q, k0, d0, s1, s2), 2*np.pi)

    # Iy already is the y-face intensity computed with exactly these arguments
    # (normal "y"); reuse it rather than running the spectral loop twice.
    D1 = Iy if view == "I" else compute_face_main(X1, Y1, Z1, "y")
    D2 = compute_face_main(X2, Y2, Z2, "z")
    # The x-face uses the "y" Fraunhofer normal, as the x-face line profile does.
    D3 = compute_face_main(X3, Y3, Z3, "y")

    # --- Normalize for facecolors
    if view == "I":
        norm = colors.Normalize(vmin=0.0, vmax=I_vmax)
        cmap = cm.gray
    elif view == "E":
        if fixed_scale_E.value:
            m = compute_E_norm_reference()
        else:
            m = float(max(np.max(np.abs(D1)), np.max(np.abs(D2)), np.max(np.abs(D3)), 1e-9))
        norm = colors.Normalize(vmin=-m, vmax=+m)
        cmap = cm.gray
    elif view == "dr":
        vmin = float(min(np.min(D1), np.min(D2), np.min(D3)))
        vmax = float(max(np.max(D1), np.max(D2), np.max(D3)))
        norm = colors.Normalize(vmin=vmin, vmax=vmax)
        cmap = cm.gray
    else:
        # Phase is cyclic: use a cyclic colormap over [0, 2π].
        norm = colors.Normalize(vmin=0.0, vmax=2*np.pi)
        cmap = cm.hsv

    FC1 = cmap(norm(D1))
    FC2 = cmap(norm(D2))
    FC3 = cmap(norm(D3))

    # --- Probe
    px, py, pz = get_probe_xyz()
    R1p = float(r1(px, py, pz, q))
    R2p = float(r2(px, py, pz, q))
    drp = float(R2p - R1p)
    dphp = float(delta_phi_exact(px, py, pz, q, k0, d0, s1, s2))
    dphp_mod = float(np.mod(dphp, 2*np.pi))

    A1 = 1.0
    A2 = float(A2_over_A1.value)
    if use_I_decay.value:
        a1p = A1 / (R1p + 1e-6)
        a2p = A2 / (R2p + 1e-6)
    else:
        a1p = A1
        a2p = A2
    I1p = a1p*a1p
    I2p = a2p*a2p
    gamma_eff = float(gamma_user.value) * float(polf)

    # Local fringe visibility V = 2γ_eff·sqrt(I1 I2) / (I1 + I2).
    V_local = 0.0
    if (I1p + I2p) > 1e-12:
        V_local = float((2*gamma_eff*np.sqrt(I1p*I2p)) / (I1p + I2p))

    # The probe intensity always uses the exact phase (Fraunhofer off).
    Ip = spectral_average_intensity(
        np.array([px]), np.array([py]), np.array([pz]),
        lam0=lam0.value, dlam=dlam.value, nlam=nlam.value,
        q=q, dphi0=d0, s1=s1, s2=s2,
        A2_over_A1=A2_over_A1.value,
        gamma_user=gamma_user.value, pol_factor=polf,
        use_1_over_r=use_I_decay.value,
        fraunhofer_on=False, fraunhofer_normal="y"
    )[0]
    Ep = float(field_snapshot_real(px, py, pz, q, k0, phase.value, d0, s1, s2, A2_over_A1.value, use_E_decay.value))

    # --- Main 3D
    with out_main:
        clear_output(wait=True)
        fig = plt.figure(figsize=(9.2, 6.6))
        ax = fig.add_subplot(111, projection="3d")
        st = int(stride.value)

        ax.plot_surface(X1, Y1, Z1, facecolors=FC1, rstride=st, cstride=st, linewidth=0, antialiased=False, shade=False)
        ax.plot_surface(X2, Y2, Z2, facecolors=FC2, rstride=st, cstride=st, linewidth=0, antialiased=False, shade=False)
        ax.plot_surface(X3, Y3, Z3, facecolors=FC3, rstride=st, cstride=st, linewidth=0, antialiased=False, shade=False)

        draw_edges(ax, STATE["xlim"], STATE["ylim"], STATE["zlim"], lw=1.0)

        # The two point sources.
        ax.scatter([-q, +q], [0, 0], [0, 0], c="red", s=55)

        if probe_on.value:
            ax.scatter([px], [py], [pz], c="lime", s=80)

        if arrows_on.value:
            # Radial directions from the chosen source, sampled on the y-face;
            # flipped to point toward the source for an incoming wave.
            src = (-q, 0, 0) if arrows_wave.value == "w1" else (q, 0, 0)
            is_in = (incoming1.value if arrows_wave.value == "w1" else incoming2.value)
            sign = -1 if is_in else +1

            step = max(1, STATE["n"]//18)
            Xq = STATE["X1"][::step, ::step]
            Yq = STATE["Y1"][::step, ::step]
            Zq = STATE["Z1"][::step, ::step]
            ux, uy, uz = unit_radial(Xq, Yq, Zq, *src)
            ax.quiver(Xq, Yq, Zq, sign*ux, sign*uy, sign*uz, length=3.0, normalize=True, linewidth=0.6)

        ax.set_xlim(STATE["xlim"]); ax.set_ylim(STATE["ylim"]); ax.set_zlim(STATE["zlim"])
        ax.set_xlabel("x (µm)"); ax.set_ylabel("y (µm)"); ax.set_zlabel("z (µm)")
        ax.set_xticks(auto_ticks(*STATE["xlim"]))
        ax.set_yticks(auto_ticks(*STATE["ylim"]))
        ax.set_zticks(auto_ticks(*STATE["zlim"]))

        ax.set_box_aspect((STATE["xlim"][1]-STATE["xlim"][0],
                           STATE["ylim"][1]-STATE["ylim"][0],
                           STATE["zlim"][1]-STATE["zlim"][0]))
        ax.view_init(elev=25, azim=-130)
        ax.grid(False)

        title = {"I":"Intensity (spectral-avg)",
                 "E":"Field snapshot E (phase-visible)",
                 "dr":"Δr map",
                 "dphi":"Δφ mod 2π map"}[view]
        title += f" | w1={'IN' if incoming1.value else 'OUT'} , w2={'IN' if incoming2.value else 'OUT'}"
        if view == "I":
            title += f" | I-scale vmax≈{I_vmax:.3g} ({'LOCK' if lock_I_scale.value else 'AUTO'})"
        ax.set_title(title)

        plt.tight_layout()
        plt.show()
        # Keep a reference after closing (presumably for the save handler).
        STATE["last_fig"] = fig
        plt.close(fig)

    # --- Multi-view 2D on y-face
    with out_multi:
        clear_output(wait=True)
        fig2 = plt.figure(figsize=(12.0, 3.8))
        axs = [fig2.add_subplot(1,3,i+1) for i in range(3)]
        extent = [STATE["xlim"][0], STATE["xlim"][1], STATE["zlim"][0], STATE["zlim"][1]]

        if contrast_I.value:
            # Min-max stretch: shows fringe shape but hides absolute levels.
            I_disp = (Iy - np.min(Iy)) / (np.max(Iy) - np.min(Iy) + 1e-12)
            axs[0].imshow(I_disp, origin="lower", extent=extent, aspect="auto", cmap="gray")
            axs[0].set_title("y-face: Intensity I (CONTRAST view)")
        else:
            axs[0].imshow(Iy, origin="lower", extent=extent, aspect="auto", cmap="gray", vmin=0.0, vmax=I_vmax)
            axs[0].set_title(f"y-face: Intensity I (ABS) | vmax≈{I_vmax:.3g}")
        axs[0].set_xlabel("x (µm)"); axs[0].set_ylabel("z (µm)")

        axs[1].imshow(dry, origin="lower", extent=extent, aspect="auto", cmap="gray")
        axs[1].set_title("y-face: Δr (µm)")
        axs[1].set_xlabel("x (µm)"); axs[1].set_ylabel("z (µm)")

        axs[2].imshow(dphiy, origin="lower", extent=extent, aspect="auto", cmap="hsv", vmin=0, vmax=2*np.pi)
        axs[2].set_title("y-face: Δφ mod 2π (rad)")
        axs[2].set_xlabel("x (µm)"); axs[2].set_ylabel("z (µm)")

        if probe_on.value and probe_mode.value == "face_y":
            for a in axs:
                a.axvline(px, lw=1.0)
                a.axhline(pz, lw=1.0)

        fig2.tight_layout()
        plt.show()
        plt.close(fig2)

    # --- Probe diagnostics
    with out_probe:
        clear_output(wait=True)
        if probe_on.value:
            print("=== Probe diagnostics ===")
            print(f"mode={probe_mode.value} | (x,y,z)=({px:.2f}, {py:.2f}, {pz:.2f}) µm")
            print(f"r1={R1p:.4f} µm | r2={R2p:.4f} µm | Δr={drp:.4f} µm | Δr/λ={drp/lam:.4f}")
            print(f"Δφ={dphp:.4f} rad | Δφ mod 2π={dphp_mod:.4f} rad")
            print(f"I(spectral-avg)={float(Ip):.6f} | E(snapshot)={Ep:.6f}")
            print(f"Local visibility V≈{V_local:.4f}  (changes with A2/A1, γ, pol θ)")

            # Classify using the interference phase itself. Δφ includes Δφ0 and
            # the IN/OUT signs, which Δr/λ alone does not. Bright when
            # cos(Δφ) ≈ +1 (whole cycles), dark when cos(Δφ) ≈ -1 (half cycles).
            cycles = dphp / (2*np.pi)
            frac = abs(cycles - round(cycles))
            frac_half = abs(cycles - (np.floor(cycles) + 0.5))
            if V_local < 1e-3:
                hint = "no fringes (visibility ≈ 0)"
            elif frac < 0.05:
                hint = "≈ constructive (bright)"
            elif frac_half < 0.05:
                hint = "≈ destructive (dark)"
            else:
                hint = "intermediate"
            print("Hint:", hint)
        else:
            print("Probe OFF")

    # --- Line profile (Intensity) + spacing + theory
    with out_profile:
        clear_output(wait=True)
        if profile_on.value:
            xs = STATE["xs"]; ys = STATE["ys"]
            yf, zf, xf = STATE["y_face"], STATE["z_face"], STATE["x_face"]

            if profile_face.value == "y":
                x_line = xs
                y_line = np.full_like(xs, yf)
                z_line = np.full_like(xs, zL.value)
                horiz = x_line
                xlabel = "x (µm)"
                normal = "y"
                L = abs(yf)
            elif profile_face.value == "z":
                x_line = xs
                y_line = np.full_like(xs, yL.value)
                z_line = np.full_like(xs, zf)
                horiz = x_line
                xlabel = "x (µm)"
                normal = "z"
                L = abs(zf)
            else:
                y_line = ys
                x_line = np.full_like(ys, xf)
                z_line = np.full_like(ys, zL.value)
                horiz = y_line
                xlabel = "y (µm)"
                normal = "y"
                L = None

            Iline = spectral_average_intensity(
                x_line, y_line, z_line,
                lam0=lam0.value, dlam=dlam.value, nlam=nlam.value,
                q=q, dphi0=d0, s1=s1, s2=s2,
                A2_over_A1=A2_over_A1.value,
                gamma_user=gamma_user.value, pol_factor=polf,
                use_1_over_r=use_I_decay.value,
                fraunhofer_on=fraunhofer_on.value, fraunhofer_normal=normal
            )

            Imax = float(np.max(Iline))
            Imin = float(np.min(Iline))
            V_line = (Imax - Imin) / (Imax + Imin + 1e-12)
            dx_fft = estimate_fringe_spacing_fft(horiz, (Iline - np.mean(Iline)))

            fig3 = plt.figure(figsize=(9.2, 2.7))
            ax3 = fig3.add_subplot(111)
            ax3.plot(horiz, Iline)
            ax3.set_xlabel(xlabel)
            ax3.set_ylabel("I (absolute)")
            ax3.grid(True, alpha=0.25)

            title = f"Line profile (Intensity) | Visibility≈{V_line:.3f}"
            if dx_fft is not None:
                title += f" | fringe spacing (FFT) ≈ {dx_fft:.3f} µm"
            ax3.set_title(title)

            # Paraxial far-field fringe spacing for a screen at distance L.
            if (profile_face.value in ("y", "z")) and (L is not None) and (L > 1e-6) and (q > 1e-6):
                dx_th = lam * L / (2.0*q)
                ax3.text(0.02, 0.92, f"Fraunhofer Δx≈λL/(2q)≈{dx_th:.3f} µm",
                         transform=ax3.transAxes, fontsize=9)

            fig3.tight_layout()
            plt.show()
            plt.close(fig3)
        else:
            print("Line profile OFF")

    # --- message
    with out_msg:
        clear_output(wait=True)
        if view != "E" and phase.value != 0.0:
            print("Tip: phase effect is visible in 'Field snapshot E'. Intensity is time-averaged.")
        if lock_I_scale.value:
            print("Lock Intensity scale is ON (good for seeing A2/A1, γ, pol effects).")
        if contrast_I.value:
            print("Contrast view is ON (this can HIDE A2/A1 effect). Turn it OFF to see true visibility.")


# =============================
# Compare A/B
# =============================
def compare_AB(_=None):
    """Draw scenario A (OUT/OUT) next to scenario B (IN/OUT) on the y-face grid.

    The quantity comes from ``compare_qty.value``:
      * ``"I"``    -- spectrally averaged intensity, grayscale from 0 up to the
                      larger ``get_intensity_vmax`` of the two panels;
      * ``"E"``    -- real field snapshot at the current ``phase``, symmetric
                      scale (fixed reference or the larger |E| of both panels);
      * ``"dphi"`` -- exact phase difference wrapped into [0, 2*pi), hsv colormap;
      * anything else (``"dr"``) -- path difference r2 - r1, which does not
                      depend on the direction signs, so both panels are equal.

    Both panels share vmin/vmax/cmap. Quiver arrows show the unit radial
    direction from source 1 at (-q, 0, 0): pointing away from it in A and
    toward it in B.

    The ignored argument lets this function be used directly as a button callback.
    """
    rebuild_grids_if_needed()

    k0 = 2*np.pi / float(lam0.value)
    q = float(q_slider.value)
    d0 = float(dphi0.value)
    polf = pol_factor()

    # (s1, s2) pairs; s = -1 outgoing, +1 incoming.
    signs_A = (-1, -1)   # A: Wave1 OUT, Wave2 OUT
    signs_B = (+1, -1)   # B: Wave1 IN,  Wave2 OUT

    X, Y, Z = STATE["X1"], STATE["Y1"], STATE["Z1"]

    # compare_qty.options holds (label, value) pairs; recover the label for the title.
    labels_by_value = {value: label for (label, value) in compare_qty.options}
    qty_label = labels_by_value.get(compare_qty.value, compare_qty.value)

    qty = compare_qty.value
    if qty == "I":
        def _intensity(signs):
            s1, s2 = signs
            return spectral_average_intensity(
                X, Y, Z, lam0=lam0.value, dlam=dlam.value, nlam=nlam.value,
                q=q, dphi0=d0, s1=s1, s2=s2,
                A2_over_A1=A2_over_A1.value,
                gamma_user=gamma_user.value, pol_factor=polf,
                use_1_over_r=use_I_decay.value,
                fraunhofer_on=fraunhofer_on.value, fraunhofer_normal="y",
            )

        panel_A = _intensity(signs_A)
        panel_B = _intensity(signs_B)
        vmin = 0.0
        vmax = max(get_intensity_vmax(panel_A), get_intensity_vmax(panel_B))
        cmap_use = "gray"
    elif qty == "E":
        def _snapshot(signs):
            s1, s2 = signs
            return field_snapshot_real(X, Y, Z, q, k0, phase.value, d0, s1, s2,
                                       A2_over_A1.value, use_E_decay.value)

        panel_A = _snapshot(signs_A)
        panel_B = _snapshot(signs_B)
        if fixed_scale_E.value:
            m = compute_E_norm_reference()
        else:
            m = float(max(np.max(np.abs(panel_A)), np.max(np.abs(panel_B)), 1e-9))
        vmin, vmax = -m, +m
        cmap_use = "gray"
    elif qty == "dphi":
        two_pi = 2*np.pi
        panel_A = np.mod(delta_phi_exact(X, Y, Z, q, k0, d0, *signs_A), two_pi)
        panel_B = np.mod(delta_phi_exact(X, Y, Z, q, k0, d0, *signs_B), two_pi)
        vmin, vmax = 0.0, two_pi
        cmap_use = "hsv"
    else:  # "dr": geometry only, identical for both scenarios
        panel_A = delta_r(X, Y, Z, q)
        panel_B = panel_A.copy()
        vmin, vmax = float(np.min(panel_A)), float(np.max(panel_A))
        cmap_use = "gray"

    xlim, zlim = STATE["xlim"], STATE["zlim"]
    extent = [xlim[0], xlim[1], zlim[0], zlim[1]]

    with out_compare:
        clear_output(wait=True)
        fig = plt.figure(figsize=(12.0, 4.2))
        panels = (
            # (axes, data, title, sign applied to Wave1's radial unit vector)
            (fig.add_subplot(1, 2, 1), panel_A, "A: OUT/OUT (Wave1 OUT, Wave2 OUT)", +1),
            (fig.add_subplot(1, 2, 2), panel_B, "B: IN/OUT  (Wave1 IN,  Wave2 OUT)", -1),
        )

        # Coarse subsample of the grid for Wave1 direction arrows.
        step = max(1, STATE["n"]//18)
        Xq, Yq, Zq = X[::step, ::step], Y[::step, ::step], Z[::step, ::step]
        ux, _uy, uz = unit_radial(Xq, Yq, Zq, -q, 0, 0)

        for ax, data, title, direction in panels:
            ax.imshow(data, origin="lower", extent=extent, aspect="auto",
                      cmap=cmap_use, vmin=vmin, vmax=vmax)
            ax.set_title(title)
            ax.set_xlabel("x (µm)")
            ax.set_ylabel("z (µm)")
            ax.quiver(Xq, Zq, direction*ux, direction*uz, angles="xy",
                      scale_units="xy", scale=0.6, width=0.002)

        fig.suptitle(f"Compare A/B on y-face | quantity={qty_label} | phase={phase.value:.2f}", y=1.02)
        fig.tight_layout()
        plt.show()
        plt.close(fig)

btn_compare.on_click(compare_AB)


# =============================
# Buttons logic
# =============================
def _set_then_redraw(assignments):
    """Assign ``(widget, value)`` pairs in order, then redraw exactly once.

    Every widget the phase/probe buttons touch (phase, probe_mode, x0, y0s, z0s)
    is in the ``update_all`` watch list, so a plain ``widget.value = ...`` fires
    one full redraw per changed widget. Before this helper, a phase click
    rendered twice and a probe click up to five times.

    The assignments are made with ``SUSPEND_UPDATES`` raised, the same guard the
    preset and reset handlers use. That guard presumably makes ``update_all``
    return early while it is set (confirm in ``update_all``). A single explicit
    ``update_all()`` follows. The flag is cleared in ``finally``, so a failing
    assignment cannot leave redraws disabled.

    Parameters
    ----------
    assignments : iterable of (widget, value)
        Applied left to right. All values are computed by the caller before
        any assignment happens.
    """
    global SUSPEND_UPDATES
    SUSPEND_UPDATES = True
    try:
        for widget, value in assignments:
            widget.value = value
    finally:
        SUSPEND_UPDATES = False
    update_all()

def on_phase_minus(_):
    """Step the snapshot phase down by one slider step, clamped at ``phase.min``."""
    _set_then_redraw([(phase, float(max(phase.min, phase.value - phase.step)))])

def on_phase_plus(_):
    """Step the snapshot phase up by one slider step, clamped at ``phase.max``."""
    _set_then_redraw([(phase, float(min(phase.max, phase.value + phase.step)))])

def on_phase_zero(_):
    """Reset the snapshot phase to 0."""
    _set_then_redraw([(phase, 0.0)])

def on_probe_center(_):
    """Switch the probe to free mode and move it to the centre of the current domain."""
    rebuild_grids_if_needed()
    xlim, ylim, zlim = STATE["xlim"], STATE["ylim"], STATE["zlim"]
    _set_then_redraw([
        (probe_mode, "free"),
        (x0, 0.5*(xlim[0]+xlim[1])),
        (y0s, 0.5*(ylim[0]+ylim[1])),
        (z0s, 0.5*(zlim[0]+zlim[1])),
    ])

def on_probe_s1(_):
    """Switch the probe to free mode and place it on source 1 at (-q, 0, 0)."""
    _set_then_redraw([
        (probe_mode, "free"),
        (x0, -q_slider.value),
        (y0s, 0.0),
        (z0s, 0.0),
    ])

def on_probe_s2(_):
    """Switch the probe to free mode and place it on source 2 at (+q, 0, 0)."""
    _set_then_redraw([
        (probe_mode, "free"),
        (x0, +q_slider.value),
        (y0s, 0.0),
        (z0s, 0.0),
    ])

def _apply_preset(*, y_lo, y_hi, half_width, fraunhofer, dlam_value, n_lam, pol):
    """Load one classroom preset into the controls, then redraw once.

    Sets the y range to [y_lo, y_hi] and both the x and z ranges to
    [-half_width, half_width]. Also sets ``fraunhofer_on``, ``dlam``, ``nlam``
    and ``pol_deg`` from the arguments. Every preset additionally forces
    ``gamma_user = 1.0`` and ``main_view = "I"``.

    Assignment order is fixed: y_max before y_min, then min before max for x
    and z. It is kept that way in case the range widgets constrain each other.
    ``SUSPEND_UPDATES`` is raised during the assignments so the watched widgets
    do not each trigger a redraw. ``update_all`` presumably checks this flag;
    confirm there. The flag is always cleared in ``finally``. ``update_all()``
    runs only if every assignment succeeded.
    """
    global SUSPEND_UPDATES
    SUSPEND_UPDATES = True
    try:
        y_max.value = y_hi
        y_min.value = y_lo
        x_min.value, x_max.value = -half_width, half_width
        z_min.value, z_max.value = -half_width, half_width
        fraunhofer_on.value = fraunhofer
        dlam.value = dlam_value
        nlam.value = n_lam
        gamma_user.value = 1.0
        pol_deg.value = pol
        main_view.value = "I"
    finally:
        SUSPEND_UPDATES = False
    update_all()

def preset_near(_):
    """Near-field view: y in [-30, 0], x/z in [-40, 40], exact phase, single wavelength."""
    _apply_preset(y_lo=-30.0, y_hi=0.0, half_width=40.0,
                  fraunhofer=False, dlam_value=0.0, n_lam=1, pol=0.0)

def preset_far(_):
    """Far-field view: y in [0, 200], x/z in [-80, 80], Fraunhofer on, single wavelength."""
    _apply_preset(y_lo=0.0, y_hi=200.0, half_width=80.0,
                  fraunhofer=True, dlam_value=0.0, n_lam=1, pol=0.0)

def preset_coh(_):
    """Far-field view with a finite spectral width (dlam=0.06 sampled at nlam=21)."""
    _apply_preset(y_lo=0.0, y_hi=200.0, half_width=80.0,
                  fraunhofer=True, dlam_value=0.06, n_lam=21, pol=0.0)

def preset_pol(_):
    """Far-field view, single wavelength, with the polarization angle set to 60 degrees."""
    _apply_preset(y_lo=0.0, y_hi=200.0, half_width=80.0,
                  fraunhofer=True, dlam_value=0.0, n_lam=1, pol=60.0)

def on_save(_):
    """Save the last rendered main figure (PNG) and the current settings (JSON).

    Files go to ``save_dir`` (default ``"outputs"``) as
    ``<base>_<YYYYmmdd_HHMMSS>.png`` / ``.json``, where ``<base>`` is
    ``base_name`` (default ``"interference"``). The PNG is written only if
    ``STATE["last_fig"]`` holds a figure. The JSON always contains
    ``collect_metadata()``.

    Only the files actually written are reported in ``out_msg``. Previously the
    PNG path was listed as saved even when no figure existed.

    NOTE(review): the timestamp has one-second resolution, so two saves within
    the same second overwrite each other.
    """
    rebuild_grids_if_needed()
    meta = collect_metadata()
    folder = save_dir.value.strip() or "outputs"
    ensure_dir(folder)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    base = base_name.value.strip() or "interference"

    png_path = os.path.join(folder, f"{base}_{ts}.png")
    json_path = os.path.join(folder, f"{base}_{ts}.json")

    written = []

    fig = STATE.get("last_fig")
    if fig is not None:
        fig.savefig(png_path, dpi=220)
        written.append(png_path)

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    written.append(json_path)

    with out_msg:
        clear_output(wait=True)
        print("✅ Saved:")
        for path in written:
            print(" -", path)
        if fig is None:
            print("⚠️ No figure has been rendered yet, so no PNG was written.")

def on_reset(_):
    """Restore every control to its default value, drop the locked intensity scale, redraw once.

    Widgets are reassigned with ``SUSPEND_UPDATES`` raised, presumably so the
    ``update_all`` observers skip their per-widget redraws (confirm in
    ``update_all``). The flag is cleared in ``finally``, and a single
    ``update_all()`` follows.
    """
    global SUSPEND_UPDATES
    SUSPEND_UPDATES = True
    try:
        # Domain, physics, main-display and intensity-scale defaults, in this order.
        for widget, default in (
            (x_min, -20), (x_max, 20),
            (y_min, -20), (y_max, 0),
            (z_min, -20), (z_max, 20),
            (n_slider, 220),
            (stride, 4),

            (lam0, 0.6328),
            (dlam, 0.0),
            (nlam, 1),
            (q_slider, 5.0),
            (dphi0, 0.0),
            (A2_over_A1, 1.0),
            (gamma_user, 1.0),
            (pol_deg, 0.0),
            (use_I_decay, False),
            (use_E_decay, False),
            (incoming1, False),
            (incoming2, False),

            (main_view, "I"),
            (fraunhofer_on, False),
            (fixed_scale_E, True),
            (phase, 0.0),

            (lock_I_scale, True),
            (contrast_I, False),
        ):
            widget.value = default

        # Forget any previously locked intensity scale.
        CACHE["I_locked_vmax"] = None
        CACHE["I_locked_key"] = None

        # Probe, line-profile and arrow defaults.
        for widget, default in (
            (probe_on, True),
            (probe_mode, "face_y"),
            (u1, 0.0),
            (u2, 0.0),
            (x0, 0.0),
            (y0s, -10.0),
            (z0s, 0.0),

            (profile_on, True),
            (profile_face, "y"),
            (zL, 0.0),
            (yL, -10.0),

            (arrows_on, False),
            (arrows_wave, "w1"),
        ):
            widget.value = default

        # Blank the intensity-scale info panel.
        with I_scale_info:
            clear_output(wait=True)
    finally:
        SUSPEND_UPDATES = False

    update_all()

# Bind each button to its click handler (registration order unchanged).
for _button, _handler in (
    (btn_phase_minus, on_phase_minus),
    (btn_phase_plus, on_phase_plus),
    (btn_phase_zero, on_phase_zero),

    (btn_probe_center, on_probe_center),
    (btn_probe_s1, on_probe_s1),
    (btn_probe_s2, on_probe_s2),

    (btn_preset_near, preset_near),
    (btn_preset_far, preset_far),
    (btn_preset_coh, preset_coh),
    (btn_preset_pol, preset_pol),

    (btn_save, on_save),
    (btn_reset, on_reset),
):
    _button.on_click(_handler)
del _button, _handler


# Every control whose "value" change should trigger update_all().
watch = [
    # domain / grid
    x_min, x_max, y_min, y_max, z_min, z_max, n_slider, stride,
    # physics
    lam0, dlam, nlam, q_slider, dphi0, A2_over_A1, gamma_user, pol_deg,
    use_I_decay, use_E_decay, incoming1, incoming2,
    # main display
    main_view, fraunhofer_on, fixed_scale_E,
    phase,
    # direction arrows
    arrows_on, arrows_wave,
    # probe
    probe_on, probe_mode, u1, u2, x0, y0s, z0s,
    # line profile
    profile_on, profile_face, zL, yL,
    # intensity scaling
    lock_I_scale, contrast_I,
]
for w in watch:
    w.observe(update_all, names="value")


# =============================
# Layout
# =============================
def _build_ui():
    """Assemble the control panel as a single VBox of titled sections.

    Each section is an HTML title followed by the widget rows it governs. The
    sections are flattened, in order, into one list of VBox children.
    """
    row = widgets.HBox
    sections = (
        (
            widgets.HTML("<h2 style='margin:4px 0'>3D Two-source Interference (teaching edition)</h2>"),
            widgets.HTML("<b>Presets (for class)</b>"),
            row([btn_preset_near, btn_preset_far, btn_preset_coh, btn_preset_pol]),
        ),
        (
            widgets.HTML("<b>Domain / crop</b>"),
            row([x_min, x_max]),
            row([y_min, y_max]),
            row([z_min, z_max]),
            row([n_slider, stride]),
        ),
        (
            widgets.HTML("<b>Physics</b>"),
            row([lam0, dlam, nlam]),
            row([q_slider, dphi0]),
            row([A2_over_A1, gamma_user, pol_deg]),
            row([use_I_decay, use_E_decay, fraunhofer_on]),
            row([incoming1, incoming2]),
        ),
        (
            widgets.HTML("<b>Intensity scaling (FIX A2/A1 visibility)</b>"),
            row([lock_I_scale, btn_set_I_scale, contrast_I]),
            I_scale_info,
        ),
        (
            widgets.HTML("<b>Main display</b>"),
            row([main_view, fixed_scale_E]),
        ),
        (
            # Persian label text: "to see the waves' direction: Main=E".
            widgets.HTML("<b>Phase (برای دیدن جهت موج‌ها: Main=E)</b>"),
            row([phase, btn_phase_minus, btn_phase_plus, btn_phase_zero]),
        ),
        (
            widgets.HTML("<b>Convergence / Divergence</b>"),
            row([arrows_on, arrows_wave]),
        ),
        (
            widgets.HTML("<b>Probe</b>"),
            row([probe_on, probe_mode]),
            row([u1, u2]),
            row([x0, y0s, z0s]),
            row([btn_probe_center, btn_probe_s1, btn_probe_s2]),
        ),
        (
            widgets.HTML("<b>Line profile (real-time)</b>"),
            row([profile_on, profile_face, zL, yL]),
        ),
        (
            widgets.HTML("<b>Compare A/B (Outgoing/Outgoing vs Incoming/Outgoing)</b>"),
            row([compare_qty, btn_compare]),
        ),
        (
            widgets.HTML("<b>Save / Reset</b>"),
            row([save_dir, base_name, btn_save, btn_reset]),
        ),
    )
    return widgets.VBox([child for section in sections for child in section])


ui = _build_ui()

# Show the panel followed by every output area, then render the initial state.
display(ui, out_main, out_multi, out_probe, out_profile, out_compare, out_msg)
update_all()


VBox(children=(HTML(value="<h2 style='margin:4px 0'>3D Two-source Interference (teaching edition)</h2>"), HTML…

Output()

Output()

Output()

Output()

Output()

Output()