Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/getting_started/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"import yaml\n",
"\n",
"from skillmodels.config import TEST_DIR\n",
"from skillmodels.likelihood_function import get_maximization_inputs"
"from skillmodels.maximization_inputs import get_maximization_inputs"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"import yaml\n",
"\n",
"from skillmodels.config import TEST_DIR\n",
"from skillmodels.likelihood_function import get_maximization_inputs\n",
"from skillmodels.maximization_inputs import get_maximization_inputs\n",
"from skillmodels.simulate_data import simulate_dataset\n",
"from skillmodels.visualize_factor_distributions import (\n",
" bivariate_density_contours,\n",
Expand Down
2 changes: 1 addition & 1 deletion src/skillmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
contextlib.suppress(Exception)

from skillmodels.filtered_states import get_filtered_states
from skillmodels.likelihood_function import get_maximization_inputs
from skillmodels.maximization_inputs import get_maximization_inputs
from skillmodels.simulate_data import simulate_dataset

__all__ = ["get_maximization_inputs", "simulate_dataset", "get_filtered_states"]
2 changes: 1 addition & 1 deletion src/skillmodels/filtered_states.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as jnp
import numpy as np

from skillmodels.likelihood_function import get_maximization_inputs
from skillmodels.maximization_inputs import get_maximization_inputs
from skillmodels.params_index import get_params_index
from skillmodels.parse_params import create_parsing_info, parse_params
from skillmodels.process_debug_data import create_state_ranges
Expand Down
247 changes: 20 additions & 227 deletions src/skillmodels/likelihood_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,191 +2,50 @@

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

import skillmodels.likelihood_function_debug as lfd
from skillmodels.clipping import soft_clipping
from skillmodels.constraints import add_bounds, get_constraints
from skillmodels.kalman_filters import (
calculate_sigma_scaling_factor_and_weights,
kalman_predict,
kalman_update,
)
from skillmodels.params_index import get_params_index
from skillmodels.parse_params import create_parsing_info, parse_params
from skillmodels.process_data import process_data
from skillmodels.process_debug_data import process_debug_data
from skillmodels.process_model import process_model
from skillmodels.parse_params import parse_params

jax.config.update("jax_enable_x64", False) # noqa: FBT003


def get_maximization_inputs(model_dict, data):
    """Create inputs for optimagic's maximize function.

    Args:
        model_dict (dict): The model specification. See: :ref:`model_specs`
        data (DataFrame): dataset in long format.

    Returns a dictionary with keys:
        loglike (function): A jax jitted function that takes an optimagic-style
            params dataframe as only input and returns the scalar log
            likelihood as a float.
        loglikeobs (function): Like loglike but returns a numpy array with the
            log likelihood contribution of each observation.
        debug_loglike (function): Similar to loglike, with the following differences:
            - It is not jitted and thus faster on the first call and debuggable
            - It will add intermediate results as additional entries in the returned
              dictionary. Those can be used for debugging and plotting.
        loglike_and_gradient (function): Returns a tuple of the scalar log
            likelihood and its gradient with respect to the parameters. Faster
            than computing the two quantities separately.
        constraints (list): List of optimagic constraints that are implied by the
            model specification.
        params_template (pd.DataFrame): Parameter DataFrame with correct index and
            bounds but with empty value column.

    """
    model = process_model(model_dict)
    # Index of the flat parameter vector implied by the model specification.
    p_index = get_params_index(
        model["update_info"],
        model["labels"],
        model["dimensions"],
        model["transition_info"],
    )

    # Parsing info maps positions in the flat vector to model quantities.
    parsing_info = create_parsing_info(
        p_index,
        model["update_info"],
        model["labels"],
        model["anchoring"],
    )
    measurements, controls, observed_factors = process_data(
        data,
        model["labels"],
        model["update_info"],
        model["anchoring"],
    )

    # Sigma point weights for the unscented Kalman filter.
    sigma_scaling_factor, sigma_weights = calculate_sigma_scaling_factor_and_weights(
        model["dimensions"]["n_latent_factors"],
        model["estimation_options"]["sigma_points_scale"],
    )

    partialed_get_jnp_params_vec = functools.partial(
        _get_jnp_params_vec,
        target_index=p_index,
    )

    # Fix all model/data arguments of the three likelihood variants so that
    # each becomes a function of the parameter vector only:
    # "ll" = scalar, "llo" = per observation, "debug_ll" = with intermediates.
    partialed_loglikes = {}
    for n, fun in {
        "ll": _log_likelihood_jax,
        "llo": _log_likelihood_obs_jax,
        "debug_ll": lfd._log_likelihood_jax,
    }.items():
        partialed_loglikes[n] = _partial_some_log_likelihood_jax(
            fun=fun,
            parsing_info=parsing_info,
            measurements=measurements,
            controls=controls,
            observed_factors=observed_factors,
            model=model,
            sigma_weights=sigma_weights,
            sigma_scaling_factor=sigma_scaling_factor,
        )

    # The debug variant is deliberately left un-jitted so it stays debuggable.
    _jitted_loglike = jax.jit(partialed_loglikes["ll"])
    _jitted_loglikeobs = jax.jit(partialed_loglikes["llo"])
    _gradient = jax.jit(jax.grad(partialed_loglikes["ll"]))

    def loglike(params):
        # Convert the params DataFrame to a jnp vector, then evaluate.
        params_vec = partialed_get_jnp_params_vec(params)
        return float(_jitted_loglike(params_vec))

    def loglikeobs(params):
        params_vec = partialed_get_jnp_params_vec(params)
        return _to_numpy(_jitted_loglikeobs(params_vec))

    def loglike_and_gradient(params):
        params_vec = partialed_get_jnp_params_vec(params)
        crit = float(_jitted_loglike(params_vec))
        grad = _to_numpy(_gradient(params_vec))
        return crit, grad

    def debug_loglike(params):
        params_vec = partialed_get_jnp_params_vec(params)
        jax_output = partialed_loglikes["debug_ll"](params_vec)
        # Convert jax arrays to numpy before post-processing for plotting.
        tmp = _to_numpy(jax_output)
        tmp["value"] = float(tmp["value"])
        return process_debug_data(debug_data=tmp, model=model)

    constr = get_constraints(
        dimensions=model["dimensions"],
        labels=model["labels"],
        anchoring_info=model["anchoring"],
        update_info=model["update_info"],
        normalizations=model["normalizations"],
    )

    # Template has the correct index and bounds; users fill in start values.
    params_template = pd.DataFrame(columns=["value"], index=p_index)
    params_template = add_bounds(
        params_template,
        model["estimation_options"]["bounds_distance"],
    )

    out = {
        "loglike": loglike,
        "loglikeobs": loglikeobs,
        "debug_loglike": debug_loglike,
        "loglike_and_gradient": loglike_and_gradient,
        "constraints": constr,
        "params_template": params_template,
    }

    return out


