Add prim operators to query/update CUDA default RNG state (#114)
kiya00 committed Apr 24, 2024
1 parent e0ab648 commit 8aed212
Showing 3 changed files with 225 additions and 3 deletions.
101 changes: 101 additions & 0 deletions thunder/core/prims.py
@@ -253,6 +253,10 @@ class PrimIDs(Enum):
    # Memory access methods
    ITEM = auto()
    COPY_ = auto()
    SET_RNG_STATE = auto()
    GET_RNG_STATE = auto()
    UNPACK_RNG_STATE = auto()
    PACK_RNG_STATE = auto()


class OpTags(Enum):
@@ -2536,6 +2540,103 @@ def _uniform_meta(
)


def _set_rng_state_meta(new_state: TensorProxy, device: devices.Device | None = None) -> TensorProxy:
    utils.check_type(new_state, TensorProxy)
    utils.check(
        isinstance(new_state.dtype, dtypes.unsignedinteger),
        lambda: f"new_state dtype={new_state.dtype} was not a uint8 dtype",
    )
    utils.check(new_state.device is devices.cpu, lambda: "new_state must be a CPU tensor")
    if device is not None:
        utils.check(
            device.devicetype is devices.DeviceType.CUDA,
            lambda: "set_rng_state is only supported for CUDA devices",
            exception_type=NotImplementedError,
        )
    return TensorProxy(like=new_state)


set_rng_state = make_prim(
    PrimIDs.SET_RNG_STATE,
    "set_rng_state",
    meta=_set_rng_state_meta,
    tags=(OpTags.RANDOM_OP, OpTags.DONT_DCE),
)


# NOTE: The input state argument exists only to force an ordering dependency between successive
# get/set_rng_state calls. When the default CUDA generator state is queried for a device for the
# first time, the caller must pass the state as None.
def _get_rng_state_meta(state: TensorProxy | None, device: devices.Device | None = None) -> TensorProxy:
    # The RNG state is the concatenation of a 64-bit seed and a 64-bit offset, stored as uint8,
    # so state_shape = dtypes.int64.bytes // dtypes.uint8.bytes * 2
    state_shape = 16
    utils.check(state is not None or device is not None, lambda: "state and device cannot both be None")
    if state is not None:
        utils.check_type(state, TensorProxy)
        utils.check(
            isinstance(state.dtype, dtypes.unsignedinteger), lambda: f"state dtype={state.dtype} was not a uint8 dtype"
        )
        utils.check(state.device.devicetype is devices.DeviceType.CPU, lambda: "RNG state must be a CPU tensor")
        utils.check(
            utils.same_shape(state.shape, (state_shape,)),
            lambda: f"state shape must be ({state_shape},), but got {state.shape}",
        )
    if device is not None:
        utils.check(
            device.devicetype is devices.DeviceType.CUDA,
            lambda: "get_rng_state is only supported for CUDA devices",
            exception_type=NotImplementedError,
        )
    return TensorProxy(shape=(state_shape,), dtype=dtypes.uint8, device=devices.cpu)


get_rng_state = make_prim(
    PrimIDs.GET_RNG_STATE,
    "get_rng_state",
    meta=_get_rng_state_meta,
    tags=(OpTags.RANDOM_OP,),
)


def _unpack_rng_state_meta(state: TensorProxy) -> tuple[NumberProxy, NumberProxy]:
    utils.check_type(state, TensorProxy)
    utils.check(
        isinstance(state.dtype, dtypes.unsignedinteger), lambda: f"state dtype={state.dtype} was not a uint8 dtype"
    )
    utils.check(state.device.devicetype is devices.DeviceType.CPU, lambda: "RNG state must be a CPU tensor")
    # The RNG state is the concatenation of a 64-bit seed and a 64-bit offset, stored as uint8,
    # so state_shape = dtypes.int64.bytes // dtypes.uint8.bytes * 2
    state_shape = 16
    utils.check(
        utils.same_shape(state.shape, (state_shape,)),
        lambda: f"state shape must be ({state_shape},), but got {state.shape}",
    )
    return numberproxy(int, 0), numberproxy(int, 0)


unpack_rng_state = make_prim(
    PrimIDs.UNPACK_RNG_STATE,
    "unpack_rng_state",
    meta=_unpack_rng_state_meta,
    tags=(OpTags.RANDOM_OP,),
)


def _pack_rng_state_meta(seed: Number, offset: Number) -> TensorProxy:
    utils.check_type(seed, Number)
    utils.check_type(offset, Number)
    utils.check_same_dtype(seed, offset)
    # state_shape = dtypes.int64.bytes // dtypes.uint8.bytes * 2
    state_shape = 16
    return TensorProxy(shape=(state_shape,), dtype=dtypes.uint8, device=devices.cpu)


pack_rng_state = make_prim(
    PrimIDs.PACK_RNG_STATE,
    "pack_rng_state",
    meta=_pack_rng_state_meta,
    tags=(OpTags.RANDOM_OP,),
)


def _uniform_philox_meta(
    shape: Sequence[int],
    minval: float,
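The meta functions above encode a 16-byte state layout: a 64-bit seed followed by a 64-bit offset, both viewed as uint8 bytes. As a quick aside (not part of the diff), the round trip can be sketched with plain PyTorch on the CPU, using the same view/chunk tricks the Torch executor implementations below rely on:

import torch

seed, offset = 2, 40  # arbitrary example values
# Pack: view each 64-bit integer as eight uint8 bytes and concatenate.
state = torch.cat(
    [
        torch.tensor(seed).reshape([1]).view(torch.uint8),
        torch.tensor(offset).reshape([1]).view(torch.uint8),
    ]
)
assert state.shape == (16,) and state.dtype == torch.uint8
# Unpack: split into two 8-byte halves and reinterpret each as int64.
unpacked_seed, unpacked_offset = (t.view(torch.int64).item() for t in torch.chunk(state, 2))
assert (unpacked_seed, unpacked_offset) == (seed, offset)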
66 changes: 66 additions & 0 deletions thunder/executors/torchex.py
@@ -421,12 +421,78 @@ def _tensor_from_sequence_prims_transform(
    return tensor_from_sequence(seq_or_number, device=torch_device, dtype=torch_dtype)


def _set_rng_state_prim_impl(new_state: torch.Tensor, device: torch.device | None) -> torch.Tensor:
    # Default to the current CUDA device when none is given.
    if device is None:
        device = torch.cuda.current_device()
    torch.cuda.set_rng_state(new_state, device)
    return new_state


set_rng_state_prim_impl = ex.register_operator(
    "set_rng_state_prim_impl",
    meta=prims.set_rng_state.meta,
    fn=_set_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)


def _get_rng_state_prim_impl(state: torch.Tensor | None, device: torch.device | None) -> torch.Tensor:
    if state is not None:
        return state
    if device is None:
        device = torch.cuda.current_device()
    return torch.cuda.get_rng_state(device)


get_rng_state_prim_impl = ex.register_operator(
    "get_rng_state_prim_impl",
    meta=prims.get_rng_state.meta,
    fn=_get_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP,),
)


def _unpack_rng_state_prim_impl(s: torch.Tensor) -> tuple[int, int]:
    # Split the 16-byte state into its 8-byte seed and offset halves and reinterpret each as int64.
    seed, offset = torch.chunk(s, 2)
    return seed.view(torch.int64).item(), offset.view(torch.int64).item()


unpack_rng_state_prim_impl = ex.register_operator(
    "unpack_rng_state_prim_impl",
    meta=prims.unpack_rng_state.meta,
    fn=_unpack_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP,),
)


def _pack_rng_state_prim_impl(seed: int, offset: int) -> torch.Tensor:
    # View each 64-bit integer as eight uint8 bytes and concatenate into the 16-byte state layout.
    seed_portion = torch.tensor(seed).reshape([1]).view(torch.uint8)
    offset_portion = torch.tensor(offset).reshape([1]).view(torch.uint8)
    new_state = torch.cat([seed_portion, offset_portion])
    return new_state


pack_rng_state_prim_impl = ex.register_operator(
    "pack_rng_state_prim_impl",
    meta=prims.pack_rng_state.meta,
    fn=_pack_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP,),
)


_register_implementation(prims.full, checker=_always_executable, execution_transform=_full_transform)
_register_implementation(prims.iota, checker=_always_executable, execution_transform=_iota_transform)
_register_implementation(prims.uniform, checker=_always_executable, execution_transform=_uniform_transform)
_register_implementation(
    prims.uniform_philox, checker=_uniform_philox_prim_checker, execution_transform=_uniform_philox_prim_transform
)
_register_implementation(prims.set_rng_state, set_rng_state_prim_impl, checker=_always_executable)
_register_implementation(prims.get_rng_state, get_rng_state_prim_impl, checker=_always_executable)
_register_implementation(prims.unpack_rng_state, unpack_rng_state_prim_impl, checker=_always_executable)
_register_implementation(prims.pack_rng_state, pack_rng_state_prim_impl, checker=_always_executable)
_register_implementation(prims.randn, checker=_always_executable, execution_transform=_randn_prims_transform)
_register_implementation(
    prims.tensor_from_sequence, checker=_always_executable, execution_transform=_tensor_from_sequence_prims_transform
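Taken together, these implementations perform a query → unpack → repack → update cycle against the default CUDA generator. Here is a minimal eager sketch of that cycle — not part of the diff, and assuming a CUDA-enabled build with device 0 and an offset increment of 40 chosen arbitrarily:

import torch

device = torch.device("cuda:0")
state = torch.cuda.get_rng_state(device)  # 16-byte uint8 CPU tensor: seed bytes then offset bytes
seed, offset = (t.view(torch.int64).item() for t in torch.chunk(state, 2))
# Repack with the Philox offset advanced, then write the state back to the device generator.
new_state = torch.cat(
    [
        torch.tensor(seed).reshape([1]).view(torch.uint8),
        torch.tensor(offset + 40).reshape([1]).view(torch.uint8),
    ]
)
torch.cuda.set_rng_state(new_state, device)
assert torch.cuda.default_generators[device.index].get_offset() == offset + 40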
61 changes: 58 additions & 3 deletions thunder/tests/test_randomness.py
@@ -2,10 +2,10 @@

from torch.testing import assert_close

+import thunder
import thunder.torch as ltorch
-from thunder import compile as lc_compile
from thunder.core import devices, dtypes
-from thunder.tests.framework import TorchExecutor, instantiate
+from thunder.tests.framework import TorchExecutor, instantiate, NOTHING


@instantiate(
Expand All @@ -20,8 +20,63 @@ def test_uniform_philox(executor, device: str, dtype: dtypes.dtype):
    def func(shape, dtype, device, rng_seed, rng_offset):
        return ltorch.uniform_philox(shape, device=device, dtype=dtype, seed=rng_seed, offset=rng_offset)

-    cf = lc_compile(func, disable_preprocessing=True, executors_list=executor.executors_list())
+    cf = thunder.jit(func, executors_list=executor.executors_list())

    outputs = [cf(shape, dtype, device, rng_seed, rng_offset) for _ in range(3)]
    for o in outputs:
        assert_close(o, outputs[0])


from contextlib import contextmanager


@contextmanager
def reset_seed(cuda_generator):
    # Re-seed the generator on exit so the test does not leave a fixed seed behind.
    try:
        yield
    finally:
        cuda_generator.seed()


@instantiate(
    dtypes=NOTHING,
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(TorchExecutor,),
)
def test_rng_state_prims(executor, device: str, _):
    import thunder.core.prims as prims
    import torch

    dev = devices.to_device(device)

    def func():
        # The first query of the default CUDA RNG state must pass None as the state.
        b = prims.get_rng_state(None, device=dev)
        c = prims.get_rng_state(b, device=dev)
        seed, offset = prims.unpack_rng_state(b)
        offset = offset + 40
        new_state1 = prims.pack_rng_state(seed, offset)

        new_state1 = prims.set_rng_state(new_state1, dev)
        new_state1_1 = prims.get_rng_state(new_state1, dev)
        state1_seed, state1_offset = prims.unpack_rng_state(new_state1_1)
        return b, c, seed, offset, new_state1_1, state1_seed, state1_offset

    cuda_generator = torch.cuda.default_generators[dev.index]
    jfunc = thunder.jit(func, executors_list=executor.executors_list())
    with reset_seed(cuda_generator):
        cuda_generator.manual_seed(2)
        ori_state, ori_state_1, ori_seed, updated_offset, state1, s1_seed, s1_offset = jfunc()

        # Replay the same seed eagerly and check the traced results against the generator.
        cuda_generator.manual_seed(2)
        expect_ori_state = cuda_generator.get_state()
        expected_ori_seed = cuda_generator.initial_seed()
        expected_offset = cuda_generator.get_offset() + 40
        assert_close(expect_ori_state, ori_state)
        assert_close(expect_ori_state, ori_state_1)
        assert_close(ori_seed, expected_ori_seed)
        assert_close(expected_offset, updated_offset)

        cuda_generator.set_offset(expected_offset)
        assert_close(cuda_generator.get_state(), state1)
        assert_close(cuda_generator.initial_seed(), s1_seed)
        assert_close(cuda_generator.get_offset(), s1_offset)
