Skip to content

Commit

Permalink
add/update type hints for JAX wrapper of adjoint solver (#2190)
Browse files Browse the repository at this point in the history
* add/update type hints for JAX wrapper of adjoint solver

* fix return type of install_design_region_monitors
  • Loading branch information
oskooi committed Aug 11, 2022
1 parent 2aa9164 commit ea43f83
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 59 deletions.
66 changes: 21 additions & 45 deletions python/adjoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,33 @@


class DesignRegion:
def __init__(self, design_parameters, volume=None, size=None, center=mp.Vector3()):
def __init__(
self,
design_parameters: Iterable[onp.ndarray],
volume: mp.Volume = None,
size: mp.Vector3 = None,
center: mp.Vector3 = mp.Vector3(),
):
self.volume = volume or mp.Volume(center=center, size=size)
self.size = self.volume.size
self.center = self.volume.center
self.design_parameters = design_parameters
self.num_design_params = design_parameters.num_params

def update_design_parameters(self, design_parameters):
def update_design_parameters(self, design_parameters) -> None:
self.design_parameters.update_weights(design_parameters)

def update_beta(self, beta):
def update_beta(self, beta: float) -> None:
self.design_parameters.beta = beta

def get_gradient(
self, sim, fields_a, fields_f, frequencies, finite_difference_step
):
self,
sim: mp.Simulation,
fields_a: List[mp.DftFields],
fields_f: List[mp.DftFields],
frequencies: List[float],
finite_difference_step: float,
) -> onp.ndarray:
num_freqs = onp.array(frequencies).size
"""We have the option to linearly scale the gradients up front
using the scalegrad parameter (leftover from MPB API). Not
Expand Down Expand Up @@ -67,11 +78,11 @@ def get_gradient(
return onp.squeeze(grad).T


def _check_if_cylindrical(sim):
def _check_if_cylindrical(sim: mp.Simulation) -> bool:
return sim.is_cylindrical or (sim.dimensions == mp.CYLINDRICAL)


def _compute_components(sim):
def _compute_components(sim: mp.Simulation) -> List[int]:
return (
_ADJOINT_FIELD_COMPONENTS_CYL
if _check_if_cylindrical(sim)
Expand All @@ -88,8 +99,8 @@ def calculate_vjps(
simulation: mp.Simulation,
design_regions: List[DesignRegion],
frequencies: List[float],
fwd_fields: List[List[onp.ndarray]],
adj_fields: List[List[onp.ndarray]],
fwd_fields: List[List[mp.DftFields]],
adj_fields: List[List[mp.DftFields]],
design_variable_shapes: List[Tuple[int, ...]],
sum_freq_partials: bool = True,
finite_difference_step: float = FD_DEFAULT,
Expand Down Expand Up @@ -132,7 +143,7 @@ def install_design_region_monitors(
design_regions: List[DesignRegion],
frequencies: List[float],
decimation_factor: int = 0,
) -> List[mp.DftFields]:
) -> List[List[mp.DftFields]]:
"""Installs DFT field monitors at the design regions of the simulation."""
return [
[
Expand Down Expand Up @@ -168,41 +179,6 @@ def gather_monitor_values(monitors: List[ObjectiveQuantity]) -> onp.ndarray:
return monitor_values


def gather_design_region_fields(
simulation: mp.Simulation,
design_region_monitors: List[mp.DftFields],
frequencies: List[float],
) -> List[List[onp.ndarray]]:
"""Collects the design region DFT fields from the simulation.
Args:
simulation: the simulation object.
design_region_monitors: the installed design region monitors.
frequencies: the frequencies to monitor.
Returns:
A list of lists. Each entry (list) in the overall list corresponds one-to-
one with a declared design region. For each such contained list, the
entries correspond to the field components that are monitored. The entries
are ndarrays of rank 4 with dimensions (freq, x, y, (z-or-pad)).
The design region fields are sampled on the *Yee grid*. This makes them
fairly awkward to inspect directly. Their primary use case is supporting
gradient calculations.
"""
design_region_fields = []
for monitor in design_region_monitors:
fields_by_component = []
for component in _compute_components(simulation):
fields_by_freq = []
for freq_idx, _ in enumerate(frequencies):
fields = simulation.get_dft_array(monitor, component, freq_idx)
fields_by_freq.append(_make_at_least_nd(fields))
fields_by_component.append(onp.stack(fields_by_freq))
design_region_fields.append(fields_by_component)
return design_region_fields


def validate_and_update_design(
design_regions: List[DesignRegion], design_variables: Iterable[onp.ndarray]
) -> None:
Expand Down
20 changes: 12 additions & 8 deletions python/adjoint/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def loss(x):
value, grad = jax.value_and_grad(loss)(x)
```
"""
from typing import Callable, List, Tuple
from typing import Callable, Iterable, List, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -137,7 +137,9 @@ def __call__(self, designs: List[jnp.ndarray]) -> jnp.ndarray:
"""
return self._simulate_fn(designs)

def _run_fwd_simulation(self, design_variables):
def _run_fwd_simulation(
self, design_variables: Iterable[onp.ndarray]
) -> (jnp.ndarray, List[List[mp.DftFields]]):
"""Runs forward simulation, returning monitor values and design region fields."""
utils.validate_and_update_design(self.design_regions, design_variables)
self.simulation.reset_meep()
Expand All @@ -161,7 +163,9 @@ def _run_fwd_simulation(self, design_variables):
monitor_values = utils.gather_monitor_values(self.monitors)
return (jnp.asarray(monitor_values), fwd_design_region_monitors)

def _run_adjoint_simulation(self, monitor_values_grad):
def _run_adjoint_simulation(
self, monitor_values_grad: onp.ndarray
) -> List[List[mp.DftFields]]:
"""Runs adjoint simulation, returning design region fields."""
if not self.design_regions:
raise RuntimeError(
Expand Down Expand Up @@ -195,11 +199,11 @@ def _run_adjoint_simulation(self, monitor_values_grad):

def _calculate_vjps(
self,
fwd_fields,
adj_fields,
design_variable_shapes,
sum_freq_partials=True,
):
fwd_fields: List[List[mp.DftFields]],
adj_fields: List[List[mp.DftFields]],
design_variable_shapes: List[Tuple[int, ...]],
sum_freq_partials: bool = True,
) -> List[onp.ndarray]:
"""Calculates the VJP for a given set of forward and adjoint fields."""
return utils.calculate_vjps(
self.simulation,
Expand Down
16 changes: 10 additions & 6 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,20 @@

import meep as mp

# The calculation of finite difference gradients requires that JAX be operated with double precision
# The calculation of finite-difference gradients
# requires that JAX be operated with double precision
jax.config.update("jax_enable_x64", True)

# The step size for the finite difference gradient calculation
# The step size for the finite-difference
# gradient calculation
_FD_STEP = 1e-4

# The tolerance for the adjoint and finite difference gradient comparison
# The tolerance for the adjoint and finite-difference
# gradient comparison
_TOL = 0.1 if mp.is_single_precision() else 0.025

# We expect 3 design region monitor pointers (one for each field component)
# We expect 3 design region monitor pointers
# (one for each field component)
_NUM_DES_REG_MON = 3

mp.verbosity(0)
Expand Down Expand Up @@ -257,8 +261,8 @@ def loss_fn(x, excite_port_idx=0):
frequencies,
)
monitor_values = wrapped_meep([x])
s1p, s1m, s2m, s2p = monitor_values
t = s2m / s1p if excite_port_idx == 0 else s1m / s2p
s1p, s1m, s2p, s2m = monitor_values
t = s2p / s1p if excite_port_idx == 0 else s1m / s2m
return jnp.mean(jnp.square(jnp.abs(t)))

value, adjoint_grad = jax.value_and_grad(loss_fn)(
Expand Down

0 comments on commit ea43f83

Please sign in to comment.