# [TR-006] Interactive 3D plots

<!-- cspell:ignore cstride facecolor ianhi ipywidgets mplot rstride valinit valmax valmin valstep -->

In [None]:
%%sh
pip install matplotlib==3.4.2 numpy==1.19.5 sympy==1.8 > /dev/null

In [None]:
%config InlineBackend.figure_formats = ['svg']
%matplotlib widget
import os

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from matplotlib import cm, widgets

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

This report illustrates how to interact with [`matplotlib`](https://matplotlib.org) 3D plots through [Matplotlib sliders](https://matplotlib.org/stable/api/widgets_api.html) and [ipywidgets](https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20List.html). This functionality may later be implemented in {mod}`symplot` and/or [`mpl_interactions`](https://mpl-interactions.readthedocs.io) (see [ianhi/mpl-interactions#89](https://github.com/ianhi/mpl-interactions/issues/89)).

In this example, we create a surface plot with {obj}`~mpl_toolkits.mplot3d.Axes3D.plot_surface` for the following function:

In [None]:
x, y, a, b = sp.symbols("x y a b")
expr = sp.sqrt(x ** a + sp.sin(y / b) ** 2)
expr

The function is formulated with {mod}`sympy`, but we use {func}`~sympy.utilities.lambdify.lambdify` to convert it into a {mod}`numpy` function that can be evaluated on whole arrays at once.

In [None]:
np_expr = sp.lambdify((x, y, a, b), expr, "numpy")
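As a quick sanity check, the lambdified function should return the same number as evaluating the {mod}`sympy` expression directly. The sketch below compares both at one arbitrarily chosen sample point (the `point` values are illustrative only) and also against the formula written out by hand with {mod}`math`.

```python
import math

import sympy as sp

x, y, a, b = sp.symbols("x y a b")
expr = sp.sqrt(x ** a + sp.sin(y / b) ** 2)
np_expr = sp.lambdify((x, y, a, b), expr, "numpy")

# Arbitrary sample point, chosen only for this check
point = {x: 1.5, y: 10.0, a: -0.5, b: 20.0}
symbolic_value = float(expr.subs(point).evalf())  # evaluated by sympy
numeric_value = float(np_expr(1.5, 10.0, -0.5, 20.0))  # evaluated by numpy
expected = math.sqrt(1.5 ** -0.5 + math.sin(10.0 / 20.0) ** 2)  # by hand

print(symbolic_value, numeric_value, expected)
```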

A surface plot has to be generated over a {func}`numpy.meshgrid`. This defines the $xy$-plane over which we want to plot our function.

In [None]:
x_min, x_max = 0.1, 2
y_min, y_max = -50, +50
x_values = np.linspace(x_min, x_max, num=20)
y_values = np.linspace(y_min, y_max, num=40)
X, Y = np.meshgrid(x_values, y_values)
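With its default `indexing="xy"`, {func}`numpy.meshgrid` returns arrays whose rows follow $y$ and whose columns follow $x$, so the grid above has shape `(len(y_values), len(x_values))`. The following sketch makes that layout explicit.

```python
import numpy as np

x_values = np.linspace(0.1, 2, num=20)
y_values = np.linspace(-50, +50, num=40)
X, Y = np.meshgrid(x_values, y_values)  # default indexing="xy"

# Rows correspond to y, columns to x
print(X.shape, Y.shape)  # (40, 20) (40, 20)

# Every row of X repeats x_values, every column of Y repeats y_values
assert np.array_equal(X[0], x_values)
assert np.array_equal(Y[:, 0], y_values)
```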

The $z$-values for {obj}`~mpl_toolkits.mplot3d.Axes3D.plot_surface` can now be computed directly on this grid:

In [None]:
a_init = -0.5
b_init = 20
Z = np_expr(X, Y, a=a_init, b=b_init)
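Two details make this one-liner work. {func}`~sympy.utilities.lambdify.lambdify` names the parameters of the generated function after the symbols, so `a` and `b` can be passed by keyword; and the {mod}`numpy` operations inside it act element-wise, so the whole grid is evaluated in a single call. The sketch below checks the vectorized result against a plain point-by-point loop (`Z_loop` is an illustrative name).

```python
import numpy as np
import sympy as sp

x, y, a, b = sp.symbols("x y a b")
expr = sp.sqrt(x ** a + sp.sin(y / b) ** 2)
np_expr = sp.lambdify((x, y, a, b), expr, "numpy")

X, Y = np.meshgrid(np.linspace(0.1, 2, num=20), np.linspace(-50, 50, num=40))
a_init, b_init = -0.5, 20

# Vectorized: one call for the whole grid, a and b passed by keyword
Z = np_expr(X, Y, a=a_init, b=b_init)

# Reference: evaluate the same function one grid point at a time
Z_loop = np.empty_like(X)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        Z_loop[i, j] = np_expr(X[i, j], Y[i, j], a_init, b_init)

print(Z.shape)  # (40, 20)
print(np.allclose(Z, Z_loop))  # True
```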

We now want to create sliders for $a$ and $b$, so that we can live-update the surface plot through those sliders.

## Matplotlib widgets

Matplotlib provides its own interactive widgets in {mod}`matplotlib.widgets`, including the {obj}`~matplotlib.widgets.Slider` used here.
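Stripped of the 3D plotting, the slider mechanism comes down to two methods: `on_changed` registers a callback that receives the new value, and `set_val` changes the value programmatically, which fires the registered callbacks just like dragging the slider would. The minimal sketch below shows this on a hypothetical sine-curve plot (`freq_slider` and `on_frequency_change` are illustrative names). It builds a bare {obj}`~matplotlib.figure.Figure` instead of using {mod}`matplotlib.pyplot`, under the assumption that this keeps it from interfering with the notebook's widget backend.

```python
import numpy as np
from matplotlib import widgets
from matplotlib.figure import Figure

# Bare Figure (no pyplot), so no GUI backend is involved
fig = Figure()
plot_ax = fig.add_axes([0.1, 0.3, 0.8, 0.6])
slider_ax = fig.add_axes([0.2, 0.1, 0.65, 0.03])

t = np.linspace(0, 2 * np.pi, 100)
(line,) = plot_ax.plot(t, np.sin(t))
freq_slider = widgets.Slider(
    ax=slider_ax, label="frequency", valmin=0.5, valmax=5, valinit=1
)
received = []  # records every value passed to the callback


def on_frequency_change(val):
    # Called with the new slider value whenever the slider changes
    received.append(val)
    line.set_ydata(np.sin(val * t))
    fig.canvas.draw_idle()


# Connection id; can be passed to freq_slider.disconnect() later
cid = freq_slider.on_changed(on_frequency_change)
freq_slider.set_val(2.0)  # programmatic change fires the callback
print(received)  # [2.0]
```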

In [None]:
fig, ax = plt.subplots(ncols=1, subplot_kw={"projection": "3d"})

# Create sliders and insert them within the figure
plt.subplots_adjust(bottom=0.25)
a_slider = widgets.Slider(
    ax=plt.axes([0.2, 0.1, 0.65, 0.03]),
    label=f"${sp.latex(a)}$",
    valmin=-2,
    valmax=2,
    valinit=a_init,
)
b_slider = widgets.Slider(
    ax=plt.axes([0.2, 0.05, 0.65, 0.03]),
    label=f"${sp.latex(b)}$",
    valmin=10,
    valmax=50,
    valinit=b_init,
    valstep=1,
)


# Define what to do when a slider changes
def update_plot(val=None):
    a = a_slider.val
    b = b_slider.val
    ax.clear()
    Z = np_expr(X, Y, a, b)
    ax.plot_surface(
        X, Y, Z, rstride=3, cstride=1, cmap=cm.coolwarm, antialiased=False
    )
    ax.set_xlabel(f"${sp.latex(x)}$")
    ax.set_ylabel(f"${sp.latex(y)}$")
    ax.set_zlabel(f"${sp.latex(expr)}$")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_facecolor("white")
    fig.canvas.draw_idle()


a_slider.on_changed(update_plot)
b_slider.on_changed(update_plot)

# Plot the surface as initialization
update_plot()
plt.show()

{{ run_interactive }}

In [None]:
if STATIC_WEB_PAGE:
    from IPython.display import SVG, display

    output_file = "006-matplotlib-slider.svg"
    fig.savefig(output_file)
    display(SVG(output_file))
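The `update_plot` callback above calls `ax.clear()`, which also wipes the axis labels, ticks, and face color, so these have to be set again on every slider change. A possible alternative, sketched here under the assumption that only the surface changes between updates, keeps a reference to the collection returned by {obj}`~mpl_toolkits.mplot3d.Axes3D.plot_surface` and removes just that artist with `remove()`. The `surface` helper and the `artists` dictionary are hypothetical names introduced for this sketch, which again uses a bare {obj}`~matplotlib.figure.Figure` so that it does not touch the notebook's widget backend.

```python
import numpy as np
from matplotlib import cm, widgets
from matplotlib.figure import Figure

X, Y = np.meshgrid(np.linspace(0.1, 2, num=20), np.linspace(-50, 50, num=40))


def surface(a, b):
    # Hypothetical helper: same function as expr, written directly in NumPy
    return np.sqrt(X ** a + np.sin(Y / b) ** 2)


fig = Figure()  # bare Figure, leaves the notebook backend alone
ax = fig.add_subplot(projection="3d")
fig.subplots_adjust(bottom=0.25)
a_slider = widgets.Slider(
    ax=fig.add_axes([0.2, 0.1, 0.65, 0.03]),
    label="$a$", valmin=-2, valmax=2, valinit=-0.5,
)
b_slider = widgets.Slider(
    ax=fig.add_axes([0.2, 0.05, 0.65, 0.03]),
    label="$b$", valmin=10, valmax=50, valinit=20, valstep=1,
)

# Decorations are set once; ax is never cleared, so they persist
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
artists = {"surface": None}  # holds the currently drawn surface


def update_plot(val=None):
    if artists["surface"] is not None:
        artists["surface"].remove()  # drop only the previous surface
    artists["surface"] = ax.plot_surface(
        X, Y, surface(a_slider.val, b_slider.val),
        rstride=3, cstride=1, cmap=cm.coolwarm, antialiased=False,
    )
    fig.canvas.draw_idle()


a_slider.on_changed(update_plot)
b_slider.on_changed(update_plot)
update_plot()  # initial surface
a_slider.set_val(1.0)  # each change replaces the surface
b_slider.set_val(30)
print(len(ax.collections))  # 1
```

Whether this is faster than `ax.clear()` has not been measured here; the main difference is that static decorations only need to be set once.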