# Creating custom plots

Sometimes, the functions exposed by SymPy's plotting module are not enough to accomplish our visualization objectives. If that's the case, we can either:
1. `lambdify` the symbolic expressions and evaluate them numerically. However, this process is labor-intensive.
2. If the expressions can be plotted by the common plotting functions (`plot`, `plot3d`, `plot_parametric`, ...), use the `get_plot_data` function, which automates the _lambdifying_ process. This function accepts the same arguments as the aforementioned plotting functions, so it is easy to get the numerical data we are interested in.
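
For comparison, here is a minimal sketch of the manual route (option 1) for a function of one variable, using `sympy.lambdify` and NumPy. It only roughly mirrors the steps that `get_plot_data` automates; the expression, the range and the number of points (`100`) are arbitrary choices for illustration, not the library's defaults.

```python
import numpy as np
from sympy import symbols, sin, lambdify

x = symbols("x")
expr = sin(x) / x  # hypothetical expression, chosen only for illustration

# Turn the symbolic expression into a NumPy-aware Python function
f = lambdify(x, expr, modules="numpy")

# Discretize the range and evaluate numerically
xx = np.linspace(1, 10, 100)  # start at 1 to avoid the 0/0 at x = 0
yy = f(xx)

print(xx.shape, yy.shape)  # (100,) (100,)
```

Every new expression requires repeating these steps by hand, which is exactly the bookkeeping that option 2 removes.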

Once we have the numerical data, we can use our preferred plotting library. If we are lucky enough, we can also:
1. use one of the plotting functions as a starting point;
2. extract the numerical data;
3. extract the plot object associated with the plotting library;
4. use the appropriate command of the plotting library to add the new data to the plot.

Let's see a few examples.

## Example #1

Let's start with the imports:

In [None]:
from sympy import *
init_printing(use_latex=True)

# NOTE: here we imported all the plotting functions as well as get_plot_data
from spb import *
## In case we are only interested in get_plot_data
# from spb import get_plot_data

# In this tutorial we are going to use Plotly: the reason behind this
# choice will become clear later
from spb.backends.plotly import PB

Suppose we would like to plot the following vector field:

$$
\vec{F}(x, y) = (-y, x)
$$

Specifically, we would like a heatmap of its magnitude, with quivers (arrows representing the vector field) drawn on top of it. Currently, no plotting function achieves that in a single call. Nonetheless, we are going to see that the process is relatively easy.

Instead of dealing with vectors from the `sympy.vector` module (which is a [PITA](https://www.urbandictionary.com/define.php?term=pita)), we are going to define the vector components separately. We also restrict our domain to $-5 \le x \le 5, \, -5 \le y \le 5$:

In [None]:
var("x, y")

# vector components
u = -y
v = x
ranges = (x, -5, 5), (y, -5, 5)

# magnitude of the vector field
magn = sqrt(u**2 + v**2)
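
Before plotting, it helps to know what to expect. For this field $u^2 + v^2 = y^2 + x^2$, so the magnitude is simply the distance from the origin and its contour lines should be circles centered at the origin. A quick SymPy check (a standalone sketch that redefines the symbols):

```python
from sympy import symbols, sqrt

x, y = symbols("x, y")
u, v = -y, x
magn = sqrt(u**2 + v**2)

# (-y)**2 is automatically rewritten as y**2, so magn == sqrt(x**2 + y**2)
print(magn)                     # sqrt(x**2 + y**2)
print(magn.subs({x: 3, y: 4}))  # 5, the distance of (3, 4) from the origin
```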

### Use one of the plotting functions as a starting point

We can create a contour plot of the magnitude of the vector field. Since this plot is the starting point for our customization, it makes sense to visualize it:

In [None]:
title = r"$\vec{F}(%s, %s) = (%s, %s)$" % (latex(x), latex(y), latex(u), latex(v))
p = plot_contour((magn, "Magnitude"), *ranges, backend=PB, show=True,
                 aspect_ratio="equal", title=title,
                 contours_coloring="heatmap", contours_labels=True)

First, a note of caution for Firefox users: if the plot is not displayed correctly, try replacing the LaTeX title with a plain-text one.

Let's explain a few things:
* In this tutorial we are using `PlotlyBackend` (aka `PB`) because it exposes some keyword arguments that can be used to customize the contour plot (here, `contours_coloring` and `contours_labels`). Run the next code cell to learn more about them.
* We set an equal aspect ratio for both axes in order to avoid distortions.
* When iterating quickly during development, we can set `show=False` to skip displaying the plot.

In [None]:
help(PB)

### Extract the numerical data

Now the fun begins. In order to plot the quivers, we can think of each vector component as a function of two variables. In this simple case, each component actually depends on only one of them:

$$
\begin{aligned}
u &= f(x, y) = -y \qquad \text{(independent of } x\text{)} \\
v &= g(x, y) = x \qquad \text{(independent of } y\text{)}
\end{aligned}
$$

To visualize them, let's plot both components as surfaces:

In [None]:
p3 = plot3d(u, v, *ranges, backend=PB, legend=True)

The above plot represents the data that we would like to extract from our symbolic vector components in order to create the quivers. The easiest way to do that is to use the `get_plot_data` function, which requires the same arguments we would use for any plot function, namely a tuple of the form `expr, range, label [optional]`:

In [None]:
# NOTE: this call raises an error, explained below
xx, yy, uu = get_plot_data(u, *ranges, n=20)

Why did we get this error? The automatic algorithm inside `get_plot_data` tries to understand what kind of expression we passed in. It does so by analyzing the number of free symbols, the number of sub-expressions (for parametric expressions) and the number of ranges. Here, we provided the expression `u = -y` (one free symbol, `y`) and two ranges (`(x, -5, 5), (y, -5, 5)`). The automatic algorithm is not able to map this combination to any of the default plot functions.
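
We can inspect the information the algorithm works with directly. The snippet below is a standalone SymPy sketch (it does not call `spb`) that prints the free symbols of each component and compares their number with the number of ranges:

```python
from sympy import symbols, diff

x, y = symbols("x, y")
u, v = -y, x
ranges = (x, -5, 5), (y, -5, 5)

# Free symbols actually appearing in each component
print(u.free_symbols)  # {y}
print(v.free_symbols)  # {x}

# Two ranges were given, but each expression depends on one symbol only
print(len(ranges), len(u.free_symbols), len(v.free_symbols))  # 2 1 1

# Equivalently: u does not change along x, and v does not change along y
print(diff(u, x), diff(v, y))  # 0 0
```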

In order to solve the problem, we must explicitly tell the function what we are trying to compute. Let's first read its documentation:

In [None]:
help(get_plot_data)

As we can see, an optional `pt` keyword argument can be provided to explicitly tell the function what kind of plot we are trying to do. When this argument is provided, the automatic algorithm is bypassed.

Then, in our case:

In [None]:
xx, yy, uu = get_plot_data(u, *ranges, n=20, pt="p3d")
_, _, vv = get_plot_data(v, *ranges, n=20, pt="p3d")

Here, we explicitly asked the function to evaluate the expression in order to obtain numerical data for a function of two variables (remember, $u = f(x, y) = -y$).

In the first line of the previous cell, `xx, yy, uu` are three different two-dimensional numpy arrays:
* `xx`: represents the meshgrid along the x-axis.
* `yy`: represents the meshgrid along the y-axis.
* `uu`: represents the numerically-evaluated `u` component of the vector.

In the second line of code we are only interested in `vv`, as the first two elements of the output are identical to `xx` and `yy` (because of the same ranges and the same number of discretization points).
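
To make the structure of these arrays concrete, here is a minimal sketch that builds equivalent data by hand with `numpy.meshgrid` and `lambdify`. It assumes a uniform 20 x 20 grid over the two ranges; `get_plot_data` may differ in details (for example the grid ordering), so treat it as an illustration rather than a drop-in replacement.

```python
import numpy as np
from sympy import symbols, lambdify

x, y = symbols("x, y")
u, v = -y, x
n = 20

# Shared grid: both components use the same ranges and number of points
xx, yy = np.meshgrid(np.linspace(-5, 5, n), np.linspace(-5, 5, n))

f_u = lambdify((x, y), u, modules="numpy")
f_v = lambdify((x, y), v, modules="numpy")

# broadcast_to guards against expressions that do not depend on the grid
# at all (e.g. a constant), for which lambdify would return a scalar
uu = np.broadcast_to(f_u(xx, yy), xx.shape)
vv = np.broadcast_to(f_v(xx, yy), xx.shape)

print(xx.shape, yy.shape, uu.shape, vv.shape)  # (20, 20) (20, 20) (20, 20) (20, 20)
```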

Note that we reduced the number of evaluation points to 20, so that the quivers will be clearly visible on the plot (with too many points, the quivers might overlap).

### Extract the plot object associated to the plotting library

Let's now extract the plot object associated with Plotly:

In [None]:
plot = p.fig

In [None]:
type(plot)

### Use the appropriate command of the plotting library

As we can see, `plot` is a Plotly `Figure` object, therefore we can use Plotly's commands to add the quivers:

In [None]:
from plotly.figure_factory import create_quiver

# create_quiver returns a whole Figure: we only need its single trace
quiver = create_quiver(xx, yy, uu, vv, scale=0.1, line_color="aqua")
plot.add_traces([quiver.data[0]])
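
The `scale` argument controls the arrow length. As far as we understand Plotly's implementation, each arrow's shaft goes from $(x, y)$ to $(x + s\,u, y + s\,v)$, where $s$ is `scale` (arrowheads are drawn on top). The NumPy sketch below computes those shaft endpoints for this field with `scale=0.1` and compares the longest shaft with the grid spacing, a rough way to judge whether neighbouring arrows may overlap. It does not use Plotly itself.

```python
import numpy as np

n, scale = 20, 0.1
xx, yy = np.meshgrid(np.linspace(-5, 5, n), np.linspace(-5, 5, n))
uu, vv = -yy, xx  # components of F(x, y) = (-y, x) evaluated on the grid

# Shaft end points (assumed to match how create_quiver scales u and v)
end_x = xx + scale * uu
end_y = yy + scale * vv

# Longest shaft versus grid spacing: if it exceeds the spacing,
# neighbouring arrows may overlap
longest = np.max(np.hypot(end_x - xx, end_y - yy))
spacing = 10 / (n - 1)
print(round(longest, 3), round(spacing, 3))  # 0.707 0.526
```

A field with larger magnitudes needs a smaller `scale` to keep the arrows at a comparable length.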

At this point we might as well set the _x_ range and the _y_ range:

In [None]:
plot.update_layout(xaxis_range=ranges[0][1:], yaxis_range=ranges[1][1:])

Et voilà, done!

## Example #2

We are now going to compress all the previous commands into a single cell, with a different vector field:

In [None]:
from plotly.figure_factory import create_quiver

var("x, y")
u = x**2 - y**2 - 4
v = 2 * x * y
ranges = (x, -5, 5), (y, -5, 5)
magn = sqrt(u**2 + v**2)
title = r"$\vec{F}(%s, %s) = (%s, %s)$" % (latex(x), latex(y), latex(u), latex(v))
p = plot_contour((magn, "Magnitude"), *ranges, backend=PB, show=False,
                 aspect_ratio="equal", title=title,
                 contours_coloring="heatmap", contours_labels=True)
xx, yy, uu = get_plot_data(u, *ranges, n=15)
_, _, vv = get_plot_data(v, *ranges, n=15)
plot = p.fig
quiver = create_quiver(xx, yy, uu, vv, scale=0.035, line_color="aqua")
plot.add_traces([quiver.data[0]])
plot.update_layout(xaxis_range=ranges[0][1:], yaxis_range=ranges[1][1:])

Note that we didn't need to specify `pt="p3d"` because both components contain both `x` and `y`, hence the automatic algorithm was able to infer the kind of data we were interested in.

## Example #3

Let's now create a simple plot with two horizontal axes.

Let's suppose we are building an analytical model of a nuclear reactor fuel element of annular geometry, where the coolant flows in the inner tube. The following expressions represent the temperature distributions and the power density along the channel's length:

In [None]:
r, ro, ri = symbols("r, r_o, r_i")
mdot, cp, hc = symbols(r"\dot{m}, c_p, h_c")
alpha, k, L, z = symbols("alpha, k, L, z")
Tin, Pave = symbols(r"T_{in}, P_{ave}")

# Fuel temperature distribution along the channel.
# It depends on r and z; once r is fixed to a wall radius, the only variable is z
Tf = (Tin + (Pave * L * pi * (ro**2 - ri**2) / (2 * mdot * cp)) * 
      (1 - sin(alpha * (L / 2 - z)) / sin(alpha * L / 2)) + 
      (alpha * Pave * L  / 2) * (cos(alpha * (L / 2 - z)) / sin(alpha * L / 2)) *
      ((ro**2 - ri**2) / (2 * hc * ri) - (1 / (2 * k)) * 
      ((r**2 - ri**2) / 2 + ro**2 * log(ri / r)))
     )
Tf

In [None]:
# Fuel temperature distribution at the inner and outer walls
Twi = Tf.subs(r, ri)
Two = Tf.subs(r, ro)
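
Substituting $r = r_i$ is a good sanity check: at the inner wall $\log(r_i / r) = 0$ and $r^2 - r_i^2 = 0$, so the conduction term inside the fuel vanishes and only the convective term $(r_o^2 - r_i^2)/(2 h_c r_i)$ survives. The standalone sketch below rebuilds `Tf` with the same symbols (splitting it into two helper expressions, `heat_up` and `shape`, introduced here only for readability) and verifies this with SymPy.

```python
from sympy import symbols, sin, cos, log, pi, simplify

r, ro, ri = symbols("r, r_o, r_i")
mdot, cp, hc = symbols(r"\dot{m}, c_p, h_c")
alpha, k, L, z = symbols("alpha, k, L, z")
Tin, Pave = symbols(r"T_{in}, P_{ave}")

# Helper pieces of Tf: coolant heat-up term and axial power shape
heat_up = (Pave * L * pi * (ro**2 - ri**2) / (2 * mdot * cp)) * (
    1 - sin(alpha * (L / 2 - z)) / sin(alpha * L / 2))
shape = (alpha * Pave * L / 2) * cos(alpha * (L / 2 - z)) / sin(alpha * L / 2)

Tf = Tin + heat_up + shape * (
    (ro**2 - ri**2) / (2 * hc * ri)
    - (1 / (2 * k)) * ((r**2 - ri**2) / 2 + ro**2 * log(ri / r)))

Twi = Tf.subs(r, ri)

# At the inner wall only the convective term remains
expected = Tin + heat_up + shape * (ro**2 - ri**2) / (2 * hc * ri)
print(simplify(Twi - expected))  # 0
```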

In [None]:
# Cooling fluid temperature
Tp = (Tin + (Pave * L / 2) * pi * (ro**2 - ri**2) / (mdot * cp) * 
     (1 - sin(alpha * (L / 2 - z)) / sin(alpha * L / 2)))
Tp
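
Two quick checks on `Tp` follow from the formula. At the inlet ($z = 0$) the bracket $1 - \sin(\alpha L/2)/\sin(\alpha L/2)$ vanishes, so $T_p(0) = T_{in}$. At the outlet ($z = L$) the bracket equals 2, which gives the energy balance $\dot{m} c_p (T_p(L) - T_{in}) = P_{ave} \pi (r_o^2 - r_i^2) L$: the coolant picks up the whole power of the fuel. The standalone sketch below verifies both with SymPy.

```python
from sympy import symbols, sin, pi, simplify

ro, ri = symbols("r_o, r_i")
mdot, cp = symbols(r"\dot{m}, c_p")
alpha, L, z = symbols("alpha, L, z")
Tin, Pave = symbols(r"T_{in}, P_{ave}")

Tp = (Tin + (Pave * L / 2) * pi * (ro**2 - ri**2) / (mdot * cp) *
      (1 - sin(alpha * (L / 2 - z)) / sin(alpha * L / 2)))

# Inlet: the coolant enters at T_in
inlet = simplify(Tp.subs(z, 0) - Tin)
# Outlet: heat picked up by the coolant equals the total power of the fuel
outlet = simplify(mdot * cp * (Tp.subs(z, L) - Tin)
                  - Pave * pi * (ro**2 - ri**2) * L)
print(inlet, outlet)  # 0 0
```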

In [None]:
# Power density
P = (alpha * Pave * L) / (2 * sin(alpha * L / 2)) * cos(alpha * (L / 2 - z))
P
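
The normalization of `P` explains the name $P_{ave}$: since $\int_0^L \cos(\alpha(L/2 - z))\,dz = 2\sin(\alpha L/2)/\alpha$, the axial average of $P$ over the channel is exactly $P_{ave}$. A short standalone SymPy sketch, assuming positive symbols so that SymPy does not need to treat the case $\alpha = 0$ separately:

```python
from sympy import symbols, sin, cos, integrate, simplify

alpha, L, Pave = symbols("alpha, L, P_{ave}", positive=True)
z = symbols("z", real=True)

P = (alpha * Pave * L) / (2 * sin(alpha * L / 2)) * cos(alpha * (L / 2 - z))

# Axial average of the power density over the channel length
P_mean = simplify(integrate(P, (z, 0, L)) / L)
print(P_mean)

# Numerical check with the values used in the next cell
print(P_mean.subs({alpha: 0.031, L: 100, Pave: 1000}))
```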

In [None]:
# numerical values (they might come from some sliders)
d = {
    ri: 0.2,
    ro: 0.4,
    L: 100,
    Pave: 1000,
    hc: 1,
    alpha: 0.031,
    mdot: 1,
    k: 0.2,
    cp: 15,
    Tin: 300,
}

In [None]:
_range = (z, 0, 100)

xp, tp = get_plot_data(Tp.subs(d), _range)
xwi, twi = get_plot_data(Twi.subs(d), _range)
xwo, two = get_plot_data(Two.subs(d), _range)
xpower, power = get_plot_data(P.subs(d), _range)

import plotly.graph_objects as go

fig = go.Figure()

# z is on the vertical axis; temperatures use the bottom x-axis,
# the power density uses the top one (xaxis="x2")
fig.add_trace(go.Scatter(x=tp, y=xp, name="$T_{p}$",
                         line=dict(color="cyan")))
fig.add_trace(go.Scatter(x=twi, y=xwi, name="$T_{w, i}$",
                         line=dict(color="cyan", dash="dash")))
fig.add_trace(go.Scatter(x=two, y=xwo, name="$T_{w, o}$",
                         line=dict(color="cyan", dash="dot")))
fig.add_trace(go.Scatter(x=power, y=xpower, name="$P$", xaxis="x2",
                         line=dict(color="crimson")))

# Configure the two x-axes and reverse the y-axis (z grows downwards).
# NOTE: `titlefont` is deprecated in Plotly; the axis title font is set
# through `title=dict(text=..., font=...)` instead.
fig.update_layout(
    template="plotly_dark",
    xaxis=dict(
        title=dict(text="Temperature [K]", font=dict(color="cyan")),
        tickfont=dict(color="cyan"),
        range=[0, 3000],
    ),
    xaxis2=dict(
        title=dict(text="Fuel Power Density [W/cm^3]",
                   font=dict(color="crimson")),
        overlaying="x",
        side="top",
        tickfont=dict(color="crimson"),
        range=[0, 2000],
        showgrid=False,
    ),
    yaxis=dict(autorange="reversed"),
)

fig.show()

There we have it. We easily built a custom plot starting from symbolic expressions.