
Reducing the weight of NumberProxy used in mincut in rematerialization #425

Merged: 5 commits into main from fix_remat on May 28, 2024

Conversation

@kiya00 (Collaborator) commented May 16, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes part of #114.

Background:
When dropout is decomposed to uniform_philox, rematerialization doesn't pass the expected seed/offset to the backward trace.

Trace before this PR (based on the uniform_rng branch):

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t6 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t6: "cpu ui8[16]"
  (i7, i8) = unpack_rng_state_prim_impl(t6)
  [t1, t5] = nvFusion0(a, i7, i8)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(a, t4)  # t5: "cuda:0 f32[2, 2]"
  t10 = update_rng_state_prim_impl(i7, i8)  # t10: "cpu ui8[16]"
  set_rng_state_prim_impl(t10, devices.Device("cuda:0"))
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((a, t1), (2.0,))
# Constructed by Update Call Context (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t6, = cotangents
  a, t1, = C0
  f7, = C1
  [t39] = nvFusion0(a, f7, t1, t6)
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t33 = prims.mul(t4, t6)  # t33: "cuda:0 f32[2, 2]"
    # t34 = prims.mul(a, t6)  # t34: "cuda:0 f32[2, 2]"
    # t35 = prims.mul(f7, t34)  # t35: "cuda:0 f32[2, 2]"
    # t36 = prims.mul(t2, t35)  # t36: "cuda:0 f32[2, 2]"
    # t39 = prims.add(t33, t36)  # t39: "cuda:0 f32[2, 2]"
  return (t39,)

The mincut is ((a_in, a_out), (t1_in, t1_out)) with weight = 2 + 1 = 3. The per-proxy weights are:

a, 2.0
i7, 0.5
i8, 0.5
t0, 4.0
t1, 1.0
t2, 4.0
t3, 4.0
t4, 4.0
t5, 4.0
t33, 4.0
t34, 4.0
t35, 4.0
t36, 4.0
t39, 4.0

We expect the mincut ((a_in, a_out), (i7_in, i7_out), (i8_in, i8_out)) instead, which actually has the same weight = 0.5 + 0.5 + 2 = 3.
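
For illustration, here is a minimal sketch of this cut problem using networkx (my choice of library; Thunder's rematerialization pass constructs its graph differently, and the edges below are simplified from the traces above). It reproduces the tie and shows how the reduced NumberProxy weight proposed below breaks it:

```python
import networkx as nx

def build_cut_graph(number_proxy_factor: float = 1.0) -> nx.DiGraph:
    # Each proxy is split into an _in/_out pair whose edge capacity is its
    # weight from the table above; cutting that edge means saving the proxy.
    weights = {"a": 2.0, "i7": 0.5, "i8": 0.5, "t0": 4.0,
               "t1": 1.0, "t2": 4.0, "t3": 4.0, "t4": 4.0}
    number_proxies = {"i7", "i8"}
    g = nx.DiGraph()
    for name, w in weights.items():
        if name in number_proxies:
            w *= number_proxy_factor
        g.add_edge(f"{name}_in", f"{name}_out", capacity=w)
    # Data-flow edges; edges without a capacity attribute are infinite.
    for src, dst in [("i7_out", "t0_in"), ("i8_out", "t0_in"),
                     ("t0_out", "t1_in"), ("t1_out", "t2_in"),
                     ("a_out", "t3_in"), ("t2_out", "t3_in"),
                     ("t3_out", "t4_in")]:
        g.add_edge(src, dst)
    # Forward inputs hang off the source; backward consumers feed the sink.
    for fwd_input in ("a_in", "i7_in", "i8_in"):
        g.add_edge("source", fwd_input)
    for bwd_consumer in ("a_out", "t2_out", "t4_out"):
        g.add_edge(bwd_consumer, "sink")
    return g

for factor in (1.0, 0.1):
    cut_value, _ = nx.minimum_cut(build_cut_graph(factor), "source", "sink")
    print(factor, cut_value)
# factor 1.0 -> 3.0: (a, t1) and (a, i7, i8) tie, either may be returned
# factor 0.1 -> 2.1: (a, i7, i8) is now strictly cheaper
```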

Here's the trace with this PR: #425 (comment)

The proposed fix is to apply a factor (e.g. 0.1) that reduces the weight of NumberProxy; a sketch of the idea follows.
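
As a hypothetical sketch of the idea (not the merged code; the helper name, the base-weight argument, and the import path are assumptions for illustration):

```python
from thunder.core.proxies import NumberProxy  # import path assumed

WEIGHT_FACTOR = 0.1  # illustrative value from the proposal above

def proxy_weight(proxy, base_weight: float) -> float:
    # Hypothetical helper: scalar proxies become 10x cheaper to save, so a
    # tie such as 2 + 1 == 0.5 + 0.5 + 2 above becomes 3 vs 2.1 and the
    # mincut prefers saving the seed/offset pair over the boolean mask t1.
    if isinstance(proxy, NumberProxy):
        return base_weight * WEIGHT_FACTOR
    return base_weight
```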

cc: @IvanYashchuk

@IvanYashchuk (Collaborator) left a comment

The change makes sense to me.

The PR description has only the "before" trace. Could you please update it to show how the trace looks with this PR?

Not blocking the merge due to an absence of tests, but maybe you could come up with a short test for this change? The rematerialization tests currently live in tests/test_nvfuser_remat.py.

@kiya00 (Collaborator, Author) commented May 17, 2024

The PR description has only the "before" trace. Could you please update it to show how the trace looks with this PR?

Trace after this PR (based on the uniform_rng branch):

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t6 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t6: "cpu ui8[16]"
  (i7, i8) = unpack_rng_state_prim_impl(t6)
  del t6
  [t5] = nvFusion0(a, i7, i8)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(a, t4)  # t5: "cuda:0 f32[2, 2]"
  t10 = update_rng_state_prim_impl(i7, i8)  # t10: "cpu ui8[16]"
  del i7, i8
  set_rng_state_prim_impl(t10, devices.Device("cuda:0"))
  del t10
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((a,), (2.0, 0, 0))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t6, = cotangents
  clear_collection(cotangents)
  del cotangents
  a, = C0
  clear_collection(C0)
  del C0
  f7, i7, i8, = C1
  clear_collection(C1)
  del C1
  [t39] = nvFusion0(a, f7, i7, i8, t6)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t33 = prims.mul(t4, t6)  # t33: "cuda:0 f32[2, 2]"
    # t34 = prims.mul(a, t6)  # t34: "cuda:0 f32[2, 2]"
    # t35 = prims.mul(f7, t34)  # t35: "cuda:0 f32[2, 2]"
    # t36 = prims.mul(t2, t35)  # t36: "cuda:0 f32[2, 2]"
    # t39 = prims.add(t33, t36)  # t39: "cuda:0 f32[2, 2]"
  del a, f7, i7, i8, t6
  return (t39,)
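
Note how saved_for_backward now carries the scalars i7 and i8 instead of the boolean mask t1, and the backward fusion rematerializes t0 through t2 from the saved seed/offset.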

Not blocking the merge due to an absence of tests, but maybe you could come up with a short test for this change? The rematerialization tests currently live in tests/test_nvfuser_remat.py.

Sure, let me try to construct a graph corresponding to this case; a rough idea is sketched below.
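
A minimal sketch of such a test, assuming thunder.jit and thunder.last_traces are available and that the dropout-to-uniform_philox decomposition from the uniform_rng branch is enabled (the exact flags and assertions are placeholders):

```python
import torch
import thunder

def func(a):
    # dropout decomposes to uniform_philox + lt + mul, reproducing the
    # t5 = a * t4 pattern from the traces above
    return a * torch.nn.functional.dropout(a, p=0.5)

jfunc = thunder.jit(func)  # nvFuser executor and a CUDA device assumed
a = torch.randn(2, 2, device="cuda", requires_grad=True)
out = jfunc(a)
out.sum().backward()

fwd_trace = thunder.last_traces(jfunc)[-1]
# After this PR the seed/offset scalars (i7, i8), not the boolean mask t1,
# should appear in saved_for_backward.
print(fwd_trace)
```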

By the way, the final trace for Llama-2-7b-hf (fsdp_zero2_none_bucket) is the same before and after this PR. @IvanYashchuk

@IvanYashchuk (Collaborator)

By the way, the final trace for Llama-2-7b-hf (fsdp_zero2_none_bucket) is the same before and after this PR.

Thank you for checking! This is very good and exactly what I wanted to see.

@kiya00 (Collaborator, Author) commented May 21, 2024

Hi @mruberry, could you take a look and review?

@kiya00 (Collaborator, Author) commented May 24, 2024

Hi @nikitaved, @t-vi, could you take a look so this can be merged?

@lantiga (Collaborator) left a comment

Looks great, stamped!

@lantiga merged commit 5ff8a34 into main on May 28, 2024 (37 checks passed)
@lantiga deleted the fix_remat branch on May 28, 2024 at 20:50