From 340a4e81797f6aab92e3d3f0b1db054e3b06ccea Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 8 Apr 2026 11:14:10 +0000 Subject: [PATCH] fix reference model with RL --- xtuner/v1/rl/base/worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 6a361bd3d..5bd6cbcc9 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -37,7 +37,6 @@ from xtuner.v1.ray.config import RolloutConfig from xtuner.v1.ray.utils import free_object_refs from xtuner.v1.rl.base.loss import BaseRLLossContext -from xtuner.v1.rl.utils import gather_logprobs from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( XTUNER_DETERMINISTIC, @@ -391,9 +390,10 @@ def compute_ref_logprobs( ref_logprobs_list: list[torch.Tensor] = [] for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list): with torch.no_grad(): - ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx=None) - ref_logprobs = gather_logprobs(ref_output["logits"], shifted_labels) - ref_logprobs_list.append(ref_logprobs) + loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels}) + assert loss_ctx is not None + ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx}) + ref_logprobs_list.append(ref_output["loss"]) self._ref_model.to_device("cpu") return ref_logprobs_list