In [1]:
import matplotlib.pyplot as plt
import numpy as np
from functools import partial

import jax

# JAX uses 32-bit floats by default. Enable 64-bit mode right after importing jax and
# *before* importing project modules that may create JAX arrays at import time (possibly
# module-level constants in jax_models such as FLUX_FACTOR -- not visible here): arrays
# created before this call keep their float32 dtype. Double precision is presumably what
# lets the JAX Cash statistic pass the allclose check against Gammapy's float64 value.
jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp
from jax import tree_util as jtu
from jax_models import PowerLaw, PointSource, FluxModel, NormModel, FLUX_FACTOR
from jax_loss import CashFitStatistic

from gammapy.datasets import MapDataset
from gammapy.modeling.models import PowerLawSpectralModel, PointSpatialModel, FoVBackgroundModel
from gammapy.modeling.models import SkyModel as GPSkyModel
from gammapy.modeling import Fit
from gammapy.maps import Map

from iminuit import Minuit

In [2]:
# Load the test MapDataset (path relative to the notebook's working directory).
# Its counts are overwritten by a simulated realisation in the next cell.
dataset = MapDataset.read("../data/test-dataset-0.fits")

In [3]:
# Reference Gammapy model: a point source with a power-law spectrum plus the FoV background.
point = PointSpatialModel(frame="galactic")
spectral = PowerLawSpectralModel(amplitude="1e-10 cm-2 s-1 TeV-1")
dataset.models = [
    GPSkyModel(spatial_model=point, spectral_model=spectral, name="gc"),
    FoVBackgroundModel(dataset_name=dataset.name),
]

# Replace the counts with a Poisson realisation of the model prediction. Gammapy's default
# random_state ("random-seed") draws a fresh seed on every call, which would make the fake
# counts -- and every statistic value and fit result below -- change on each re-run.
FAKE_RANDOM_STATE = 0
dataset.fake(random_state=FAKE_RANDOM_STATE)

In [4]:
# Reference value: total fit statistic computed by Gammapy for the faked dataset.
stat_sum_gp = dataset.stat_sum()

In [5]:
# Move away from lon_0 = 0 so that negating the value actually changes the model.
point.lon_0.value = 0.1

def gp_stat():
    """Return Gammapy's ``stat_sum`` after forcing the model to be re-evaluated.

    Each call negates ``lon_0`` (toggling it between +0.1 and -0.1 after the assignment
    above). Because the parameter changes between calls, the benchmark measures a fresh
    model computation rather than a reused prediction.
    """
    lon_0 = dataset.models.parameters["lon_0"]
    lon_0.value = lon_0.value * -1.0
    return dataset.stat_sum()

In [6]:
%%timeit
# Each gp_stat() call flips lon_0, so Gammapy must recompute the model every iteration.
gp_stat()

530 ms ± 35.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# %%timeit calls gp_stat() an unknown number of times, leaving lon_0 at +0.1 or -0.1.
# Reset it (presumably to the position the data were faked at -- confirm) before the
# comparisons and the Gammapy reference fit below.
point.lon_0.value = 0

In [8]:
# JAX re-implementation of the same source model. Use a distinct name so that `point`
# keeps referring to the Gammapy spatial model attached to `dataset` (earlier cells set
# `point.lon_0`; rebinding `point` here would break them when re-run out of order).
point_jax = PointSource()
# Source position in pixel coordinates. Presumably the pixel of the Gammapy source
# position (lon_0 = lat_0 = 0) at the map centre -- TODO derive from dataset.counts.geom.
point_jax.x_0.value = jnp.array(499.5)
point_jax.y_0.value = jnp.array(499.5)

source_jax = FluxModel(spectral=PowerLaw(), spatial=point_jax)
# Same amplitude as the Gammapy model: 1e-10 cm-2 s-1 TeV-1 == 1e-6 m-2 s-1 TeV-1.
# Presumably jax_models works in m-2 and stores amplitudes divided by FLUX_FACTOR -- confirm.
source_jax.amplitude.value = jnp.array(1e-6) / FLUX_FACTOR
bkg_jax = NormModel()

