Add trace transformation to replace uniform with uniform_philox (#114)
kiya00 committed Apr 30, 2024
1 parent 7b92e30 commit 5785601
Showing 3 changed files with 155 additions and 0 deletions.
49 changes: 49 additions & 0 deletions thunder/core/rematerialization.py
@@ -649,3 +649,52 @@ def joint_fn(args, kwargs, cotangents):
    new_fw_trace = update_fusion_call_ctx(new_fw_trace)
    new_bw_trace = update_fusion_call_ctx(new_bw_trace)
    return new_fw_trace, new_bw_trace


def replace_uniform(trace: TraceCtx) -> TraceCtx:
    """For better rematerialization, replace the stateful uniform operator with the stateless uniform_philox operator and manually update the RNG state."""
    start_time_ns = time.time_ns()
    from thunder.core.trace import VariableInterface
    from thunder.core.proxies import Proxy
    from thunder.core.devices import Device
    from thunder.core.transforms import VISIT_TYPE, visitor_transform

    # Maps the new uniform_philox output back to the original uniform output.
    swapmap: dict[VariableInterface, Proxy] = {}
    # Tracks the most recent RNG state proxy per device, so consecutive
    # uniform calls chain their seed/offset updates correctly.
    prev_state: dict[Device, Proxy] = {}

    def visit_(bsym: BoundSymbolInterface) -> VISIT_TYPE:
        if bsym.sym.id == prims.PrimIDs.UNIFORM:
            dev = bsym.kwargs["device"]
            if dev not in prev_state:
                rng_state = prims.get_rng_state(None, dev)
                prev_state[dev] = rng_state
            else:
                rng_state = prims.get_rng_state(prev_state[dev], dev)
            seed, offset = prims.unpack_rng_state(rng_state)
            out = prims.uniform_philox(*bsym.args, **bsym.kwargs, seed=seed, offset=offset)
            # Each Philox call consumes four 32-bit values, so advance the offset by 4.
            advance_offset = 4
            new_offset = prims.add(offset, advance_offset)
            new_state = prims.pack_rng_state(seed, new_offset)
            new_state_1 = prims.set_rng_state(new_state, dev)
            new_vo = variableify(out)
            swapmap[new_vo] = bsym.output
            prev_state[dev] = new_state_1
            return VISIT_TYPE.REPLACE
        return VISIT_TYPE.NO_OP

    new_trace = visitor_transform(trace, visit_)

    # Swap the proxies so downstream consumers keep referring to the original output names.
    bound_symbols: list[BoundSymbolInterface] = []
    for bsym in new_trace.bound_symbols:
        nbsym: BoundSymbolInterface = bsym.from_bsym_swap_proxies(swapmap)
        bound_symbols.append(nbsym)

    new_trace.bound_symbols = bound_symbols

    end_time_ns = time.time_ns()
    elapsed_time_ns = end_time_ns - start_time_ns
    elapsed_time_millis = elapsed_time_ns // 1000000
    new_trace.set_provenance(
        TraceProvenance(f"Transform for replace uniform (took {elapsed_time_millis} milliseconds)")
    )
    return new_trace
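
To make the rewrite concrete, here is a hypothetical before/after view of a single uniform call under this transform (proxy names are illustrative, not taken from a real trace):

# Before: a stateful call that implicitly consumes the device's RNG state.
#   t0 = prims.uniform((2, 2), 0.0, 1.0, device=dev, dtype=float32)
#
# After: a pure call plus explicit state bookkeeping, so rematerialization can
# replay it without perturbing (or depending on) the global generator.
#   s0 = prims.get_rng_state(None, dev)
#   seed, offset = prims.unpack_rng_state(s0)
#   t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=dev, dtype=float32, seed=seed, offset=offset)
#   new_offset = prims.add(offset, 4)
#   s1 = prims.set_rng_state(prims.pack_rng_state(seed, new_offset), dev)
# A later uniform on the same device chains from s1 through prev_state.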
5 changes: 5 additions & 0 deletions thunder/executors/torch_autograd.py
@@ -113,6 +113,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
        fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace, bw_trace.names)
        _fsdp_comm_bucketing.update_name_set(bw_trace)

    # Replace uniform with uniform_philox and RNG state operators for better rematerialization
    from thunder.core.rematerialization import replace_uniform

    fw_trace = replace_uniform(fw_trace)

    # Now we can run the optimization passes on the forward trace
    # TODO Restore request for no rematerialization
    fw_extrace = transform_for_execution(
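
The replacement is sound because Philox output is a pure function of the (seed, offset) pair. A minimal sketch in plain PyTorch of the property the transform relies on, assuming a CUDA device and a PyTorch version whose CUDA generators expose get_offset (the tests below already use it):

import torch

dev = torch.device("cuda:0")
gen = torch.cuda.default_generators[dev.index]

gen.manual_seed(20)                 # resets both seed and offset
first = torch.rand(8, device=dev)   # drawing numbers consumes Philox counters
offset_after = gen.get_offset()

gen.manual_seed(20)                 # rewind to the same (seed, offset)
replay = torch.rand(8, device=dev)

assert torch.equal(first, replay)        # same (seed, offset) -> same numbers
assert gen.get_offset() == offset_after  # and the same state advance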
101 changes: 101 additions & 0 deletions thunder/tests/test_randomness.py
@@ -68,3 +68,104 @@ def func():
    assert_close(cuda_generator.get_state(), state1)
    assert_close(cuda_generator.initial_seed(), s1_seed)
    assert_close(cuda_generator.get_offset() // (1 if executor == TorchExecutor else 4), s1_offset)


@instantiate(
    dtypes=(dtypes.float32, dtypes.float16, dtypes.float64),
    devicetypes=(devices.DeviceType.CUDA,),
)
def test_rng_state_uniform_philox_reproducibility(executor, device: str, dtype: dtypes.dtype):
    import torch

    def func(a):
        b = ltorch.uniform_like(a, device=a.device, dtype=a.dtype)
        d = torch.nn.functional.dropout(a, p=0.5)
        c = ltorch.uniform_like(a, device=a.device, dtype=a.dtype)
        return c * b * a * d

    dev = devices.to_torch_device(device)
    cuda_generator = torch.cuda.default_generators[dev.index]
    a = torch.randn(2, 2, device=dev, dtype=dtypes.to_torch_dtype(dtype), requires_grad=True)
    a1 = a.detach().clone()
    a1.requires_grad_()

    jfunc = thunder.jit(func, executors_list=executor.executors_list())

    with torch.random.fork_rng(devices=(dev,)):
        torch.cuda.manual_seed(20)
        expects = []
        for _ in range(4):
            out = jfunc(a)
            out.sum().backward()
            expects.append(out)
            expects.append(a.grad)

        results = []
        torch.cuda.manual_seed(20)
        for _ in range(4):
            out = jfunc(a1)
            out.sum().backward()
            results.append(out)
            results.append(a1.grad)

    for expected, result in zip(expects, results):
        assert_close(expected, result)


@instantiate(
    dtypes=(dtypes.float32, dtypes.float16, dtypes.float64),
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(TorchExecutor,),
)
def test_uniform_philox_vs_uniform(executor, device: str, dtype: dtypes.dtype):
    import torch

    dev = devices.to_torch_device(device)
    cuda_generator = torch.cuda.default_generators[dev.index]

    def func(a):
        b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
        e = a * b
        c = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
        f = e + c
        d = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
        return f * d

    a = torch.randn(2, 2, device=dev, dtype=dtypes.to_torch_dtype(dtype), requires_grad=True)
    a1 = a.detach().clone().requires_grad_()

    jfunc = thunder.jit(func, executors_list=executor.executors_list())

    with torch.random.fork_rng(devices=(dev,)):
        cuda_generator.manual_seed(20)
        expects = []
        # get the results of uniform_philox with RNG state updates
        for _ in range(4):
            out = jfunc(a)
            expects.append(out)
        # 3 uniform_like calls per forward x 4 runs = 12 Philox calls, each advancing the offset by 4
        assert cuda_generator.get_offset() == 12 * 4
        fwd_trc = [
            t for t in thunder.last_traces(jfunc) if getattr(t.get_provenance(), "pss", "") == "Augmented forward pass"
        ][0]
        from thunder.core.prims import PrimIDs

        uniform_philox_sym = [PrimIDs.UNIFORM_PHILOX, "torch.uniform_philox"]
        uniform_sym = [PrimIDs.UNIFORM, "torch.uniform"]
        # The augmented forward trace is recorded before the transform, so it contains no
        # uniform_philox; the final execution trace must contain no uniform after the replacement.
        assert all(t.sym.id not in uniform_philox_sym for t in fwd_trc.bound_symbols)
        assert all(t.sym.id not in uniform_sym for t in thunder.last_traces(jfunc)[-1].bound_symbols)

        # get the results of uniform
        results = []
        cuda_generator.manual_seed(20)
        from unittest.mock import patch

        # Patch replace_uniform to return the recorded (still uniform-based) forward trace,
        # effectively disabling the philox rewrite for this run.
        with patch("thunder.core.rematerialization.replace_uniform") as replace_uniform_mock:
            replace_uniform_mock.return_value = fwd_trc
            jfunc = thunder.jit(func, executors_list=executor.executors_list())
            for _ in range(4):
                out = jfunc(a1)
                results.append(out)
            assert cuda_generator.get_offset() == 12 * 4

    for expected, result in zip(expects, results):
        assert_close(expected, result)
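
For reference, an alternative way to disable the rewrite in a test is to patch replace_uniform with an identity function instead of a recorded trace; a hypothetical sketch (reusing func, a1, and executor from the test above):

from unittest.mock import patch

# Because torch_autograd imports replace_uniform inside split_forward_backward,
# patching the source module takes effect at jit time.
with patch("thunder.core.rematerialization.replace_uniform", side_effect=lambda trc: trc):
    jfunc_plain = thunder.jit(func, executors_list=executor.executors_list())
    out_plain = jfunc_plain(a1)  # runs with the original uniform operator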
