
[TORCH] Add support for aten.rms_norm op #4206

Open
@sharavana20

Description


I would like to request support for the `torch.rms_norm` operation (`aten.rms_norm`) in the Torch dialect of Torch-MLIR.
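For reference, `torch.rms_norm` normalizes each input slice over the trailing dimensions given by `normalized_shape` and then scales the result by `weight`. In the reproduction below, `normalized_shape = [3, 4]`, so each slice has N = 12 elements. This is my reading of the PyTorch documentation, written out per element:

$$
y_i = \frac{x_i}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon}}\, w_i
$$

Here $\epsilon$ is the `eps` argument and $w_i$ is the matching element of `weight`.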

I tested `torch.rms_norm` with `fx.export_and_import`, and the export fails with the following error:

[Screenshot of the error output; not reproduced in text]

Minimal Reproduction

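import torch
import torch.nn as nn
from torch_mlir import fx


def run(f):
    # Print the test name as a header, run the test, then leave a blank line.
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
def test_rms_norm():
    class RMSNorm(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Normalize over the last two dimensions of the input.
            normalized_shape = [3, 4]
            input, weight = x
            return torch.rms_norm(input, normalized_shape, weight, eps=0.8)

    # The (input, weight) tuple is passed as a single positional argument.
    # This export_and_import call produces the error described above.
    exported = fx.export_and_import(
        RMSNorm(), (torch.randn(1, 2, 3, 4), torch.randn(3, 4)), output_type="torch"
    )
    print(exported)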
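Until the op is supported, one possible workaround is to write RMSNorm with elementwise and reduction ops (`pow`, `mean`, `rsqrt`, `mul`). This is only a sketch, and it assumes those ops already lower through the Torch dialect. I have not checked that against Torch-MLIR. The names `rms_norm_decomposed` and `RMSNormDecomposed` are hypothetical helpers introduced here. The default-`eps` handling follows my reading of the `torch.nn.RMSNorm` documentation, and the sketch does not reproduce any internal upcasting for reduced-precision dtypes.

```python
import torch
import torch.nn as nn


def rms_norm_decomposed(x, normalized_shape, weight=None, eps=None):
    # Hypothetical helper: reduce over the trailing len(normalized_shape) dims,
    # which are the dims torch.rms_norm normalizes over.
    dims = tuple(range(-len(normalized_shape), 0))
    if eps is None:
        # Assumption: mirror the documented default of torch.finfo(dtype).eps.
        eps = torch.finfo(x.dtype).eps
    # Mean of squares over the normalized dims, kept for broadcasting.
    mean_sq = x.pow(2).mean(dim=dims, keepdim=True)
    y = x * torch.rsqrt(mean_sq + eps)
    if weight is not None:
        # Elementwise scale; weight has shape normalized_shape.
        y = y * weight
    return y


class RMSNormDecomposed(nn.Module):
    # Hypothetical drop-in for the RMSNorm module in the reproduction above.
    def forward(self, x):
        input, weight = x
        return rms_norm_decomposed(input, [3, 4], weight, eps=0.8)
```

`RMSNormDecomposed()` can be passed to `fx.export_and_import` with the same `(torch.randn(1, 2, 3, 4), torch.randn(3, 4))` inputs and `output_type="torch"`. I expect this would avoid `aten.rms_norm` in the exported graph, but that is an assumption. Native support for the op would still be preferable.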
