
feat: TP-aware KDLoss with distributed softmax and T² scaling#1499

Merged
akoumpa merged 1 commit into main from ssameni/feat_tp_kd_loss
Mar 9, 2026

Conversation

@Separius
Contributor

@Separius Separius commented Mar 9, 2026

Add tensor-parallel support to KDLoss via two new module-level helpers:

  • _infer_tp_group_from_dtensor: extracts the TP ProcessGroup from a vocab-sharded DTensor logit, avoiding an explicit tp_group argument in most cases.
  • _kl_forward_tp: computes per-token KL using numerically stable global softmax/log-softmax over all_reduce, keeping logits on local shards to avoid gathering the full vocabulary.
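The shard-local softmax trick behind `_kl_forward_tp` can be sketched in a single process by emulating the two `all_reduce` steps (global MAX for stability, global SUM for the denominator) as reductions over a list of vocabulary shards. All names below are hypothetical illustrations of the description above, not the PR's actual code:

```python
import torch


def kl_forward_tp_sketch(student_shards, teacher_shards, temperature=1.0):
    """Per-token KL(teacher || student) over vocab shards.

    Single-process sketch: each list element stands for one TP rank's
    local vocab shard of shape [tokens, vocab_local]; the all_reduce
    collectives are emulated by max/sum across the list.
    """

    def shard_log_softmax(shards):
        shards = [s / temperature for s in shards]
        # Global max per token for numerical stability (emulates all_reduce MAX).
        gmax = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
        exps = [torch.exp(s - gmax[:, None]) for s in shards]
        # Global softmax denominator per token (emulates all_reduce SUM).
        denom = sum(e.sum(dim=-1) for e in exps)
        return [s - gmax[:, None] - denom.log()[:, None] for s in shards]

    s_logp = shard_log_softmax(student_shards)
    t_logp = shard_log_softmax(teacher_shards)
    # KL = sum_v p_t * (log p_t - log p_s); each shard holds a slice of v,
    # so the per-shard partial sums add up to the full-vocabulary KL.
    return sum((tl.exp() * (tl - sl)).sum(dim=-1)
               for tl, sl in zip(t_logp, s_logp))
```

Because only per-token scalars cross ranks, the full vocabulary is never gathered.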

KDLoss.forward gains a tp_group parameter (default None, backward-compatible) and auto-detects a TP group from DTensor student_logits. T² loss scaling (Hinton et al., 2015) is applied when temperature != 1 so that gradient magnitudes stay independent of the chosen temperature.
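The T² scaling can be sketched on a single device as follows (hypothetical function name and signature; the actual KDLoss API may differ):

```python
import torch
import torch.nn.functional as F


def kd_loss_sketch(student_logits, teacher_logits, temperature=1.0):
    """KD loss with T^2 scaling (Hinton et al., 2015) -- illustrative sketch."""
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.log_softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(s, t, log_target=True, reduction="batchmean")
    # Dividing the logits by T shrinks the soft-target gradients by ~1/T^2,
    # so multiply the loss by T^2 to keep gradient magnitudes independent
    # of the chosen temperature.
    if temperature != 1.0:
        kl = kl * temperature**2
    return kl
```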

Tests are extended with single-process, gloo-backed fixtures that verify the TP path matches the non-TP path at world_size=1, plus dedicated tests for T² scaling and _infer_tp_group_from_dtensor.
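A world-size-1 gloo fixture of the kind described can look like this (a sketch assuming default env-var rendezvous; not the repository's actual fixture):

```python
import os
import torch
import torch.distributed as dist


def single_process_gloo_group():
    """Create (or reuse) a world-size-1 gloo process group.

    Hypothetical fixture sketch: at world_size=1 every all_reduce is the
    identity, so the TP loss path must match the non-TP path exactly.
    """
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
    return dist.group.WORLD
```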

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Mar 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Separius
Contributor Author

Separius commented Mar 9, 2026

@akoumpa for visibility

@akoumpa
Contributor

akoumpa commented Mar 9, 2026

/ok to test 8ffe1e7

@akoumpa akoumpa merged commit 30fbb00 into main Mar 9, 2026
52 checks passed
@akoumpa akoumpa deleted the ssameni/feat_tp_kd_loss branch March 9, 2026 23:59