Skip to content

Commit

Permalink
fix init in sngp (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed May 11, 2023
1 parent eb3660a commit df5cf45
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions fortuna/model/model_manager/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
self.ridge_penalty = ridge_penalty
self.momentum = momentum
self.mean_field_factor = mean_field_factor
self._gp_output_model = None
self._gp_output_model = self._get_output_model()
self._gp_output_model_mutable_keys = [
"sngp_random_features",
"sngp_laplace_covariance",
Expand Down Expand Up @@ -263,7 +263,6 @@ def init(
f"In order to use SNGP the output shape of the provide model has to be of shape"
f"(batch_size, n_features)."
)
self._gp_output_model = self._get_output_model()
gp_params = self._gp_output_model.init(rngs, jnp.zeros(output_shape), **kwargs)
params = nested_update(model_params.unfreeze(), gp_params.unfreeze())
return dict(model=FrozenDict(params))

0 comments on commit df5cf45

Please sign in to comment.