Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256},
{993, 512},
{768, 1024},
{65536, 128},
{65504, 128},
{16384, 1632},
};

Expand Down
17 changes: 15 additions & 2 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,19 @@ std::pair<double, double> getTolerances(const DType type) {

template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
// Check how many RNG calls are required to generate one uniform random value
int rng_calls_per_val = 0;
{
std::mt19937 gen1 = *gen, gen2 = *gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
const float _ = dis(gen1);
while (gen2 != gen1) {
auto _ = gen2();
++rng_calls_per_val;
}
}

// Generate uniform random values in parallel
#pragma omp parallel proc_bind(spread)
{
std::mt19937 gen_local = *gen;
Expand All @@ -702,14 +715,14 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
const int chunk_size = (size + threads_num - 1) / threads_num;
const int idx_min = chunk_size * thread_ID;
const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
gen_local.discard(idx_min);
gen_local.discard(idx_min * rng_calls_per_val);
std::uniform_real_distribution<> dis(-2.0, 1.0);

for (int i = idx_min; i < idx_max; ++i) {
data[i] = static_cast<T>(dis(gen_local));
}
}
gen->discard(size);
gen->discard(size * rng_calls_per_val);
}

void fillUniform(Tensor *t) {
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/distributed/run_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
return {"rtol": 1.3e-6, "atol": 4e-5}
return {"rtol": 1e-4, "atol": 1e-4}
raise ValueError(f"Unsupported dtype ({dtype})")


Expand Down
28 changes: 16 additions & 12 deletions transformer_engine/common/normalization/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ Compute always in FP32
namespace transformer_engine {
namespace normalization {

bool& use_zero_centered_gamma_in_weight_dtype();

cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
: cudnn_frontend::NormFwdPhase_t::INFERENCE;
Expand All @@ -49,13 +47,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode, bool training) {
// TODO: Add scaling_mode to general_key is needed
uint64_t general_key = static_cast<uint32_t>(itype) | (static_cast<uint32_t>(otype) << 3) |
(static_cast<uint32_t>(ctype) << 6) | (static_cast<uint32_t>(wtype) << 9) |
(uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 |
(uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) |
(uint32_t(mode) << 19) | (uint32_t(training) << 22);
bool is_tuned, NVTEScalingMode mode, bool training,
bool gamma_in_weight_dtype) {
static_assert(NVTE_INVALID_SCALING < 1024,
"This function assumes at most 10 bits used in the scaling mode.");
static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType");
uint64_t general_key = static_cast<uint64_t>(itype) | (static_cast<uint64_t>(otype) << 5) |
(static_cast<uint64_t>(ctype) << 10) |
(static_cast<uint64_t>(wtype) << 15) | (uint64_t(NormType) << 20) |
(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
}

Expand Down Expand Up @@ -466,11 +468,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode, const bool training) {
const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) {
const DType ctype = DType::kFloat32;
bool is_tuned = is_aligned && (batch_size % 4 == 0);
auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size,
hidden_size, zero_centered_gamma, is_tuned, mode, training);
auto key =
get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size,
zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype);

auto it = normalizationPlanMap.find(key);
if (it != normalizationPlanMap.end()) {
Expand Down Expand Up @@ -528,6 +531,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
}

// Only for testing, not thread-safe
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/common/normalization/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true);
bool training = true, bool gamma_in_weight_dtype = false);

template <typename KernelParamsType>
class TeNormalizationRegistry {
Expand Down Expand Up @@ -307,7 +307,8 @@ class NormalizationPlanRegistry {
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true);
const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true,
const bool gamma_in_weight_dtype = false);

private:
NormalizationPlanRegistry() {}
Expand Down Expand Up @@ -381,6 +382,8 @@ bool is_ptr_aligned(const Args*... ptrs) {
bool use_cudnn_norm_fwd();
bool use_cudnn_norm_bwd();

bool& use_zero_centered_gamma_in_weight_dtype();

} // namespace normalization
} // namespace transformer_engine

Expand Down
11 changes: 9 additions & 2 deletions transformer_engine/common/normalization/layernorm/ln_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "../../common.h"
#include "../common.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {

Expand Down Expand Up @@ -64,9 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool is_aligned = true;
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);

bool gamma_in_weight_dtype = false;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr,
Expand All @@ -83,7 +86,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training);
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);

if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
Expand Down Expand Up @@ -150,9 +154,11 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te

NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr,
Expand All @@ -165,7 +171,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);

if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
Expand Down
11 changes: 9 additions & 2 deletions transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/normalization.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"

namespace transformer_engine {
Expand Down Expand Up @@ -53,9 +54,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;

bool gamma_in_weight_dtype = false;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr);
Expand All @@ -68,7 +71,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training);
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);

if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
Expand Down Expand Up @@ -126,9 +130,11 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const

NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
Expand All @@ -141,7 +147,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);

if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
Expand Down