Skip to content

Commit

Permalink
Add trace transform to replace uniform with stateless uniform(uniform…
Browse files Browse the repository at this point in the history
…_philox) and RNG state query/updating (#114)
  • Loading branch information
kiya00 committed Apr 18, 2024
1 parent 649c3d7 commit 89181ee
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 0 deletions.
86 changes: 86 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ class PrimIDs(Enum):
# Memory access methods
ITEM = auto()
COPY_ = auto()
SET_SEED = auto()
SET_OFFSET=auto()
GET_SEED =auto()
GET_OFFSET = auto()
SET_RNG_STATE = auto()
GET_RNG_STATE = auto()
UNPACK_RNG_STATE = auto()
PACK_RNG_STATE = auto()


class OpTags(Enum):
Expand Down Expand Up @@ -2451,6 +2459,84 @@ def _uniform_meta(
tags=(OpTags.RANDOM_OP,),
)

def _set_seed_meta(s):
return None
def _get_seed_meta():
# return numberproxy(int, 0)
return TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64)
def _set_offset_meta(s):
return None
def _get_offset_meta():
return TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64)
# return numberproxy(int, 0)
def _set_rng_state_meta(new_state): # TODO, device
return None
def _get_rng_state_meta(): #TODO device
# static const size_t seed_size = sizeof(uint64_t);
# static const size_t offset_size = sizeof(int64_t);
# static const size_t total_size = seed_size + offset_size;

# auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
state_shape = dtypes.int64.bytes//dtypes.uint8.bytes * 2
return TensorProxy(shape=(state_shape,), dtype=dtypes.uint8, device=devices.cpu)

set_rng_state = make_prim(
PrimIDs.SET_RNG_STATE,
"set_rng_state",
meta=_set_rng_state_meta,
tags=(OpTags.RANDOM_OP, OpTags.DONT_DCE),
)
get_rng_state = make_prim(
PrimIDs.GET_RNG_STATE,
"get_rng_state",
meta=_get_rng_state_meta,
tags=(OpTags.RANDOM_OP,),
)
def _unpack_rng_state_meta(state):
return numberproxy(int, 0), numberproxy(int, 0)
# return TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64), TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64)

unpack_rng_state = make_prim(
PrimIDs.UNPACK_RNG_STATE,
"unpack_rng_state",
meta=_unpack_rng_state_meta,
tags=(OpTags.RANDOM_OP,),
)
def _pack_rng_state_meta(seed, offset):
state_shape = dtypes.int64.bytes//dtypes.uint8.bytes * 2
return TensorProxy(shape=(state_shape,), dtype=dtypes.uint8, device=devices.cpu)

pack_rng_state = make_prim(
PrimIDs.PACK_RNG_STATE,
"pack_rng_state",
meta=_pack_rng_state_meta,
tags=(OpTags.RANDOM_OP,),
)
set_seed = make_prim(
PrimIDs.SET_SEED,
"set_seed",
meta=_set_seed_meta,
tags=(OpTags.RANDOM_OP,),
)
get_seed = make_prim(
PrimIDs.GET_SEED,
"get_seed",
meta=_get_seed_meta,
tags=(OpTags.RANDOM_OP,),
)
set_offset = make_prim(
PrimIDs.SET_OFFSET,
"set_offset",
meta=_set_offset_meta,
tags=(OpTags.RANDOM_OP,OpTags.DONT_DCE),
)
get_offset = make_prim(
PrimIDs.GET_OFFSET,
"get_offset",
meta=_get_offset_meta,
tags=(OpTags.RANDOM_OP,),
)


def _uniform_philox_meta(
shape: Sequence[int],
Expand Down
57 changes: 57 additions & 0 deletions thunder/executors/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,63 @@
comment_symbols = {prims.PrimIDs.COMMENT, prims.PrimIDs.UNPACK_TRIVIAL}


def replace_uniform(trace: TraceCtx) -> TraceCtx:
start_time_ns = time.time_ns()
from thunder.torch import uniform_philox

swapmap: dict[Variable, Proxy] = {}

def update_swapmap(o: Any, no: Any) -> None:
if isinstance(o, Proxy):
check(
isinstance(no, Proxy),
lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}",
)

vo = variableify(o)
vno = variableify(no)
if vo == vno:
return
swapmap[vno] = o

def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE:
import thunder.torch as ltorch
if bsym.sym.id == prims.PrimIDs.UNIFORM:
rng_state = prims.get_rng_state()
print(rng_state.shape, rng_state.dtype)
# seed, offset = ltorch.chunk(rng_state, 2)
seed, offset = prims.unpack_rng_state(rng_state)
# seed = prims.get_seed()
# offset = prims.get_offset()
out = uniform_philox(*bsym.args,**bsym.kwargs, seed=seed, offset=offset)
adv_offs = 4
new_offset = prims.add(offset, adv_offs)
new_state = prims.pack_rng_state(seed, new_offset)
prims.set_rng_state(new_state)
safe_map_flat(update_swapmap, bsym.output, out)
return transforms.VISIT_TYPE.REPLACE
else:
return transforms.VISIT_TYPE.NO_OP

extrace = transforms.visitor_transform(trace, visit_)

# Restores original variables
bound_symbols: list[BoundSymbol] = []
for bsym in extrace.bound_symbols:
nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap)
bound_symbols.append(nbsym)

extrace.bound_symbols = bound_symbols

end_time_ns = time.time_ns()
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_millis = elapsed_time_ns // 1000000
extrace.set_provenance(
TraceProvenance(f"Transform for replace uniform (took {elapsed_time_millis} milliseconds)")
)
return extrace


# Transforms a trace by determining which execution transforms to call given the list of executors in priority order
def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
start_time_ns = time.time_ns()
Expand Down
7 changes: 7 additions & 0 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def make_trace(func):

fw_traces = [fw_trace]
bw_traces = [bw_trace]
# print(primal_trace)
# print(fw_trace, bw_trace)

from thunder.distributed import FSDPType

Expand Down Expand Up @@ -137,6 +139,11 @@ def make_trace(func):
fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace, bw_trace.names)
_fsdp_comm_bucketing.update_name_set(bw_trace)

print(fw_trace)
from thunder.executors.passes import replace_uniform
fw_trace = replace_uniform(fw_trace)
print(fw_trace)

