In [1]:
import numpy as np

import jax
from jax import numpy as jnp
from jax import tree_util as jtu
from jax_models import PowerLaw, PointSource, FluxModel, NormModel, FLUX_FACTOR
from jax_loss import CashFitStatistic, FitStatistics

from gammapy.datasets import MapDataset, Datasets
from gammapy.modeling.models import PowerLawSpectralModel, PointSpatialModel, FoVBackgroundModel
from gammapy.modeling.models import SkyModel as GPSkyModel
from gammapy.modeling import Fit
from gammapy.maps import Map

from iminuit import Minuit

jax.config.update("jax_enable_x64", True)


In [2]:
dataset = MapDataset.read("../data/test-dataset-0.fits")

In [3]:
point = PointSpatialModel(frame="galactic")
spectral = PowerLawSpectralModel(amplitude="1e-10 cm-2 s-1 TeV-1")
dataset.models = [GPSkyModel(spatial_model=point, spectral_model=spectral, name="gc"), FoVBackgroundModel(dataset_name=dataset.name)]

In [4]:
datasets = Datasets()

for idx in range(5):
    dataset.fake()
    datasets.append(dataset.copy(name=f"dataset-{idx}"))

datasets.models = dataset.models

In [5]:
stat_sum_gp = datasets.stat_sum()

In [6]:
point.lon_0.value = 0.1

def gp_stat():
    # Trigger the recomputation of the model
    point.lon_0.value *= -1. 
    return datasets.stat_sum()

In [7]:
%%timeit
gp_stat()

3.69 s ± 375 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
point.lon_0.value = 0

In [34]:
point = PointSource()
point.x_0.value = jnp.array(499.5)
point.y_0.value = jnp.array(499.5)

source_jax = FluxModel(spectral=PowerLaw(), spatial=point)
source_jax.amplitude.value = jnp.array(1e-6) / FLUX_FACTOR
bkg_jax = NormModel()
bkg_jax.spectral.index.value = jnp.array(0.)

In [35]:
stats_jax = {}

for dataset in datasets:
    stats_jax[dataset.name] = CashFitStatistic.from_gp_dataset(
        models=[source_jax, bkg_jax], dataset=dataset)
    

stat_all = FitStatistics(stats_jax)

In [11]:
stat_all_sum_jax = jax.jit(stat_all.__call__)

In [12]:
%%timeit
stat_all_sum_jax()

557 ms ± 24.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
stat_sum_jax = stat_all_sum_jax()

In [14]:
stat_all.fit_statistics["dataset-0"].npred_models.models["model_0"].npred().sum()

Array(16837.67517882, dtype=float64)

In [15]:
datasets[0].npred_signal().data.sum()

16843.219967856094

In [18]:
assert jnp.allclose(stat_sum_jax, stat_sum_gp)

In [19]:
npred_gp = dataset.npred()
npred_jax = stat_jax.npred_models.npred()

NameError: name 'stat_jax' is not defined

In [15]:
npred_jax_map = Map.from_geom(dataset.counts.geom, data=npred_jax)

NameError: name 'npred_jax' is not defined

In [16]:
assert jnp.allclose(npred_jax, npred_gp.data, atol=1e-2)

NameError: name 'npred_jax' is not defined

In [20]:
stat_all.fit_statistics["dataset-0"].npred_models.models["model_1"].model.spectral

PowerLaw(index=Parameter(value=Array(0., dtype=float32, weak_type=True), unit='', frozen=False), reference=Parameter(value=Array(1., dtype=float32, weak_type=True), unit='TeV', frozen=True))

In [17]:
p1 = stat_all.fit_statistics["dataset-0"].npred_models.models["model_0"].model.amplitude.value

In [18]:
p2 = stat_all.fit_statistics["dataset-4"].npred_models.models["model_0"].model.amplitude.value


In [19]:
p1 is p2

True

