In [None]:
import numpy as np
import pyvista as pv

In [None]:
def e_surf(k, q, m1, m2):
    """Energy-surface function for momenta ``k - q_vec`` and ``k + q_vec``.

    Computes::

        sqrt(|k - q_vec|^2 + m1^2) + sqrt(|k + q_vec|^2 + m2^2) - 2 * q0

    where ``q0 = q[0]`` and ``q_vec = q[1:]``.

    Parameters
    ----------
    k : ndarray, shape (..., 3)
        Momenta; the last axis holds the three spatial components and all
        leading axes are broadcast. May be complex. In that case the
        "squared norm" is the bilinear sum ``k.k`` (no complex conjugate),
        and ``np.sqrt`` takes its principal branch.
    q : ndarray, shape (4,)
        ``(q0, qx, qy, qz)``.
    m1, m2 : float
        Mass terms added under the first and second square root.

    Returns
    -------
    ndarray, shape k.shape[:-1]
        Function values; the notebook contours / inverts its zero set.
    """
    q0 = q[0]
    q_vec = q[1:]
    minus_sq = np.sum((k - q_vec) ** 2, axis=-1)
    plus_sq = np.sum((k + q_vec) ** 2, axis=-1)
    first = np.sqrt(minus_sq + m1**2)
    second = np.sqrt(plus_sq + m2**2)
    return first + second - 2 * q0

In [3]:
# Evaluate e_surf on a 3-D momentum grid and render its zero-level isosurface.
res = 100
x = y = z = np.linspace(-1.2, 1.2, res)

# indexing="ij" makes array axis 0 <-> x, axis 1 <-> y, axis 2 <-> z.
# Together with the Fortran-order flatten below, this matches PyVista/VTK
# point ordering (x varies fastest). The default "xy" indexing would put y
# on axis 0 and silently swap x and y in the rendered grid.
ks = np.stack(np.meshgrid(x, y, z, indexing="ij"), axis=-1)

q = np.array([1, 0, 0, 0.5])  # q[0] and spatial part q[1:], as used by e_surf
m1 = 0.5
m2 = 0.5

vals = e_surf(ks, q, m1, m2)

grid = pv.ImageData()
grid.dimensions = np.array(vals.shape)
grid.origin = (x[0], y[0], z[0])
grid.spacing = (x[1] - x[0], y[1] - y[0], z[1] - z[0])

grid.point_data["vals"] = vals.flatten(order="F")
surf = grid.contour([0])

plotter = pv.Plotter()
plotter.add_mesh(surf, color="cyan", opacity=0.5, smooth_shading=True)

# Arrow along q[1:] / q[0] (assumes q[0] != 0).
arrow = pv.Arrow(start=(0, 0, 0), direction=q[1:] / q[0], scale=0.3)
plotter.add_mesh(arrow, color="red")

plotter.show_grid()
plotter.show(interactive=False)