def _partial_some_log_likelihood_jax(
fun,
def log_likelihood(
params,
parsing_info,
measurements,
controls,
observed_factors,
model,
sigma_weights,
transition_func,
sigma_scaling_factor,
sigma_weights,
dimensions,
labels,
estimation_options,
is_measurement_iteration,
is_predict_iteration,
iteration_to_period,
observed_factors,
):
update_info = model["update_info"]
is_measurement_iteration = (update_info["purpose"] == "measurement").to_numpy()
_periods = pd.Series(update_info.index.get_level_values("period").to_numpy())
is_predict_iteration = ((_periods - _periods.shift(-1)) == -1).to_numpy()
last_period = model["labels"]["periods"][-1]
# iteration_to_period is used as an indexer to loop over arrays of different lengths
# in a jax.lax.scan. It needs to work for arrays of length n_periods and not raise
# IndexErrors on tracer arrays of length n_periods - 1 (i.e. n_transitions).
# To achieve that, we replace the last period by -1.
iteration_to_period = _periods.replace(last_period, -1).to_numpy()

return functools.partial(
fun,
return log_likelihood_obs(
params=params,
parsing_info=parsing_info,
measurements=measurements,
controls=controls,
transition_func=model["transition_info"]["func"],
transition_func=transition_func,
sigma_scaling_factor=sigma_scaling_factor,
sigma_weights=sigma_weights,
dimensions=model["dimensions"],
labels=model["labels"],
estimation_options=model["estimation_options"],
dimensions=dimensions,
labels=labels,
estimation_options=estimation_options,
is_measurement_iteration=is_measurement_iteration,
is_predict_iteration=is_predict_iteration,
iteration_to_period=iteration_to_period,
observed_factors=observed_factors,
)
).sum()


def _log_likelihood_obs_jax(
def log_likelihood_obs(
params,
parsing_info,
measurements,
Expand Down Expand Up @@ -287,40 +146,6 @@ def _log_likelihood_obs_jax(
).sum(axis=0)


def _log_likelihood_jax(
    params,
    parsing_info,
    measurements,
    controls,
    transition_func,
    sigma_scaling_factor,
    sigma_weights,
    dimensions,
    labels,
    estimation_options,
    is_measurement_iteration,
    is_predict_iteration,
    iteration_to_period,
    observed_factors,
):
    """Return the scalar log likelihood.

    Thin wrapper around ``_log_likelihood_obs_jax`` that sums the
    per-observation contributions into a single scalar.

    """
    shared_kwargs = {
        "params": params,
        "parsing_info": parsing_info,
        "measurements": measurements,
        "controls": controls,
        "transition_func": transition_func,
        "sigma_scaling_factor": sigma_scaling_factor,
        "sigma_weights": sigma_weights,
        "dimensions": dimensions,
        "labels": labels,
        "estimation_options": estimation_options,
        "is_measurement_iteration": is_measurement_iteration,
        "is_predict_iteration": is_predict_iteration,
        "iteration_to_period": iteration_to_period,
        "observed_factors": observed_factors,
    }
    contributions = _log_likelihood_obs_jax(**shared_kwargs)
    return contributions.sum()


def _scan_body(
carry,
loop_args,
Expand Down Expand Up @@ -427,35 +252,3 @@ def _one_arg_predict(kwargs, transition_func):
**kwargs,
)
return new_states, new_upper_chols, kwargs["states"]


def _to_numpy(obj):
if isinstance(obj, dict):
res = {}
for key, value in obj.items():
if np.isscalar(value):
res[key] = value
else:
res[key] = np.array(value)

elif np.isscalar(obj):
res = obj
else:
res = np.array(obj)

return res


def _get_jnp_params_vec(params, target_index):
if set(params.index) != set(target_index):
additional_entries = params.index.difference(target_index).tolist()
missing_entries = target_index.difference(params.index).tolist()
msg = "Invalid params DataFrame. "
if additional_entries:
msg += f"Your params have additional entries: {additional_entries}. "
if missing_entries:
msg += f"Your params have missing entries: {missing_entries}. "
raise ValueError(msg)

vec = jnp.array(params.reindex(target_index)["value"].to_numpy())
return vec
2 changes: 1 addition & 1 deletion src/skillmodels/likelihood_function_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from skillmodels.parse_params import parse_params


def _log_likelihood_jax(
def log_likelihood(
params,
parsing_info,
measurements,
Expand Down
Loading