In [37]:
def prepare_parameters_iminuit(tree):
    """Prepare parameters for iminuit"""
    values, treedef = jtu.tree_flatten_with_path(tree)

    ids, id_map, name_map = [], {}, {}

    for (path, value) in values:
        id_key = id(value)
        ids.append(id_key)
        id_map.setdefault(id_key, value)
        name_map.setdefault(id_key, jtu.keystr(path[:-1]))

    unique_ids = list(id_map)
    idxs = np.array([unique_ids.index(id_) for id_ in ids])
    return list(id_map.values()), list(name_map.values()), treedef, idxs


x0, names, treedef, idxs = prepare_parameters_iminuit(stat_all)

def loss_minuit(x):
    values = [jnp.array(_) for _ in x[idxs]]
    stat_jax = treedef.unflatten(values)
    return stat_jax()


In [22]:
minuit = Minuit(jax.jit(loss_minuit), np.array(x0), name=names, grad=jax.jit(jax.grad(loss_minuit)))

# Use same defaults as Gammapy
minuit.tol = 0.1
minuit.strategy = 1

minuit.fixed[".fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.reference"] = True
minuit.fixed[".fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.reference"] = True
minuit.fixed[".fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.index"] = True

In [23]:
minuit.migrad()

Migrad,Migrad.1
FCN = 1.223e+07,"Nfcn = 76, Ngrad = 3"
EDM = 4.71e-05 (Goal: 0.0002),time = 93.7 sec
Valid Minimum,Below EDM threshold (goal x 10)
No parameters at limit,Below call limit
Hesse ok,Covariance accurate

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.amplitude,1.004,0.004,,,,,
1.0,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.index,1.9994,0.0022,,,,,
2.0,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.reference,1.00,0.01,,,,,yes
3.0,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spatial.x_0,499.494,0.015,,,,,
4.0,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spatial.y_0,499.515,0.015,,,,,
5.0,.fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.index,0.0,0.1,,,,,yes
6.0,.fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.reference,1.00,0.01,,,,,yes

0,1,2,3,4,5,6,7
,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.amplitude,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.index,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.reference,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spatial.x_0,.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spatial.y_0,.fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.index,.fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.reference
.fit_statistics['dataset-0'].npred_models.models['model_0'].model.amplitude,1.54e-05,-3e-6 (-0.384),0,0,-0,0,0
.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.index,-3e-6 (-0.384),4.9e-06,0e-6,0e-6,0e-6,0e-6,0e-6
.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spectral.reference,0,0e-6,0,0,0,0,0
.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spatial.x_0,0,0e-6,0,0.000215,0,0,0
.fit_statistics['dataset-0'].npred_models.models['model_0'].model.spatial.y_0,-0,0e-6,0,0,0.000219,0,0
.fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.index,0,0e-6,0,0,0,0,0
.fit_statistics['dataset-0'].npred_models.models['model_1'].model.spectral.reference,0,0e-6,0,0,0,0,0


In [24]:
fit = Fit()
result = fit.run(datasets)

  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)
  result = super().__array_ufunc__(function, method, *arrays, **kwargs)


In [30]:
result.minuit

Migrad,Migrad.1
FCN = 1.223e+07,Nfcn = 122
EDM = 0.000308 (Goal: 0.0002),time = 568.7 sec
Valid Minimum,Below EDM threshold (goal x 10)
No parameters at limit,Below call limit
Hesse ok,Covariance accurate

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,par_000_index,1.9992,0.0022,,,,,
1.0,par_001_amplitude,1.004,0.004,,,,,
2.0,par_002_lon_0,0.16e-3,0.15e-3,,,,,
3.0,par_003_lat_0,0.02e-3,0.15e-3,,,-90,90,

0,1,2,3,4
,par_000_index,par_001_amplitude,par_002_lon_0,par_003_lat_0
par_000_index,4.89e-06,-3e-6 (-0.384),-0,0
par_001_amplitude,-3e-6 (-0.384),1.54e-05,-0,0
par_002_lon_0,-0,-0,2.17e-08,0 (0.007)
par_003_lat_0,0,0,0 (0.007),2.16e-08
