
Enable matmul for nvFuser #207

Merged: 3 commits merged into main from nvf_matmul on May 3, 2024
Conversation


@Priya2698 commented Apr 17, 2024

What does this PR do?

Enables matmul in nvFuser. Part of resolving NVIDIA/Fuser#2053
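
For reference, a minimal sketch (not from this PR) of the nvFuser frontend op being enabled here; it mirrors the fd.ops.matmul call in the repro further down:

import torch
from nvfuser import FusionDefinition, DataType

def matmul_fusion(fd: FusionDefinition) -> None:
    # Two bfloat16 matrices multiplied directly via the frontend matmul op.
    a = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16)
    b = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16)
    out = fd.ops.matmul(a, b)
    fd.add_output(out)

with FusionDefinition() as fd:
    matmul_fusion(fd)

inputs = [
    torch.randn(16, 16, dtype=torch.bfloat16, device='cuda'),
    torch.randn(16, 32, dtype=torch.bfloat16, device='cuda'),
]
outputs = fd.execute(inputs)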

@Priya2698 changed the title from "register matmul" to "Enable matmul for nvFuser" on Apr 17, 2024
@jjsjann123

Do you mind trying this in your branch?
#193

Note: you might want to remove the nv_enable_bookend option and replace it with enabling matmul, as sketched below.
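
A sketch of that swap, assuming both knobs are passed as thunder.jit compile options; nv_enable_matmul is an assumed name here, only nv_enable_bookend appears in this thread:

import torch
import thunder

def fn(a, b):
    return torch.matmul(a, b)

# Before (per #193): jfn = thunder.jit(fn, nv_enable_bookend=False)
# After: enable nvFuser matmul claiming instead (option name assumed).
jfn = thunder.jit(fn, nv_enable_matmul=True)

a = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
b = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
out = jfn(a, b)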

I'm getting an assert on the dtype being reduced-precision float. (Maybe add that to the check?! But why do we have that in the first place, wasn't it kicked to ATen?!)

It's also failing with:
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.ops.permute(T1, dims=[1, 0])
    T3 = fd.ops.permute(T0, dims=[2, 1, 0])
    S4 = fd.define_scalar(16, dtype=DataType.Int)
    S5 = fd.define_scalar(32, dtype=DataType.Int)
    V6 = fd.define_vector([S4, S5], dtype=DataType.Int)
    T7 = fd.ops.reshape(T3, new_shape=V6)
    T8 = fd.ops.matmul(T2, T7)
    S9 = fd.define_scalar(16, dtype=DataType.Int)
    S10 = fd.define_scalar(16, dtype=DataType.Int)
    S11 = fd.define_scalar(2, dtype=DataType.Int)
    V12 = fd.define_vector([S9, S10, S11], dtype=DataType.Int)
    T13 = fd.ops.reshape(T8, new_shape=V12)
    T14 = fd.ops.permute(T13, dims=[2, 1, 0])
    fd.add_output(T14)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((512,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 16, 16), (256, 16, 1)),
    torch.randn((256,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 16), (16, 1)),
]   
fd.execute(inputs)
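
For comparison, the same computation in eager PyTorch (a sketch to cross-check what the failing fusion computes):

import torch

a = torch.randn((512,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 16, 16), (256, 16, 1))
b = torch.randn((256,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 16), (16, 1))

# Mirrors the fusion above: transpose b, flatten the permuted a, matmul, reshape back.
t2 = b.permute(1, 0)                          # (16, 16)
t7 = a.permute(2, 1, 0).reshape(16, 32)       # (16, 32)
t8 = torch.matmul(t2, t7)                     # (16, 32)
ref = t8.reshape(16, 16, 2).permute(2, 1, 0)  # (2, 16, 16)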

@jjsjann123 left a comment


Sorry, this fell off my radar.

cc'ing @IvanYashchuk regarding the matmul checker rejecting vs. throwing an error. It makes more sense to just reject the matmul instead of throwing an error on an outdated nvFuser version (our stable release is still using an older nvFuser version).

Meanwhile, stamping to merge!
We can revisit Ivan's suggestion if he has a strong opinion on the exception; we can throw a warning if the concern is just about silently running with an outdated library.
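
A minimal sketch of that rejection path (all names and the version threshold below are illustrative, not the actual checker in this PR):

import warnings

from looseversion import LooseVersion
import nvfuser

# Hypothetical minimum version; the real cutoff would be whenever nvFuser added matmul support.
MATMUL_MIN_VERSION = LooseVersion("0.2.0")

def can_claim_matmul() -> bool:
    # Reject (with a warning) instead of raising, so a stable release pinned to
    # an older nvFuser simply falls back to another executor.
    if LooseVersion(str(nvfuser.version())) < MATMUL_MIN_VERSION:
        warnings.warn("Installed nvFuser is too old for matmul; not claiming it for fusion.")
        return False
    return True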

@t-vi left a comment


@t-vi merged commit 831d6d0 into main on May 3, 2024
35 of 39 checks passed
@t-vi deleted the nvf_matmul branch on May 3, 2024 at 11:45