In [None]:
import mfem.ser as mfem
import ipywidgets as widgets
from glvis import glvis
from math import *

In [None]:
# Element type names grouped by spatial dimension. The names are passed
# as the final argument to mfem.mesh.Mesh, and the "2d" group is used to
# decide whether the z extent applies. Insertion order is preserved when
# the groups are flattened into the element-type dropdown options.
element_types = {
    "2d": [
        "TRIANGLE",
        "QUADRILATERAL",
    ],
    "3d": [
        "TETRAHEDRON",
        "HEXAHEDRON",
    ],
}

# Maps the basis name shown in the UI to the finite element collection
# class that generate() calls as basis(order, dimension).
basis_map = dict(
    H1=mfem.H1_FECollection,
    L2=mfem.L2_FECollection,
)

class callback_coeff(mfem.PyCoefficient):
    """Scalar coefficient that defers evaluation to a ``callback`` attribute.

    The ``callback`` attribute is not set by the class itself; it must be
    assigned on the instance before the coefficient is evaluated.
    """

    def EvalValue(self, x):
        """Return ``self.callback(x)`` for the coordinate vector ``x``."""
        value = self.callback(x)
        return value

    
class callback_vcoeff(mfem.VectorPyCoefficient):
    """Vector coefficient that defers evaluation to a ``callback`` attribute.

    The ``callback`` attribute is not set by the class itself; it must be
    assigned on the instance before the coefficient is evaluated.
    """

    def EvalValue(self, x):
        """Return ``self.callback(x)`` for the coordinate vector ``x``."""
        value = self.callback(x)
        return value
    
# Module-level references keep these MFEM objects alive: if they are
# garbage collected the runtime crashes (segfault).
coeff = fec = fespace = None


def generate(shape: tuple,
             element_type: str,
             order: int,
             basis: mfem.FiniteElementCollection,
             func: callable = None,
             transform: callable = None):
    """Build a mesh and, optionally, a grid function projected onto it.

    Args:
        shape: Element counts per axis, unpacked into the mesh constructor.
            Only the first two entries are used when ``element_type`` is
            one of the 2D element types.
        element_type: Element type name passed to the mesh constructor.
        order: Order handed to ``basis``.
        basis: Called as ``basis(order, mesh.Dimension())`` to create the
            finite element collection.
        func: Optional callback taking a coordinate vector; wrapped in a
            ``callback_coeff`` and projected onto a grid function.
        transform: Optional callback taking a coordinate vector; wrapped in
            a ``callback_vcoeff`` and applied with ``mesh.Transform``.

    Returns:
        The mesh alone when ``func`` is None, otherwise the tuple
        ``(mesh, grid_function)``.
    """
    global fec, fespace, coeff

    # 2D element types only take the first two extents.
    if element_type in element_types["2d"]:
        shape = shape[:2]
    mesh = mfem.mesh.Mesh(*shape, element_type)

    if transform is not None:
        warp = callback_vcoeff(mesh.Dimension())
        warp.callback = transform
        mesh.Transform(warp)

    # Stored in the module-level names so they outlive this call.
    fec = basis(order, mesh.Dimension())
    fespace = mfem.FiniteElementSpace(mesh, fec)

    if func is None:
        return mesh

    grid_function = mfem.GridFunction(fespace)
    coeff = callback_coeff()
    coeff.callback = func
    grid_function.ProjectCoefficient(coeff)
    return (mesh, grid_function)

In [None]:
# widget parts
# One slider per mesh extent; all three share the same range and step.
nx, ny, nz = (
    widgets.IntSlider(
        description=f"{axis}:",
        min=5,
        max=50,
        step=5,
        continuous_update=False,
    )
    for axis in ("nx", "ny", "nz")
)
order = widgets.BoundedIntText(
    description='Order:',
    value=2,
    min=1,
    step=1,
    continuous_update=False,
)
# Options are every element type name, 2D group first, then 3D.
element_type = widgets.Dropdown(
    description='Element Type:',
    options=[name for group in element_types.values() for name in group],
    value="TRIANGLE",
    style={'description_width': 'initial'},
)
basis = widgets.Dropdown(
    description="Basis:",
    options=basis_map.keys(),
    style={'description_width': 'initial'},
)
# Free-form expression inputs; their placeholders are set by toggle_dim.
func = widgets.Textarea(
    layout={"height": "50px"},
    disabled=False,
)
transform = widgets.Textarea(
    layout={"height": "50px"},
    disabled=False,
)
# Viewer seeded with a 2D mesh built from the initial widget values.
initial_mesh = mfem.mesh.Mesh(nx.value, ny.value, element_type.value)
g = glvis(initial_mesh)

