Skip to content

Commit

Permalink
Add plot3d function (#2305)
Browse files Browse the repository at this point in the history
* Update visualization.py

* Update simulation.py

* Add lighting

* Update docs

* Make skimage import backwards compatible

* Add camera settings

* Minor edits, and remove __main__

* Remove distance

* Update prism_epsilon.png
  • Loading branch information
SkandanC committed Nov 17, 2022
1 parent 9ed6ec1 commit a7cd7da
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 26 deletions.
8 changes: 1 addition & 7 deletions doc/docs/Python_Tutorials/Basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -1192,13 +1192,7 @@ sim = mp.Simulation(resolution=50,
cell_size=cell_size,
geometry=geometry)

sim.init_sim()

eps_data = sim.get_epsilon()

from mayavi import mlab
s = mlab.contour3d(eps_data, colormap="YlGnBu")
mlab.show()
sim.plot3D()
```

![](../images/prism_epsilon.png#center)
Expand Down
Binary file modified doc/docs/images/prism_epsilon.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 14 additions & 3 deletions python/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4811,14 +4811,25 @@ def plot_fields(self, **kwargs):

return vis.plot_fields(self, **kwargs)

def plot3D(self):
def plot3D(
self, save_to_image: bool = False, image_name: str = "sim.png", **kwargs
):
"""
Uses Mayavi to render a 3D simulation domain. The simulation object must be 3D.
Uses vispy to render a 3D scene of the simulation object. The simulation object must be 3D.
Can also be embedded in Jupyter notebooks.
Args:
save_to_image: if True, saves the image to a file
image_name: the name of the image file to save to
kwargs: Camera settings.
scale_factor: float, camera zoom factor
azimuth: float, azimuthal angle in degrees
elevation: float, elevation angle in degrees
"""
import meep.visualization as vis

return vis.plot3D(self)
return vis.plot3D(self, save_to_image, image_name, **kwargs)

def visualize_chunks(self):
"""
Expand Down
182 changes: 166 additions & 16 deletions python/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from matplotlib.figure import Figure
from typing import Callable, Union, Any, Tuple, List, Optional


# ------------------------------------------------------- #
# Visualization
# ------------------------------------------------------- #
Expand Down Expand Up @@ -398,23 +399,23 @@ def sort_points(xy):
ax.plot(
[a.y for a in intersection],
[a.z for a in intersection],
**line_args
**line_args,
)
return ax
# Plot XZ
elif sim_size.y == 0:
ax.plot(
[a.x for a in intersection],
[a.z for a in intersection],
**line_args
**line_args,
)
return ax
# Plot XY
elif sim_size.z == 0:
ax.plot(
[a.x for a in intersection],
[a.y for a in intersection],
**line_args
**line_args,
)
return ax
else:
Expand Down Expand Up @@ -988,27 +989,176 @@ def plot2D(
return ax


def plot3D(sim: Simulation):
from mayavi import mlab
def plot3D(sim, save_to_image: bool = False, image_name: str = "sim.png", **kwargs):
from vispy.scene.visuals import Box, Mesh
from vispy.scene import SceneCanvas, transforms

if sim.dimensions < 3:
raise ValueError("Simulation must have 3 dimensions to visualize 3D")
try:
from skimage.measure import marching_cubes
except:
from skimage.measure import marching_cubes_lewiner as marching_cubes
from vispy.visuals.filters import ShadingFilter

xmin, xmax, ymin, ymax, zmin, zmax = box_vertices(
sim.geometry_center, sim.cell_size
# Set canvas
canvas = SceneCanvas(keys="interactive", bgcolor="white")

view = canvas.central_widget.add_view()
view.camera = "turntable"

# Get domain measurements
sim_center, sim_size = sim.geometry_center, sim.cell_size

xmin, xmax, ymin, ymax, zmin, zmax = mp.visualization.box_vertices(
sim_center, sim_size, sim.is_cylindrical
)

Nx = int(sim.cell_size.x * sim.resolution) + 1
Ny = int(sim.cell_size.y * sim.resolution) + 1
Nz = int(sim.cell_size.z * sim.resolution) + 1
grid_resolution = sim.resolution

Nx = int((xmax - xmin) * grid_resolution + 1)
Ny = int((ymax - ymin) * grid_resolution + 1)
Nz = int((zmax - zmin) * grid_resolution + 1)

xtics = np.linspace(xmin, xmax, Nx)
ytics = np.linspace(ymin, ymax, Ny)
ztics = np.linspace(zmin, zmax, Nz)

eps_data = sim.get_epsilon_grid(xtics, ytics, ztics)
s = mlab.contour3d(eps_data, colormap="YlGnBu")
return s
# Get eps for geometry
eps_data = np.round(np.real(sim.get_epsilon_grid(xtics, ytics, ztics)), 2)

unique = np.unique(np.abs(eps_data)).tolist()

# Remove background material
unique.remove(np.round(np.abs(np.asarray(sim.default_material.epsilon_diag)), 2)[0])

mesh_midpoint = (sim_size[0] / 2, sim_size[1] / 2, sim_size[2] / 2)

light_dir = (0, 0, -1, 0)

# Build geometry
for i, eps in enumerate(unique):
eps_ = np.array(eps_data.flatten() == eps).astype(int).reshape(eps_data.shape)
marching_cube = marching_cubes(
eps_,
0.99,
spacing=(sim.cell_size.x / Nx, sim.cell_size.y / Ny, sim.cell_size.z / Nz),
)
vertices, faces = marching_cube[0], marching_cube[1]

mesh = Mesh(
vertices,
faces,
color=(
1 - ((i + 1) / len(unique)),
1 - ((i + 1) / len(unique)),
1 - ((i + 1) / len(unique)),
0.8,
),
)

mesh.transform = transforms.MatrixTransform()
mesh.transform.translate(np.asarray(sim.geometry_center))
shading_filter = ShadingFilter(shininess=100)
shading_filter.light_dir = light_dir[:3]
mesh.attach(shading_filter)
view.add(mesh)

# Build source
thickness = (
sim.boundary_layers[0].thickness if not len(sim.boundary_layers) < 1 else 0
)
for source in sim.sources:
size = tuple(source.size)
source_box = Box(
*size,
color=(1, 0, 0, 1), # red
)
center = list(source.center)
source_box.transform = transforms.MatrixTransform()
source_box.transform.translate(np.asarray(mesh_midpoint))
source_box.transform.translate(center)
source_box.transform.translate(tuple(sim.geometry_center))
view.add(source_box)

# Build monitors
for mon in sim.dft_objects:
for reg in mon.regions:
size = list(reg.size)
monitor_box = Box(
*size,
color=(0, 0, 1, 1), # blue
)
center = list(reg.center)
monitor_box.transform = transforms.MatrixTransform()
vector = [0, 0, 0]
vector[reg.direction] = 1
vector = mp.Vector3(*vector)
monitor_box.transform.translate(tuple(mesh_midpoint))
monitor_box.transform.translate(center)
monitor_box.transform.translate(tuple(sim.geometry_center))
view.add(monitor_box)

# Build boundaries
for box_center_top in [
np.add(mesh_midpoint, (0, 0, sim_size[2] / 2 - thickness / 2)),
np.subtract(mesh_midpoint, (0, 0, sim_size[2] / 2 - thickness / 2)),
]:
box = _build_3d_pml(sim_size[0], sim_size[1], thickness, box_center_top)
view.add(box)

for box_center_right in [
np.add(mesh_midpoint, (sim_size[0] / 2 - thickness / 2, 0, 0)),
np.subtract(mesh_midpoint, (sim_size[0] / 2 - thickness / 2, 0, 0)),
]:
box = _build_3d_pml(thickness, sim_size[1], sim_size[2], box_center_right)
view.add(box)

for box_center_front in [
np.add(mesh_midpoint, (0, sim_size[1] / 2 - thickness / 2, 0)),
np.subtract(mesh_midpoint, (0, sim_size[1] / 2 - thickness / 2, 0)),
]:
box = _build_3d_pml(sim_size[0], thickness, sim_size[2], box_center_front)
view.add(box)

# Camera options
view.camera.center = mesh_midpoint
view.camera.scale_factor = getattr(
kwargs, "scale_factor", 2 * np.linalg.norm(sim_size)
)
view.camera.elevation = getattr(kwargs, "elevation", 10)
view.camera.azimuth = getattr(kwargs, "azimuth", 45)
view.camera.transform.imap(light_dir)

# Plot or save
if save_to_image:
image = canvas.render()
import imageio

imageio.imwrite(image_name, image)

return

canvas.show(run=True)


def _build_3d_pml(x: float, y: float, thickness: float, translate: tuple):
from vispy.scene.visuals import Box
from vispy.scene import transforms
from vispy.visuals.filters import WireframeFilter

box = Box(
x,
y,
thickness,
color=(0, 1, 0, 0.2), # green but transparent
# color=None,
)
box.transform = transforms.MatrixTransform()
box.transform.rotate(90, (1, 0, 0))
box.transform.translate(translate)
wireframe_filter = WireframeFilter(width=2)
box.mesh.attach(wireframe_filter)

return box


def visualize_chunks(sim: Simulation):
Expand Down Expand Up @@ -1446,7 +1596,7 @@ def to_jshtml(self, fps: int) -> JS_Animation:
Nframes=Nframes,
fill_frames=fill_frames,
interval=interval,
**mode_dict
**mode_dict,
)
return JS_Animation(html_string)

Expand Down

0 comments on commit a7cd7da

Please sign in to comment.