diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 9bd86619df..cef4334a72 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -31,7 +31,9 @@ def calculate_kl_penalty_joschu2020( - logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor + logprobs_policy: torch.Tensor, + logprobs_reference: torch.Tensor, + clamp_value: Optional[float] = 20.0, ) -> torch.Tensor: """Calculates a per-token estimate of the KL Divergence between two log_probs. @@ -41,6 +43,8 @@ def calculate_kl_penalty_joschu2020( logprobs_reference: torch.Tensor (b, s) """ r = logprobs_reference - logprobs_policy + if clamp_value is not None: + r = r.clamp(min=-clamp_value, max=clamp_value) return torch.exp(r) - r - 1