def toggle_dim(event=None):
    """Sync the nz slider and the placeholder hints with the element type.

    The nz slider is disabled for 2D element types and enabled otherwise,
    and the placeholder text of the two expression boxes is set to match
    the number of coordinates available.

    Args:
        event: Unused; accepted so the function can be registered as a
            widget observer as well as called directly.
    """
    is_2d = element_type.value in element_types["2d"]
    if is_2d:
        coeff_hint = "Gridfunction Coeff: scalar expr with x and y"
        warp_hint = "Mesh Transform: (<x expr>, <y expr>)"
    else:
        coeff_hint = "Gridfunction Coeff: scalar expr with x, y, and z"
        warp_hint = "Mesh Transform: (<x expr>, <y expr>, <z expr>)"

    nz.disabled = is_2d
    func.placeholder = coeff_hint
    transform.placeholder = warp_hint
        
# setup handlers
def _compile_point_expr(source, label):
    """Turn a user-entered expression into a callback over a coordinate vector.

    Args:
        source: Expression text using the names ``x``, ``y`` and ``z``.
        label: Pseudo-filename given to ``compile``; it shows up in
            tracebacks so errors can be attributed to the right text box.

    Returns:
        None when ``source`` is empty or only whitespace, otherwise a
        function ``callback(x)`` that evaluates the expression with
        ``x[0]``, ``x[1]`` and ``x[2]`` (0 when the vector has fewer than
        three entries) bound to ``x``, ``y`` and ``z``.

    Raises:
        SyntaxError: If ``source`` is not a valid Python expression.
    """
    # Strip first: whitespace-only input counts as "no expression", and
    # compile() in "eval" mode rejects leading indentation.
    expr = source.strip()
    if not expr:
        return None

    # Parse once per update instead of once per evaluation point. This also
    # surfaces syntax errors here, before any MFEM object invokes the callback.
    code = compile(expr, label, "eval")

    def callback(x):
        local = {
            "x": x[0],
            "y": x[1],
            "z": x[2] if len(x) > 2 else 0
        }
        # NOTE(security): this executes arbitrary Python typed into the
        # text area. Do not expose this widget to untrusted users.
        # Passing None for globals keeps module globals visible to the
        # expression, as in a plain eval() call.
        return eval(code, None, local)

    return callback


def show(event=None):
    """Regenerate the mesh (and grid function) from the widgets and display it.

    Reads the current values of the size, order, element type and basis
    widgets plus the two expression boxes, builds the data with
    ``generate`` and hands the result to the viewer.

    Args:
        event: Unused; accepted so the function can serve as both a button
            click handler and a widget observer.
    """
    # gridfunction callback
    gfcb = _compile_point_expr(func.value, "<gridfunction coeff>")

    # mesh callback
    mcb = _compile_point_expr(transform.value, "<mesh transform>")

    g.display(generate(
        shape=(nx.value, ny.value, nz.value),
        element_type=element_type.value,
        order=order.value,
        basis=basis_map[basis.value],
        func=gfcb,
        transform=mcb
    ))

button = widgets.Button(description="Update")
button.on_click(show)

# The figure is redrawn whenever any of these controls changes value. The
# expression text boxes are not observed; they take effect on the next
# redraw (for example via the Update button).
for _control in (nx, ny, nz, order, element_type, basis):
    _control.observe(show, names="value")
# Registered after show, matching the order in which handlers were added.
element_type.observe(toggle_dim, names="value")

centered = widgets.Layout(
    display='inline-flex',
    align_items='center',
    justify_content="center",
)

# initial state
toggle_dim()

# build widget: one centered row per group of controls, viewer on top
widgets.VBox([
    widgets.HBox(row, layout=centered)
    for row in (
        [g],
        [nx, ny, nz],
        [order, element_type],
        [basis],
        [func],
        [transform],
        [button],
    )
])