Skip to content

Commit

Permalink
fix: support data generation with jax (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed May 11, 2021
1 parent 66e24e0 commit 18acfa7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
" expression=model.expression,\n",
" parameters=model.parameter_defaults,\n",
")\n",
"intensity = LambdifiedFunction(sympy_model, backend=\"numpy\")\n",
"intensity = LambdifiedFunction(sympy_model, backend=\"jax\")\n",
"data_converter = HelicityTransformer(model.adapter)\n",
"phsp_sample = generate_phsp(100_000, model.adapter.reaction_info)\n",
"data_sample = generate_data(\n",
Expand Down Expand Up @@ -246,7 +246,7 @@
"source": [
"phsp_set = data_converter.transform(phsp_sample)\n",
"data_set = data_converter.transform(data_sample)\n",
"data_frame = pd.DataFrame(data_set.to_pandas())\n",
"data_frame = pd.DataFrame(data_set)\n",
"data_frame[\"m_12\"].hist(bins=100, alpha=0.5, density=True)\n",
"indicate_masses()\n",
"plt.legend();"
Expand Down
4 changes: 2 additions & 2 deletions docs/usage/step2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"data_frame = pd.DataFrame(data_set.to_pandas())\n",
"phsp_frame = pd.DataFrame(data_set.to_pandas())\n",
"data_frame = pd.DataFrame(data_set)\n",
"phsp_frame = pd.DataFrame(data_set)\n",
"data_frame"
]
},
Expand Down
4 changes: 3 additions & 1 deletion src/tensorwaves/data/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Implementations of `.DataTransformer`."""

import numpy as np
from ampform.kinematics import EventCollection, HelicityAdapter

from tensorwaves.interfaces import DataSample, DataTransformer
Expand All @@ -17,4 +18,5 @@ def __init__(self, helicity_adapter: HelicityAdapter) -> None:

def transform(self, dataset: DataSample) -> DataSample:
events = EventCollection({int(k): v for k, v in dataset.items()})
return self.__helicity_adapter.transform(events)
dataset = self.__helicity_adapter.transform(events)
return {key: np.array(values) for key, values in dataset.items()}

0 comments on commit 18acfa7

Please sign in to comment.