In [1]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import jax.numpy as jnp
import finite_diff
from scipy.stats import qmc



In [3]:
# --- Problem configuration -------------------------------------------------
nx = 101  # grid points per side; the solution vector has nx * nx entries
omega = 0.2
n_test_sources = 5  # number of Latin-hypercube source locations to solve for
# NOTE(review): assumed side length of the meshed domain (consistent with the
# 30..570 source margin below) — confirm against build_mesh / coords.
plot_extent = 600

coords, conn, coeffs, diag_coeff, epsilon = finite_diff.build_mesh(
    nx, omega=omega, dielectric_loc=jnp.array([400, 400])
)

source_domain = jnp.array(
    [[30, 30], [570, 570]]
)  # reduced domain to avoid having source too close to boundary

sampler = qmc.LatinHypercube(d=2, seed=345)
test_sources = source_domain[0, :] + (
    source_domain[1, :] - source_domain[0, :]
) * sampler.random(n=n_test_sources)

A = finite_diff.build_operator(nx, conn, coeffs, diag_coeff)

# --- Solve all sources at once ----------------------------------------------
# A does not depend on the source, so solving inside the loop re-factorized the
# same (nx*nx, nx*nx) system once per source. Stacking the right-hand sides as
# columns lets jnp.linalg.solve factorize A a single time.
rhs_vectors = [
    finite_diff.build_vector(source, coords, omega=omega) for source in test_sources
]
fields = jnp.linalg.solve(A, jnp.column_stack(rhs_vectors))  # column i <-> source i
# Keep `b` bound to the last right-hand side, as the per-source loop used to
# leave it; the `jnp.sum(b)` sanity-check cell reads it.
b = rhs_vectors[-1]

# Heatmap axis coordinates are the same for every panel and every source.
grid_axis = jnp.linspace(0, plot_extent, nx)


def plot_complex_field(field, nx, axis, title):
    """Plot the real and imaginary parts of a flat solution vector side by side.

    Parameters
    ----------
    field : array
        Solution vector of length ``nx * nx``; reshaped to ``(nx, nx)`` for display.
    nx : int
        Number of grid points per side.
    axis : array
        Coordinates used for both heatmap axes (length ``nx``).
    title : str
        Figure title.

    Returns
    -------
    plotly.graph_objects.Figure
        A 1x2 subplot figure ("Real", "Imaginary").
    """
    figure = make_subplots(rows=1, cols=2, column_titles=["Real", "Imaginary"])
    # (subplot column, component, colorbar x-position). Without an explicit
    # position both traces' colorbars are drawn at the same default spot to the
    # right of the figure and overlap; move the first into the inter-panel gap.
    panels = (
        (1, jnp.real(field), 0.46),
        (2, jnp.imag(field), 1.02),
    )
    for col, component, colorbar_x in panels:
        figure.add_trace(
            go.Heatmap(
                z=component.reshape(nx, nx),
                x=axis,
                y=axis,
                colorbar=dict(x=colorbar_x),
            ),
            row=1,
            col=col,
        )
    figure.update_layout(width=1200, height=600, title_text=title)
    return figure


# --- Plot each solution -------------------------------------------------------
for i in range(test_sources.shape[0]):
    out_field = fields[:, i]
    fig = plot_complex_field(
        out_field,
        nx,
        grid_axis,
        f"Source at ({float(test_sources[i, 0]):.1f}, {float(test_sources[i, 1]):.1f})",
    )
    fig.show()

KeyboardInterrupt: 

In [None]:
# Sanity check: total of the right-hand-side vector `b` (the last source's
# vector, left bound by the solve cell). Depends on that cell having run.
jnp.asarray(b).sum()

Array(0.+1.88495561j, dtype=complex128)