# Simple Document-Level CP Sharding Simulation

In [None]:
import torch
from torch.autograd import Function

class PerDocumentCPSharding(Function):
    """Two-worker CP sharding of equal-length documents.

    Each sequence is cut into four equal quarters: worker 0 owns the two
    outer quarters, worker 1 owns the two inner quarters.
    """

    @staticmethod
    def forward(ctx, input, shard_id, world_size=2):
        """
        Pick the tokens that belong to ``shard_id``.

        - input (Tensor): [B, L, D] input tensor, L divisible by 4
        - shard_id (int): 0 (first + last quarter) or 1 (middle half)
        - world_size (int): must be 2

        Returns a tensor of shape [B, L/2, D].
        Raises ValueError for an unsupported world_size, an L that is not a
        multiple of 4, or an unknown shard_id.
        """
        if world_size != 2:
            raise ValueError("This implementation only supports world_size==2.")

        _, seq_len, _ = input.shape
        quarter, leftover = divmod(seq_len, 4)
        if leftover:
            raise ValueError("Sequence length L must be divisible by 4.")

        # Everything backward needs to place gradients back into [B, L, D].
        ctx.input_shape = input.shape
        ctx.quarter = quarter
        ctx.shard_id = shard_id

        if shard_id == 0:
            outer_quarters = (input[:, :quarter, :], input[:, 3 * quarter:, :])
            return torch.cat(outer_quarters, dim=1)
        if shard_id == 1:
            return input[:, quarter:3 * quarter, :]
        raise ValueError("shard_id must be 0 or 1.")

    @staticmethod
    def backward(ctx, grad_output):
        """
        Scatter the shard gradient back to the positions it was taken from.

        Worker 0: the first ``quarter`` rows of grad_output go to tokens
        0 : quarter and the remaining rows go to tokens 3*quarter : L.
        Worker 1: grad_output goes to tokens quarter : 3*quarter.
        Positions owned by the other worker receive zero gradient.
        """
        quarter = ctx.quarter
        owner = ctx.shard_id

        full_grad = grad_output.new_zeros(ctx.input_shape)

        if owner == 0:
            # grad_output is [B, 2*quarter, D]: head quarter, then tail quarter.
            head = grad_output[:, :quarter, :]
            tail = grad_output[:, quarter:, :]
            full_grad[:, :quarter, :] = head
            full_grad[:, 3 * quarter:, :] = tail
        elif owner == 1:
            full_grad[:, quarter:3 * quarter, :] = grad_output
        else:
            raise ValueError("Invalid shard_id in backward pass.")

        # No gradients for the non-tensor arguments shard_id and world_size.
        return full_grad, None, None

# Wrapper
def per_document_cp_shard(input, shard_id, world_size=2):
    """Return the shard of ``input`` owned by worker ``shard_id``.

    Thin wrapper around ``PerDocumentCPSharding.apply`` so the sharding takes
    part in autograd.

    Args:
        input (Tensor): [B, L, D] tensor with L divisible by 4.
        shard_id (int): 0 or 1.
        world_size (int): must be 2; defaults to 2 to match the default of
            ``PerDocumentCPSharding.forward``.
    """
    return PerDocumentCPSharding.apply(input, shard_id, world_size)


In [None]:
# Test usage
if __name__ == "__main__":
    # A batch of 2 documents, each with 16 tokens and feature dimension 8,
    # so one quarter of a document is L // 4 = 4 tokens.
    B, L, D = 2, 16, 8
    world_size = 2  # Only two shards are supported here

    # Create an input tensor
    x = torch.randn(B, L, D, requires_grad=True)

    # Simulate worker 0 (shard_id=0).
    # Every per_document_cp_shard call builds its own graph, so backward does
    # not need retain_graph=True.
    out0 = per_document_cp_shard(x, shard_id=0, world_size=world_size)
    loss0 = out0.sum()
    loss0.backward()
    print("Worker 0 input gradient:")
    print(x.grad)  # Nonzero only in token positions 0:4 and 12:16

    # Reset gradients so worker 1's result is not accumulated onto worker 0's
    x.grad.zero_()

    # Simulate worker 1 (shard_id=1)
    out1 = per_document_cp_shard(x, shard_id=1, world_size=world_size)
    loss1 = out1.sum()
    loss1.backward()
    print("Worker 1 input gradient:")
    print(x.grad)  # Nonzero only in token positions 4:12

