diff --git a/emulmps/emulmps_emul/emulmps_w0wa.py b/emulmps/emulmps_emul/emulmps_w0wa.py index 836e0a8..420c8d6 100644 --- a/emulmps/emulmps_emul/emulmps_w0wa.py +++ b/emulmps/emulmps_emul/emulmps_w0wa.py @@ -457,7 +457,7 @@ def _compute_boost_approximation(self, Pk_lin, params: np.ndarray) -> np.ndarray return_boost=True, Plin_in=Pk_lin * h**3, ) - return boost / h**3 + return boost def _predict_fracs_all_z(self, params_norm: np.ndarray) -> np.ndarray: """Linear model: normalised params -> log-frac for all z.""" @@ -592,6 +592,15 @@ def get_boost( boost = (np.exp(self._predict_nl_fracs_all_z(x_norm)) * syren_boost).astype(np.float32) + k_t = 0.005 # [1/Mpc] + n = 2.0 # 1 = pure exponential, larger = sharper transition + self._lin_to_nl_weight = ( + 1.0 - np.exp(-(self.K_MODES / k_t)**n) + ).astype(np.float32) + + self._k_lin_mask = self.K_MODES < k_t + boost = 1.0 + (boost - 1.0) * self._lin_to_nl_weight + return self.K_MODES, self.Z_MODES, boost def has_nl_model(self) -> bool: