1 change: 1 addition & 0 deletions tests/rl/test_producer.py
@@ -49,6 +49,7 @@ async def test_sync_produce_strategy(self):
         mock_agent_loop = MagicMock()
         mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
         mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
+        mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
 
         async def mock_gen(rs):
             await asyncio.sleep(0.01 * rs[0].id)
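Note: the added mock matters because plain `MagicMock` attributes are not awaitable; `AsyncMock(return_value=...)` is what lets the producer `await` the metadata lookup the way it would await a real remote call. A minimal self-contained sketch of the pattern (the `rollout_ctl` attribute names mirror the test above; everything else is illustrative):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def main():
    mock_agent_loop = MagicMock()
    # MagicMock creates nested attributes on access; AsyncMock makes the
    # final call awaitable, similar to awaiting a Ray `.remote()` result.
    mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(
        return_value={"server_url_dict": {}}
    )
    metadata = await mock_agent_loop.rollout_ctl.get_rollout_metadata.remote()
    assert metadata == {"server_url_dict": {}}


asyncio.run(main())
```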
2 changes: 1 addition & 1 deletion xtuner/v1/rl/agent_loop/agent_loop.py
@@ -13,7 +13,7 @@


 class AgentLoopConfig(ABC, BaseModel):
-    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)  # TODO: extra="forbid"
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
     hf_checkpoint: str
     sample_params: SampleParams

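Note: moving from `extra="allow"` to `extra="forbid"` (resolving the TODO) means a misspelled or stale key in a config now fails validation instead of being silently attached to the model. A minimal sketch of the behavior change, using a stand-in model rather than the real `AgentLoopConfig`:

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class StrictConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    hf_checkpoint: str


StrictConfig(hf_checkpoint="ckpt")  # ok

try:
    # "hf_checkpont" is misspelled; with extra="allow" it would have been
    # stored as an extra attribute, with extra="forbid" it is an error.
    StrictConfig(hf_checkpoint="ckpt", hf_checkpont="typo")
except ValidationError as e:
    print(e)  # reports an extra_forbidden error for hf_checkpont
```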
4 changes: 3 additions & 1 deletion xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
@@ -3,7 +3,7 @@
 import re
 from typing import cast
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from xtuner.v1.data_proto import RolloutState, SampleParams
 from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
@@ -28,6 +28,8 @@ def build(self, rollout_controller, judger=None, logger=None) -> "GSM8KToolAgent


 class FunctionCall(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     name: str
     arguments: dict

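Note: for `FunctionCall` the strict config doubles as input validation on model-emitted tool calls: a payload with unexpected keys is rejected rather than partially accepted. A sketch using the model exactly as defined above:

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class FunctionCall(BaseModel):
    model_config = ConfigDict(extra="forbid")

    name: str
    arguments: dict


# A well-formed tool call parses as before.
call = FunctionCall.model_validate({"name": "calculator", "arguments": {"expr": "1+2"}})

# A payload with a stray key now raises instead of being silently accepted.
try:
    FunctionCall.model_validate({"name": "calculator", "arguments": {}, "id": "abc"})
except ValidationError:
    print("rejected malformed tool call")
```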
6 changes: 5 additions & 1 deletion xtuner/v1/rl/replay_buffer.py
@@ -7,7 +7,7 @@

 import pandas as pd
 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status
 from xtuner.v1.rl.utils import (
@@ -410,11 +410,15 @@ async def resume(self, path: str | Path) -> None:


 class SyncReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     def build(self):
         return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage())
 
 
 class AsyncReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     def build(self):
         policy = StalenessReplayPolicy()
         return ReplayBuffer(policy=policy, storage_backend=NaiveStorage())
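Note: these two configs declare no fields, so `extra="forbid"` is what guarantees that passing any option to them is an error; under pydantic v2's default (`extra="ignore"`) an unknown kwarg would be dropped silently. A sketch with a stand-in `build()` (the real one returns a `ReplayBuffer`):

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class SyncReplayBufferConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    def build(self):
        return "replay-buffer"  # stand-in for ReplayBuffer(policy=..., storage_backend=...)


SyncReplayBufferConfig().build()  # fine: the config takes no options

try:
    SyncReplayBufferConfig(max_size=10)  # would be dropped silently under extra="ignore"
except ValidationError:
    print("unknown option rejected")
```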
6 changes: 5 additions & 1 deletion xtuner/v1/train/rl_colocate_trainer.py
@@ -8,6 +8,7 @@
 import ray
 import torch
 from mmengine.dist import get_rank
+from mmengine.runner import set_random_seed
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Literal, TypedDict

@@ -26,7 +27,7 @@
 from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
 from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run
 from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
-from xtuner.v1.utils import get_logger, is_hf_model_path, timer
+from xtuner.v1.utils import get_logger, is_hf_model_path, set_deterministic, timer
 from xtuner.v1.utils.device import get_device, get_torch_device_module

@@ -283,6 +284,9 @@ def __init__(
         # self._total_epochs = total_epochs  # TODO
         self._cur_step = 0
         self._global_train_step = 0
+        self._seed = seed
+        set_deterministic()
+        set_random_seed(seed)
         self.global_batch_size = global_batch_size
 
         # main components
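Note: the call order in `__init__` is deliberate: `set_deterministic()` switches PyTorch into deterministic mode first (only when `XTUNER_DETERMINISTIC=true`, per the `misc.py` change below), and `set_random_seed(seed)` then seeds the Python/NumPy/PyTorch RNGs via mmengine. A standalone sketch of the same sequence, without the trainer scaffolding:

```python
import os

import torch
from mmengine.runner import set_random_seed


def set_deterministic():
    # Mirrors the helper added in xtuner/v1/utils/misc.py below.
    if os.getenv("XTUNER_DETERMINISTIC") == "true":
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        torch.use_deterministic_algorithms(True, warn_only=True)


seed = 42
set_deterministic()    # configure determinism first ...
set_random_seed(seed)  # ... then seed random/numpy/torch
```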
2 changes: 2 additions & 0 deletions xtuner/v1/utils/__init__.py
@@ -17,6 +17,7 @@
     get_padding_length,
     is_hf_model_path,
     record_git_info,
+    set_deterministic,
 )
 from .pad import pad_to_max_length, pad_to_multiple_of
 from .profile import profile_time, profile_time_and_memory, timer, timer_logger
@@ -62,4 +63,5 @@
"clean_param_name",
"CacheDict",
"CacheObj",
"set_deterministic",
]
8 changes: 8 additions & 0 deletions xtuner/v1/utils/misc.py
@@ -9,6 +9,7 @@
 from types import FunctionType
 from typing import Annotated
 
+import torch
 from huggingface_hub import constants
 from mmengine import is_installed

@@ -24,6 +25,13 @@
 logger = get_logger()
 XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true"
 
+
+def set_deterministic():
+    if XTUNER_DETERMINISTIC:
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True, warn_only=True)
+
+
 # https://github.com/python/cpython/issues/82300#issuecomment-2169035092
 if sys.version_info >= (3, 13):
     SharedMemory = _mpshm.SharedMemory
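Note on the helper itself: `warn_only=True` makes PyTorch log a warning instead of raising when an op has no deterministic implementation, and `CUBLAS_WORKSPACE_CONFIG=":16:8"` is one of the two values cuBLAS accepts for reproducible GEMMs (the other being `":4096:8"`). Because `XTUNER_DETERMINISTIC` is read once at import time, a usage sketch would set the variable before importing xtuner:

```python
import os

# Hypothetical usage sketch: the env var is captured when xtuner.v1.utils.misc
# is first imported, so it must be set before that import happens
# (e.g. `XTUNER_DETERMINISTIC=true python train.py`).
os.environ["XTUNER_DETERMINISTIC"] = "true"

from xtuner.v1.utils import set_deterministic

set_deterministic()  # no-op unless XTUNER_DETERMINISTIC == "true"
```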