In [25]:
import torch
import math

def batched_simt_reshape_with_offset(x: torch.Tensor, offset: int, row_length: int, pad_value: int = 0) -> torch.Tensor:
    """
    Reshapes a batched 3D tensor into a new 4D tensor based on an offset value,
    grouping elements and padding each row to the specified length while preserving the embedding dimension.
    This version is optimized for SIMT architectures and handles batched inputs with embeddings.

    Args:
    x (torch.Tensor): Input batched 3D tensor of shape (batch_size, sequence_length, d_embed)
    offset (int): Number of elements to group before padding
    row_length (int): Desired length of each row in the output
    pad_value (int): Value to use for padding (default: 0)

    Returns:
    torch.Tensor: Reshaped and padded batched 4D tensor
    """
    batch_size, input_length, d_embed = x.shape
    num_rows = math.ceil(input_length / row_length)
    
    # Create output tensor filled with pad_value
    output = torch.full((batch_size, num_rows, row_length, d_embed), pad_value, dtype=x.dtype, device=x.device)
    
    # Create a mask for non-pad values (assuming pad_value is applied across the entire embedding)
    mask = (x != pad_value).any(dim=-1)
    
    # Get the indices of non-pad values
    batch_indices, seq_indices = torch.where(mask)
    
    # Calculate row and column indices for the output
    row_indices = (seq_indices // offset) % num_rows
    col_indices = seq_indices % offset
    
    # Filter out column indices that are out of bounds
    valid_indices = col_indices < row_length
    batch_indices = batch_indices[valid_indices]
    row_indices = row_indices[valid_indices]
    col_indices = col_indices[valid_indices]
    seq_indices = seq_indices[valid_indices]
    
    # Use advanced indexing to fill the output tensor
    output[batch_indices, row_indices, col_indices] = x[batch_indices, seq_indices]
    
    return output

In [29]:
# Example usage with batched inputs including embedding dimension
d_embed = 3
x1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]])
x2 = torch.tensor([[1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]])
x3 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [0, 0, 0], [0, 0, 0]])
x4 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]])

# Create a batched input
batched_input = torch.stack([x1, x2, x3, x4])

row_length = 4

result = batched_simt_reshape_with_offset(batched_input, offset=1, row_length=row_length)

print(result.shape)

# You can access individual results like this:
print("Result for x1:")
print(result[0])
print("Result for x2:")
print(result[1])
print("Result for x3:")
print(result[2])
print("Result for x4:")
print(result[3])

torch.Size([4, 2, 4, 3])
Result for x1:
tensor([[[ 7,  8,  9],
         [ 0,  0,  0],
         [ 0,  0,  0],
         [ 0,  0,  0]],

        [[10, 11, 12],
         [ 0,  0,  0],
         [ 0,  0,  0],
         [ 0,  0,  0]]])
Result for x2:
tensor([[[1, 2, 3],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[4, 5, 6],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]])
Result for x3:
tensor([[[13, 14, 15],
         [ 0,  0,  0],
         [ 0,  0,  0],
         [ 0,  0,  0]],

        [[16, 17, 18],
         [ 0,  0,  0],
         [ 0,  0,  0],
         [ 0,  0,  0]]])
Result for x4:
tensor([[[7, 8, 9],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[4, 5, 6],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]])
