setup.py (28 changes: 15 additions & 13 deletions)
@@ -35,6 +35,19 @@ def te_version() -> str:
with open(root_path / "VERSION", "r") as f:
version = f.readline().strip()

try:
output = subprocess.run(
["git", "rev-parse" , "--short", "HEAD"],
capture_output=True,
cwd=root_path,
check=True,
universal_newlines=True,
)
except (CalledProcessError, OSError):
commit = ""
else:
commit = output.stdout.strip()

# [augment] Here is where we replace the git hash with our own versioning.
# You can disable this behavior with NVTE_NO_AUGMENT_VERSION=1.
if not int(os.getenv("NVTE_NO_AUGMENT_VERSION", "0")):
@@ -43,21 +56,10 @@ def te_version() -> str:
torch_version = parse(torch.__version__)
cuda_version = parse(torch.version.cuda)
version_string = f".cu{cuda_version.major}{cuda_version.minor}.torch{torch_version.major}{torch_version.minor}"
return version + "+augment" + version_string
return version + "+augment" + version_string + "." + commit

if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")):
try:
output = subprocess.run(
["git", "rev-parse" , "--short", "HEAD"],
capture_output=True,
cwd=root_path,
check=True,
universal_newlines=True,
)
except (CalledProcessError, OSError):
pass
else:
commit = output.stdout.strip()
if len(commit) > 0:
version += f"+{commit}"
return version

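For illustration (all values hypothetical): with torch 2.1, CUDA 12.1, and short commit abc1234, te_version() now returns a string like 1.2.0+augment.cu121.torch21.abc1234. Running with NVTE_NO_AUGMENT_VERSION=1 skips the augment suffix and falls through to the NVTE_NO_LOCAL_VERSION path, which yields 1.2.0+abc1234.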
transformer_engine/common/layer_norm/ln_api.cpp (37 changes: 31 additions & 6 deletions)
@@ -4,7 +4,9 @@
* See LICENSE for license information.
************************************************************************/

#include <cstdlib>
#include <transformer_engine/layer_norm.h>
#include <string>
#include <vector>
#include "ln.h"
#include "../common.h"
@@ -31,6 +33,9 @@ Compute always in FP32
namespace transformer_engine {
namespace layer_norm {

// [Augment] Forward declare helper kernel added to avoid using memset.
void launch_zero_out(void *, size_t, size_t, cudaStream_t);

using namespace transformer_engine;

// Create registries and provide runtime versions of config hash functions.
@@ -232,16 +237,36 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}

  // NOTE[augment]: this env var exists to restore the prior behavior of TE (i.e., use a memset
  // kernel). To get the upstream behavior, run with NVTE_FORCE_MEMSET=1.
const char *envval = std::getenv("NVTE_FORCE_MEMSET");

Review comment: The only nit is that it would be good to document the flag. This is a fallback to the original behavior of the library.
bool force_memset = (envval != nullptr) && (std::string(envval) == "1");
// Clear buffers
if ( params.fp8_out ) {
cudaMemsetAsync(params.amax, 0,
layer_norm::product(z->amax.shape) *
typeToSize(z->amax.dtype), stream);
if ( force_memset ) {
cudaMemsetAsync(params.amax, 0,
layer_norm::product(z->amax.shape) *
typeToSize(z->amax.dtype), stream);
} else {
// [Augment] Use the zero-out kernel, not memset.
layer_norm::launch_zero_out(params.amax,
layer_norm::product(z->amax.shape),
typeToSize(z->amax.dtype),
stream);
}
}
if ( launch_params.barrier_size > 0 ) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
if ( force_memset ) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
} else {
// [Augment] Use the zero-out kernel, not memset.
layer_norm::launch_zero_out(params.barrier,
layer_norm::product(barrier->data.shape),
typeToSize(barrier->data.dtype),
stream);
}
}

// Launch the kernel.
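To make the intent concrete, below is a minimal, self-contained sketch of the pattern this change enables: the zeroing launch records into a CUDA graph as an ordinary kernel node alongside the rest of the forward pass. Everything here is hypothetical toy code (zero_out_f32 stands in for the helper kernel the PR adds); it is not TE's actual capture path. To restore the upstream memset behavior instead, set NVTE_FORCE_MEMSET=1 in the environment (e.g. NVTE_FORCE_MEMSET=1 python train.py).

```cuda
// Toy sketch: capture a kernel-based zero-out into a CUDA graph and replay it.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void zero_out_f32(float *x, size_t n) {
    size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (i < n) x[i] = 0.0f;
}

int main() {
    const size_t n = 1 << 20;
    float *amax = nullptr;
    cudaMalloc(&amax, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Begin capture: every launch on `stream` becomes a node in the graph.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    zero_out_f32<<<(n + 127) / 128, 128, 0, stream>>>(amax, n);
    // ...the normalization forward kernel would be captured here as well...
    cudaStreamEndCapture(stream, &graph);

    // Instantiate once, then replay cheaply on every step.
    cudaGraphExec_t exec;
    cudaGraphInstantiateWithFlags(&exec, graph, 0);
    cudaGraphLaunch(exec, stream);
    cudaStreamSynchronize(stream);
    printf("replayed graph with kernel-based zeroing\n");

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(amax);
    return 0;
}
```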
transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu (42 changes: 42 additions & 0 deletions)
@@ -10,6 +10,48 @@

using namespace transformer_engine::layer_norm;

// [Augment] What follows is a small custom kernel (and launch function) to zero out a buffer.
// We use this to replace a call to cudaMemsetAsync, which introduces gaps in CUDA graph
// execution. I am sure there is a more natural place to put this, but I haven't spent the time
// to figure out the TE build system.
namespace transformer_engine::layer_norm {

// Kernel itself: simple blockwise loop
template <typename T>
__launch_bounds__(128)
__global__ void zero_out(T *x, const size_t N) {
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int stride = blockDim.x * gridDim.x;
for (int i = tidx + bidx * blockDim.x; i < N; i += stride) {
x[i] = 0;
}
}

// Launch function: switch over element size and use an appropriate dtype for each.
// NOTE: if speed ever becomes an issue, this ought to be vectorized.
void launch_zero_out(void *buf, size_t N, size_t elem_size, cudaStream_t stream) {
int num_blocks = DIVUP(static_cast<int>(N), 128);
switch (elem_size) {
case 1:
zero_out<byte><<<num_blocks, 128, 0, stream>>>(reinterpret_cast<byte*>(buf), N);
break;
case 2:
zero_out<fp16><<<num_blocks, 128, 0, stream>>>(reinterpret_cast<fp16*>(buf), N);
break;
case 4:
zero_out<float><<<num_blocks, 128, 0, stream>>>(reinterpret_cast<float*>(buf), N);
break;
case 8:
zero_out<double><<<num_blocks, 128, 0, stream>>>(reinterpret_cast<double*>(buf), N);
break;
default:
break;
}
}

}; // namespace transformer_engine::layer_norm

template<
typename weight_t,
typename input_t,
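The NOTE in the launcher above mentions vectorization as a possible follow-up. Purely as a hedged sketch (not part of the PR), one way to do it is to treat the buffer as raw bytes and issue 16-byte uint4 stores with a bytewise tail; this assumes the pointer comes from cudaMalloc and is therefore at least 16-byte aligned. A side effect is that a byte-oriented launcher would no longer need the elem_size switch, since every dtype involved zeroes to the same bit pattern.

```cuda
// Hypothetical vectorized variant (not in the PR): zero nbytes of buf using
// 16-byte stores, with a bytewise tail for the trailing nbytes % 16 bytes.
#include <cuda_runtime.h>
#include <cstddef>

__launch_bounds__(128)
__global__ void zero_out_vectorized(void *buf, size_t nbytes) {
    uint4 *vec = reinterpret_cast<uint4 *>(buf);
    const size_t nvec = nbytes / sizeof(uint4);
    const size_t tid = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

    // Grid-stride loop over 16-byte chunks.
    for (size_t i = tid; i < nvec; i += stride) {
        vec[i] = make_uint4(0u, 0u, 0u, 0u);
    }

    // Tail: at most 15 trailing bytes, one byte per thread.
    const size_t tail_start = nvec * sizeof(uint4);
    if (tid < nbytes - tail_start) {
        reinterpret_cast<unsigned char *>(buf)[tail_start + tid] = 0;
    }
}

// Byte-oriented launcher: no dtype switch required.
inline void launch_zero_out_vectorized(void *buf, size_t nbytes, cudaStream_t stream) {
    constexpr int kThreads = 128;
    const size_t nvec = nbytes / sizeof(uint4);
    // +1 guarantees at least one block so the tail is covered even when nvec == 0.
    const int blocks = static_cast<int>((nvec + kThreads - 1) / kThreads) + 1;
    zero_out_vectorized<<<blocks, kThreads, 0, stream>>>(buf, nbytes);
}
```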
transformer_engine/common/rmsnorm/rmsnorm_api.cpp (35 changes: 30 additions & 5 deletions)
@@ -4,7 +4,9 @@
* See LICENSE for license information.
************************************************************************/

#include <cstdlib>
#include <numeric>
#include <string>
#include <vector>
#include "../common.h"
#include "rmsnorm.h"
@@ -35,6 +37,9 @@ namespace transformer_engine {

namespace layer_norm {
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size);

// [Augment] Forward declare helper kernel added to avoid using memset.
void launch_zero_out(void *, size_t, size_t, cudaStream_t);
}

namespace rmsnorm {
@@ -177,15 +182,35 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
}

  // NOTE[augment]: this env var exists to restore the prior behavior of TE (i.e., use a memset
  // kernel). To get the upstream behavior, run with NVTE_FORCE_MEMSET=1.
const char *envval = std::getenv("NVTE_FORCE_MEMSET");
bool force_memset = (envval != nullptr) && (std::string(envval) == "1");
// Clear buffers
if (params.fp8_out) {
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
if (force_memset) {
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
stream);
} else {
// [Augment] Use the zero-out kernel, not memset.
layer_norm::launch_zero_out(params.amax,
rmsnorm::product(z->amax.shape),
typeToSize(z->amax.dtype),
stream);
}
}
if (launch_params.barrier_size > 0) {
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
if (force_memset) {
cudaMemsetAsync(params.barrier, 0,
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
stream);
} else {
// [Augment] Use the zero-out kernel, not memset.
layer_norm::launch_zero_out(params.barrier,
rmsnorm::product(barrier->data.shape),
typeToSize(barrier->data.dtype),
stream);
}
}

// Launch the kernel.
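Finally, a small deduplication note: the NVTE_FORCE_MEMSET parsing above is now copied verbatim in both ln_api.cpp and rmsnorm_api.cpp. A hedged sketch of a shared helper follows (the name env_flag_enabled and its location are hypothetical, not part of the PR):

```cuda
// Hypothetical shared helper (not in the PR): centralizes the
// "env var set to exactly 1" check used by both normalization paths.
#include <cstdlib>
#include <cstring>

namespace transformer_engine {

// Returns true iff the environment variable `name` is set to "1".
inline bool env_flag_enabled(const char *name) {
    const char *val = std::getenv(name);
    return val != nullptr && std::strcmp(val, "1") == 0;
}

}  // namespace transformer_engine

// At the call sites:
//   const bool force_memset = env_flag_enabled("NVTE_FORCE_MEMSET");
```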