In [9]:
# Build the JAX Cash statistic from the Gammapy dataset with the JAX source and background models.
stat_jax = CashFitStatistic.from_gp_dataset(models=[source_jax, bkg_jax], dataset=dataset)
# jit the bound, argument-free __call__. Whatever stat_jax holds is captured when the
# function is traced, so later changes to stat_jax are presumably not seen by
# stat_jax_jit -- re-jit if its parameters are modified.
stat_jax_jit = jax.jit(stat_jax.__call__)

In [10]:
# First call traces and compiles the function; keep it out of the timing cell below.
stat_sum_jax = stat_jax_jit()

In [11]:
%%timeit
stat_jax_jit()

148 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
# The JAX total statistic must match Gammapy's (jnp.allclose defaults: rtol=1e-5, atol=1e-8).
# Report both values if it does not.
assert jnp.allclose(stat_sum_jax, stat_sum_gp), (
    f"stat_sum_jax={float(stat_sum_jax)!r} vs stat_sum_gp={float(stat_sum_gp)!r}"
)

In [19]:
# Predicted counts from both implementations, for a pixel-level comparison.
npred_gp = dataset.npred()
npred_jax = stat_jax.npred_models.npred()

In [21]:
# Wrap the JAX prediction in a Gammapy Map on the counts geometry.
# NOTE(review): npred_jax_map is not used in the visible cells -- plot it or remove it.
npred_jax_map = Map.from_geom(dataset.counts.geom, data=npred_jax)

In [22]:
# Pixel-wise agreement of predicted counts (rtol=1e-5 default, atol=1e-2 counts, presumably
# to absorb small per-pixel differences between the two model evaluations).
# Report the worst pixel if the check fails.
assert jnp.allclose(npred_jax, npred_gp.data, atol=1e-2), (
    f"max |npred_jax - npred_gp| = {float(jnp.max(jnp.abs(npred_jax - npred_gp.data)))}"
)

In [23]:
def prepare_parameters_iminuit(tree):
    """Flatten a parameter pytree into the inputs ``iminuit.Minuit`` expects.

    Parameters
    ----------
    tree : pytree
        Object whose leaves are the parameter values (here the JAX fit statistic).

    Returns
    -------
    x0 : list
        Leaf values in JAX flattening order, used as Minuit's starting point.
    names : list of str
        One name per leaf, built from the leaf's key path with its final entry dropped
        (presumably the attribute holding the raw value inside a parameter object).
    treedef : PyTreeDef
        Structure needed to rebuild ``tree`` from a flat sequence of values.
    """
    leaves_with_paths, treedef = jtu.tree_flatten_with_path(tree)
    x0 = [leaf for _, leaf in leaves_with_paths]
    names = [jtu.keystr(key_path[:-1]) for key_path, _ in leaves_with_paths]
    return x0, names, treedef


x0, names, treedef = prepare_parameters_iminuit(stat_jax)

def loss_minuit(x):
    """Evaluate the JAX Cash statistic for a flat parameter vector ``x``.

    ``x`` is ordered like ``x0``/``names`` (the flattening order of ``stat_jax``). It is
    unflattened into a statistic object carrying the trial parameters, which is then called.
    """
    # Deliberately not named `stat_jax`: this is a rebuilt copy, not the global object.
    trial_stat = treedef.unflatten(jnp.array(x))
    return trial_stat()


In [24]:
# Compile the objective and its JAX autodiff gradient; Minuit calls both repeatedly.
loss_minuit_jit = jax.jit(loss_minuit)
grad_minuit_jit = jax.jit(jax.grad(loss_minuit))

minuit = Minuit(loss_minuit_jit, np.array(x0), name=names, grad=grad_minuit_jit)

# Use same defaults as Gammapy
minuit.tol = 0.1
minuit.strategy = 1

# Keep the spectral reference energies and the background spectral index fixed. This
# mirrors the Gammapy fit below, whose Minuit table has no such free parameters.
for fixed_name in (
    ".npred_models.models['model_0'].model.spectral.reference",
    ".npred_models.models['model_1'].model.spectral.reference",
    ".npred_models.models['model_1'].model.spectral.index",
):
    minuit.fixed[fixed_name] = True

In [25]:
# Minimise the JAX statistic with MIGRAD, using the JAX gradient passed as `grad`.
minuit.migrad()

