In [None]:
import torch

def compute_overlap_length(left: torch.Tensor, right: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
    """
    Count, for each closed interval [left[i], right[i]], how many valid positions it covers.

    Positions are the indices 0..T-1 of `valid`. Any part of an interval that
    falls outside that range contributes nothing, so an interval lying entirely
    outside [0, T-1], or an inverted one (left > right), has overlap 0.

    Args:
        left (torch.Tensor): Tensor of shape (N,) with the inclusive left bounds.
        right (torch.Tensor): Tensor of shape (N,) with the inclusive right bounds.
        valid (torch.Tensor): Binary tensor of shape (T,) (1 for valid, 0 otherwise).
            `left` and `right` are expected to be on the same device as `valid`.

    Returns:
        torch.Tensor: Long tensor of shape (N,) holding the overlap length of
        each interval with the valid positions.
    """
    num_positions = valid.size(0)

    # Row vector of candidate positions, shape (1, T); it broadcasts against the
    # (N, 1) bounds below, so no explicit expand to (N, T) is required.
    positions = torch.arange(num_positions, device=valid.device).unsqueeze(0)

    # The bounds are deliberately NOT clamped to [0, T-1]. Clamping would pull an
    # interval that lies completely outside the range onto position 0 or T-1 and
    # wrongly credit it with that position. The comparisons already ignore any
    # index outside 0..T-1 because `positions` only contains in-range indices.
    interval_masks = (positions >= left.unsqueeze(1)) & (positions <= right.unsqueeze(1))

    # (N, T) @ (T,) -> (N,): sums the valid flags selected by each interval mask.
    overlaps = torch.matmul(interval_masks.float(), valid.float()).long()

    return overlaps

# Example usage
if __name__ == "__main__":
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # example also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    left = torch.tensor([3], device=device)
    right = torch.tensor([8], device=device)
    valid = torch.tensor([0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0], device=device)

    # Positions 3..8 of `valid` are 1, 0, 1, 0, 1, 1, so the printed tensor holds 4.
    result = compute_overlap_length(left, right, valid)
    print(result)

tensor([4], device='cuda:0')
