From 22f4bcaf48f537d68ce65387edbdee2078cf5a98 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 26 Aug 2025 16:13:33 -0700 Subject: [PATCH 1/3] test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/pytorch/attention/test_cp_utils.py | 615 ++++++++++++++++++ .../dot_product_attention/context_parallel.py | 174 ++++- 3 files changed, 789 insertions(+), 1 deletion(-) create mode 100644 tests/pytorch/attention/test_cp_utils.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index e5b4b58617..7f061d222a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py new file mode 100644 index 0000000000..68d6e0c10d --- /dev/null +++ b/tests/pytorch/attention/test_cp_utils.py @@ -0,0 +1,615 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for context parallel utils.""" +import torch +import unittest +from typing import Tuple +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import get_batch_on_this_cp_rank, pad_thd_sequences_for_cp, generate_positional_ids_for_cp + + +class TestSequencePadding(unittest.TestCase): + def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self): + """Test with custom padding values for all tensors.""" + # Setup + + input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3]) + cu_seqlens = torch.tensor([0, 3, 5, 9]) + labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100]) + positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]) + divisibility_factor = 8 + + pid = 777 + label_pad = -200 + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Sequence: [ a a a p p p p p b b pppppp ccccpppp] + print("input_ids_padded: ", input_ids_padded) + print("labels_padded: ", labels_padded) + print("positional_ids_padded: ", positional_ids_padded) + print("cu_seqlens_padded: ", cu_seqlens_padded) + + expected_input_ids = torch.tensor( + [1, 1, 1, pid, pid, pid, pid, pid, 2, 2, pid, pid, pid, pid, pid, pid, 3, 3, 3, 3, pid, pid, pid, pid] + ) + expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24]) + expected_labels_padded = torch.tensor( + [ + -100, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + 13, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + ] + ) + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + assert torch.equal(input_ids_padded, expected_input_ids) + assert torch.equal(labels_padded, expected_labels_padded) + assert torch.equal(positional_ids_padded, expected_positional_ids) + assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded) + + def test_mixed_sequence_lengths_with_divisibility_factor(self): + """Test with sequences both shorter and longer than divisibility factor.""" + # Setup - divisibility factor 6 + # Seq 1: length 2 (shorter than 6, needs 4 padding) + # Seq 2: length 7 (longer than 6, needs 5 padding to reach 12) + # Seq 3: length 4 (shorter than 6, needs 2 padding) + # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12) + + input_ids = torch.tensor([1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + labels = torch.tensor( + [10, 11, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 33, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + ) + positional_ids = torch.tensor([0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + cu_seqlens = torch.tensor([0, 2, 9, 13, 23]) + divisibility_factor = 6 + + pid = 999 + label_pad = -300 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: [1,1] + 4 pads = 6 total + # Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total + # Seq 3: [3,3,3,3] + 2 pads = 6 total + # Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total + + expected_input_ids = torch.tensor( + [ + 1, + 1, + pid, + pid, + pid, + pid, # Seq 1: 2 + 4 padding + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + pid, + pid, + pid, + pid, # Seq 2: 7 + 5 padding + 3, + 3, + 3, + 3, + pid, + pid, # Seq 3: 4 + 2 padding + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + pid, + pid, # Seq 4: 10 + 2 padding + ] + ) + + expected_labels = torch.tensor( + [ + 10, + 11, + label_pad, + label_pad, + label_pad, + label_pad, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + 30, + 31, + 32, + 33, + label_pad, + label_pad, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [ + 0, + 1, + 2, + 3, + 4, + 5, # Seq 1 positions continue through padding + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 2 positions continue + 0, + 1, + 2, + 3, + 4, + 5, # Seq 3 positions continue + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 4 positions continue + ] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + def test_sequences_longer_than_divisibility_factor(self): + """Test with all sequences longer than the divisibility factor.""" + # Setup - divisibility factor 4, all sequences longer than 4 + # Seq 1: length 7 (needs 1 padding to reach 8) + # Seq 2: length 11 (needs 1 padding to reach 12) + # Seq 3: length 5 (needs 3 padding to reach 8) + + input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, # 7 tokens + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, # 11 tokens + 3, + 3, + 3, + 3, + 3, # 5 tokens + ] + ) + labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 300, + 301, + 302, + 303, + 304, + ] + ) + positional_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4]) + cu_seqlens = torch.tensor([0, 7, 18, 23]) + divisibility_factor = 4 + + pid = 888 + label_pad = -400 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: 7 + 1 pad = 8 (divisible by 4) + # Seq 2: 11 + 1 pad = 12 (divisible by 4) + # Seq 3: 5 + 3 pads = 8 (divisible by 4) + + expected_input_ids = torch.tensor( + [1, 1, 1, 1, 1, 1, 1, pid, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, pid, 3, 3, 3, 3, 3, pid, pid, pid] + ) + + expected_labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + label_pad, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + label_pad, + 300, + 301, + 302, + 303, + 304, + label_pad, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + +class TestContextParallelUtils(unittest.TestCase): + """Test utilities for context parallel functionality.""" + + def setUp(self): + """Set up mock distributed environment.""" + # Mock torch.distributed functions + self.original_get_world_size = torch.distributed.get_world_size + self.original_get_rank = torch.distributed.get_rank + + def tearDown(self): + """Restore original torch.distributed functions.""" + torch.distributed.get_world_size = self.original_get_world_size + torch.distributed.get_rank = self.original_get_rank + + def _mock_distributed_env(self, cp_size, cp_rank): + """Mock the distributed environment for testing.""" + def mock_get_world_size(group=None): + return cp_size + def mock_get_rank(group=None): + return cp_rank + + torch.distributed.get_world_size = mock_get_world_size + torch.distributed.get_rank = mock_get_rank + + def test_cp_rank_slicing_simple_case(self): + """Test CP rank slicing with a simple 2-rank, single sequence case.""" + # Setup: Single sequence of length 8, CP size = 2 + # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each + # Rank 0 gets slices [0,1] and [6,7] (first and last) + # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last) + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]]) + expected_labels_r1 = torch.tensor([[30, 40, 50, 60]]) + expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5]) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_cp_rank_slicing_multiple_sequences(self): + """Test CP rank slicing with multiple sequences.""" + # Setup: Two sequences of length 8 each, CP size = 2 + # Total sequence length = 16, cu_seqlens = [0, 8, 16] + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]) + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8, 16]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # For each sequence, rank 0 gets first and last slices + # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8] + # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_with_cp_size_1(self): + """Test that CP size = 1 returns original tensors unchanged.""" + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=1, cp_rank=0) + input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # With CP size = 1, should return original tensors + self.assertTrue(torch.equal(input_ids_result, input_ids)) + self.assertTrue(torch.equal(labels_result, labels)) + self.assertTrue(torch.equal(pos_ids_result, position_ids)) + + def test_cp_rank_slicing_sequence_dim_detection(self): + """Test that the function correctly detects sequence dimension.""" + # Test with sequence dimension = 0 (sequence_length, batch_size) + input_ids = torch.tensor([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]) # (8, 2) + labels = torch.tensor([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]) + position_ids = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Should get indices [0,1] and [6,7] along dimension 0 + expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_mixed_dimensions(self): + """Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension.""" + # Setup: Single sequence of length 8, CP size = 2 + # This tests the opposite case from the simple test: + # - input_ids and labels: 1D (no batch dimension) + # - position_ids: 2D (has batch dimension) + + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D + labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D + position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result + expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result + expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result + expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result + expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_integration_with_padding_and_cp_slicing(self): + """Integration test: pad sequences then slice for CP ranks.""" + # Start with unpadded sequences + input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2] + labels = torch.tensor([10, 11, 20, 21, 22]) + positional_ids = torch.tensor([0, 1, 0, 1, 2]) + cu_seqlens = torch.tensor([0, 2, 5]) + divisibility_factor = 4 # Will pad to lengths 4 and 4 + + # First, pad sequences + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=0, + padding_label_id=-100, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8] + expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0]) + self.assertTrue(torch.equal(input_ids_padded, expected_padded)) + + # Now test CP slicing with cp_size=2 + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens_padded, + input_ids_padded.unsqueeze(0), labels_padded.unsqueeze(0), positional_ids_padded + ) + + # Each sequence of length 4 gets divided into 4 slices of size 1 + # Rank 0 gets slices [0] and [3] from each sequence + # Seq 1: indices [0] and [3] -> values [1] and [0] + # Seq 2: indices [4] and [7] -> values [2] and [0] + expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c6f4647c04..6f66bc60f6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4,7 +4,7 @@ """Context Parallelism.""" import os -from typing import List, Union +from typing import List, Union, Tuple import torch import transformer_engine_torch as tex @@ -3927,3 +3927,175 @@ def attn_forward_func_with_cp( raise ValueError(f"Unsupported communication type: {cp_comm_type}!") return out + + +def pad_thd_sequences_for_cp( + input_ids: torch.Tensor, + labels: torch.Tensor, + cu_seqlens: torch.Tensor, + divisibility_factor: int, + padding_token_id: int = 0, + padding_label_id: int = -100, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pads sequences to be divisible by the divisibility factor. + + Args: + input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences + labels: Tensor of shape (1, N) or (N,) containing labels for each token + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + padding_token_id: Token ID to use for padding (default: 0) + padding_label_id: Label ID to use for padding (default: -100) + + Returns: + Tuple of: + - input_ids_padded: Padded input_ids tensor + - labels_padded: Padded labels tensor + - cu_seqlens_padded: Cumulative sequence lengths accounting for padding + """ + # Flatten input_ids and labels if needed + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + if labels.dim() == 2: + labels = labels.squeeze(0) + + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) + padding_amounts = [((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item() for l in seqlens] + + # Extract sequences and labels for each batch item + batch_sequences = [input_ids[start.item():end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] + batch_labels = [labels[start.item():end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] + + # Pad sequences and labels to required length + input_ids_padded = torch.cat([torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) if pad > 0 else seq for seq, pad in zip(batch_sequences, padding_amounts)]) + labels_padded = torch.cat([torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) if pad > 0 else seq for seq, pad in zip(batch_labels, padding_amounts)]) + + # Compute cumulative padded sequence lengths, starting from 0 + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + cu_seqlens_padded = torch.cumsum( + torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0 + ) + + return input_ids_padded, labels_padded, cu_seqlens_padded + + +def generate_positional_ids_for_cp( + cu_seqlens: torch.Tensor, + divisibility_factor: int, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """Generate positional IDs for sequences padded to be divisible by divisibility_factor. + + Args: + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + dtype: Data type for the generated positional IDs (default: torch.long) + + Returns: + Generated positional_ids tensor where each sequence starts from 0 and continues through padding + """ + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence + padding_amounts = [((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item() for l in seqlens] + + # Generate positional IDs for each padded sequence (each starts from 0) + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + positional_ids = torch.cat([ + torch.arange(0, int(length), dtype=dtype) for length in padded_lengths + ]) + + return positional_ids + + +def get_batch_on_this_cp_rank( + cu_seqlens_padded: torch.Tensor, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + position_ids_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup = None, + qvk_format: str = "thd", +): + """Slice batch input along sequence dimension into multiple chunks for THD format. + + This function is inteded for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + if qvk_format == "thd": + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + if cp_size > 1: + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError("1D tensor shape doesn't match expected sequence length. Make sure the inputs are in THD format and padded correctly.") + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError("Make sure the inputs are in THD format and padded correctly.") + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + position_ids_padded = process_tensor(position_ids_padded) + else: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded, position_ids_padded \ No newline at end of file From 2f739c14c1ab299982921be0acea00f837301191 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 9 Sep 2025 15:19:17 -0700 Subject: [PATCH 2/3] assert line change Signed-off-by: Jonathan Mitchell --- .../pytorch/attention/dot_product_attention/context_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 6f66bc60f6..1eadd0e2dc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4096,6 +4096,6 @@ def process_tensor(val): labels_padded = process_tensor(labels_padded) position_ids_padded = process_tensor(position_ids_padded) else: - raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") return input_ids_padded, labels_padded, position_ids_padded \ No newline at end of file From fc77446968dc9aff88ccbf700dd177e2a5783d5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Sep 2025 22:19:52 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_cp_utils.py | 206 +++++++++++++----- .../dot_product_attention/context_parallel.py | 77 +++++-- 2 files changed, 210 insertions(+), 73 deletions(-) diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py index 68d6e0c10d..00200c62d2 100644 --- a/tests/pytorch/attention/test_cp_utils.py +++ b/tests/pytorch/attention/test_cp_utils.py @@ -6,7 +6,11 @@ import torch import unittest from typing import Tuple -from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import get_batch_on_this_cp_rank, pad_thd_sequences_for_cp, generate_positional_ids_for_cp +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + get_batch_on_this_cp_rank, + pad_thd_sequences_for_cp, + generate_positional_ids_for_cp, +) class TestSequencePadding(unittest.TestCase): @@ -31,12 +35,12 @@ def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_ padding_token_id=pid, padding_label_id=label_pad, ) - + positional_ids_padded = generate_positional_ids_for_cp( cu_seqlens, divisibility_factor, ) - + # Sequence: [ a a a p p p p p b b pppppp ccccpppp] print("input_ids_padded: ", input_ids_padded) print("labels_padded: ", labels_padded) @@ -44,7 +48,32 @@ def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_ print("cu_seqlens_padded: ", cu_seqlens_padded) expected_input_ids = torch.tensor( - [1, 1, 1, pid, pid, pid, pid, pid, 2, 2, pid, pid, pid, pid, pid, pid, 3, 3, 3, 3, pid, pid, pid, pid] + [ + 1, + 1, + 1, + pid, + pid, + pid, + pid, + pid, + 2, + 2, + pid, + pid, + pid, + pid, + pid, + pid, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + pid, + ] ) expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24]) expected_labels_padded = torch.tensor( @@ -92,11 +121,39 @@ def test_mixed_sequence_lengths_with_divisibility_factor(self): # Seq 3: length 4 (shorter than 6, needs 2 padding) # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12) - input_ids = torch.tensor([1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + input_ids = torch.tensor( + [1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + ) labels = torch.tensor( - [10, 11, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 33, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + [ + 10, + 11, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 30, + 31, + 32, + 33, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + ] + ) + positional_ids = torch.tensor( + [0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ) - positional_ids = torch.tensor([0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) cu_seqlens = torch.tensor([0, 2, 9, 13, 23]) divisibility_factor = 6 @@ -112,7 +169,7 @@ def test_mixed_sequence_lengths_with_divisibility_factor(self): padding_token_id=pid, padding_label_id=label_pad, ) - + positional_ids_padded = generate_positional_ids_for_cp( cu_seqlens, divisibility_factor, @@ -315,7 +372,9 @@ def test_sequences_longer_than_divisibility_factor(self): 304, ] ) - positional_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4]) + positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4] + ) cu_seqlens = torch.tensor([0, 7, 18, 23]) divisibility_factor = 4 @@ -331,7 +390,7 @@ def test_sequences_longer_than_divisibility_factor(self): padding_token_id=pid, padding_label_id=label_pad, ) - + positional_ids_padded = generate_positional_ids_for_cp( cu_seqlens, divisibility_factor, @@ -343,7 +402,36 @@ def test_sequences_longer_than_divisibility_factor(self): # Seq 3: 5 + 3 pads = 8 (divisible by 4) expected_input_ids = torch.tensor( - [1, 1, 1, 1, 1, 1, 1, pid, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, pid, 3, 3, 3, 3, 3, pid, pid, pid] + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + pid, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + 3, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + ] ) expected_labels = torch.tensor( @@ -390,27 +478,30 @@ def test_sequences_longer_than_divisibility_factor(self): self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + class TestContextParallelUtils(unittest.TestCase): """Test utilities for context parallel functionality.""" - + def setUp(self): """Set up mock distributed environment.""" # Mock torch.distributed functions self.original_get_world_size = torch.distributed.get_world_size self.original_get_rank = torch.distributed.get_rank - + def tearDown(self): """Restore original torch.distributed functions.""" torch.distributed.get_world_size = self.original_get_world_size torch.distributed.get_rank = self.original_get_rank - + def _mock_distributed_env(self, cp_size, cp_rank): """Mock the distributed environment for testing.""" + def mock_get_world_size(group=None): return cp_size + def mock_get_rank(group=None): return cp_rank - + torch.distributed.get_world_size = mock_get_world_size torch.distributed.get_rank = mock_get_rank @@ -420,38 +511,38 @@ def test_cp_rank_slicing_simple_case(self): # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each # Rank 0 gets slices [0,1] and [6,7] (first and last) # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last) - + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected cu_seqlens = torch.tensor([0, 8]) - + # Test rank 0 self._mock_distributed_env(cp_size=2, cp_rank=0) input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # Rank 0 should get indices [0,1] and [6,7] expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]]) expected_labels_r0 = torch.tensor([[10, 20, 70, 80]]) expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7]) - + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) - + # Test rank 1 self._mock_distributed_env(cp_size=2, cp_rank=1) input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # Rank 1 should get indices [2,3] and [4,5] expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]]) expected_labels_r1 = torch.tensor([[30, 40, 50, 60]]) expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5]) - + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) @@ -460,25 +551,27 @@ def test_cp_rank_slicing_multiple_sequences(self): """Test CP rank slicing with multiple sequences.""" # Setup: Two sequences of length 8 each, CP size = 2 # Total sequence length = 16, cu_seqlens = [0, 8, 16] - + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]) - labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]]) + labels = torch.tensor( + [[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]] + ) position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) cu_seqlens = torch.tensor([0, 8, 16]) - + # Test rank 0 self._mock_distributed_env(cp_size=2, cp_rank=0) input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # For each sequence, rank 0 gets first and last slices # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8] # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18] expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]]) expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]]) expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7]) - + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) @@ -489,12 +582,12 @@ def test_cp_rank_slicing_with_cp_size_1(self): labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) cu_seqlens = torch.tensor([0, 8]) - + self._mock_distributed_env(cp_size=1, cp_rank=0) input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # With CP size = 1, should return original tensors self.assertTrue(torch.equal(input_ids_result, input_ids)) self.assertTrue(torch.equal(labels_result, labels)) @@ -503,21 +596,27 @@ def test_cp_rank_slicing_with_cp_size_1(self): def test_cp_rank_slicing_sequence_dim_detection(self): """Test that the function correctly detects sequence dimension.""" # Test with sequence dimension = 0 (sequence_length, batch_size) - input_ids = torch.tensor([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]) # (8, 2) - labels = torch.tensor([[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]) - position_ids = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]) + input_ids = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) # (8, 2) + labels = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) + position_ids = torch.tensor( + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]] + ) cu_seqlens = torch.tensor([0, 8]) - + self._mock_distributed_env(cp_size=2, cp_rank=0) input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # Should get indices [0,1] and [6,7] along dimension 0 expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]]) - + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) @@ -528,38 +627,38 @@ def test_cp_rank_slicing_mixed_dimensions(self): # This tests the opposite case from the simple test: # - input_ids and labels: 1D (no batch dimension) # - position_ids: 2D (has batch dimension) - + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch cu_seqlens = torch.tensor([0, 8]) - + # Test rank 0 self._mock_distributed_env(cp_size=2, cp_rank=0) input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # Rank 0 should get indices [0,1] and [6,7] expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim) - + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) - + # Test rank 1 self._mock_distributed_env(cp_size=2, cp_rank=1) input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( cu_seqlens, input_ids, labels, position_ids ) - + # Rank 1 should get indices [2,3] and [4,5] expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim) - + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) @@ -572,7 +671,7 @@ def test_integration_with_padding_and_cp_slicing(self): positional_ids = torch.tensor([0, 1, 0, 1, 2]) cu_seqlens = torch.tensor([0, 2, 5]) divisibility_factor = 4 # Will pad to lengths 4 and 4 - + # First, pad sequences input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( input_ids.unsqueeze(0), @@ -582,34 +681,35 @@ def test_integration_with_padding_and_cp_slicing(self): padding_token_id=0, padding_label_id=-100, ) - + positional_ids_padded = generate_positional_ids_for_cp( cu_seqlens, divisibility_factor, ) - + # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8] expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0]) self.assertTrue(torch.equal(input_ids_padded, expected_padded)) - + # Now test CP slicing with cp_size=2 - + # Test rank 0 self._mock_distributed_env(cp_size=2, cp_rank=0) input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( - cu_seqlens_padded, - input_ids_padded.unsqueeze(0), labels_padded.unsqueeze(0), positional_ids_padded + cu_seqlens_padded, + input_ids_padded.unsqueeze(0), + labels_padded.unsqueeze(0), + positional_ids_padded, ) - + # Each sequence of length 4 gets divided into 4 slices of size 1 # Rank 0 gets slices [0] and [3] from each sequence # Seq 1: indices [0] and [3] -> values [1] and [0] # Seq 2: indices [4] and [7] -> values [2] and [0] expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]]) - - self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 1eadd0e2dc..f00bd573f1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3963,15 +3963,41 @@ def pad_thd_sequences_for_cp( seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) - padding_amounts = [((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item() for l in seqlens] + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] # Extract sequences and labels for each batch item - batch_sequences = [input_ids[start.item():end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] - batch_labels = [labels[start.item():end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] + batch_sequences = [ + input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + batch_labels = [ + labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] # Pad sequences and labels to required length - input_ids_padded = torch.cat([torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) if pad > 0 else seq for seq, pad in zip(batch_sequences, padding_amounts)]) - labels_padded = torch.cat([torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) if pad > 0 else seq for seq, pad in zip(batch_labels, padding_amounts)]) + input_ids_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_sequences, padding_amounts) + ] + ) + labels_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_labels, padding_amounts) + ] + ) # Compute cumulative padded sequence lengths, starting from 0 padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) @@ -3988,27 +4014,31 @@ def generate_positional_ids_for_cp( dtype: torch.dtype = torch.long, ) -> torch.Tensor: """Generate positional IDs for sequences padded to be divisible by divisibility_factor. - + Args: cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths divisibility_factor: Each sequence length must be divisible by this factor dtype: Data type for the generated positional IDs (default: torch.long) - + Returns: Generated positional_ids tensor where each sequence starts from 0 and continues through padding """ # Compute the sequence lengths from cu_seqlens seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - + # List: amount of padding needed for each sequence - padding_amounts = [((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item() for l in seqlens] - + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + # Generate positional IDs for each padded sequence (each starts from 0) padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) - positional_ids = torch.cat([ - torch.arange(0, int(length), dtype=dtype) for length in padded_lengths - ]) - + positional_ids = torch.cat( + [torch.arange(0, int(length), dtype=dtype) for length in padded_lengths] + ) + return positional_ids @@ -4022,7 +4052,7 @@ def get_batch_on_this_cp_rank( ): """Slice batch input along sequence dimension into multiple chunks for THD format. - This function is inteded for use in self attention. It will not work for cross attention because + This function is inteded for use in self attention. It will not work for cross attention because it does not handle the case where the sequence length of the query and key are different. Which are parallelized across GPUs in a context parallel group. @@ -4038,7 +4068,9 @@ def get_batch_on_this_cp_rank( # Calculate the chunk sizes for each sequence total_slices_of_any_sequence = 2 * cp_size - slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence # Process each tensor directly instead of using keys_to_change loop def process_tensor(val): @@ -4050,20 +4082,25 @@ def process_tensor(val): seq_len_val = cu_seqlens_padded[-1].item() else: seq_len_val = cu_seqlens_padded[-1] - + # Handle 1D tensors (like position_ids that don't have batch dimension) if val.ndim == 1: if val.shape[0] == seq_len_val: current_seq_dim = 0 else: - raise ValueError("1D tensor shape doesn't match expected sequence length. Make sure the inputs are in THD format and padded correctly.") + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) elif val.ndim >= 2: if val.shape[1] == seq_len_val: current_seq_dim = 1 elif val.shape[0] == seq_len_val: current_seq_dim = 0 else: - raise ValueError("Make sure the inputs are in THD format and padded correctly.") + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) else: raise ValueError("Tensor must be at least 1D") @@ -4098,4 +4135,4 @@ def process_tensor(val): else: raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") - return input_ids_padded, labels_padded, position_ids_padded \ No newline at end of file + return input_ids_padded, labels_padded, position_ids_padded