
Initial commit to pass scale as Tensor for multi_tensor_scale op #2594

Merged

ksivaman merged 11 commits into NVIDIA:main from vasunvidia:vrengasamy/multi_tensor_scale_cg on Mar 12, 2026

Conversation

@vasunvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vasunvidia vasunvidia force-pushed the vrengasamy/multi_tensor_scale_cg branch from b2a5ae5 to 4081afc on February 11, 2026 19:17
@vasunvidia vasunvidia marked this pull request as ready for review February 11, 2026 19:20
@vasunvidia vasunvidia force-pushed the vrengasamy/multi_tensor_scale_cg branch from 4081afc to cfd4370 on February 11, 2026 19:20
@greptile-apps
Contributor

greptile-apps Bot commented Feb 11, 2026

Greptile Summary

This PR introduces multi_tensor_scale_tensor, a new variant of the fused overflow-check-and-scale kernel where the scale factor is passed as a single-element CUDA tensor instead of a scalar float. The implementation cleanly refactors the existing ScaleFunctor and the new ScalePtrFunctor to share a common scale_chunk device function, eliminating the previously noted code-duplication concern.

Key changes:

  • New ScalePtrFunctor delegates to a templated scale_chunk helper (shared with ScaleFunctor) using overloaded get_scale_value() device functions to handle both float and float* scale inputs (see the sketch after this list).
  • New C API nvte_multi_tensor_scale_tensor_cuda, C++ extension wrapper, and pybind11 binding wiring the kernel up to Python.
  • multi_tensor_scale_tensor exported from transformer_engine.pytorch.optimizers.
  • Tests mirror the existing test_multi_tensor_scale and cover correct scaling and overflow detection for NaN/Inf inputs.
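
A minimal sketch of the overload pattern described above (illustrative only; scale_chunk_demo and its simplified loop are stand-ins, not the actual scale.cu helper):

__device__ __forceinline__ float get_scale_value(float scale) { return scale; }
__device__ __forceinline__ float get_scale_value(const float *scale_ptr) { return *scale_ptr; }

// Simplified stand-in for the shared chunk helper: both functors call one routine,
// passing either a float (ScaleFunctor) or a float* (ScalePtrFunctor); overload
// resolution picks the matching get_scale_value.
template <typename in_t, typename out_t, typename scale_arg_t>
__device__ void scale_chunk_demo(const in_t *in, out_t *out, int n,
                                 volatile int *is_infinite_gmem, scale_arg_t scale_arg) {
  const float scale = get_scale_value(scale_arg);
  bool finite = true;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const float v = static_cast<float>(in[i]);
    out[i] = static_cast<out_t>(v * scale);
    finite = finite && isfinite(v);
  }
  if (!finite) *is_infinite_gmem = 1;  // racy write; any thread setting 1 is enough
}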

Outstanding concerns (carried from prior review rounds):

  • scale.cpp: No TORCH_CHECK(scale.is_cuda()) / TORCH_CHECK(scale.numel() == 1) guard — passing a CPU tensor causes an illegal device memory access at runtime.
  • nvte_multi_tensor_scale_tensor_cuda (scale.cu): reinterpret_cast<float*>(scale_tensor->data.dptr) is performed without verifying dtype == kFloat32 or numel() == 1, silently misinterpreting the bits of any other type.
  • scale_chunk: isfinite() is only checked for input tensor elements; if *scale_ptr is Inf or NaN, outputs become non-finite but is_infinite_gmem is never set, creating a silent correctness hole for loss-scale overflow detection.
  • The scale pointer is dereferenced once per thread per block from global memory without __ldg(), causing redundant L2 traffic for large tensor lists.

Confidence Score: 2/5

  • Not safe to merge — multiple unresolved validation gaps can cause silent runtime crashes or silent correctness failures in loss-scale overflow detection.
  • The refactoring itself (scale_chunk helper, get_scale_value overloads) is clean and correct. However, three issues from earlier review rounds remain unaddressed: (1) no CUDA device check for the scale tensor risks an illegal memory access crash, (2) no dtype/numel check before the reinterpret_cast risks silent bit-misinterpretation, and (3) a non-finite scale tensor bypasses the is_infinite flag entirely, creating a silent training-stability bug. These are not theoretical — they affect production AMP training workflows.
  • transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp (missing CUDA/numel checks) and transformer_engine/common/multi_tensor/scale.cu (missing dtype validation and non-finite scale propagation).

Important Files Changed

  • transformer_engine/common/multi_tensor/scale.cu: Introduces ScalePtrFunctor and a shared scale_chunk helper to eliminate code duplication. The refactoring is sound, but: (1) stale *noop_gmem reference in the commented-out guard, (2) missing dtype/numel validation before the reinterpret_cast<float*> at the API boundary, and (3) scale_chunk never checks isfinite(*scale_ptr), so a non-finite loss scale silently corrupts outputs without setting is_infinite_gmem.
  • transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp: Adds the multi_tensor_scale_tensor_cuda wrapper. No CUDA device check or element-count check is performed on the scale tensor; passing a CPU tensor will cause an illegal memory access in device code when the kernel dereferences the host pointer.
  • transformer_engine/common/include/transformer_engine/multi_tensor.h: Adds a well-documented nvte_multi_tensor_scale_tensor_cuda declaration. The doc comment for the scale parameter could mention the float32 / single-element requirements, but is otherwise correct.
  • transformer_engine/pytorch/csrc/extensions.h: Adds the new multi_tensor_scale_tensor_cuda C++ declaration with consistent is_infinite naming.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Registers multi_tensor_scale_tensor in pybind11 with GIL release. Straightforward and correct.
  • transformer_engine/pytorch/optimizers/__init__.py: Re-exports multi_tensor_scale_tensor from transformer_engine_torch. Minimal, correct change.
  • tests/pytorch/test_multi_tensor.py: New test_multi_tensor_scale_tensor mirrors the existing scalar-scale test and covers downscaling correctness plus overflow detection for NaN/Inf inputs. Does not test a non-CUDA scale tensor (the expected error path) or a non-finite scale tensor (the known undetected-overflow gap).

Sequence Diagram

sequenceDiagram
    participant PY as Python (optimizers)
    participant CPP as scale.cpp (PyTorch ext)
    participant API as nvte API (scale.cu)
    participant KERN as CUDA Kernel

    PY->>CPP: multi_tensor_scale_tensor(chunk_size, is_infinite, tensor_lists, scale_tensor)
    CPP->>CPP: makeTransformerEngineTensor(is_infinite)
    CPP->>CPP: makeTransformerEngineTensor(scale)
    CPP->>API: nvte_multi_tensor_scale_tensor_cuda(...)
    API->>API: convertNVTETensorCheck(scale)
    API->>API: reinterpret_cast<float*>(scale_tensor->data.dptr)
    API->>KERN: multi_tensor_apply<2>(..., ScalePtrFunctor, scale_ptr)
    loop Per block / chunk
        KERN->>KERN: ScalePtrFunctor::operator() → scale_chunk(scale_ptr)
        KERN->>KERN: get_scale_value(scale_ptr) — dereferences GPU ptr
        KERN->>KERN: multiply inputs, track isfinite(r_in[ii])
        KERN-->>API: write is_infinite_gmem if !finite
    end
    API-->>CPP: return
    CPP-->>PY: return (is_infinite tensor updated in-place)

Last reviewed commit: e27f134

Contributor

@greptile-apps greptile-apps Bot left a comment


6 files reviewed, 6 comments


Comment thread transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp Outdated
Comment on lines +219 to +223
Tensor *scale_tensor = convertNVTETensorCheck(scale);
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
    chunk_size, *convertNVTETensorCheck(noop_flag),
    convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
    reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
Contributor


validate scale tensor dtype and size before casting. should check that:

  • dtype is kFloat32
  • numel() == 1 (single element tensor)
  Tensor *scale_tensor = convertNVTETensorCheck(scale);
  NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32, 
             "scale tensor must be float32, got ", to_string(scale_tensor->dtype()));
  NVTE_CHECK(scale_tensor->numel() == 1,
             "scale tensor must have exactly 1 element, got ", scale_tensor->numel());
  multi_tensor_scale::multi_tensor_scale_tensor_cuda(

// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
float scale = *scale_ptr;
Contributor


all threads in a block read from the same device memory location without synchronization - could cause redundant memory traffic but functionally correct. consider caching in shared memory or using __ldg() for read-only cache optimization

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
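
For illustration, a hedged sketch of the two approaches mentioned above (not code from the PR); either one avoids issuing a generic global load of the scale from every thread:

// Option 1: route the read-only load through the read-only data cache.
__device__ __forceinline__ float load_scale_ldg(const float *scale_ptr) {
  return __ldg(scale_ptr);
}

// Option 2: one thread loads the scale and broadcasts it via shared memory.
// Must be called before any divergent early return so all threads reach __syncthreads().
__device__ __forceinline__ float load_scale_shared(const float *scale_ptr) {
  __shared__ float s_scale;
  if (threadIdx.x == 0) s_scale = *scale_ptr;
  __syncthreads();
  return s_scale;
}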

Contributor

@greptile-apps greptile-apps Bot left a comment


6 files reviewed, 2 comments


NVTE_API_CALL(nvte_multi_tensor_scale_tensor_cuda);
using namespace transformer_engine;

Tensor *scale_tensor = convertNVTETensorCheck(scale);
Contributor


validate scale tensor before casting - should check dtype is kFloat32 and numel() == 1

Suggested change
Tensor *scale_tensor = convertNVTETensorCheck(scale);
Tensor *scale_tensor = convertNVTETensorCheck(scale);
NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32,
"scale tensor must be float32, got ", to_string(scale_tensor->dtype()));
NVTE_CHECK(scale_tensor->numel() == 1,
"scale tensor must have exactly 1 element, got ", scale_tensor->numel());
multi_tensor_scale::multi_tensor_scale_tensor_cuda(

@greptile-apps
Contributor

greptile-apps Bot commented Feb 18, 2026

Additional Comments (1)

tests/pytorch/test_multi_tensor.py
add test cases for multi_tensor_scale_tensor to verify the tensor-based scale parameter works with different dtypes, shapes, and edge cases (wrong dtype, multi-element tensors)

Collaborator

@timmoon10 timmoon10 left a comment


Overall reasonable, but shows signs of repeated copy-pasting. Also, can we add a test?

* \warning This API is **experimental** and subject to change.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
Collaborator


This documentation is incorrect. We should change this variable name to something more accurate and update everywhere.

Suggested change
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[out] isfinite Whether the kernel detected a non-finite input value.

Collaborator Author


Changed noop_flag to is_infinite for nvte_multi_tensor_scale_tensor_cuda and nvte_multi_tensor_scale_cuda ops

int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// note for clarification to future michael:
Collaborator


Nit: There is no Michael taking responsibility for this. This comment is originally from NVIDIA/apex@6763a8b and it's made its way here through multiple levels of blind copy-pastes.

Collaborator Author


Removed the comment.

Comment on lines +21 to 31
void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor is_infinite,
                                    std::vector<std::vector<at::Tensor>> tensor_lists,
                                    at::Tensor scale) {
  auto is_infinite_cu = makeTransformerEngineTensor(is_infinite);
  auto scale_cu = makeTransformerEngineTensor(scale);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  nvte_multi_tensor_scale_tensor_cuda(chunk_size, is_infinite_cu.data(), tensor_lists_ptr.data(),
                                      num_lists, num_tensors, scale_cu.data(),
                                      at::cuda::getCurrentCUDAStream());
}
Contributor


Missing CUDA device validation for scale tensor

The scale parameter is an at::Tensor that could reside on CPU (host) memory. Inside ScalePtrFunctor, the kernel dereferences scale_ptr directly in device code:

float scale = *scale_ptr;

If scale is a CPU tensor, every GPU thread will attempt to read a host memory address from device code, resulting in an illegal memory access CUDA error at runtime (typically showing as a silent hang or a CUDA error: an illegal memory access was encountered crash).

A device check should be added before wrapping the tensor, similar to how other PyTorch CUDA extensions guard their inputs:

void multi_tensor_scale_tensor_cuda(int chunk_size, at::Tensor is_infinite,
                                    std::vector<std::vector<at::Tensor>> tensor_lists,
                                    at::Tensor scale) {
  TORCH_CHECK(scale.is_cuda(), "scale tensor must be on CUDA device");
  TORCH_CHECK(scale.numel() == 1, "scale tensor must contain exactly 1 element, got ", scale.numel());
  TORCH_CHECK(scale.scalar_type() == at::kFloat, "scale tensor must be float32");
  auto is_infinite_cu = makeTransformerEngineTensor(is_infinite);
  auto scale_cu = makeTransformerEngineTensor(scale);

Without this guard, passing a CPU-side torch.tensor([1.0]) (very easy to do accidentally) silently corrupts the kernel launch.

Comment on lines +105 to +172
template <typename in_t, typename out_t>
struct ScalePtrFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *is_infinite_gmem,
TensorListMetadata<2> &tl, // NOLINT(*)
float *scale_ptr) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
float scale = *scale_ptr;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]);
in += chunk_idx * chunk_size;

out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]);
out += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];

// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_in, in, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
// store
load_store(out, r_out, i_start, 0);
}
} else {
// Non-divergent exit condition for __syncthreads, not necessary here
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite)
*is_infinite_gmem = 1; // Blindly fire off a write. These will race but that's ok.
}
};
Contributor


ScalePtrFunctor duplicates ScaleFunctor almost entirely

ScalePtrFunctor is identical to ScaleFunctor except for the one-line dereference float scale = *scale_ptr; at the top. The remaining ~60 lines of aligned/unaligned loop logic, ILP unrolling, finite tracking, and the is_infinite_gmem write are copy-pasted verbatim.

This duplication means any future bugfix or optimization (e.g. adding __ldg() for the scale pointer, fixing the finite flag propagation, or changing the loop bounds logic) must be applied in both functors, which is a maintenance hazard. Consider templating the scale acquisition so a single functor covers both cases, for example via a ScaleGetter policy type or by simply passing a pointer and dereferencing it in the existing ScaleFunctor when USE_PTR is set.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
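
For illustration, a hedged sketch of the policy-type alternative suggested above (ScaleGetter and the struct names are hypothetical, not code in the PR; the body assumes the surrounding scale.cu context, e.g. TensorListMetadata and the existing ILP loops):

// Policies that differ only in how they obtain the scale value.
struct ScalarScaleGetter {
  float scale;
  __device__ __forceinline__ float get() const { return scale; }
};

struct PointerScaleGetter {
  const float *scale_ptr;
  __device__ __forceinline__ float get() const { return *scale_ptr; }
};

// A single functor templated on the getter would cover both the scalar and tensor variants.
template <typename in_t, typename out_t, typename ScaleGetter>
struct UnifiedScaleFunctor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int *is_infinite_gmem,
                                             TensorListMetadata<2> &tl,  // NOLINT(*)
                                             ScaleGetter getter) {
    const float scale = getter.get();
    // (the existing aligned/unaligned ILP loops from ScaleFunctor would go here, unchanged)
  }
};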

vasunvidia and others added 7 commits March 2, 2026 21:33
  • Initial commit to pass scale as Tensor for multi_tensor_scale op
  • [pre-commit.ci] auto fixes from pre-commit.com hooks
  • Enable capturable mode for optimizer if store_param_remainders is passed but not actually enabled
  • Revert "Enable capturable mode for optimizer if store_param_remainders is passed but not actually enabled" (reverts commit 74a9bcc)
  • Apply suggestion from @greptile-apps[bot]
  • Change noop_flag to is_infinite
  • [pre-commit.ci] auto fixes from pre-commit.com hooks

@vasunvidia vasunvidia force-pushed the vrengasamy/multi_tensor_scale_cg branch from 207a20d to aa4f177 on March 3, 2026 05:33
Comment thread transformer_engine/pytorch/csrc/extensions.h Outdated
Update transformer_engine/pytorch/csrc/extensions.h
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +113 to +170
float scale = *scale_ptr;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]);
in += chunk_idx * chunk_size;

out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]);
out += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];

// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_in, in, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
// store
load_store(out, r_out, i_start, 0);
}
} else {
// Non-divergent exit condition for __syncthreads, not necessary here
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite)
*is_infinite_gmem = 1; // Blindly fire off a write. These will race but that's ok.
Contributor


is_infinite_gmem not set when scale itself is non-finite

ScalePtrFunctor reads the scale from GPU memory at line 113 (float scale = *scale_ptr;) but never checks whether that value is finite. The finite flag only tracks isfinite(r_in[ii]), so if *scale_ptr is NaN or Inf:

  • A finite input r_in[ii] multiplied by Inf or NaN produces a non-finite output.
  • Yet finite = finite && isfinite(r_in[ii]) remains true.
  • is_infinite_gmem is never set, and the caller has no way to detect the overflow.

With the scalar-float variant (ScaleFunctor), callers always hold the scale value on the host and can validate it before the kernel launch. With the new tensor variant, the scale lives in GPU memory and there is no equivalent host-side check. This creates a silent correctness hole where a non-finite loss-scale tensor can corrupt the parameter tensors without triggering the overflow flag.

Consider including the scale in the finiteness check:

float scale = *scale_ptr;
bool finite = isfinite(scale);   // seed finite with scale validity

or alternatively check the output value:

finite = finite && isfinite(static_cast<float>(r_out[ii]));

Remove duplication
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Add test for scale tensor cuda
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member

/te-ci

Comment on lines +43 to +45
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
Contributor


Stale variable name in commented-out guard

The commented-out guard references *noop_gmem, which was the old parameter name when this comment lived inside ScaleFunctor::operator(). After the refactoring that moved this logic into scale_chunk, the actual parameter is is_infinite_gmem. Any developer who later decides to enable this guard will use the wrong identifier and get a compile error.

Suggested change
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
// I'd like this kernel to propagate infs/nans.
// if(*is_infinite_gmem == 1)
// return;

@ksivaman ksivaman merged commit 06a23e3 into NVIDIA:main Mar 12, 2026
25 of 32 checks passed
vthumbe1503 pushed a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
Initial commit to pass scale as Tensor for multi_tensor_scale op (NVIDIA#2594)

* Initial commit to pass scale as Tensor for multi_tensor_scale op

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable capturable mode for optimizer if store_param_remainders is passed but not actually enabled

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Revert "Enable capturable mode for optimizer if store_param_remainders is passed but not actually enabled"

This reverts commit 74a9bcc.

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Apply suggestion from @greptile-apps[bot]

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change noop_flag to is_infinite

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Update transformer_engine/pytorch/csrc/extensions.h

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove duplication

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add test for scale tensor cuda

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