Migrad,Migrad.1
FCN = 2.44e+06,"Nfcn = 92, Ngrad = 3"
EDM = 3.8e-08 (Goal: 0.0002),time = 26.1 sec
Valid Minimum,Below EDM threshold (goal x 10)
No parameters at limit,Below call limit
Hesse ok,Covariance accurate

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,.npred_models.models['model_0'].model.amplitude,1.005,0.009,,,,,
1.0,.npred_models.models['model_0'].model.spectral.index,2.005,0.005,,,,,
2.0,.npred_models.models['model_0'].model.spectral.reference,1.00,0.01,,,,,yes
3.0,.npred_models.models['model_0'].model.spatial.x_0,499.479,0.033,,,,,
4.0,.npred_models.models['model_0'].model.spatial.y_0,499.505,0.033,,,,,
5.0,.npred_models.models['model_1'].model.norm,0.9982,0.0017,,,,,
6.0,.npred_models.models['model_1'].model.spectral.index,0.0,0.1,,,,,yes
7.0,.npred_models.models['model_1'].model.spectral.reference,1.00,0.01,,,,,yes

0,1,2,3,4,5,6,7,8
,.npred_models.models['model_0'].model.amplitude,.npred_models.models['model_0'].model.spectral.index,.npred_models.models['model_0'].model.spectral.reference,.npred_models.models['model_0'].model.spatial.x_0,.npred_models.models['model_0'].model.spatial.y_0,.npred_models.models['model_1'].model.norm,.npred_models.models['model_1'].model.spectral.index,.npred_models.models['model_1'].model.spectral.reference
.npred_models.models['model_0'].model.amplitude,7.74e-05,-0.017e-3 (-0.392),0,-0,-0,-0.2e-6 (-0.011),0,0
.npred_models.models['model_0'].model.spectral.index,-0.017e-3 (-0.392),2.46e-05,0,-0 (-0.001),-0,-0.2e-6 (-0.023),0,0
.npred_models.models['model_0'].model.spectral.reference,0,0,0,0.0000,0.0000,0e-6,0,0
.npred_models.models['model_0'].model.spatial.x_0,-0,-0 (-0.001),0.0000,0.00112,0.0000,0,0.0000,0.0000
.npred_models.models['model_0'].model.spatial.y_0,-0,-0,0.0000,0.0000,0.00109,0,0.0000,0.0000
.npred_models.models['model_1'].model.norm,-0.2e-6 (-0.011),-0.2e-6 (-0.023),0e-6,0,0,3.03e-06,0e-6,0e-6
.npred_models.models['model_1'].model.spectral.index,0,0,0,0.0000,0.0000,0e-6,0,0
.npred_models.models['model_1'].model.spectral.reference,0,0,0,0.0000,0.0000,0e-6,0,0


In [26]:
# Reference fit with Gammapy's own Fit on the same dataset, to compare against the
# JAX/iminuit result above (values, errors and covariances).
fit = Fit()
result = fit.run(dataset)

In [27]:
# Show the Minuit object behind the Gammapy fit for a side-by-side comparison.
result.minuit

Migrad,Migrad.1
FCN = 2.44e+06,Nfcn = 169
EDM = 0.000157 (Goal: 0.0002),time = 117.5 sec
Valid Minimum,Below EDM threshold (goal x 10)
No parameters at limit,Below call limit
Hesse ok,Covariance accurate

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,par_000_index,2.005,0.005,,,,,
1.0,par_001_amplitude,1.005,0.009,,,,,
2.0,par_002_lon_0,0.15e-3,0.33e-3,,,,,
3.0,par_003_lat_0,-0.16e-3,0.33e-3,,,-90,90,
4.0,par_004_norm,0.9982,0.0017,,,,,

0,1,2,3,4,5
,par_000_index,par_001_amplitude,par_002_lon_0,par_003_lat_0,par_004_norm
par_000_index,2.46e-05,-0.017e-3 (-0.392),0,-0 (-0.001),-0.2e-6 (-0.023)
par_001_amplitude,-0.017e-3 (-0.392),7.74e-05,0,-0,-0.2e-6 (-0.011)
par_002_lon_0,0,0,1.09e-07,-0 (-0.011),-0
par_003_lat_0,-0 (-0.001),-0,-0 (-0.011),1.11e-07,0
par_004_norm,-0.2e-6 (-0.023),-0.2e-6 (-0.011),-0,0,3.03e-06
