# TensorWaves demo for the PANDA Seminar (December 2021)

<!-- cspell:ignore cmap coolwarm cstride iplt ipympl ipyplot meshgrid rstride timeit xlim xticks ylabel ylim zlabel zticks -->
This notebook accompanies [these slides](https://docs.google.com/presentation/d/e/2PACX-1vSymz5AjdhPw4Kz1pKhdFMnFGYuQvVaC8WbV_HTg770x6RDYoP-Anv9tn88DSuzvSiiQ9F4pcDGVExv/pub). They were presented during a PANDA Seminar on 13 December 2021.

Related notebooks for this presentation:
- [QRules demo](./qrules.ipynb)
- [AmpForm demo](./ampform.ipynb)

For more extensive examples, see **[tensorwaves.rtfd.io](https://tensorwaves.readthedocs.io)**.

## Install dependencies

In [None]:
%pip install -q ipympl jax jaxlib git+https://github.com/ComPWA/tensorwaves@main

In [None]:
%load_ext autoreload
%autoreload
%config InlineBackend.figure_formats = ['svg']
%matplotlib widget
import inspect

import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from black import FileMode, format_str
from matplotlib import cm
from sympy import sin, sqrt
from tensorwaves.function.sympy import lambdify

## Core basics

In [None]:
# Two variables (x, y) spanning the plot domain and two parameters (a, b)
# that are controlled by sliders further down
x, y, a, b = sp.symbols("x y a b")
expression = sqrt(sin(y / b) ** 2 + x**a)
expression  # last line of the cell, so Jupyter renders the expression

In [None]:
# Convert the symbolic expression into a plain NumPy function and show the
# source code that was generated for it, formatted with black for readability.
# use_cse=False is passed so the generated source is not split into
# common sub-expressions.
function = lambdify(expression, symbols=(x, y, a, b), backend="numpy", use_cse=False)
src = format_str(inspect.getsource(function), mode=FileMode())
print(src)

In [None]:
# Sample the (x, y) plane on a regular grid that spans the plot domain
x_min, x_max = 0.1, 2
y_min, y_max = -50, +50
x_values = np.linspace(x_min, x_max, num=20)
y_values = np.linspace(y_min, y_max, num=40)
X, Y = np.meshgrid(x_values, y_values)

# Bare 3D figure (no widget toolbar, header or footer) into which an
# interactive callback draws the surface
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
ax.set(
    xlim=(x_min, x_max),
    ylim=(y_min, y_max),
    xlabel=f"${sp.latex(x)}$",
    ylabel=f"${sp.latex(y)}$",
    zlabel=f"${sp.latex(expression)}$",
    xticks=[],
    yticks=[],
    zticks=[],
)
surface = None  # handle to the most recently drawn surface, if any


@ipywidgets.interact(a=(-2.0, 2.0), b=(10, 50))
def plot(a=-0.5, b=20):
    """Replace the plotted surface with one evaluated at parameters a and b."""
    global surface
    # Evaluate first, so that a failing evaluation leaves the old surface intact
    z_values = function(X, Y, a, b)
    if surface is not None:
        surface.remove()
    surface = ax.plot_surface(
        X,
        Y,
        z_values,
        cmap=cm.coolwarm,
        rstride=2,
        cstride=1,
        antialiased=False,
    )

## Fast computations

In [None]:
from tensorwaves.function.sympy import create_function

# Random (x, y) points inside the plot domain; the parameters a and b are
# passed as scalars alongside the arrays.
sample_size = 1_000_000
data = dict(
    x=np.random.uniform(x_min, x_max, size=sample_size),
    y=np.random.uniform(y_min, y_max, size=sample_size),
    a=-0.5,
    b=20,
)
numpy_function = create_function(expression, backend="numpy")
numba_function = create_function(expression, backend="numba")
jax_function = create_function(expression, backend="jax")

# JIT backends such as Numba and JAX typically compile on the first call.
# Call each function once here so that the %%timeit cells below measure
# evaluation only, not one-off compilation time. The JAX result is blocked on
# because JAX dispatches its computations asynchronously.
numpy_function(data)
numba_function(data)
jax_function(data).block_until_ready()

In [None]:
%%timeit
# Evaluate the NumPy-backend function on the full `data` sample
numpy_function(data)

In [None]:
%%timeit
# Evaluate the Numba-backend function on the full `data` sample
numba_function(data)

In [None]:
%%timeit
jax_function(data)