# Now we can run the optimization passes on the forward trace
# TODO Restore request for no rematerialization
fw_extrace = transform_for_execution(
Expand Down
80 changes: 80 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,86 @@ def _tensor_from_sequence_prims_transform(
return tensor_from_sequence(seq_or_number, device=torch_device, dtype=torch_dtype)



def _get_seed_prim_impl():
cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
return torch.tensor(cuda_generator.initial_seed())
# print("++++: ",cuda_generator.initial_seed())
# return cuda_generator.initial_seed()
get_seed_prim_impl = ex.register_operator(
"get_seed_prim_impl", meta=prims.get_seed.meta, fn=_get_seed_prim_impl
)
_register_implementation(prims.get_seed, get_seed_prim_impl, checker=_always_executable, )

def _get_offset_prim_impl():
cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
return torch.tensor(cuda_generator.get_offset())
# return cuda_generator.get_offset()
get_offset_prim_impl = ex.register_operator(
"get_offset_prim_impl", meta=prims.get_offset.meta, fn=_get_offset_prim_impl
)
_register_implementation(prims.get_offset, get_offset_prim_impl, checker=_always_executable, )


def _set_seed_prim_impl(s):
cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
return cuda_generator.manual_seed(s)
set_seed_prim_impl = ex.register_operator(
"set_seed_prim_impl", meta=prims.set_seed.meta, fn=_set_seed_prim_impl, tags=(prims.OpTags.RANDOM_OP,prims.OpTags.DONT_DCE)
)
_register_implementation(prims.set_seed, set_seed_prim_impl, checker=_always_executable, )


def _set_offset_prim_impl(s):
cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
cuda_generator.set_offset(s.item())
set_offset_prim_impl = ex.register_operator(
"set_offset_prim_impl", meta=prims.set_offset.meta, fn=_set_offset_prim_impl, tags=(prims.OpTags.RANDOM_OP,prims.OpTags.DONT_DCE)
)
_register_implementation(prims.set_offset, set_offset_prim_impl, checker=_always_executable, )

def _set_rng_state_prim_impl(s):
cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
cuda_generator.set_state(s)
set_rng_state_prim_impl = ex.register_operator(
"set_rng_state_prim_impl", meta=prims.set_rng_state.meta, fn=_set_rng_state_prim_impl, tags=(prims.OpTags.RANDOM_OP,prims.OpTags.DONT_DCE)
)
_register_implementation(prims.set_rng_state, set_rng_state_prim_impl, checker=_always_executable, )

def _get_rng_state_prim_impl():
cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
return cuda_generator.get_state()
# return cuda_generator.graphsafe_get_state()
get_rng_state_prim_impl = ex.register_operator(
"get_rng_state_prim_impl", meta=prims.get_rng_state.meta, fn=_get_rng_state_prim_impl, tags=(prims.OpTags.RANDOM_OP,prims.OpTags.DONT_DCE)
)
_register_implementation(prims.get_rng_state, get_rng_state_prim_impl, checker=_always_executable, )

def _unpack_rng_state_prim_impl(s):
seed, offset = torch.chunk(s, 2)
# return seed.view(torch.int64), offset.view(torch.int64)
return seed.view(torch.int64).item(), offset.view(torch.int64).item()
unpack_rng_state_prim_impl = ex.register_operator(
"unpack_rng_state_prim_impl", meta=prims.unpack_rng_state.meta, fn=_unpack_rng_state_prim_impl, tags=(prims.OpTags.RANDOM_OP,prims.OpTags.DONT_DCE)
)
_register_implementation(prims.unpack_rng_state, unpack_rng_state_prim_impl, checker=_always_executable, )

def _pack_rng_state_prim_impl(seed, offset):
seed = torch.tensor(seed)
offset = torch.tensor(offset)
seed_portion = seed.reshape([1]).view(torch.uint8)
offset_portion = offset.reshape([1]).view(torch.uint8)
new_state = torch.cat([seed_portion, offset_portion])
return new_state
pack_rng_state_prim_impl = ex.register_operator(
"pack_rng_state_prim_impl", meta=prims.pack_rng_state.meta, fn=_pack_rng_state_prim_impl, tags=(prims.OpTags.RANDOM_OP,prims.OpTags.DONT_DCE)
)
_register_implementation(prims.pack_rng_state, pack_rng_state_prim_impl, checker=_always_executable, )





_register_implementation(prims.full, checker=_always_executable, execution_transform=_full_transform)
_register_implementation(prims.iota, checker=_always_executable, execution_transform=_iota_transform)
_register_implementation(prims.uniform, checker=_always_executable, execution_transform=_uniform_transform)
Expand Down
58 changes: 58 additions & 0 deletions thunder/tests/test_randomness.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,64 @@ def func(shape, dtype, device, rng_seed, rng_offset):

cf = lc_compile(func, disable_preprocessing=True, executors_list=executor.executors_list())

import torch
rng_seed_tensor = torch.tensor(rng_seed)
rng_offset_tensor = torch.tensor(rng_offset)
# outputs = [cf(shape, dtype, device, rng_seed_tensor, rng_offset_tensor) for _ in range(3)]
outputs = [cf(shape, dtype, device, rng_seed, rng_offset) for _ in range(3)]
import thunder
print(thunder.last_traces(cf)[-1])
# print(thunder.last_backward_traces(cf)[-1])
for o in outputs:
assert_close(o, outputs[0])


@instantiate(
dtypes=(dtypes.float32, dtypes.float16, dtypes.float64),
devicetypes=(devices.DeviceType.CUDA,),
)
def test_rng_state_uniform_philox(executor, device: str, dtype: dtypes.dtype):
import torch
import thunder
def func(a):
b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
# b = torch.nn.functional.dropout(a, p=0.5)
c = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
# b = torch.uniform(a.shape, device=a.device, dtype=a.dtype)
return c*b

cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]

a = torch.randn(2, 2, device="cuda", requires_grad=True)
a1 = a.detach().clone()
a1.requires_grad_()

jfunc = thunder.jit(func)
cuda_generator.manual_seed(20)
expects = []
for _ in range(4):
out = jfunc(a)
print("b: ", out)
out.sum().backward()
print(a.grad)
expects.append(out)
expects.append(a.grad)
print("------------------")

results = []
cuda_generator.manual_seed(20)

# a = torch.randn(2, 2, device="cuda", requires_grad=True)
print(a1)
for _ in range(4):
out = jfunc(a1)
print("b: ", out)
out.sum().backward()
print(a1.grad)
results.append(out)
results.append(a1.grad)

print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])
for expected, result in zip(expects, results):
assert_close(expected, result)

0 comments on commit 89181ee

Please sign in to comment.