# PyHEP 2022 Notebook Talk ― fit

In [None]:
import logging
import os

import black
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from IPython.display import Image
from matplotlib.animation import PillowWriter
from tensorwaves.optimizer.callbacks import Callback

# Quieten backend log noise in the notebook output. TF_CPP_MIN_LOG_LEVEL="3"
# is read by TensorFlow's C++ runtime (presumably suppressing all but fatal
# messages -- verify against the TensorFlow docs).
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
# Only let ERROR-level records through the "absl" logger (presumably where
# JAX's chatty warnings originate).
JAX_LOGGER = logging.getLogger(name="absl")
JAX_LOGGER.setLevel(level=logging.ERROR)


def format_str(src: str) -> str:
    """Return *src* reformatted by ``black`` using black's default mode."""
    default_mode = black.FileMode()
    formatted = black.format_str(src, mode=default_mode)
    return formatted


class FitAnimation(Callback):
    """Optimizer callback that records the progress of a fit as an animated GIF.

    Each frame has two panels. The top panel shows the data histogram, the
    current model curve, and the current parameter values in the title. The
    bottom panel traces every parameter value over the recorded frames.

    Args:
        data: Mapping whose ``"x"`` entry holds the data sample to histogram.
        function: Parametrized function with a ``parameters`` mapping; it is
            called as ``function({"x": binning})`` to draw the model curve.
        binning: Histogram bin edges; also the x-values at which the model
            curve is drawn.
        output_file: Path of the GIF written by the Pillow writer.
        estimated_iterations: Initial upper x-limit of the parameter panel.
        fps: Frame rate of the output GIF.
    """

    def __init__(
        self, data, function, binning, output_file, estimated_iterations, fps=15
    ):
        # Remember the caller's interactive mode so on_optimize_end can restore
        # it, rather than unconditionally switching interactive mode on.
        self.__was_interactive = plt.isinteractive()
        plt.ioff()
        self.__function = function
        # Keep the binning: the model curve must be re-evaluated on the same
        # x-values it was first drawn on, not on a global named ``binning``.
        self.__binning = binning
        self.__fig, (self.__ax1, self.__ax2) = plt.subplots(
            figsize=(7, 7), nrows=2, tight_layout=True
        )
        self.__ax1.hist(data["x"], bins=binning, alpha=0.7, density=True, label="data")
        self.__line = self.__ax1.plot(
            binning,
            self.__function({"x": binning}),
            c="red",
            linewidth=2,
            label="model",
        )[0]
        self.__ax1.legend(loc="upper right")

        # One trace per parameter, starting at x=0 with its current value. The
        # line label is the parameter name; _update_traceback looks values up
        # by it.
        self.__par_lines = [
            self.__ax2.plot(0, value, label=par)[0]
            for par, value in self.__function.parameters.items()
        ]
        self.__ax2.set_xlim(0, estimated_iterations)
        self.__ax2.set_title("Parameter values")
        # Pair handles and LaTeX labels explicitly instead of relying on the
        # labels-only form of legend(), which matplotlib discourages.
        self.__ax2.legend(
            self.__par_lines,
            [
                f"${sp.latex(sp.Symbol(par_name))}$"
                for par_name in self.__function.parameters
            ],
            loc="upper right",
        )

        self.__writer = PillowWriter(fps=fps)
        self.__writer.setup(self.__fig, outfile=output_file)

    def on_optimize_start(self, logs):
        self._update_plot()

    def on_optimize_end(self, logs):
        self._update_plot()
        # Encode and write the GIF once, after all frames have been grabbed.
        self.__writer.finish()
        # The figure only exists to produce the GIF; close it so it is not
        # kept alive or rendered as a stray static plot.
        plt.close(self.__fig)
        if self.__was_interactive:
            plt.ion()

    def on_iteration_end(self, iteration, logs):
        # Only grab a frame here: finishing the writer on every iteration
        # re-encoded and rewrote the entire GIF each time.
        self._update_plot()

    def on_function_call_end(self, function_call, logs):
        self._update_plot()

    def _update_plot(self):
        """Refresh both panels and append the result as a new GIF frame."""
        self._update_parametrization_plot()
        self._update_traceback()
        self.__writer.grab_frame()

    def _update_parametrization_plot(self):
        """Redraw the model curve and show the current parameter values."""
        title = self._render_parameters(self.__function.parameters)
        self.__ax1.set_title(title)
        self.__line.set_ydata(self.__function({"x": self.__binning}))

    def _update_traceback(self):
        """Append the current parameter values to their traces and rescale."""
        for line in self.__par_lines:
            par_name = line.get_label()
            new_value = self.__function.parameters[par_name]
            x = line.get_xdata()
            line.set_xdata([*x, x[-1] + 1])
            line.set_ydata([*line.get_ydata(), new_value])
        y_values = np.concatenate([line.get_ydata() for line in self.__par_lines])
        y_min, y_max = y_values.min(), y_values.max()
        # Pad by 10% of the span, falling back to 10% of the magnitude (or 0.1)
        # when all values coincide. Scaling the extremes by 1.1 instead would
        # clip a positive minimum or a negative maximum.
        margin = 0.1 * (y_max - y_min) or 0.1 * abs(y_max) or 0.1
        self.__ax2.set_ylim(y_min - margin, y_max + margin)

    @staticmethod
    def _render_parameters(parameters):
        """Return a LaTeX math string ``$name=value,...$`` for the plot title."""
        values = [
            f"{sp.latex(sp.Symbol(name))}={value:.2g}"
            for name, value in parameters.items()
        ]
        return f'${",".join(values)}$'

