In [None]:
# Select the inline matplotlib backend for this notebook.
%matplotlib inline


# Multi-instrument joint analysis with Gamera model


## Introduction


We are going to reproduce the analysis of the public Crab datasets from [Nigro et al. 2019](https://www.aanda.org/articles/aa/full_html/2019/05/aa34938-18/aa34938-18.html),
using the gamera Crab model instead of the log-parabola we used previously.
In practice, we have to:

- Read a DL4 datasets file
- Read the `~gammapy.modeling.models.Models` to apply to the datasets.
- Create a `~gammapy.modeling.Fit` object and run it to fit the model parameters
- Plot the spectrum obtained from the joint fit together with the ones obtained for each instrument fit in their respective validity range.



In [None]:
from IPython.display import display

from pathlib import Path

import numpy as np
import astropy.units as u
from astropy.coordinates import SkyCoord

import matplotlib.pyplot as plt

from gammapy.datasets import Datasets

from gammapy.modeling import Fit
from gammapy.modeling.models import Models



## Read the models

Read the models we defined previously

In [None]:
# Directory and file holding the previously written model definition.
path = Path("models")
path.mkdir(exist_ok=True)
filename = path / "models_crab_gamera.yaml"

models = Models.read(filename)

# Fit configuration for the first model: 'effic' and 'index' are left free,
# while 'lon_0' and 'lat_0' are held fixed.
for par_name, is_frozen in (
    ("effic", False),
    ("index", False),
    ("lon_0", True),
    ("lat_0", True),
):
    models[0].parameters[par_name].frozen = is_frozen

## Read the datasets

We read all the public Crab datasets.

In [None]:
# YAML index file of the 1D joint-Crab datasets (path relative to the working directory).
filename = "./datasets/joint_crab/datasets_joint_crab_1d.yaml"

datasets_joint = Datasets.read(filename)

# Attach the models read above to the datasets.
datasets_joint.models = models

# Joint fit

Define the fit instance

In [None]:
# Options handed to the optimizer through the Fit object.
optimize_opts = {"strategy": 1, "tolerance": 0.1}
fit = Fit(optimize_opts=optimize_opts)

Let's start by fitting the data from each instrument independently

In [None]:
instruments = ["fermi", "magic", "veritas", "fact", "hess", "hawc"]
results = []
for instrument in instruments:
    # Select this instrument's datasets by matching its tag in the dataset name.
    datasets_instrument = Datasets(
        [d for d in datasets_joint if instrument in d.name]
    )
    datasets_instrument.models = models

    result_instrument = fit.run(datasets=datasets_instrument)

    # Validity range of the fit: the energy span covered by the bins where
    # the safe mask is true AND counts were detected.
    Emin = np.inf * u.TeV
    Emax = -np.inf * u.TeV
    for d in datasets_instrument:
        energy_edges = d.counts.geom.axes["energy"].edges
        # The comparison must be parenthesised: `&` binds tighter than `>`,
        # so `mask & counts > 0` would combine the mask with the raw counts
        # before thresholding. Work on the underlying arrays to be explicit.
        is_valid = d.mask_safe.data & (d.counts.data > 0)
        # First index array of np.where -> indices along the leading data axis
        # (used as the energy axis here, as in the original code).
        ind = np.where(is_valid)[0]
        if ind.size == 0:
            # No usable bin in this dataset: it does not constrain the range.
            continue
        Emin = np.minimum(Emin, energy_edges[ind[0]])
        # +1 selects the upper edge of the last valid bin.
        Emax = np.maximum(Emax, energy_edges[ind[-1] + 1])

    # Save the datasets, fit result and valid energy range. The models are
    # copied so that later fits on the shared models do not alter this record.
    result = {
        "instrument": instrument,
        "Emin": Emin,
        "Emax": Emax,
        "models_best_fit": datasets_instrument.models.copy(),
        "result_minuit": result_instrument,
        "datasets": datasets_instrument,
    }
    results.append(result)


Now we do the joint fit

In [None]:
# Fit all datasets simultaneously.
# NOTE(review): the models object is the same one used in the per-instrument
# fits, so this fit presumably starts from the parameter values left by the
# last of them — confirm if the starting point matters.
result_joint = fit.run(datasets=datasets_joint)


In [None]:
# Print the text summary of the joint fit result.
print(result_joint)

In [None]:
# Show the parameter table of the joint fit with the notebook's rich display
# (rendered as a table) instead of a plain-text dump.
display(result_joint.parameters.to_table())

In [None]:
# The joint validity range runs from the lowest per-instrument lower bound
# to the highest per-instrument upper bound.
lower_bounds = u.Quantity([res["Emin"] for res in results])
upper_bounds = u.Quantity([res["Emax"] for res in results])
Emin = lower_bounds.min()
Emax = upper_bounds.max()

# Record the joint fit in the same layout as the per-instrument entries.
result = {
    "instrument": "joint",
    "Emin": Emin,
    "Emax": Emax,
    "models_best_fit": datasets_joint.models.copy(),
    "result_minuit": result_joint,
    "datasets": datasets_joint,
}
results.append(result)


Finally, we compare the results of the individual fits and the joint fit

In [None]:
# One shared axis for all spectra, drawn with sed_type "e2dnde".
fig, ax = plt.subplots(figsize=(10,6))
plot_kwargs = dict(sed_type="e2dnde", ax=ax)

# Default matplotlib colour cycle: one colour per stored fit result.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

for entry, color in zip(results, colors):
    model = entry["models_best_fit"]["crab"]
    spectrum = model.spectral_model
    # Draw each model only over the energy range stored with its result.
    energy_bounds = u.Quantity([entry["Emin"], entry["Emax"]]).to("TeV")
    spectrum.plot(
        energy_bounds,
        label=entry["instrument"],
        ls="-",
        color=color,
        **plot_kwargs,
    )
    spectrum.plot_error(
        energy_bounds,
        facecolor=color,
        alpha=0.3,
        **plot_kwargs,
    )

ax.set_ylim(5e-13, 2e-10)
ax.legend(loc=3)
plt.show()