Worker 0 input gradient:
tensor([[[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0

# Now including varying document lengths

In [None]:
import torch
from torch.autograd import Function
import math

class PerDocumentCPShardingVarLen(Function):
    """Two-worker per-document CP sharding for padded, variable-length batches."""

    @staticmethod
    def forward(ctx, input, doc_lens, shard_id, world_size):
        """
        Forward pass for per-document CP sharding with variable document lengths.

        For each document:
          - Compute left_end = floor(doc_len / 4)
          - Compute right_start = floor(3 * doc_len / 4)

        Then, for each document:
          - If shard_id == 0: output = concatenate( tokens[0:left_end], tokens[right_start:doc_len] )
          - If shard_id == 1: output = tokens[left_end:right_start]

        Since document lengths vary, outputs are zero-padded to a common length.

        Args:
            input (Tensor): Padded input tensor of shape [B, L, D]
            doc_lens (Tensor or sequence of int): B true document lengths, each in [0, L].
            shard_id (int): Worker id (0 or 1). (Only world_size==2 is supported.)
            world_size (int): Number of shards; must be 2.

        Returns:
            Tensor: Sharded output tensor of shape [B, max_out_len, D]

        Raises:
            ValueError: if world_size != 2, shard_id is not 0 or 1, doc_lens does
                not hold exactly B entries, or a length lies outside [0, L].
        """
        if world_size != 2:
            raise ValueError("This implementation only supports world_size==2.")
        # Checked before the loop so an empty batch is validated as well.
        if shard_id not in (0, 1):
            raise ValueError("shard_id must be 0 or 1.")

        B, L, D = input.shape
        # Single host transfer for all lengths instead of one per document.
        lengths = [int(n) for n in torch.as_tensor(doc_lens).reshape(-1).tolist()]
        if len(lengths) != B:
            raise ValueError(
                f"doc_lens must hold one length per document: expected {B}, got {len(lengths)}."
            )

        spans_per_doc = []  # per document: (start, end) source spans, in output order
        max_out_len = 0
        for i, L_i in enumerate(lengths):
            # A length beyond the padded size would be silently truncated by
            # slicing here and then mis-scattered in backward.
            if not 0 <= L_i <= L:
                raise ValueError(f"doc_lens[{i}]={L_i} is outside the valid range [0, {L}].")
            left_end = L_i // 4            # First quarter boundary
            right_start = (3 * L_i) // 4   # Start of last quarter
            if shard_id == 0:
                # Worker 0 gets the first quarter and the last quarter
                spans = ((0, left_end), (right_start, L_i))
            else:
                # Worker 1 gets the middle half
                spans = ((left_end, right_start),)
            spans_per_doc.append(spans)
            max_out_len = max(max_out_len, sum(end - start for start, end in spans))

        # Copy each span straight into the zero-padded output.
        out_tensor = input.new_zeros((B, max_out_len, D))
        for i, spans in enumerate(spans_per_doc):
            offset = 0
            for start, end in spans:
                count = end - start
                if count > 0:
                    out_tensor[i, offset:offset + count, :] = input[i, start:end, :]
                    offset += count

        # Only plain Python data is kept on ctx; backward needs nothing else.
        ctx.spans_per_doc = spans_per_doc
        ctx.input_shape = input.shape  # [B, L, D]

        return out_tensor

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: get gradients for the original padded input.

        For each document, the rows of grad_output are written back, in order,
        to the source spans recorded by forward:
          - shard 0: [0 : left_end] followed by [right_start : doc_len]
          - shard 1: [left_end : right_start]
        Rows of grad_output that correspond to output padding are ignored, and
        input positions not owned by this shard receive zero gradient.

        Args:
            grad_output (Tensor): Gradient of the sharded output, shape [B, max_out_len, D]

        Returns:
            Tuple: (grad_input, None, None, None)
                   grad_input has shape [B, L, D] corresponding to the original input
        """
        grad_input = grad_output.new_zeros(ctx.input_shape)
        for i, spans in enumerate(ctx.spans_per_doc):
            offset = 0
            for start, end in spans:
                count = end - start
                if count > 0:
                    grad_input[i, start:end, :] = grad_output[i, offset:offset + count, :]
                    offset += count

        return grad_input, None, None, None

# Wrapper
def per_document_cp_shard_varlen(input, doc_lens, shard_id, world_size):
    """Shard a padded batch per document via ``PerDocumentCPShardingVarLen``.

    All four arguments are forwarded unchanged, in order, to the autograd
    function's ``apply`` and its result is returned.
    """
    shard = PerDocumentCPShardingVarLen.apply
    return shard(input, doc_lens, shard_id, world_size)

In [None]:
# Test usage
if __name__ == "__main__":
    # Example with a batch of 3 documents
    # The padded input tensor is of shape [B, L, D], L is the maximum document length
    B = 3
    max_length = 20  # Maximum padded length
    D = 8
    # Document lengths (each document should have at least 4 tokens for the split to make sense)
    doc_lens = torch.tensor([16, 12, 18])
    # Create a padded input tensor (for simplicity, use random data)
    input_tensor = torch.randn(B, max_length, D, requires_grad=True)

    # Simulate processing on worker 0 (shard_id=0)
    out_worker0 = per_document_cp_shard_varlen(input_tensor, doc_lens, shard_id=0, world_size=2)
    # And worker 1 (shard_id=1)
    out_worker1 = per_document_cp_shard_varlen(input_tensor, doc_lens, shard_id=1, world_size=2)

    # For demonstration, compute a dummy loss on each worker's output
    loss0 = out_worker0.sum()
    loss1 = out_worker1.sum()

    # Backpropagate. The two outputs come from separate calls and therefore
    # separate graphs, so neither backward needs retain_graph=True.
    loss0.backward()
    loss1.backward()

    # The longest document (18 tokens) yields 4 + 5 = 9 tokens on worker 0 and
    # 13 - 4 = 9 tokens on worker 1, so both outputs are padded to length 9.
    print("Sharded outputs (worker 0) shape:", out_worker0.shape)
    print("Sharded outputs (worker 1) shape:", out_worker1.shape)
    print("Expected output: [3, 9, 8] (3 documents, 9 tokens per document after padding, and 8 feature dimensions).")


Sharded outputs (worker 0) shape: torch.Size([3, 9, 8])
Sharded outputs (worker 1) shape: torch.Size([3, 9, 8])
Expected ouput: [3, 9, 8] (3 documents, 9 tokens per document after padding, and 8 feature dimensions).


# Now integrate per-document cp sharding with an attention layer

In [None]:
import torch
from torch.autograd import Function
import math

# Dummies for simulating distributed environment
def get_distributed_rank(cp_group):
    """Return the simulated rank stored under the ``'rank'`` key of ``cp_group``.

    For testing, ``cp_group`` is a plain mapping; a missing key raises KeyError.
    """
    rank = cp_group['rank']
    return rank

class DummyCPStream:
    """Placeholder CP stream for single-process simulation."""

    def wait_stream(self, stream):
        """Accept ``stream`` and return immediately; nothing is synchronized."""
        return None

# Per-document CP sharding (variable-length) implementation.
class PerDocumentCPShardingVarLen(Function):
    """Two-worker per-document CP sharding for padded, variable-length batches."""

    @staticmethod
    def forward(ctx, input, doc_lens, shard_id, world_size):
        """
        For each document in the batch:
          - Compute left_end = floor(doc_len / 4)
          - Compute right_start = floor(3 * doc_len / 4)
        Then:
          - If shard_id == 0: output = concat( tokens[0:left_end], tokens[right_start:doc_len] )
          - If shard_id == 1: output = tokens[left_end:right_start]
        Since documents have variable true lengths, each document's sharded output is
        zero-padded to the maximum output length in the batch.

        Args:
            input (Tensor): padded input of shape [B, padded_length, D].
            doc_lens (Tensor or sequence of int): B true lengths, each in [0, padded_length].
            shard_id (int): 0 or 1.
            world_size (int): must be 2.

        Returns:
            Tensor of shape [B, max_out_len, D].

        Raises:
            ValueError: if world_size != 2, shard_id is not 0 or 1, doc_lens does
                not hold exactly B entries, or a length lies outside
                [0, padded_length].
        """
        if world_size != 2:
            raise ValueError("This implementation only supports world_size==2.")
        # Checked before the loop so an empty batch is validated as well.
        if shard_id not in (0, 1):
            raise ValueError("shard_id must be 0 or 1.")

        B, padded_length, D = input.shape  # padded_length is the padded sequence length
        # Single host transfer for all lengths instead of one per document.
        lengths = [int(n) for n in torch.as_tensor(doc_lens).reshape(-1).tolist()]
        if len(lengths) != B:
            raise ValueError(
                f"doc_lens must hold one length per document: expected {B}, got {len(lengths)}."
            )

        spans_per_doc = []  # per document: (start, end) source spans, in output order
        max_out_len = 0
        for i, L_i in enumerate(lengths):
            # A length beyond the padded size would be silently truncated by
            # slicing here and then mis-scattered in backward.
            if not 0 <= L_i <= padded_length:
                raise ValueError(
                    f"doc_lens[{i}]={L_i} is outside the valid range [0, {padded_length}]."
                )
            left_end = L_i // 4
            right_start = (3 * L_i) // 4
            if shard_id == 0:
                # Worker 0: first quarter and last quarter
                spans = ((0, left_end), (right_start, L_i))
            else:
                # Worker 1: middle half
                spans = ((left_end, right_start),)
            spans_per_doc.append(spans)
            max_out_len = max(max_out_len, sum(end - start for start, end in spans))

        # Copy each span straight into the zero-padded output.
        out_tensor = input.new_zeros((B, max_out_len, D))
        for i, spans in enumerate(spans_per_doc):
            offset = 0
            for start, end in spans:
                count = end - start
                if count > 0:
                    out_tensor[i, offset:offset + count, :] = input[i, start:end, :]
                    offset += count

        # Only plain Python data is kept on ctx; backward needs nothing else.
        ctx.spans_per_doc = spans_per_doc
        ctx.input_shape = input.shape  # [B, padded_length, D]

        return out_tensor

    @staticmethod
    def backward(ctx, grad_output):
        """
        Gets gradients from the sharded space back to the original input positions.

        Rows of grad_output are written back, in order, to the source spans
        recorded by forward. Rows that correspond to output padding are ignored
        and positions not owned by this shard receive zero gradient. The result
        has the original padded shape [B, padded_length, D].
        """
        grad_input = grad_output.new_zeros(ctx.input_shape)
        for i, spans in enumerate(ctx.spans_per_doc):
            offset = 0
            for start, end in spans:
                count = end - start
                if count > 0:
                    grad_input[i, start:end, :] = grad_output[i, offset:offset + count, :]
                    offset += count
        return grad_input, None, None, None

def per_document_cp_shard_varlen(input, doc_lens, shard_id, world_size):
    """Autograd-aware entry point for ``PerDocumentCPShardingVarLen``.

    Forwards the four arguments unchanged to the function's ``apply`` and
    returns whatever it returns.
    """
    shard = PerDocumentCPShardingVarLen.apply
    return shard(input, doc_lens, shard_id, world_size)

# Backward helper for mapping sharded gradients back to full input
def per_document_cp_shard_varlen_backward(grad_sharded, doc_lens, padded_length, shard_id, world_size):
    """Scatter a sharded gradient back to the positions of the full padded input.

    Uses the same split points as the forward sharding
    (left_end = doc_len // 4, right_start = 3 * doc_len // 4):
      - shard 0: rows [0 : left_end] and [right_start : doc_len]
      - shard 1: rows [left_end : right_start]
    Rows of ``grad_sharded`` beyond a document's shard length are ignored and
    all other positions of the result are zero.

    Args:
        grad_sharded (Tensor): gradient of the sharded tensor, [B, max_out_len, D].
        doc_lens (Tensor or sequence of int): true document lengths; the first
            B entries are used.
        padded_length (int): sequence length of the full padded input.
        shard_id (int): 0 or 1.
        world_size (int): accepted for signature symmetry with the forward
            wrapper; not used here.

    Returns:
        Tensor of shape [B, padded_length, D] with the dtype and device of
        ``grad_sharded``.

    Raises:
        ValueError: if shard_id is not 0 or 1 (previously any non-zero value
            was silently treated as shard 1).
    """
    if shard_id not in (0, 1):
        raise ValueError("shard_id must be 0 or 1.")

    B, _, D = grad_sharded.size()
    grad_input = grad_sharded.new_zeros(B, padded_length, D)
    # One host transfer for every length; also lets plain sequences through.
    lengths = torch.as_tensor(doc_lens).reshape(-1).tolist()
    for i in range(B):
        doc_len = int(lengths[i])
        left_end = doc_len // 4
        right_start = (3 * doc_len) // 4
        if shard_id == 0:
            num_first = left_end
            num_last = doc_len - right_start
            if num_first > 0:
                grad_input[i, :left_end, :] = grad_sharded[i, :num_first, :]
            if num_last > 0:
                grad_input[i, right_start:doc_len, :] = grad_sharded[i, num_first:num_first + num_last, :]
        else:  # shard_id == 1
            grad_input[i, left_end:right_start, :] = grad_sharded[i, :right_start - left_end, :]
    return grad_input

# Attention layer using per-document CP sharding
class AttnLayerWithPerDocCPSharding(Function):
    """Scaled dot-product attention over this worker's per-document shard (2 workers)."""

    @staticmethod
    def forward(ctx,
                is_training,
                q,
                k,
                v,
                doc_lens,
                dropout_p,
                softmax_scale,
                cp_group,
                cp_stream):
        """
        Forward pass for a scaled dot-product attention layer that applies per-document CP sharding.
        q, k, v: tensors of shape [B, S, D] where S is the padded sequence length.
        doc_lens: 1D tensor of length B with the true document lengths.

        The layer first applies per-document CP sharding to q, k, v, then computes
        attention among the tokens of each document's shard.

        Sharded documents are zero-padded to a common length. Padding positions
        are masked out: they are never attended to, and output rows at padding
        positions are zero. Without the mask, padded keys would receive softmax
        weight and padded query rows would feed gradient into k and v.

        Args:
            is_training (bool): dropout is applied only when True.
            dropout_p (float): dropout probability on the attention weights.
            softmax_scale (float or None): score scale; None means D ** -0.5.
            cp_group: object understood by ``get_distributed_rank``.
            cp_stream: stored on ctx; not used by this implementation.

        Returns:
            Tensor of shape [B, max_shard_len, D].
        """
        if softmax_scale is None:
            softmax_scale = q.size(-1) ** (-0.5)
        rank = get_distributed_rank(cp_group)
        world_size = 2  # Only two workers supported

        # Save padded sequence length from original q, k, v
        padded_length = q.size(1)

        # Apply per-document CP sharding to q, k, v
        sharded_q = per_document_cp_shard_varlen(q, doc_lens, rank, world_size)
        sharded_k = per_document_cp_shard_varlen(k, doc_lens, rank, world_size)
        sharded_v = per_document_cp_shard_varlen(v, doc_lens, rank, world_size)

        # Number of real tokens this rank holds per document, using the same
        # split points as the sharding (doc_len // 4 and 3 * doc_len // 4).
        shard_lens = []
        for doc_len in torch.as_tensor(doc_lens).reshape(-1).tolist()[:q.size(0)]:
            doc_len = int(doc_len)
            middle = (3 * doc_len) // 4 - doc_len // 4
            shard_lens.append(middle if rank == 1 else doc_len - middle)
        shard_lens = torch.tensor(shard_lens, dtype=torch.long, device=sharded_q.device)
        positions = torch.arange(sharded_q.size(1), device=sharded_q.device)
        valid = positions.unsqueeze(0) < shard_lens.unsqueeze(1)   # [B, T]
        pair_mask = valid.unsqueeze(2) & valid.unsqueeze(1)        # [B, T, T]

        # Compute scaled dot-product attention on sharded tokens
        attn_scores = torch.matmul(sharded_q, sharded_k.transpose(-2, -1)) * softmax_scale
        attn_scores = attn_scores.masked_fill(~pair_mask, float("-inf"))
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # Padded query rows are entirely -inf and come out of softmax as NaN;
        # every entry of such a row is masked, so this also clears the NaNs.
        attn_probs = attn_probs.masked_fill(~pair_mask, 0.0)

        # Draw the dropout mask explicitly (entries are 0 or 1 / (1 - p)) so
        # backward can differentiate through exactly the same mask.
        if is_training and dropout_p > 0:
            dropout_scale = torch.nn.functional.dropout(
                torch.ones_like(attn_probs), p=dropout_p, training=True)
            dropped_probs = attn_probs * dropout_scale
        else:
            dropout_scale = None
            dropped_probs = attn_probs
        output = torch.matmul(dropped_probs, sharded_v)

        # Save context for backward. attn_probs is the pre-dropout softmax
        # output, which is what the softmax Jacobian needs.
        to_save = [sharded_q, sharded_k, sharded_v, attn_probs, doc_lens]
        if dropout_scale is not None:
            to_save.append(dropout_scale)
        ctx.save_for_backward(*to_save)
        ctx.has_dropout = dropout_scale is not None
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.rank = rank
        ctx.world_size = world_size
        # Save original padded length
        ctx.padded_length = padded_length

        return output

    @staticmethod
    def backward(ctx, dout):
        """
        Backward pass for the attention layer.
        Computes gradients for sharded q, k, v and then maps them back to the original full inputs.

        With P the softmax output, M the dropout scale mask and Pd = P * M:
          dV  = Pd^T @ dout
          dPd = dout @ V^T,  dP = dPd * M
          dS  = P * (dP - sum(dP * P, dim=-1)) * softmax_scale
          dQ  = dS @ K,  dK = dS^T @ Q
        Masked (padding) entries of P are zero, so they contribute nothing.
        """
        saved = ctx.saved_tensors
        sharded_q, sharded_k, sharded_v, attn_probs, doc_lens = saved[:5]
        dropout_scale = saved[5] if ctx.has_dropout else None
        softmax_scale = ctx.softmax_scale
        rank = ctx.rank
        world_size = ctx.world_size
        padded_length = ctx.padded_length

        # Gradient w.r.t. the (post-dropout) attention weights
        d_probs = torch.matmul(dout, sharded_v.transpose(-2, -1))
        if dropout_scale is not None:
            dropped_probs = attn_probs * dropout_scale
            # Backprop through dropout: same mask and scale as in forward
            d_probs = d_probs * dropout_scale
        else:
            dropped_probs = attn_probs

        # Compute gradients for sharded_v (uses the weights actually applied)
        d_sharded_v = torch.matmul(dropped_probs.transpose(-2, -1), dout)

        # Backprop through softmax using the pre-dropout probabilities
        d_scores = attn_probs * (d_probs - (d_probs * attn_probs).sum(dim=-1, keepdim=True))
        d_scores = d_scores * softmax_scale

        d_sharded_q = torch.matmul(d_scores, sharded_k)
        d_sharded_k = torch.matmul(d_scores.transpose(-2, -1), sharded_q)

        # Map gradients from the sharded space back to full inputs
        dq = per_document_cp_shard_varlen_backward(d_sharded_q, doc_lens, padded_length, rank, world_size)
        dk = per_document_cp_shard_varlen_backward(d_sharded_k, doc_lens, padded_length, rank, world_size)
        dv = per_document_cp_shard_varlen_backward(d_sharded_v, doc_lens, padded_length, rank, world_size)

        # Forward inputs: (is_training, q, k, v, doc_lens, dropout_p, softmax_scale, cp_group, cp_stream)
        return None, dq, dk, dv, None, None, None, None, None

# Wrapper
def attn_layer_with_perdoc_cp_sharding(is_training, q, k, v, doc_lens, dropout_p, softmax_scale, cp_group, cp_stream):
    """Run ``AttnLayerWithPerDocCPSharding`` through autograd.

    All arguments are passed positionally, in the order the function's
    ``forward`` expects them, and its output is returned unchanged.
    """
    forward_args = (is_training, q, k, v, doc_lens, dropout_p, softmax_scale, cp_group, cp_stream)
    return AttnLayerWithPerDocCPSharding.apply(*forward_args)

# Test usage
if __name__ == "__main__":
    # Fix the RNG so the random embeddings below are reproducible.
    torch.manual_seed(0)

    # 4 documents with real text.
    # Token counts are for the whitespace split used below.
    documents = [
        # Document about England (48 tokens)
        "England is a country in the United Kingdom, known for its rich history, iconic landmarks such as Big Ben, the Tower of London, and Buckingham Palace. Its cities, especially London, offer vibrant arts, literature, and culinary scenes, while the countryside features rolling hills, quaint villages, and centuries-old castles.",
        # Document about France (56 tokens)
        "France is renowned for its art, fashion, and culture, with a history that dates back centuries. The country offers stunning architecture like the Eiffel Tower and the Louvre, picturesque countryside, and gourmet cuisine that has influenced the world. From the vineyards of Bordeaux to the beaches of the Riviera, France captivates visitors with its romantic charm.",
        # Document about Germany (58 tokens)
        "Germany is a country known for its engineering, innovation, and rich cultural heritage. It boasts a blend of modern cities like Berlin and Munich with historic sites such as the Brandenburg Gate and ancient castles. Germany's economy is one of the strongest in Europe, and its traditions in music, philosophy, and automotive technology continue to influence the world.",
        # Document about Poland (35 tokens)
        "Poland is a nation with a vibrant history and dynamic cultural scene. From medieval castles and historic cities like Krakow and Gdansk to modern economic growth, Poland offers a fascinating mixture of old and new."
    ]

    # Tokenize documents (simple whitespace split)
    token_lists = [doc.split() for doc in documents]
    doc_lengths = [len(tokens) for tokens in token_lists]
    B = len(documents)
    S = max(doc_lengths)  # padded sequence length
    D = 8  # embedding dimension

    print("Document token counts:", doc_lengths)
    print("Padded sequence length:", S)

    # Create a padded input tensor.
    # For each document, simulate embeddings with random values;
    # positions past the document's length stay zero.
    padded_inputs = torch.zeros(B, S, D)
    for i, tokens in enumerate(token_lists):
        L_i = len(tokens)
        padded_inputs[i, :L_i, :] = torch.randn(L_i, D)
    # Enable gradients only after the in-place fill above.
    padded_inputs.requires_grad_()

    # Convert doc_lengths to tensor.
    doc_lens_tensor = torch.tensor(doc_lengths)

    dropout_p = 0.0  # Set dropout=0 for testing
    softmax_scale = None  # Let the layer compute it automatically.
    cp_group = {'rank': 0}  # Simulate worker 0.
    cp_stream = DummyCPStream()

    # The same tensor is passed as q, k and v, so its .grad accumulates dq + dk + dv.
    print("\n--- Simulating Worker 0 with Real Documents ---")
    out_worker0 = attn_layer_with_perdoc_cp_sharding(True, padded_inputs, padded_inputs, padded_inputs,
                                                     doc_lens_tensor, dropout_p, softmax_scale, cp_group, cp_stream)
    # print("Output (worker 0):", out_worker0)
    print("Output shape (worker 0):", out_worker0.shape)
    # The longest document (58 tokens) gives worker 0 14 + 15 = 29 tokens.
    print("Expected output shape for worker0 is [4, 29, 8]")  # [4 documents, max 29 tokens, 8 features]

    loss0 = out_worker0.sum()
    loss0.backward()
    print("Gradients for padded_inputs after worker 0 backward:", padded_inputs.grad.shape)
    print("Expected gradient shape for padded_inputs should be [4, S, 8] where S =", S)

    # ----- Simulate Worker 1 -----
    # Clear worker 0's gradient so worker 1's is not accumulated onto it.
    padded_inputs.grad.zero_()
    cp_group['rank'] = 1  # Switch to worker 1.

    print("\n--- Simulating Worker 1 with Real Documents ---")
    out_worker1 = attn_layer_with_perdoc_cp_sharding(True, padded_inputs, padded_inputs, padded_inputs,
                                                     doc_lens_tensor, dropout_p, softmax_scale, cp_group, cp_stream)
    # print("Output (worker 1):", out_worker1)
    print("Output shape (worker 1):", out_worker1.shape)
    # The longest document (58 tokens) gives worker 1 43 - 14 = 29 tokens.
    print("Expected output shape for worker1 is [4, 29, 8]")  # [4 documents, max 29 tokens, 8 features]

    loss1 = out_worker1.sum()
    loss1.backward()
    print("Gradients for padded_inputs after worker 1 backward:", padded_inputs.grad.shape)
    print("Expected gradient shape for padded_inputs should be [4, S, 8] where S =", S)


Document token counts: [48, 56, 58, 35]
Padded sequence length: 58

--- Simulating Worker 0 with Real Documents ---
Output shape (worker 0): torch.Size([4, 29, 8])
Expected output shape for worker0 is [4, 29, 8]
Gradients for padded_inputs after worker 0 backward: torch.Size([4, 58, 8])
Expected gradient shape for padded_inputs should be [4, S, 8] where S = 58

--- Simulating Worker 1 with Real Documents ---
Output shape (worker 1): torch.Size([4, 29, 8])
Expected output shape for worker1 is [4, 29, 8]
Gradients for padded_inputs after worker 1 backward: torch.Size([4, 58, 8])
Expected gradient shape for padded_inputs should be [4, S, 8] where S = 58


# Now try a distributed environment with an arbitrary number of workers (not necessarily 2), using CUDA streams

In [None]:
import torch
from torch.autograd import Function
import math

# Dummies for simulating distributed environment
def get_distributed_rank(cp_group):
    """Look up the simulated worker rank.

    For testing, ``cp_group`` is a plain mapping and the rank is the value
    stored under its ``'rank'`` key; a missing key raises KeyError.
    """
    rank = cp_group['rank']
    return rank

class DummyCPStream:
    """Minimal CP stream stand-in that waits by synchronizing the other stream."""

    def wait_stream(self, stream):
        """Wait for ``stream`` by calling its ``synchronize()`` method."""
        synchronize = stream.synchronize  # simple synchronization
        synchronize()

# Per-document CP sharding (variable-length) implementation for general world_size
class PerDocumentCPShardingVarLenGeneral(Function):
    """Per-document CP sharding for padded variable-length batches and any world_size.

    Each document is cut into N = 2 * world_size contiguous chunks and the
    worker with rank r receives chunk r and chunk N - r - 1.
    """

    # The indices computed by the most recent forward call, kept on the class so
    # they can be retrieved later. Every forward call overwrites this value.
    last_indices_info = None

    @staticmethod
    def forward(ctx, input, doc_lens, shard_id, world_size):
        """
        For each document in the batch:
          - Let N = 2 * world_size.
          - Compute seg_len_base = doc_len // N and rem = doc_len % N
            (the first ``rem`` chunks are one token longer).
          - For worker with rank r:
              * Segment 0 indices:
                    start0 = r * seg_len_base + min(r, rem)
                    end0   = start0 + seg_len_base + (1 if r < rem else 0)
              * Segment 1 indices:
                    Let r2 = N - r - 1.
                    start1 = r2 * seg_len_base + min(r2, rem)
                    end1   = start1 + seg_len_base + (1 if r2 < rem else 0)
          - Return two tensors, seg0 and seg1, each zero-padded to the maximum
            segment length across the batch.

        Args:
            input (Tensor): padded input of shape [B, padded_length, D].
            doc_lens (Tensor or sequence of int): B true lengths, each in [0, padded_length].
            shard_id (int): rank of this worker, 0 <= shard_id < world_size.
            world_size (int): number of workers, >= 1.

        Returns:
            Tuple (seg0_padded, seg1_padded) of shapes [B, max_seg0, D] and
            [B, max_seg1, D]; both have length 0 for an empty batch.

        Raises:
            ValueError: if world_size < 1, shard_id is out of range, doc_lens does
                not hold exactly B entries, or a length lies outside
                [0, padded_length].
        """
        if world_size < 1:
            raise ValueError("world_size must be a positive integer.")
        # An out-of-range rank would produce negative or overlapping slices.
        if not 0 <= shard_id < world_size:
            raise ValueError(f"shard_id must be in [0, {world_size}).")

        B, padded_length, D = input.shape
        # Single host transfer for all lengths instead of one per document.
        lengths = [int(n) for n in torch.as_tensor(doc_lens).reshape(-1).tolist()]
        if len(lengths) != B:
            raise ValueError(
                f"doc_lens must hold one length per document: expected {B}, got {len(lengths)}."
            )

        N = 2 * world_size
        r = shard_id
        r2 = N - r - 1
        indices_info = []
        for i, L_i in enumerate(lengths):
            if not 0 <= L_i <= padded_length:
                raise ValueError(
                    f"doc_lens[{i}]={L_i} is outside the valid range [0, {padded_length}]."
                )
            seg_len_base, rem = divmod(L_i, N)
            start0 = r * seg_len_base + min(r, rem)
            end0 = start0 + seg_len_base + (1 if r < rem else 0)
            start1 = r2 * seg_len_base + min(r2, rem)
            end1 = start1 + seg_len_base + (1 if r2 < rem else 0)
            indices_info.append({'start0': start0, 'end0': end0,
                                 'start1': start1, 'end1': end1,
                                 'doc_len': L_i})

        # default=0 keeps an empty batch from crashing max().
        max_seg0 = max((info['end0'] - info['start0'] for info in indices_info), default=0)
        max_seg1 = max((info['end1'] - info['start1'] for info in indices_info), default=0)
        seg0_padded = input.new_zeros((B, max_seg0, D))
        seg1_padded = input.new_zeros((B, max_seg1, D))
        for i, info in enumerate(indices_info):
            seg0_len = info['end0'] - info['start0']
            seg1_len = info['end1'] - info['start1']
            # Every end index is <= doc_len <= padded_length, so the padded
            # input can be sliced directly.
            if seg0_len > 0:
                seg0_padded[i, :seg0_len, :] = input[i, info['start0']:info['end0'], :]
            if seg1_len > 0:
                seg1_padded[i, :seg1_len, :] = input[i, info['start1']:info['end1'], :]

        # Save indices_info in both ctx and as a class variable for later retrieval
        ctx.indices_info = indices_info
        PerDocumentCPShardingVarLenGeneral.last_indices_info = indices_info
        ctx.input_shape = input.shape  # [B, padded_length, D]
        return seg0_padded, seg1_padded

    @staticmethod
    def backward(ctx, grad_seg0, grad_seg1):
        """
        For each document, scatter gradients from grad_seg0 and grad_seg1 back into
        a gradient tensor of shape [B, padded_length, D] using the saved indices.
        Rows that correspond to segment padding are ignored; positions owned by
        other workers receive zero gradient.
        """
        B, padded_length, D = ctx.input_shape
        grad_input = grad_seg0.new_zeros(B, padded_length, D)
        for i, info in enumerate(ctx.indices_info):
            start0, end0 = info['start0'], info['end0']
            start1, end1 = info['start1'], info['end1']
            seg0_len = end0 - start0
            seg1_len = end1 - start1
            if seg0_len > 0:
                grad_input[i, start0:end0, :] = grad_seg0[i, :seg0_len, :]
            if seg1_len > 0:
                grad_input[i, start1:end1, :] = grad_seg1[i, :seg1_len, :]
        return grad_input, None, None, None

def per_document_cp_shard_varlen_general(input, doc_lens, shard_id, world_size):
    """Autograd-aware entry point for ``PerDocumentCPShardingVarLenGeneral``.

    Forwards the four arguments unchanged to the function's ``apply`` and
    returns its result (the two padded segment tensors).
    """
    shard = PerDocumentCPShardingVarLenGeneral.apply
    return shard(input, doc_lens, shard_id, world_size)

# Backward helper for mapping gradients from segments back to full input
def per_document_cp_shard_varlen_general_backward(grad_seg0, grad_seg1, indices_info, padded_length):
    """
    Rebuild a full-length gradient of shape [B, padded_length, D] from the two
    segment gradients.

    ``indices_info`` holds one dict per document with the keys 'start0', 'end0',
    'start1' and 'end1'. Rows [start0:end0] are filled from grad_seg0 and rows
    [start1:end1] from grad_seg1; everything else stays zero. The result has
    the dtype and device of grad_seg0.
    """
    batch, _, dim = grad_seg0.size()
    full_grad = grad_seg0.new_zeros(batch, padded_length, dim)
    for doc, bounds in enumerate(indices_info):
        lo0, hi0 = bounds['start0'], bounds['end0']
        lo1, hi1 = bounds['start1'], bounds['end1']
        # Segment 0 first, then segment 1, taking only the unpadded rows.
        for seg_grad, lo, hi in ((grad_seg0, lo0, hi0), (grad_seg1, lo1, hi1)):
            full_grad[doc, lo:hi, :] = seg_grad[doc, :hi - lo, :]
    return full_grad

# ------------------------------------------------------------------
# Attention layer using per-document CP sharding with CUDA streams
# ------------------------------------------------------------------
class AttnLayerWithPerDocCPShardingGeneral(Function):
    """
    Scaled dot-product attention over the two per-document CP segments owned by
    one rank, with a hand-written backward pass.

    Forward shards q, k and v with `per_document_cp_shard_varlen_general`, runs
    softmax(Q K^T * scale) V independently on segment 0 and segment 1, and
    concatenates the two results along the sequence dimension. Backward splits
    the incoming gradient the same way, differentiates each segment's attention
    and scatters the segment gradients back to the original [B, S, D] layout.

    NOTE(review): each segment attends only to itself and no padding mask is
    applied; if the sharding function pads segments to a per-batch maximum, the
    padded keys take part in the softmax. Confirm this is intended.
    """

    @staticmethod
    def _attend_segment(q_seg, k_seg, v_seg, softmax_scale, dropout_p, is_training):
        """
        Attention for a single segment.

        Returns (out, probs, dropped_probs):
        - out: dropped_probs @ v_seg
        - probs: softmax output before dropout
        - dropped_probs: probs after dropout; the very same tensor as `probs`
          when dropout is inactive (not training or dropout_p <= 0)
        """
        scores = torch.matmul(q_seg, k_seg.transpose(-2, -1)) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        dropped_probs = probs
        if is_training and dropout_p > 0:
            dropped_probs = torch.nn.functional.dropout(probs, p=dropout_p, training=True)
        out = torch.matmul(dropped_probs, v_seg)
        return out, probs, dropped_probs

    @staticmethod
    def _segment_grads(dout_seg, q_seg, k_seg, v_seg, probs, dropped_probs, softmax_scale):
        """
        Gradients (d_q, d_k, d_v) of one segment's attention.

        With P = softmax output, D = P * M the post-dropout weights (M is the
        scaled keep mask) and dA = dout @ V^T the gradient w.r.t. D:
            dP = dA * M
            dZ = P * (dP - sum(dP * P)) = D * dA - P * sum(dA * D)
        so the mask itself is never needed, only P and D. Without dropout
        D == P and this is the ordinary softmax backward.
        """
        # dV = D^T @ dout
        d_v = torch.matmul(dropped_probs.transpose(-2, -1), dout_seg)
        # dA = dout @ V^T  (gradient w.r.t. the post-dropout weights)
        d_dropped = torch.matmul(dout_seg, v_seg.transpose(-2, -1))
        weighted = d_dropped * dropped_probs
        d_scores = weighted - probs * weighted.sum(dim=-1, keepdim=True)
        # The scores were multiplied by softmax_scale in forward
        d_scores = d_scores * softmax_scale
        # dQ = dZ @ K ; dK = dZ^T @ Q
        d_q = torch.matmul(d_scores, k_seg)
        d_k = torch.matmul(d_scores.transpose(-2, -1), q_seg)
        return d_q, d_k, d_v

    @staticmethod
    def forward(ctx,
                is_training,
                q,
                k,
                v,
                doc_lens,
                dropout_p,
                softmax_scale,
                cp_group,
                cp_stream,
                world_size):
        """
        Scaled dot-product attention layer that applies per-document CP sharding with CUDA streams
        - is_training (bool): dropout is applied only when this is true and dropout_p > 0
        - q, k, v (Tensor): [B, S, D] (S = padded sequence length)
        - doc_lens (Tensor): 1D tensor of length B with true document lengths
        - dropout_p (float): dropout probability on the attention weights
        - softmax_scale (float or None): score multiplier; None selects D ** -0.5
        - cp_group: passed to get_distributed_rank to obtain this worker's rank
        - cp_stream: not used by the computation; only stored on ctx
        - world_size (int): number of workers in the distributed environment

        Returns the two segment outputs concatenated along dim 1.
        """
        if softmax_scale is None:
            softmax_scale = q.size(-1) ** (-0.5)
        rank = get_distributed_rank(cp_group)

        # Apply per-document CP sharding to q, k, v using the general function
        q_seg0, q_seg1 = per_document_cp_shard_varlen_general(q, doc_lens, rank, world_size)
        k_seg0, k_seg1 = per_document_cp_shard_varlen_general(k, doc_lens, rank, world_size)
        v_seg0, v_seg1 = per_document_cp_shard_varlen_general(v, doc_lens, rank, world_size)
        # Retrieve indices_info stored by the sharding function (class-level state,
        # so read it immediately after the sharding calls).
        indices_info = PerDocumentCPShardingVarLenGeneral.last_indices_info

        attend = AttnLayerWithPerDocCPShardingGeneral._attend_segment
        if q.is_cuda:
            # Process segment 0 and segment 1 concurrently on two side streams.
            main_stream = torch.cuda.current_stream(q.device)
            stream0 = torch.cuda.Stream(device=q.device)
            stream1 = torch.cuda.Stream(device=q.device)
            # The sharded inputs were produced on the current stream; the side
            # streams must not start reading them before that work has finished.
            stream0.wait_stream(main_stream)
            stream1.wait_stream(main_stream)
            with torch.cuda.stream(stream0):
                out_seg0, probs0, dropped0 = attend(q_seg0, k_seg0, v_seg0, softmax_scale, dropout_p, is_training)
            with torch.cuda.stream(stream1):
                out_seg1, probs1, dropped1 = attend(q_seg1, k_seg1, v_seg1, softmax_scale, dropout_p, is_training)
            # Join: later work on the current stream waits for both side streams
            # (stream-level dependency instead of a device-wide synchronize).
            main_stream.wait_stream(stream0)
            main_stream.wait_stream(stream1)
            # These tensors were allocated on a side stream but are consumed on
            # the current stream; tell the caching allocator about that use.
            for produced in (out_seg0, probs0, dropped0, out_seg1, probs1, dropped1):
                produced.record_stream(main_stream)
        else:
            # Non-CUDA tensors: streams give no concurrency and torch.cuda.Stream()
            # is unavailable without CUDA, so run the two segments back to back.
            out_seg0, probs0, dropped0 = attend(q_seg0, k_seg0, v_seg0, softmax_scale, dropout_p, is_training)
            out_seg1, probs1, dropped1 = attend(q_seg1, k_seg1, v_seg1, softmax_scale, dropout_p, is_training)

        # Concatenate the two segments along the sequence dimension
        output = torch.cat([out_seg0, out_seg1], dim=1)  # shape: [B, (max_seg0 + max_seg1), D]

        # Save context for backward. Both the pre-dropout and post-dropout
        # attention weights are needed for a correct gradient under dropout.
        ctx.save_for_backward(q_seg0, q_seg1, k_seg0, k_seg1, v_seg0, v_seg1,
                              probs0, probs1, dropped0, dropped1)
        ctx.indices_info = indices_info
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.cp_group = cp_group
        ctx.cp_stream = cp_stream
        ctx.rank = rank
        ctx.world_size = world_size
        ctx.input_shape = q.shape  # original [B, S, D]
        return output

    @staticmethod
    def backward(ctx, dout):
        """
        Splits the gradient of the concatenated output into two parts corresponding to
        the two segments and computes gradients for q, k, and v.

        Returns one entry per forward argument; only q, k and v receive gradients,
        each of shape [B, S, D].
        """
        (q_seg0, q_seg1, k_seg0, k_seg1, v_seg0, v_seg1,
         probs0, probs1, dropped0, dropped1) = ctx.saved_tensors
        softmax_scale = ctx.softmax_scale
        padded_length = ctx.input_shape[1]

        # Forward concatenated [seg0 | seg1] along dim 1; split dout the same way.
        len0 = q_seg0.size(1)
        len1 = q_seg1.size(1)
        dout_seg0 = dout[:, :len0, :]  # shape: [B, len0, D]
        dout_seg1 = dout[:, len0:len0 + len1, :]  # shape: [B, len1, D]

        segment_grads = AttnLayerWithPerDocCPShardingGeneral._segment_grads
        d_q_seg0, d_k_seg0, d_v_seg0 = segment_grads(
            dout_seg0, q_seg0, k_seg0, v_seg0, probs0, dropped0, softmax_scale)
        d_q_seg1, d_k_seg1, d_v_seg1 = segment_grads(
            dout_seg1, q_seg1, k_seg1, v_seg1, probs1, dropped1, softmax_scale)

        # Map gradients from segments back to full inputs using the backward helper.
        dq = per_document_cp_shard_varlen_general_backward(d_q_seg0, d_q_seg1, ctx.indices_info, padded_length)
        dk = per_document_cp_shard_varlen_general_backward(d_k_seg0, d_k_seg1, ctx.indices_info, padded_length)
        dv = per_document_cp_shard_varlen_general_backward(d_v_seg0, d_v_seg1, ctx.indices_info, padded_length)
        return None, dq, dk, dv, None, None, None, None, None, None

def attn_layer_with_perdoc_cp_sharding_general(is_training, q, k, v, doc_lens, dropout_p, softmax_scale, cp_group, cp_stream, world_size):
    """
    Functional entry point for AttnLayerWithPerDocCPShardingGeneral.

    Passes every argument, positionally and in the same order, to the autograd
    Function's `apply` and returns its output.
    """
    attn_fn = AttnLayerWithPerDocCPShardingGeneral.apply
    return attn_fn(
        is_training,
        q,
        k,
        v,
        doc_lens,
        dropout_p,
        softmax_scale,
        cp_group,
        cp_stream,
        world_size,
    )

# Test usage for the general implementation with fake documents.
# This is a shape-only smoke test: it runs one forward and one backward pass
# and prints the resulting shapes next to hard-coded expected shapes.
if __name__ == "__main__":
    # Fixed seed so the random inputs are reproducible
    torch.manual_seed(0)

    # Create dummy inputs: q, k, v of shape [B, S, D]
    B = 4
    S = 160  # padded sequence length (assume documents up to 160 tokens)
    D = 8
    q = torch.randn(B, S, D, requires_grad=True)
    k = torch.randn(B, S, D, requires_grad=True)
    v = torch.randn(B, S, D, requires_grad=True)
    # Fake document lengths (one entry per batch row, each <= S)
    doc_lens = torch.tensor([158, 127, 160, 155])

    dropout_p = 0.1  # Dropout is active (is_training=True below); only shapes are checked, so the exact rate is not important
    softmax_scale = None  # None makes the layer fall back to D ** -0.5
    world_size = 4  # Number of workers
    # Simulate for worker 1
    # NOTE(review): presumably get_distributed_rank reads the 'rank' key of this dict; confirm against its definition
    cp_group = {'rank': 1}
    cp_stream = DummyCPStream()

    print("Document token counts:", doc_lens.tolist())
    print("Padded sequence length:", S)

    out_general = attn_layer_with_perdoc_cp_sharding_general(True, q, k, v, doc_lens, dropout_p, softmax_scale, cp_group, cp_stream, world_size)
    print("Output shape (general):", out_general.shape)
    # Expected output shape: [B, L_worker, D] where L_worker = (max_seg0 + max_seg1) for the batch
    print("Expected output shape: [4, 40, 8]")

    # Summing the output gives a scalar loss, so backward() needs no explicit gradient argument
    loss_general = out_general.sum()
    loss_general.backward()
    # The gradient must come back in the original padded layout of q
    print("Gradients for q (general) shape:", q.grad.shape)
    print("Expected gradients for q shape: [4, 160, 8]")


Document token counts: [158, 127, 160, 155]
Padded sequence length: 160
Output shape (general): torch.Size([4, 40, 8])
Expected output shape: [4, 40, 8]
Gradients for q (general) shape: torch.Size([4, 160, 8])
Expected gradients for q shape: [4, 160, 8]