## Performing a fit

In [None]:
import sympy as sp


def gaussian(x, mu, sigma) -> sp.Expr:
    """Return an unnormalized Gaussian centred at *mu* with width *sigma*.

    There is no 1/(sigma*sqrt(2*pi)) prefactor, so the value at ``x == mu``
    is ``exp(0) == 1``.
    """
    standardized = (x - mu) / sigma
    return sp.exp(-(standardized**2) / 2)


def poisson(x, k) -> sp.Expr:
    """Return the Poisson-shaped term ``x**k * exp(-x) / k!``."""
    weighted = x**k * sp.exp(-x)
    return weighted / sp.factorial(k)

In [None]:
# Symbols for the 1D model: the observable x, three amplitudes a, b, c, and
# the positions/widths of two Gaussian peaks.
# NOTE: the Python names are 1-based (mu1, mu2, sigma1, sigma2) while the SymPy
# names created by "mu_(:2) sigma_(:2)" are 0-based (mu_0, mu_1, sigma_0,
# sigma_1); e.g. the string key "mu_0" in initial_parameters refers to mu1.
x, a, b, c, mu1, mu2, sigma1, sigma2 = sp.symbols("x (a:c) mu_(:2) sigma_(:2)")
# Model: two weighted Gaussian peaks plus a weighted Poisson-shaped term (k=2).
expression_1d = (
    a * gaussian(x, mu1, sigma1) + b * gaussian(x, mu2, sigma2) + c * poisson(x, k=2)
)
# Bare expression: rendered as the cell output in the notebook.
expression_1d

In [None]:
# Parameter values, keyed by SymPy symbol, that function_1d is created with and
# that the generated data sample therefore follows.
parameter_defaults = {
    a: 0.15,
    b: 0.05,
    c: 0.3,
    mu1: 1.0,
    sigma1: 0.3,
    mu2: 2.7,
    sigma2: 0.5,
}
# Model with every parameter fixed, leaving only x free (shown as cell output).
substituted_expr_1d = expression_1d.subs(parameter_defaults)
substituted_expr_1d

