Description
I would like to request support for the torch.rms_norm operation in the Torch dialect of Torch-MLIR.
I tested torch.rms_norm with fx.export_and_import, and the export fails with an error. A minimal reproduction follows.
Minimal Reproduction
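```python
import torch
from torch import nn
from torch_mlir import fx


def run(f):
    # Print the test name as a header, run the test, then a blank line.
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
def test_rms_norm():
    class RMSNorm(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Normalize over the last two dimensions (3 x 4).
            normalized_shape = [3, 4]
            input, weight = x
            return torch.rms_norm(input, normalized_shape, weight, eps=0.8)

    # Export the module and import it into the Torch dialect.
    exported = fx.export_and_import(
        RMSNorm(), (torch.randn(1, 2, 3, 4), torch.randn(3, 4)), output_type='torch'
    )
    print(exported)
```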
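For reference, torch.rms_norm scales each group of trailing elements covered by normalized_shape by the reciprocal of their root mean square and then multiplies elementwise by weight: y = x / sqrt(mean(x^2) + eps) * weight. The sketch below is a pure-Python illustration of those semantics, which may help when checking a lowering or decomposition. It assumes the normalized dimensions have already been flattened into rows, and the helper name rms_norm_ref is hypothetical (it is not part of PyTorch or Torch-MLIR).

```python
import math
from typing import List, Sequence


def rms_norm_ref(
    rows: Sequence[Sequence[float]], weight: Sequence[float], eps: float
) -> List[List[float]]:
    """Hypothetical reference: RMS-normalize each row independently.

    Each row stands for one group of elements that torch.rms_norm would
    normalize together (the flattened normalized_shape dimensions), and
    weight is the flattened weight tensor of the same length.
    """
    out = []
    for row in rows:
        if len(row) != len(weight):
            raise ValueError("row length must match weight length")
        # Mean of squares over the normalized group, plus eps for stability.
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        # Scale by 1/RMS, then apply the elementwise weight.
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


print(rms_norm_ref([[2.0, 2.0]], [1.0, 3.0], eps=0.0))  # [[1.0, 3.0]]
```

With the reproduction's shapes, the (1, 2, 3, 4) input forms 2 groups of 3 x 4 = 12 elements, and the (3, 4) weight flattens to 12 values applied identically to both groups.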