# Step 3: Perform fit

As explained in the {doc}`previous step <2_generate_data>`, a {class}`.Model` object and its lambdified {class}`.Function` form behave just like a mathematical function that takes a set of data points as an argument and returns a list of intensities (real numbers). At this stage, we want to optimize the parameters of the intensity model, so that it matches the distribution of our data sample. This is what we call 'performing a fit'.

First, load the relevant data from the previous steps.

In [None]:
import pickle

from tensorwaves.model import SympyModel

with open("helicity_model.pickle", "rb") as stream:
    model = pickle.load(stream)
with open("data_set.pickle", "rb") as stream:
    data_set = pickle.load(stream)
with open("phsp_set.pickle", "rb") as stream:
    phsp_set = pickle.load(stream)
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameters,
)

## 3.1 Define estimator

To perform a fit, you need to define an {class}`.Estimator`. An estimator is a measure of the discrepancy between the intensity model and the data distribution to which you fit. In PWA, we usually use an *unbinned negative log likelihood estimator*.

Generally, the intensity model is not normalized, but a log likelihood estimator requires a normalized function. This is where the {ref}`phase space dataset <usage/2_generate_data:2.1 Generate phase space sample>` comes into play again: the intensity is evaluated separately with the phase space dataset so that the output of the {class}`.Function` can be normalized with it. The phase space sample is therefore a required argument!
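
To make the role of the phase space sample concrete, here is a minimal NumPy sketch of an unbinned negative log likelihood. It is not the {class}`.SympyUnbinnedNLL` implementation: `toy_intensity`, `unbinned_nll`, and the one-dimensional samples are hypothetical stand-ins, and the normalization integral is simply approximated by the mean intensity over a flat sample.

```python
import numpy as np


def toy_intensity(x, mean, scale=1.0, width=0.1):
    # Hypothetical unnormalized intensity: a Gaussian bump with arbitrary scale
    return scale * np.exp(-0.5 * ((x - mean) / width) ** 2)


def unbinned_nll(intensity, data, phsp):
    # Monte Carlo estimate of the normalization integral over phase space
    # (up to a constant volume factor that does not shift the minimum)
    normalization = np.mean(intensity(phsp))
    return -np.sum(np.log(intensity(data) / normalization))


rng = np.random.default_rng(seed=0)
phsp = rng.uniform(0.0, 1.0, size=100_000)  # flat stand-in for a phase space sample
data = rng.normal(loc=0.5, scale=0.1, size=1_000)
data = data[(data > 0.0) & (data < 1.0)]  # keep events inside the sampled range

nll_true = unbinned_nll(lambda x: toy_intensity(x, mean=0.5), data, phsp)
nll_shifted = unbinned_nll(lambda x: toy_intensity(x, mean=0.6), data, phsp)
nll_rescaled = unbinned_nll(
    lambda x: toy_intensity(x, mean=0.5, scale=5.0), data, phsp
)
print(nll_true < nll_shifted)  # the generating parameter value is preferred
```

Because the intensity is divided by its phase space average, an overall scale factor cancels: `nll_rescaled` agrees with `nll_true` up to floating-point rounding.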

```{margin}
If you want to correct for the efficiency of the detector, you should insert a *detector-reconstructed* phase space sample here.
```

In [None]:
from tensorwaves.estimator import SympyUnbinnedNLL

estimator = SympyUnbinnedNLL(
    sympy_model,
    data_set,
    phsp_set,
    backend="jax",
)

Note that the {class}`.SympyUnbinnedNLL` can be expressed with different backends (it creates a {class}`.LambdifiedFunction` internally from the template {class}`.Model`). Here, we use {func}`jax <jax.jit>`, which turns out to be the fastest backend for this model.
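
The idea of expressing one symbolic expression in a computational backend can be sketched with plain SymPy. The example below is an illustration only: the symbols and the toy expression are assumptions, and it calls `sympy.lambdify` directly rather than going through {class}`.LambdifiedFunction`.

```python
import numpy as np
import sympy as sp

# Hypothetical toy expression: a squared Lorentzian-like line shape
x, m, w = sp.symbols("x m w", real=True)
expression = 1 / ((x - m) ** 2 + w**2)

# Translate the symbolic expression into a vectorized NumPy function
numpy_function = sp.lambdify((x, m, w), expression, modules="numpy")

values = numpy_function(np.array([0.0, 1.0, 2.0]), 1.0, 0.5)
print(values)
```

The same symbolic expression can be lambdified again with a different `modules` argument, which is what makes it cheap to switch backends without touching the model itself.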

## 3.2 Optimize the intensity model

Now it's time to perform the fit!

Starting the fit itself is quite simple: just create an optimizer instance of your choice, here [Minuit2](https://root.cern.ch/doc/master/Minuit2Page.html), and call its {meth}`~.Minuit2.optimize` method to start the fitting process. Notice that the {meth}`~.Minuit2.optimize` method requires a second argument. This is a mapping of parameter names that you want to fit to their initial values.

Let's first select a few of the parameters that we saw in {ref}`Step 3.1 <usage/3_perform_fit:3.1 Define estimator>` and feed them to the optimizer to run the fit. Notice that we modify the parameters slightly to make the fit more interesting (after all, we are fitting to a data sample that was generated with this very same amplitude model).

In [None]:
initial_parameters = {
    "C[J/\\psi(1S) \\to f_{0}(1500)_{0} \\gamma_{+1};f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}]": 1.0
    + 0.0j,
    "Gamma_f(0)(500)": 0.3,
    "Gamma_f(0)(980)": 0.1,
    "m_f(0)(1710)": 1.75,
    "Gamma_f(0)(1710)": 0.2,
}

Recall that a {class}`.Function` object computes the intensity for a certain dataset. This becomes visible when we use these intensities as weights on the phase space sample and plot the result together with the original dataset. Here, we look at the invariant mass distribution of the final states `1` and `2`, which, {ref}`as we saw before <usage/2_generate_data:2.3 Visualize kinematic variables>`, are the final state particles $\pi^0,\pi^0$.
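
The weighting itself can be sketched without any plotting. In the example below, the flat sample and `toy_intensity` are hypothetical stand-ins for the phase space sample and the intensity model; `np.histogram` takes the same `weights` and `density` arguments as `plt.hist`.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
phsp = rng.uniform(0.0, 2.0, size=50_000)  # flat stand-in for a phase space sample


def toy_intensity(mass):
    # Hypothetical intensity with a peak at mass = 1.0
    return 1.0 / ((mass - 1.0) ** 2 + 0.1**2)


weights = toy_intensity(phsp)
counts, edges = np.histogram(phsp, bins=100, weights=weights, density=True)

bin_widths = np.diff(edges)
bin_centers = 0.5 * (edges[:-1] + edges[1:])
peak_position = bin_centers[np.argmax(counts)]
```

The flat sample takes on the shape of the intensity once it is weighted, and `density=True` normalizes the area of the histogram to one, so that it can be compared directly with the normalized data histogram.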

Don't forget to use {meth}`~.Function.update_parameters` first!

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def compare_model(
    variable_name,
    data_set,
    phsp_set,
    intensity_model,
    bins=100,
):
    data = data_set[variable_name]
    phsp = phsp_set[variable_name]
    intensities = intensity_model(phsp_set)
    plt.hist(data, bins=bins, alpha=0.5, label="data", density=True)
    plt.hist(
        phsp,
        weights=intensities,
        bins=bins,
        histtype="step",
        color="red",
        label="fit model",
        density=True,
    )
    plt.legend()

In [None]:
from tensorwaves.model import LambdifiedFunction

intensity = LambdifiedFunction(sympy_model, backend="numpy")
intensity.update_parameters(initial_parameters)
compare_model("m_12", data_set, phsp_set, intensity)

Finally, we create an {class}`.Optimizer` to {meth}`~.Minuit2.optimize` the model (which is embedded in the {class}`.Estimator`). Here, we choose the {class}`.Minuit2` optimizer, which is the most common optimizer in high-energy physics.

Notice that the {class}`.Minuit2` class allows one to list {mod}`~tensorwaves.optimizer.callbacks`. These are called during the {meth}`~.Optimizer.optimize` method. Here, we use {class}`.CallbackList` to 'stack' several callbacks together.

```{tip}
To define your own callback, create a class that inherits from the {class}`.Callback` class and feed it to the {class}`.Minuit2` constructor.
```
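
The callback mechanism follows the observer pattern. The sketch below illustrates that pattern only and does not use the tensorwaves API: `SketchCallback`, its `on_step` method, and the `scan` function are hypothetical names, so check the {class}`.Callback` API reference for the methods you actually need to override.

```python
class SketchCallback:
    """Hypothetical stand-in for a callback base class (not the tensorwaves API)."""

    def on_step(self, step, estimator_value, parameters):
        pass


class BestValueTracker(SketchCallback):
    """Record every estimator value and remember the lowest one."""

    def __init__(self):
        self.best_value = float("inf")
        self.history = []

    def on_step(self, step, estimator_value, parameters):
        self.history.append(estimator_value)
        self.best_value = min(self.best_value, estimator_value)


def scan(function, values, callback):
    # Toy "optimizer": evaluate the function on a grid and notify the callback
    for step, value in enumerate(values):
        callback.on_step(step, function(value), {"x": value})


tracker = BestValueTracker()
scan(lambda x: (x - 2.0) ** 2, [0.0, 1.0, 2.0, 3.0], tracker)
print(tracker.best_value)
```

The optimizer only knows that it has to notify the callback, while the callback decides what to do with the information, such as writing it to a file or tracking the best value.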

```{margin}
The computation time depends on the complexity of the model, the number of data events, the size of the phase space sample, and the number of free parameters. This model is rather small and has only a few free parameters, so the optimization shouldn't take more than a few seconds.
```

In [None]:
from tensorwaves.optimizer.callbacks import (
    CallbackList,
    CSVSummary,
    TFSummary,
    YAMLSummary,
)
from tensorwaves.optimizer.minuit import Minuit2

minuit2 = Minuit2(
    callback=CallbackList(
        [
            TFSummary(),
            YAMLSummary("current_fit_result.yaml"),
            CSVSummary("fit_traceback.csv", step_size=2),
        ]
    ),
    use_analytic_gradient=False,  # disabled here: gradients are estimated numerically
)
result = minuit2.optimize(estimator, initial_parameters)
result

As can be seen, the values of the optimized parameters in the result are again comparable to the original values that are contained in the model ({attr}`.SympyModel.parameters`):

In [None]:
optimized_parameters = result["parameter_values"]
for p in optimized_parameters:
    print(p)
    print(f"  initial:   {initial_parameters[p]:.3}")
    print(f"  optimized: {optimized_parameters[p]:.3}")
    print(f"  original:  {sympy_model.parameters[p]:.3}")

## 3.3 Export and import

In {ref}`usage/3_perform_fit:3.2 Optimize the intensity model`, we initialized {obj}`.Minuit2` with some callbacks that are {class}`.Loadable`. Such callback classes offer the possibility to {meth}`.Loadable.load_latest_parameters`, so that you can resume the optimization process if it crashes or if you pause it. Loading the latest parameters goes as follows:

In [None]:
latest_parameters = YAMLSummary.load_latest_parameters(
    "current_fit_result.yaml"
)
latest_parameters

To restart the fit with the latest parameters, simply rerun as before.

In [None]:
minuit2 = Minuit2()
minuit2.optimize(estimator, latest_parameters)

Lo and behold: the parameters were already optimized, so the fit converged faster!
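
The effect of restarting from already optimized parameters can be reproduced with a toy problem. This sketch swaps in SciPy's BFGS minimizer for Minuit2, and `toy_estimator` with its target values is a hypothetical stand-in for the real estimator.

```python
import numpy as np
from scipy.optimize import minimize


def toy_estimator(parameters):
    # Hypothetical estimator with its minimum at (1.5, 0.2)
    target = np.array([1.5, 0.2])
    return float(np.sum((parameters - target) ** 2))


# First fit: start away from the minimum
first_fit = minimize(toy_estimator, x0=np.array([1.0, 0.5]), method="BFGS")

# Restart: use the optimized parameters as the new starting point
restarted_fit = minimize(toy_estimator, x0=first_fit.x, method="BFGS")

print(first_fit.nfev, restarted_fit.nfev)  # number of function evaluations
```

The restarted fit begins where the gradient is already within tolerance, so it needs fewer function evaluations than the first fit.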

## 3.4 Visualize

### Plot optimized model

Using the same method as above, we update the parameters of the {class}`.Function` and plot it again over the data distribution.

In [None]:
intensity.update_parameters(latest_parameters)
compare_model("m_12", data_set, phsp_set, intensity)

### Analyze optimization process

Note that {ref}`in Step 3.2 <usage/3_perform_fit:3.2 Optimize the intensity model>`, we initialized {class}`.Minuit2` with a {class}`.TFSummary` callback as well. Its output files provide a nice, interactive representation of the fit process and can be viewed with [TensorBoard](https://www.tensorflow.org/tensorboard/get_started) as follows:

````{tabbed} Terminal
```bash
tensorboard --logdir logs
```
````

````{tabbed} Python
```python
import tensorboard as tb

tb.notebook.list()  # View open TensorBoard instances
tb.notebook.start(args_string="--logdir logs")
```
See more info [here](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks#tensorboard_in_notebooks)
````

````{tabbed} Jupyter notebook
```ipython
%load_ext tensorboard
%tensorboard --logdir logs
```
See more info [here](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks#tensorboard_in_notebooks)
````

An alternative would be to use the output of the {class}`.CSVSummary` callback. Here's an example:

In [None]:
import pandas as pd

fit_traceback = pd.read_csv("fit_traceback.csv")
fit_traceback.plot("function_call", "estimator_value")
fit_traceback.plot("function_call", sorted(initial_parameters));