diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 6a361bd3d..5bd6cbcc9 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -37,7 +37,6 @@
 from xtuner.v1.ray.config import RolloutConfig
 from xtuner.v1.ray.utils import free_object_refs
 from xtuner.v1.rl.base.loss import BaseRLLossContext
-from xtuner.v1.rl.utils import gather_logprobs
 from xtuner.v1.train.trainer import LoadCheckpointConfig
 from xtuner.v1.utils import (
     XTUNER_DETERMINISTIC,
@@ -391,9 +390,10 @@ def compute_ref_logprobs(
         ref_logprobs_list: list[torch.Tensor] = []
         for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list):
             with torch.no_grad():
-                ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None)
-                ref_logprobs = gather_logprobs(ref_output["logits"], shifted_labels)
-                ref_logprobs_list.append(ref_logprobs)
+                loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels})
+                assert loss_ctx is not None
+                ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
+                ref_logprobs_list.append(ref_output["loss"])

         self._ref_model.to_device("cpu")
         return ref_logprobs_list