[PyTorch] Inference mode disables initializing quantized weights with column-wise usage #1847

Merged
ksivaman merged 4 commits into NVIDIA:main from timmoon10:inference-mode-weight-init
Jun 13, 2025

@timmoon10
Collaborator

Description

When initializing a model with quantized weights, the required data is different for training and inference: training requires row-wise data for the forward GEMM and column-wise data for the dgrad GEMM, while inference only requires row-wise data. This PR adds logic so that a model initialized within no-grad mode or inference mode only initializes quantized weights with the data required for inference. It is also less aggressive about deallocating weight data, in order to handle cases that alternate between training and validation modes.
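As a rough illustration of the selection rule described above, the following is a minimal sketch in plain Python, not the code from this PR. The names `WeightUsages` and `usages_for_init` are hypothetical, and the grad-mode flag is passed in explicitly so the snippet runs without a GPU or any dependencies.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WeightUsages:
    """Hypothetical record of which quantized layouts to allocate."""

    rowwise: bool
    columnwise: bool


def usages_for_init(grad_enabled: bool) -> WeightUsages:
    """Pick the quantized data to allocate at weight initialization.

    Row-wise data feeds the forward GEMM, so it is always needed.
    Column-wise data feeds the dgrad GEMM, so it is only needed
    when gradients may be computed.
    """
    return WeightUsages(rowwise=True, columnwise=grad_enabled)


# Assumption: in PyTorch the flag could come from torch.is_grad_enabled(),
# which is False inside both torch.no_grad() and torch.inference_mode().
training = usages_for_init(grad_enabled=True)
inference = usages_for_init(grad_enabled=False)
print(training)
print(inference)
```

Keying the decision on the ambient grad mode means callers would not need a separate flag: wrapping model construction in a no-grad or inference-mode context is enough to skip the column-wise data.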

This is an alternative to #1827. The heuristic API in that PR is somewhat redundant with these plain PyTorch APIs, but it is also more general.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Avoid initializing quantized weights with column-wise usage if grads are not enabled
  • Do not deallocate unnecessary usages in weight tensors during forward pass
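The second change can be pictured with a toy model. This is only a sketch under stated assumptions: `ToyQuantizedWeight`, `forward_pass`, and the `free_unused` flag are hypothetical names invented for illustration, and the "aggressive" branch is a simplified stand-in for eager deallocation rather than the library's previous code.

```python
class ToyQuantizedWeight:
    """Hypothetical stand-in that tracks which usages hold data."""

    def __init__(self, rowwise: bool, columnwise: bool) -> None:
        self.allocated = {"rowwise": rowwise, "columnwise": columnwise}


def forward_pass(weight: ToyQuantizedWeight, grad_enabled: bool, free_unused: bool) -> None:
    """Toy forward pass that optionally frees data it does not need."""
    if free_unused and not grad_enabled:
        # Aggressive variant: drop column-wise data during a no-grad forward.
        weight.allocated["columnwise"] = False
    # Conservative variant (free_unused=False): leave allocations untouched.


# A weight initialized for training holds both usages.
aggressive = ToyQuantizedWeight(rowwise=True, columnwise=True)
conservative = ToyQuantizedWeight(rowwise=True, columnwise=True)

# Validation step in the middle of training (grads disabled).
forward_pass(aggressive, grad_enabled=False, free_unused=True)
forward_pass(conservative, grad_enabled=False, free_unused=False)

# Back in training, the dgrad GEMM needs column-wise data again.
print("aggressive keeps column-wise:", aggressive.allocated["columnwise"])
print("conservative keeps column-wise:", conservative.allocated["columnwise"])
```

In this toy model, only the conservative variant still has column-wise data available when training resumes after the validation step.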

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 added 2 commits June 4, 2025 00:16
…ce mode

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested review from ksivaman and ptrendx June 4, 2025 00:37
@timmoon10 timmoon10 added the bug and enhancement labels Jun 4, 2025
@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 timmoon10 added the 2.5.0 label Jun 5, 2025
import pytest
import os

import transformer_engine.pytorch
Member

Why is this needed?

Collaborator Author

@timmoon10 timmoon10 Jun 9, 2025

I find it convenient to be able to access a class without explicitly doing from ... import ...:

module = transformer_engine.pytorch.ops.Linear(hidden_size, hidden_size)

It's just a matter of style though. Within the package we explicitly list the imports in order to guarantee only relative imports, but this isn't relevant for tests since we always do absolute imports. Also, Google's style guide recommends importing packages and modules rather than individual classes.
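The two import styles under discussion can be compared side by side. The sketch below uses the standard-library `xml.etree.ElementTree` module purely as a stand-in so that it runs anywhere; it is not taken from the test file in this PR.

```python
# Style 1: import the dotted module path, then spell out the full
# name at the call site. This binds the top-level name `xml`.
import xml.etree.ElementTree

root_a = xml.etree.ElementTree.Element("config")

# Style 2: import the class directly into the local namespace.
from xml.etree.ElementTree import Element

root_b = Element("config")

# Both names refer to the same class object.
print(xml.etree.ElementTree.Element is Element)
print(root_a.tag, root_b.tag)
```

The first style makes the origin of each class visible at the call site, at the cost of longer expressions.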

Comment thread tests/pytorch/test_sanity.py Outdated
Comment thread tests/pytorch/test_sanity.py Outdated
timmoon10 and others added 2 commits June 9, 2025 16:50
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Collaborator Author

/te-ci pytorch

Member

@ksivaman ksivaman left a comment

LGTM

@ksivaman ksivaman merged commit 655512c into NVIDIA:main Jun 13, 2025
21 checks passed
@timmoon10 timmoon10 deleted the inference-mode-weight-init branch June 13, 2025 03:22
@ptrendx ptrendx mentioned this pull request Jul 29, 2025
Labels

2.5.0, bug, enhancement

2 participants