In [None]:
# Plot the fixed model for 0 <= x <= 5: each additive term in gray and the
# total in red.
x_range = (x, 0, 5)
p1 = sp.plot(substituted_expr_1d, x_range, show=False, line_color="red")
# The .args of the sum are its individual terms.
p2 = sp.plot(*substituted_expr_1d.args, x_range, show=False, line_color="gray")
# Overlay the total curve on the plot of the terms, then display it.
p2.append(p1[0])
p2.show()

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# Convert the SymPy model into a numerical function on the "jax" backend, with
# parameter_defaults as its parameter values. use_cse=False presumably keeps
# the generated source (printed in the next cell) free of
# common-subexpression temporaries -- TODO confirm against tensorwaves docs.
function_1d = create_parametrized_function(
    expression=expression_1d,
    parameters=parameter_defaults,
    backend="jax",
    use_cse=False,
)

In [None]:
from tensorwaves.function import get_source_code

# Show the code generated for function_1d, reformatted with black for
# readability.
src = get_source_code(function_1d)
src = format_str(src)
print(src)

In [None]:
from tensorwaves.data import NumpyDomainGenerator, NumpyUniformRNG

# Seeded RNG so the generated samples are reproducible.
rng = NumpyUniformRNG(seed=0)
# Samples x within the same [0, 5] range used for the plots.
domain_generator = NumpyDomainGenerator(boundaries={"x": (0, 5)})
# NOTE(review): `domain` is not used by later cells, but this call draws from
# `rng`, so removing it would presumably change the data sample generated in
# the next cell.
domain = domain_generator.generate(1_000_000, rng)

In [None]:
from tensorwaves.data import IntensityDistributionGenerator

# Generate a 1,000,000-event sample following function_1d, using
# domain_generator for candidate x values. At this point function_1d still
# holds parameter_defaults, so these are the "true" values the fit should find.
data_generator = IntensityDistributionGenerator(domain_generator, function_1d)
data = data_generator.generate(1_000_000, rng)

In [None]:
# Starting values for the fit, keyed by SymPy parameter name. They differ from
# parameter_defaults (note "mu_0" is the Python symbol mu1, "mu_1" is mu2, etc.).
initial_parameters = {
    "a": 0.2,
    "b": 0.3,
    "c": 0.4,
    "mu_0": 0.3,
    "mu_1": 3.2,
    "sigma_0": 0.3,
    "sigma_1": 0.4,
}
# Move function_1d away from the values the data were generated with.
function_1d.update_parameters(initial_parameters)

In [None]:
# Compare the generated data with the model at the initial_parameters values.
fig, ax = plt.subplots(figsize=(5, 3))
ax.set_xlabel("$x$")
ax.set_yticks([])
# 200 edges -> 199 bins over the generated x range [0, 5].
binning = np.linspace(0, 5, num=200)
# density=True normalizes the histogram; bin_values (the bin heights) are
# reused below as the observed values of the chi-squared fit.
bin_values, *_ = ax.hist(data["x"], bins=binning, alpha=0.7, density=True, label="data")
bin_centers = (binning[1:] + binning[:-1]) / 2
# Bin centres serve both as plotting points and as the fit domain.
plot_domain = {"x": bin_centers}
lines = ax.plot(
    bin_centers, function_1d(plot_domain), c="red", linewidth=2, label="model"
)
ax.legend(loc="upper right")
plt.show()

In [None]:
from tensorwaves.estimator import ChiSquared
from tensorwaves.optimizer import Minuit2

# Chi-squared between function_1d evaluated on plot_domain (the bin centres)
# and the normalized histogram heights in bin_values.
estimator = ChiSquared(
    function_1d,
    domain=plot_domain,
    observed_values=bin_values,
    backend="jax",
)
# Plain fit without callbacks, starting from initial_parameters.
optimizer = Minuit2()
fit_result = optimizer.optimize(estimator, initial_parameters)
fit_result

In [None]:
# Repeat the same fit from initial_parameters, this time recording the
# optimizer callbacks into an animated GIF, then display the GIF.
# estimated_iterations only sets the initial x-range of the parameter panel.
optimizer = Minuit2(
    callback=FitAnimation(
        data,
        function_1d,
        binning,
        "fit-animation.gif",
        estimated_iterations=290,
        fps=25,
    ),
)
fit_result = optimizer.optimize(estimator, initial_parameters)
Image("fit-animation.gif")