# Testing clamp_tensor Against the Original clamp_col Function

In [1]:
import torch

def clamp_tensor(tensor: torch.Tensor, bounds: torch.Tensor) -> torch.Tensor:
    """
    Restrict every value of ``tensor`` to the interval described by ``bounds``.

    Two bound layouts are accepted:

    * ``(2, dim)`` -- one lower row and one upper row shared by all samples.
    * ``(batch_size, 2, dim)`` -- a separate lower/upper pair for each sample.

    The input tensor is not modified; a new tensor is returned.

    Args:
        tensor (Tensor): Values to clamp. Presumably of shape
            ``(batch_size, dim)``; it only needs to broadcast against the
            extracted lower/upper bounds.
        bounds (Tensor): Lower and upper bounds in one of the layouts above.

    Returns:
        Tensor: ``max(min(tensor, upper), lower)`` computed element-wise.

    Raises:
        ValueError: If ``bounds`` matches neither supported layout.
    """
    rank = bounds.ndim
    is_shared = rank == 2 and bounds.shape[0] == 2
    is_per_sample = rank == 3 and bounds.shape[1] == 2

    # Reject unsupported layouts before touching the data.
    if not (is_shared or is_per_sample):
        raise ValueError("Invalid bounds dimension. Expected bounds of shape (2, dim) or (batch_size, 2, dim).")

    if is_shared:
        # Row 0 holds the lower limits, row 1 the upper limits.
        lo, hi = bounds[0], bounds[1]
    else:
        # Index 0/1 along the middle axis selects lower/upper per sample.
        lo, hi = bounds[:, 0, :], bounds[:, 1, :]

    # Cap from above first, then lift from below (the lower bound wins if they cross).
    capped = torch.min(tensor, hi)
    return torch.max(capped, lo)

# Original clamp_col function
def clamp_col(t: torch.Tensor, bounds: torch.Tensor) -> torch.Tensor:
    """
    Clamp ``t`` in place, one column (or one element) at a time, and return it.

    Args:
        t (Tensor): Tensor to clamp; it is indexed as ``t[row, col]`` and is
            overwritten with the clamped values.
        bounds (Tensor): Either a 2-D tensor whose row 0 / row 1 hold the
            lower / upper limits per column, or a tensor indexed as
            ``bounds[row, 0 or 1, col]`` carrying per-sample limits.

    Returns:
        Tensor: The same ``t`` object that was passed in, after clamping.

    Raises:
        AssertionError: In the per-sample case, if the leading dimension of
            ``bounds`` differs from that of ``t``.
    """
    if bounds.ndim == 2:
        # Shared limits: clamp a whole column per iteration.
        n_cols = bounds.shape[-1]
        for col in range(n_cols):
            t[:, col] = t[:, col].clamp(min=bounds[0, col], max=bounds[1, col])
        return t

    # Per-sample limits: every (row, col) entry has its own interval.
    assert bounds.shape[0] == t.shape[0], "Batch size of bounds and tensor must match."
    n_cols = bounds.shape[-1]
    n_rows = bounds.shape[0]
    for col in range(n_cols):
        for row in range(n_rows):
            t[row, col] = t[row, col].clamp(
                min=bounds[row, 0, col],
                max=bounds[row, 1, col],
            )
    return t




In [2]:


# Test Case 1: Common bounds for all samples
# Fix the RNG seed so Restart Kernel -> Run All reproduces the same inputs.
torch.manual_seed(0)
batch_size = 10
dim = 5
# Scale the standard-normal draws: unscaled values almost never leave [-3, 3],
# so the wider bounds (+/-3, +/-4, +/-5) would otherwise go unexercised.
t = torch.randn(batch_size, dim) * 5.0
bounds = torch.tensor([
    [-1.0, -2.0, -3.0, -4.0, -5.0],  # Lower bounds
    [1.0, 2.0, 3.0, 4.0, 5.0]        # Upper bounds
])


# clamp_col writes into its argument, so each function gets its own copy.
t1 = t.clone()
t2 = t.clone()


output1 = clamp_tensor(t1, bounds)
output2 = clamp_col(t2, bounds)


if torch.allclose(output1, output2):
    print("Test Case 1 Passed: Outputs are identical.")
else:
    print("Test Case 1 Failed: Outputs differ.")
    print("Difference:\n", output1 - output2)


Test Case 1 Passed: Outputs are identical.


In [3]:
# Test Case 2: Per-sample bounds
# Generate per-sample bounds of shape (batch_size, 2, dim).
# Each sample's bounds are scaled by its own factor so that the rows really
# differ; with identical rows the test could not tell per-sample handling
# apart from shared bounds.
sample_scale = torch.linspace(0.1, 1.0, steps=batch_size).unsqueeze(1)  # Shape: (batch_size, 1)
lower_bounds = torch.linspace(-1, -5, steps=dim).repeat(batch_size, 1) * sample_scale
upper_bounds = torch.linspace(1, 5, steps=dim).repeat(batch_size, 1) * sample_scale
bounds_per_sample = torch.stack([lower_bounds, upper_bounds], dim=1)  # Shape: (batch_size, 2, dim)


# clamp_col writes into its argument, so each function gets its own copy.
t1 = t.clone()
t2 = t.clone()


output1 = clamp_tensor(t1, bounds_per_sample)
output2 = clamp_col(t2, bounds_per_sample)


if torch.allclose(output1, output2):
    print("Test Case 2 Passed: Outputs are identical.")
else:
    print("Test Case 2 Failed: Outputs differ.")
    print("Difference:\n", output1 - output2)



Test Case 2 Passed: Outputs are identical.
