diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py
index 9459fe7a3..956754693 100644
--- a/tests/rl/test_producer.py
+++ b/tests/rl/test_producer.py
@@ -49,6 +49,7 @@ async def test_sync_produce_strategy(self):
         mock_agent_loop = MagicMock()
         mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
         mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
+        mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})

         async def mock_gen(rs):
             await asyncio.sleep(0.01 * rs[0].id)
diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py
index 04cdd4442..dcd87eb71 100644
--- a/xtuner/v1/rl/agent_loop/agent_loop.py
+++ b/xtuner/v1/rl/agent_loop/agent_loop.py
@@ -13,7 +13,7 @@


 class AgentLoopConfig(ABC, BaseModel):
-    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)  # TODO: extra="forbid"
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

     hf_checkpoint: str
     sample_params: SampleParams
diff --git a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
index b18ddc775..39a37cb82 100644
--- a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
+++ b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
@@ -3,7 +3,7 @@
 import re
 from typing import cast

-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict

 from xtuner.v1.data_proto import RolloutState, SampleParams
 from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
@@ -28,6 +28,8 @@ def build(self, rollout_controller, judger=None, logger=None) -> "GSM8KToolAgent


 class FunctionCall(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     name: str
     arguments: dict
diff --git a/xtuner/v1/rl/replay_buffer.py b/xtuner/v1/rl/replay_buffer.py
index 58f19d90e..6dabce99d 100644
--- a/xtuner/v1/rl/replay_buffer.py
+++ b/xtuner/v1/rl/replay_buffer.py
@@ -7,7 +7,7 @@

 import pandas as pd
 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict

 from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status
 from xtuner.v1.rl.utils import (
@@ -410,11 +410,15 @@ async def resume(self, path: str | Path) -> None:


 class SyncReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     def build(self):
         return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage())


 class AsyncReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     def build(self):
         policy = StalenessReplayPolicy()
         return ReplayBuffer(policy=policy, storage_backend=NaiveStorage())
diff --git a/xtuner/v1/train/rl_colocate_trainer.py b/xtuner/v1/train/rl_colocate_trainer.py
index 7f5429fe6..114983190 100644
--- a/xtuner/v1/train/rl_colocate_trainer.py
+++ b/xtuner/v1/train/rl_colocate_trainer.py
@@ -8,6 +8,7 @@
 import ray
 import torch
 from mmengine.dist import get_rank
+from mmengine.runner import set_random_seed
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Literal, TypedDict

@@ -26,7 +27,7 @@
 from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
 from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run
 from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
-from xtuner.v1.utils import get_logger, is_hf_model_path, timer
+from xtuner.v1.utils import get_logger, is_hf_model_path, set_deterministic, timer
 from xtuner.v1.utils.device import get_device, get_torch_device_module


@@ -283,6 +284,9 @@ def __init__(
         # self._total_epochs = total_epochs  # TODO
         self._cur_step = 0
         self._global_train_step = 0
+        self._seed = seed
+        set_deterministic()
+        set_random_seed(seed)
         self.global_batch_size = global_batch_size

         # main components
diff --git a/xtuner/v1/utils/__init__.py b/xtuner/v1/utils/__init__.py
index 2c8d866b4..1704607a2 100644
--- a/xtuner/v1/utils/__init__.py
+++ b/xtuner/v1/utils/__init__.py
@@ -17,6 +17,7 @@
     get_padding_length,
     is_hf_model_path,
     record_git_info,
+    set_deterministic,
 )
 from .pad import pad_to_max_length, pad_to_multiple_of
 from .profile import profile_time, profile_time_and_memory, timer, timer_logger
@@ -62,4 +63,5 @@
     "clean_param_name",
     "CacheDict",
     "CacheObj",
+    "set_deterministic",
 ]
diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py
index e9aaf82bb..44f7339a2 100644
--- a/xtuner/v1/utils/misc.py
+++ b/xtuner/v1/utils/misc.py
@@ -9,6 +9,7 @@
 from types import FunctionType
 from typing import Annotated

+import torch
 from huggingface_hub import constants
 from mmengine import is_installed

@@ -24,6 +25,13 @@
 logger = get_logger()
 XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true"

+
+def set_deterministic():
+    if XTUNER_DETERMINISTIC:
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True, warn_only=True)
+
+
 # https://github.com/python/cpython/issues/82300#issuecomment-2169035092
 if sys.version_info >= (3, 13):
     SharedMemory = _mpshm.SharedMemory
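
Note on the extra="forbid" changes above, a minimal sketch of Pydantic v2's documented behavior rather than code from this patch: once model_config = ConfigDict(extra="forbid") is set, constructing a model with an unrecognized field raises a ValidationError instead of silently attaching it, so typos in rollout and replay-buffer configs fail fast. Using the FunctionCall model from gsm8k_with_tool.py:

    from pydantic import BaseModel, ConfigDict, ValidationError

    class FunctionCall(BaseModel):
        model_config = ConfigDict(extra="forbid")

        name: str
        arguments: dict

    FunctionCall(name="calc", arguments={"expr": "1 + 1"})  # accepted

    try:
        # "argments" is a hypothetical typo; extra="forbid" rejects it.
        FunctionCall(name="calc", argments={"expr": "1 + 1"})
    except ValidationError as err:
        print(err)  # reports the extra input and the missing "arguments" field

Under the old extra="allow" setting the second call would succeed, keeping the typo'd field and leaving "arguments" unset, which is exactly the failure mode the TODO in agent_loop.py was tracking.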
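Usage note for the new determinism hook, a sketch under one assumption drawn from the patch: XTUNER_DETERMINISTIC is evaluated once at import time of xtuner/v1/utils/misc.py, so the environment variable must be set before xtuner is imported for the set_deterministic() call in the trainer's __init__ to have any effect. A hypothetical launch-side snippet:

    import os

    # Must run before any xtuner import; misc.py reads the flag at module load.
    os.environ["XTUNER_DETERMINISTIC"] = "true"

    from xtuner.v1.utils import set_deterministic

    # With the flag set, this pins CUBLAS_WORKSPACE_CONFIG=":16:8" and enables
    # torch.use_deterministic_algorithms(True, warn_only=True); otherwise it is a no-op.
    set_deterministic()

Because the patch passes warn_only=True, ops without deterministic implementations emit a warning instead of raising, so this is best-effort determinism; per-process seeding is handled separately by the set_random_seed(seed) call added in rl_colocate_trainer.py.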