Widget(value='<iframe src="http://localhost:37799/index.html?ui=P_0x7840a3f2e9f0_0&reconnect=auto" class="pyvi…

In [4]:
import numpy as np
import matplotlib.pyplot as plt


def plot_complex_plane(xs, ys, ax=None):
    """Plot a complex -> complex function as a domain-coloring image.

    Hue encodes the phase of ``ys``. Brightness encodes ``|ys|``, linearly
    normalized by the largest magnitude among the finite entries. NaN/inf
    entries of ``ys`` are drawn fully transparent.

    Parameters
    ----------
    xs : ndarray of complex, 2-D
        Sample points in the complex plane, e.g. ``X + 1j*Y`` from
        ``np.meshgrid``. The image uses ``origin="lower"``, so the imaginary
        part is assumed to increase along axis 0 and the real part along
        axis 1 (default "xy" meshgrid indexing).
    ys : ndarray of complex, same shape as ``xs``
        Function values at ``xs``.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to the current axes (``plt.gca()``).

    Returns
    -------
    matplotlib.image.AxesImage
        The image artist added to ``ax``.

    Raises
    ------
    ValueError
        If no entry of ``ys`` is finite.
    """
    if ax is None:
        ax = plt.gca()

    # Mask invalid data
    valid_mask = np.isfinite(ys)
    if not np.any(valid_mask):
        raise ValueError("All ys values are NaN or inf — nothing to plot.")

    # Replace invalid entries with 0 so angle/abs stay finite; those pixels
    # are made transparent below anyway.
    ys_safe = np.where(valid_mask, ys, 0)
    phase = np.angle(ys_safe)
    mag = np.abs(ys_safe)
    max_mag = np.nanmax(mag)
    if max_mag != 0:
        mag = mag / max_mag

    # Phase in [-pi, pi] -> hue in [0, 1]; brightness = normalized magnitude.
    hue = (phase + np.pi) / (2 * np.pi)
    rgba = plt.cm.hsv(hue)
    rgba[..., :3] *= mag[..., None]

    # Fully transparent where ys was NaN/inf.
    rgba[..., -1] = np.where(valid_mask, 1.0, 0.0)

    # Plotting extents (robust to NaNs in xs)
    x_real = np.real(xs)
    y_imag = np.imag(xs)
    x_min, x_max = np.nanmin(x_real), np.nanmax(x_real)
    y_min, y_max = np.nanmin(y_imag), np.nanmax(y_imag)

    # Draw via ax.imshow (not plt.imshow) so the caller's axes are honoured
    # even when they are not the current axes.
    return ax.imshow(
        rgba,
        origin="lower",
        extent=[x_min, x_max, y_min, y_max],
        interpolation="nearest",
        aspect="equal",  # maintain correct aspect ratio
    )


def plot_complex(xs, ys, ax=None):
    """Plot a real -> complex function as two curves: real and imaginary part.

    Parameters
    ----------
    xs : array_like of real
        Sample points.
    ys : array_like of complex
        Function values at ``xs``. Any input accepted by ``np.real`` /
        ``np.imag`` works, including plain Python sequences.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to the current axes (``plt.gca()``),
        the same convention as ``plot_complex_plane``.

    Returns
    -------
    list of matplotlib.lines.Line2D
        ``[real_line, imag_line]``, labelled "re" and "im" (call
        ``ax.legend()`` to show the labels).
    """
    if ax is None:
        ax = plt.gca()
    (re_line,) = ax.plot(xs, np.real(ys), label="re")
    (im_line,) = ax.plot(xs, np.imag(ys), label="im")
    return [re_line, im_line]





In [5]:
from IPython.display import HTML, display
from matplotlib.animation import FuncAnimation

# Complex-plane sample grid xs = X + iY on [-2, 2] x [-2, 2].
x = np.linspace(-2, 2, 300)
y = np.linspace(-2, 2, 300)
X, Y = np.meshgrid(x, y)
xs = X + 1j * Y

# Momenta are taken along this fixed spatial direction: k = xs * k_hat.
k_hat = np.array([0, 0, 1])

fig, ax = plt.subplots()


def update(frame):
    """Draw one animation frame: 1 / e_surf with m1 = m2 = ``frame``.

    NOTE(review): reads the global ``q`` set in the earlier PyVista cell
    (hidden-state dependency on cell order).
    """
    m1 = m2 = frame
    # Exact zeros of e_surf give inf/nan on inversion. plot_complex_plane
    # renders those transparent, so the RuntimeWarnings are only noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        inv_e = 1 / e_surf(xs[..., None] * k_hat, q, m1, m2)
    # Remove the previous frame's image. Otherwise images stack up (memory
    # grows) and older frames show through the transparent non-finite pixels.
    ax.clear()
    plot_complex_plane(xs, inv_e, ax=ax)
    ax.set_title(f"m1 = m2 = {m1:.2f}")
    ax.set_xlabel("Re")
    ax.set_ylabel("Im")


# Frames: one per mass value.
m_values = np.linspace(0.5, 1, 30)
ani = FuncAnimation(fig, update, frames=m_values, interval=200)

display(HTML(ani.to_jshtml()))
# Close the figure so the static last frame is not rendered below the animation.
plt.close(fig)
