From 96527688ed99bf652a1bb4bffa3ba155db75e984 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 9 Aug 2023 20:58:31 -0700 Subject: [PATCH] Fix pybind strings for RMSNorm Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 6dc48a4b5c..93196962e0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -30,11 +30,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); - m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "LN FWD FP8"); - m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "LN FWD FP8"); - m.def("rmsnorm_bwd", &rmsnorm_bwd, "LN BWD"); - m.def("rmsnorm_fwd", &rmsnorm_fwd, "LN FWD"); - m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "LN FWD"); + m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8"); + m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8"); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); + m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD");