# Differentiable Bubble Sort

A differentiable implementation of bubble sort with a configurable (learnable) comparator function.


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [5]:
# Check PyTorch installation and device availability
print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.6.0+cu126
Using device: cuda
GPU: NVIDIA GeForce RTX 4070 Laptop GPU


## Differentiable Swap Function

The core of differentiable sorting is implementing a continuous approximation of the swap operation. We use linear interpolation to create a differentiable swap function.

**Mathematical Foundation:**
Given two elements `a` and `b`, and a swap parameter `t`:

$$
\text{new}_a = a \cdot t + b \cdot (1 - t)
$$
$$
\text{new}_b = b \cdot t + a \cdot (1 - t)  
$$

**Key Properties:**
- When `t = 0`: Elements are fully swapped (`a ↔ b`)
- When `t = 1`: Elements remain unchanged  
- Values between 0 and 1 create partial swaps, enabling gradient flow
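
To make the endpoint behaviour concrete, here is a minimal scalar sketch of the two formulas above. The helper name `soft_swap_pair` is illustrative only and is not used elsewhere in this notebook.

```python
def soft_swap_pair(a, b, t):
    """Interpolate two scalars: t=0 swaps them, t=1 leaves them unchanged."""
    new_a = a * t + b * (1 - t)
    new_b = b * t + a * (1 - t)
    return new_a, new_b


print(soft_swap_pair(1.0, 2.0, 0.0))  # (2.0, 1.0): full swap
print(soft_swap_pair(1.0, 2.0, 1.0))  # (1.0, 2.0): unchanged
print(soft_swap_pair(1.0, 2.0, 0.5))  # (1.5, 1.5): both become the average
```

For every `t`, `new_a + new_b` equals `a + b`, so the interpolation redistributes the two values without changing their total.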

**Alternative Approaches:**
This notebook uses linear interpolation for simplicity and interpretability. Other differentiable sorting strategies include:
- [Softmax-based approximations](https://github.com/johnhw/differentiable_sorting) 
- [Optimal transport methods](https://arxiv.org/pdf/1905.11885.pdf)
- [Projections onto the permutahedron](https://arxiv.org/pdf/2002.08871.pdf)

Each approach has different trade-offs in terms of approximation quality, computational efficiency, and gradient properties.
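
For contrast with the interpolation used in this notebook, one common smooth relaxation replaces the hard min and max of a compare-exchange step with log-sum-exp versions. The sketch below is a generic illustration of that idea on scalars. It is not taken from any of the linked projects, and the function names are made up here.

```python
import math


def smooth_max(a, b):
    """Log-sum-exp relaxation of max(a, b), computed in a numerically stable way."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def smooth_compare_exchange(a, b):
    """Return (soft minimum, soft maximum); their sum equals a + b."""
    hi = smooth_max(a, b)
    lo = a + b - hi
    return lo, hi


lo, hi = smooth_compare_exchange(3.0, 1.0)
print(lo, hi)  # lo is slightly below 1, hi is slightly above 3
```

Unlike the interpolation approach, this variant needs no separate comparator: the ordering decision is folded into the smooth maximum itself.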

In [6]:
def swap(x, i, j, t=None):
    """
    Differentiable swap function using linear interpolation.

    Args:
        x: torch.Tensor - Input tensor of shape [sequence_length, feature_size]
        i: int - First index to swap
        j: int - Second index to swap
        t: torch.Tensor - Swap parameter (scalar or tensor)
                         If None, defaults to full swap (t=0)

    Returns:
        torch.Tensor - Tensor with elements at positions i and j interpolated
    """
    if t is None:
        t = torch.tensor(0.0)

    # Ensure t is a scalar for consistent behavior
    if isinstance(t, torch.Tensor) and t.numel() > 1:
        t = t.mean()  # Use mean if t is a tensor

    # Clone to avoid in-place operations
    result = x.clone()

    # Linear interpolation swap
    # t=0: full swap, t=1: no swap
    result[i] = t * x[i] + (1 - t) * x[j]
    result[j] = t * x[j] + (1 - t) * x[i]

    return result

In [7]:
# Test the swap function with gradients
x = torch.tensor(
    [[1, 1, 0, 0], [1, 1, 0, 1], [1, 0, 0, 0], [1, 0, 1, 0], [1, 1, 1, 0]],
    dtype=torch.float32,
    requires_grad=True,
)

t = torch.zeros_like(x, requires_grad=True)
i, j = 1, 2

# Forward pass
z = swap(x, i, j, t)
print("Swapped tensor:")
print(z)

# Compute gradients
loss = z.sum()
loss.backward()

print("\nGradient w.r.t. input x:")
print(x.grad)
print("\nGradient w.r.t. swap parameter t:")
print(t.grad)

Swapped tensor:
tensor([[1., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 1., 0., 1.],
        [1., 0., 1., 0.],
        [1., 1., 1., 0.]], grad_fn=<CopySlices>)

Gradient w.r.t. input x:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

Gradient w.r.t. swap parameter t:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])


In [8]:
# Simple test case: three rows with a single feature each
x = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float32, requires_grad=True)
t = torch.zeros_like(x, requires_grad=True)
i, j = 0, 1

z = swap(x, i, j, t)
print("Original tensor:")
print(x.detach())
print("Swapped tensor:")
print(z)

# Test gradients
loss = z.sum()
loss.backward()
print("Gradient w.r.t. x:")
print(x.grad)
print("Gradient w.r.t. t:")
print(t.grad)

Original tensor:
tensor([[1.],
        [2.],
        [3.]])
Swapped tensor:
tensor([[2.],
        [1.],
        [3.]], grad_fn=<CopySlices>)
Gradient w.r.t. x:
tensor([[1.],
        [1.],
        [1.]])
Gradient w.r.t. t:
tensor([[0.],
        [0.],
        [0.]])
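
Both tests above report a zero gradient for `t`. That is expected for this particular loss and does not indicate broken gradient flow: the interpolation conserves `x[i] + x[j]`, so `z.sum()` does not depend on `t`. The sketch below uses an illustrative scalar helper to contrast the sum loss with a loss that looks at only one position.

```python
import torch


def soft_swap_pair(a, b, t):
    """Scalar version of the interpolation swap (illustrative helper)."""
    return t * a + (1 - t) * b, t * b + (1 - t) * a


a = torch.tensor(1.0)
b = torch.tensor(2.0)
t = torch.tensor(0.3, requires_grad=True)

new_a, new_b = soft_swap_pair(a, b, t)

# Symmetric loss: new_a + new_b == a + b for every t, so the gradient is zero.
(grad_sum,) = torch.autograd.grad(new_a + new_b, t, retain_graph=True)

# Position-sensitive loss: only new_a contributes, and d(new_a)/dt = a - b.
(grad_first,) = torch.autograd.grad(new_a, t)

print(grad_sum.item())    # 0.0
print(grad_first.item())  # -1.0
```

Any loss that depends on which value ends up in which position, such as the MSE against a sorted target used later, gives `t` a non-zero gradient.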


## Differentiable Bubble Sort Algorithm

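Classic bubble sort repeatedly compares adjacent elements and swaps them when they are in the wrong order. The implementation below uses a closely related exchange-sort loop: each position `i` is compared with every later position `j`. In our differentiable version:

1. **Comparator Function**: Instead of hard comparisons (`a > b`), we use a comparator function (rule-based or learned) that outputs a continuous swap parameter `t ∈ [0,1]`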
2. **Soft Swapping**: Elements are interpolated rather than discretely swapped, maintaining differentiability
3. **End-to-End Learning**: The entire sorting process remains differentiable, allowing gradient-based optimization

**Key Innovation**: The swap parameter `t` replaces discrete conditional logic, enabling backpropagation through the sorting algorithm.

In [9]:
def bubble_sort(x, cmp_fun):
    """
    Differentiable bubble sort implementation.

    Args:
        x: torch.Tensor - Input tensor of shape [sequence_length, feature_size]
        cmp_fun: callable - Comparator function that takes a tensor of shape
                           [batch_size, 2, feature_size] and returns swap parameters

    Returns:
        torch.Tensor - Sorted tensor (approximately, depending on comparator quality)
    """
    x_len = x.shape[0]

    # Bubble sort algorithm with differentiable comparisons
    for i in range(x_len):
        for j in range(i + 1, x_len):
            # Prepare comparison input: stack elements i and j
            cmp_input = torch.stack([x[i], x[j]], dim=0).unsqueeze(
                0
            )  # [1, 2, feature_size]

            # Get swap parameter from comparator
            t = cmp_fun(cmp_input)  # Returns [batch_size] tensor

            # Apply differentiable swap
            x = swap(x, i, j, t)

    return x
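
Note that the loops in `bubble_sort` compare position `i` with every later position `j`, not only with its neighbour, so the comparator is called once per index pair. The small sketch below (the helper name is illustrative) lists that order for a length-4 input.

```python
from itertools import combinations


def visited_pairs(n):
    """Index pairs in the order the nested loops of bubble_sort visit them."""
    return [(i, j) for i in range(n) for j in range(i + 1, n)]


pairs = visited_pairs(4)
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

# Same sequence as itertools.combinations, i.e. n * (n - 1) / 2 comparator calls.
print(pairs == list(combinations(range(4), 2)), len(pairs))  # True 6
```

For the 10-row data used later this amounts to 45 comparator calls per forward pass, each followed by a full clone of the tensor in `swap`.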

### Sample Comparator Function

This is a simple rule-based comparator for demonstration. It sorts based on the sum of elements in each vector (number of 1s in binary vectors).

**Note**: `torch.sign` is piecewise constant, so its gradient is zero wherever it is defined. No learning signal can pass through this comparator, which makes it unsuitable for end-to-end learning. It serves only as a test to verify that our sorting logic works correctly.

In [10]:
def sample_comparator(x):
    """
    Rule-based comparator that sorts by sum of elements.

    Args:
        x: torch.Tensor - Shape [batch_size, 2, feature_size]

    Returns:
        torch.Tensor - Swap parameters of shape [batch_size]
    """
    # Sum elements for each vector
    sums = x.sum(dim=-1)  # [batch_size, 2]
    diff = sums[:, 0] - sums[:, 1]  # Compare first vs second element

    # Convert to swap parameter: 0 = swap, 1 = don't swap
    # If first element sum > second element sum, swap (t=0)
    # If first element sum < second element sum, don't swap (t=1)
    # If the sums are equal, sign(0) = 0 gives t=0.5 (both rows are averaged)
    return 1 - (torch.sign(diff) + 1) / 2


# Test the sample comparator
x_test = torch.tensor(
    [
        [1, 0, 0, 0],  # sum = 1
        [1, 1, 1, 1],  # sum = 4
    ],
    dtype=torch.float32,
)

cmp_input = x_test.unsqueeze(0)  # Add batch dimension
cmp_result = sample_comparator(cmp_input)
print(f"Comparator input shapes: {cmp_input.shape}")
print(f"Element sums: {x_test.sum(dim=-1)}")
print(f"Comparator output (swap parameter): {cmp_result}")
print("Interpretation: t=1 means 'don't swap' (first element should come first)")

Comparator input shapes: torch.Size([1, 2, 4])
Element sums: tensor([1., 4.])
Comparator output (swap parameter): tensor([1.])
Interpretation: t=1 means 'don't swap' (first element should come first)
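
Because the rule-based comparator relies on `torch.sign`, it passes no gradient back to its input. A smooth stand-in with the same orientation can be obtained by replacing the sign with a sigmoid. The function below is a hypothetical sketch, not part of the original notebook, and the `temperature` parameter is an assumption introduced here to control how sharp the decision is.

```python
import torch


def soft_sum_comparator(x, temperature=1.0):
    """Smooth stand-in for a sign-based rule (hypothetical helper).

    x has shape [batch_size, 2, feature_size]; returns t of shape [batch_size].
    """
    sums = x.sum(dim=-1)            # [batch_size, 2]
    diff = sums[:, 0] - sums[:, 1]  # > 0 when the first row has the larger sum
    # Larger first sum -> t near 0 (swap); smaller first sum -> t near 1 (keep).
    return torch.sigmoid(-diff / temperature)


pair = torch.tensor(
    [[[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]], requires_grad=True
)
t = soft_sum_comparator(pair)
t.sum().backward()
print(t)                                 # close to 1: keep the current order
print(pair.grad.abs().sum().item() > 0)  # True: gradient reaches the inputs
```

Lower temperatures push the output toward the hard 0/1 decisions of the sign-based rule, at the cost of smaller gradients away from ties.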


In [11]:
# Test bubble sort with simple 1D case
x = torch.tensor([[3.0], [1.0], [2.0]], dtype=torch.float32, requires_grad=True)
print("Original tensor:")
print(x.detach().flatten())

sorted_x = bubble_sort(x, sample_comparator)
print("Sorted tensor:")
print(sorted_x.detach().flatten())

# Test gradient computation
loss = sorted_x.sum()
loss.backward()
print("Gradient w.r.t. input:")
print(x.grad.flatten())

Original tensor:
tensor([3., 1., 2.])
Sorted tensor:
tensor([1., 2., 3.])
Gradient w.r.t. input:
tensor([1., 1., 1.])


In [12]:
# Test with binary vectors (sorting by number of 1s)
x = torch.tensor(
    [
        [1, 1, 0],  # 2 ones
        [1, 0, 0],  # 1 one
        [1, 1, 1],  # 3 ones
    ],
    dtype=torch.float32,
    requires_grad=True,
)

print("Original tensor (to be sorted by number of 1s):")
print(x.detach())
print("Number of 1s per row:", x.sum(dim=1).detach())

sorted_x = bubble_sort(x, sample_comparator)
print("\nSorted tensor:")
print(sorted_x.detach())
print("Number of 1s per row:", sorted_x.sum(dim=1).detach())

# Compute gradients
loss = sorted_x.sum()
loss.backward()
print("\nGradient w.r.t. input:")
print(x.grad)

Original tensor (to be sorted by number of 1s):
tensor([[1., 1., 0.],
        [1., 0., 0.],
        [1., 1., 1.]])
Number of 1s per row: tensor([2., 1., 3.])

Sorted tensor:
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
Number of 1s per row: tensor([1., 2., 3.])

Gradient w.r.t. input:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [13]:
# Test with more complex binary vectors
x = torch.tensor(
    [
        [1, 1, 0, 0],  # 2 ones
        [1, 1, 0, 1],  # 3 ones
        [1, 0, 0, 0],  # 1 one
        [1, 0, 1, 0],  # 2 ones
        [1, 1, 1, 0],  # 3 ones
    ],
    dtype=torch.float32,
    requires_grad=True,
)

print("Original tensor:")
print(x.detach())
print("Number of 1s per row:", x.sum(dim=1).detach())

sorted_x = bubble_sort(x, sample_comparator)
print("\nSorted tensor (should be ordered by number of 1s):")
print(sorted_x.detach())
print("Number of 1s per row:", sorted_x.sum(dim=1).detach())

# Compute gradients
loss = sorted_x.sum()
loss.backward()
print("\nGradient w.r.t. input:")
print(x.grad)

Original tensor:
tensor([[1., 1., 0., 0.],
        [1., 1., 0., 1.],
        [1., 0., 0., 0.],
        [1., 0., 1., 0.],
        [1., 1., 1., 0.]])
Number of 1s per row: tensor([2., 3., 1., 2., 3.])

Sorted tensor (should be ordered by number of 1s):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.5000, 0.5000, 0.0000],
        [1.0000, 0.5000, 0.5000, 0.0000],
        [1.0000, 1.0000, 0.5000, 0.5000],
        [1.0000, 1.0000, 0.5000, 0.5000]])
Number of 1s per row: tensor([1., 2., 2., 3., 3.])

Gradient w.r.t. input:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
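
The fractional entries in the sorted output above come from ties. When two rows have the same number of ones, `torch.sign` returns 0, the rule-based comparator outputs `t = 0.5`, and the soft swap replaces both rows with their average. The row sums still come out in order, but the rows are no longer a permutation of the input. A minimal reproduction with the two tied rows that contain 2 ones:

```python
import torch

row_a = torch.tensor([1.0, 1.0, 0.0, 0.0])  # 2 ones
row_b = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 2 ones

diff = row_a.sum() - row_b.sum()    # 0 for a tie
t = 1 - (torch.sign(diff) + 1) / 2  # same rule as sample_comparator

# Same interpolation as swap(): with t = 0.5 both results are the average.
blended_a = t * row_a + (1 - t) * row_b
blended_b = t * row_b + (1 - t) * row_a

print(t.item())            # 0.5
print(blended_a.tolist())  # [1.0, 0.5, 0.5, 0.0]
print(blended_b.tolist())  # [1.0, 0.5, 0.5, 0.0]
```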


In [14]:
# Generate test data: lower triangular matrix
# Each row has an increasing number of 1s (from 1 to 10)
def generate_data():
    return np.tril(np.ones((10, 10), dtype=np.float32))


actual_data = generate_data()
print("Target sorted data (lower triangular matrix):")
print("Each row should have 1, 2, 3, ..., 10 ones respectively")
print(actual_data)

Target sorted data (lower triangular matrix):
Each row should have 1, 2, 3, ..., 10 ones respectively
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]


In [15]:
# Create shuffled version for testing
shuffled_data = generate_data()
np.random.shuffle(shuffled_data)
print("Shuffled data to be sorted:")
print("Number of 1s per row:", shuffled_data.sum(axis=1))
print(shuffled_data)

Shuffled data to be sorted:
Number of 1s per row: [ 1.  6.  3.  7. 10.  9.  4.  2.  5.  8.]
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]]


In [16]:
# Test sorting with the rule-based comparator
shuffled_tensor = torch.tensor(shuffled_data, dtype=torch.float32)
sorted_result = bubble_sort(shuffled_tensor, sample_comparator)

print("Sorted result:")
print("Number of 1s per row:", sorted_result.sum(dim=1).detach())
print(sorted_result.detach())

print("\nVerification - should be all zeros if sorting is perfect:")
difference = sorted_result.detach() - torch.tensor(actual_data)
print("Max absolute difference:", torch.abs(difference).max().item())

Sorted result:
Number of 1s per row: tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

Verification - should be all zeros if sorting is perfect:
Max absolute difference: 0.0


## Learnable Neural Comparator

The key innovation is replacing the rule-based comparator with a learnable neural network. This enables:

1. **End-to-end Learning**: The sorting criterion can be learned from data rather than hand-coded
2. **Complex Patterns**: Neural networks can learn sophisticated comparison functions beyond simple rules
3. **Gradient Flow**: Unlike rule-based comparators with discrete operations, neural comparators maintain differentiability

**Architecture**: A simple MLP that takes two concatenated feature vectors and outputs a swap parameter `t ∈ [0,1]` (values near 0 mean swap, values near 1 mean keep the current order).

In [17]:
class ComparatorBlock(nn.Module):
    """
    Neural network comparator for differentiable sorting.

    Takes two feature vectors and outputs a swap parameter t ∈ [0,1].
    - t ≈ 0: Swap the elements (first element should come after second)
    - t ≈ 1: Don't swap (first element should come before second)
    """

    def __init__(self, feature_size, hidden_size=10):
        super().__init__()

        # Input size is 2 * feature_size (two concatenated vectors)
        input_size = 2 * feature_size

        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),  # Output in [0,1] range
        )

        # Initialize weights using He initialization for ReLU networks
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """
        Args:
            x: torch.Tensor - Shape [batch_size, 2, feature_size]

        Returns:
            torch.Tensor - Swap parameters of shape [batch_size]
        """
        batch_size, _, feature_size = x.shape

        # Flatten the two vectors and concatenate them
        x_flat = x.reshape(batch_size, -1)  # [batch_size, 2 * feature_size]

        # Pass through network
        output = self.network(x_flat).squeeze(-1)  # [batch_size]

        return output

In [18]:
# Test the neural comparator before training
feature_size = shuffled_data.shape[-1]
learned_comparator = ComparatorBlock(feature_size)

# Test forward pass
test_input = torch.zeros((1, 2, feature_size))
test_output = learned_comparator(test_input)
print(f"Neural comparator output shape: {test_output.shape}")
print(f"Output value (random initialization): {test_output.item():.4f}")

# Test sorting with untrained comparator
shuffled_tensor = torch.tensor(shuffled_data, dtype=torch.float32)
untrained_result = bubble_sort(shuffled_tensor, learned_comparator)

# Compute initial loss (L2 distance to target)
target_tensor = torch.tensor(actual_data, dtype=torch.float32)
initial_loss = F.mse_loss(untrained_result, target_tensor)
print(f"\nInitial loss (before training): {initial_loss.item():.4f}")
print("This should be high since the comparator is randomly initialized")

Neural comparator output shape: torch.Size([1])
Output value (random initialization): 0.5000

Initial loss (before training): 0.1936
This should be high since the comparator is randomly initialized
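
The output of exactly 0.5 is a consequence of the test input rather than of the random weights. With all biases initialized to zero, an all-zero input produces a zero pre-activation at every layer, and `sigmoid(0) = 0.5`. The sketch below reproduces this with a smaller stand-in network; the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 1), nn.Sigmoid())
for layer in net:
    if isinstance(layer, nn.Linear):
        nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
        nn.init.zeros_(layer.bias)

with torch.no_grad():
    out_zero = net(torch.zeros(1, 20))  # zero input: weights never matter
    out_rand = net(torch.rand(1, 20))   # non-zero input: depends on the weights

print(out_zero.item())  # 0.5
print(out_rand.item())
```

A non-zero probe input is therefore a more informative check of what a freshly initialized comparator outputs.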


In [19]:
# Test gradient computation through the learned comparator
learned_comparator = ComparatorBlock(feature_size)
shuffled_tensor = torch.tensor(shuffled_data, dtype=torch.float32, requires_grad=True)

# Forward pass
sorted_result = bubble_sort(shuffled_tensor, learned_comparator)
target_tensor = torch.tensor(actual_data, dtype=torch.float32)
loss = F.mse_loss(sorted_result, target_tensor)

# Backward pass
loss.backward()

print(f"Loss: {loss.item():.4f}")
print(f"Input gradient norm: {shuffled_tensor.grad.norm().item():.4f}")
print(
    f"Number of comparator parameters: {sum(p.numel() for p in learned_comparator.parameters())}"
)

# Check if gradients flow to the comparator
total_grad_norm = 0
for param in learned_comparator.parameters():
    if param.grad is not None:
        total_grad_norm += param.grad.norm().item() ** 2
total_grad_norm = total_grad_norm**0.5

print(f"Comparator gradient norm: {total_grad_norm:.4f}")
print(
    "✓ Gradients are flowing through the entire pipeline!"
    if total_grad_norm > 0
    else "✗ No gradients in comparator"
)

Loss: 0.1628
Input gradient norm: 0.0326
Number of comparator parameters: 331
Comparator gradient norm: 0.0984
✓ Gradients are flowing through the entire pipeline!


## End-to-End Training

Now we train the neural comparator to learn the sorting criterion. The network learns to output appropriate swap parameters by minimizing the reconstruction loss between the sorted output and the target sorted data.

**Training Process:**
1. **Forward Pass**: Input data → Neural Comparator → Bubble Sort → Output
2. **Loss Computation**: MSE between sorted output and target  
3. **Backward Pass**: Gradients flow back through the entire pipeline
4. **Parameter Update**: Adam optimizer updates the comparator weights

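**Key Insight**: The loss never mentions row sums. To reduce it, the comparator has to produce swap decisions that agree with ordering rows by their number of ones, at least for the pairs that occur on this training example.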

In [20]:
# Training setup
learned_comparator = ComparatorBlock(feature_size)
optimizer = torch.optim.Adam(learned_comparator.parameters(), lr=3e-4)

# Training data (only the comparator is trained, so the input needs no gradient)
shuffled_tensor = torch.tensor(shuffled_data, dtype=torch.float32)
target_tensor = torch.tensor(actual_data, dtype=torch.float32)


def training_step():
    """Single training step"""
    optimizer.zero_grad()

    # Forward pass through sorting algorithm
    sorted_result = bubble_sort(shuffled_tensor, learned_comparator)

    # Compute reconstruction loss
    loss = F.mse_loss(sorted_result, target_tensor)

    # Backward pass
    loss.backward()
    optimizer.step()

    return loss.item()


# Training loop
print("Training the neural comparator...")
print("Epoch\tLoss")
print("-" * 20)

for epoch in range(1000):
    loss = training_step()

    if epoch % 100 == 0:
        print(f"{epoch}\t{loss:.6f}")

print(f"Final loss: {loss:.6f}")

Training the neural comparator...
Epoch	Loss
--------------------
0	0.137729
100	0.088227
200	0.062058
300	0.044952
400	0.032504
500	0.023212
600	0.015396
700	0.007712
800	0.004370
900	0.002778
Final loss: 0.001930
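
The loop above optimizes the comparator on one fixed shuffle, so it only ever sees the pairs that arise from that ordering. One possible variation, which is not part of this notebook and is offered only as a sketch, is to draw a fresh row permutation for each step. The helper name below is hypothetical.

```python
import torch


def shuffled_copy(target, generator=None):
    """Return the rows of `target` in a random order (hypothetical helper)."""
    perm = torch.randperm(target.shape[0], generator=generator)
    return target[perm]


target = torch.tril(torch.ones(10, 10))
gen = torch.Generator().manual_seed(0)
sample = shuffled_copy(target, gen)

# Same rows in a different order: the multiset of row sums is unchanged.
print(sample.sum(dim=1))
```

Inside `training_step`, the result of such a helper would take the place of `shuffled_tensor`. Whether this improves generalization here has not been tested.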


In [21]:
# Evaluate the trained model
with torch.no_grad():
    # Get final sorted result
    final_sorted = bubble_sort(shuffled_tensor, learned_comparator)
    final_rounded = torch.round(final_sorted)

    print("Final sorted result (after training):")
    print("Number of 1s per row:", final_sorted.sum(dim=1))
    print(final_sorted)

    print("\nRounded result:")
    print(final_rounded)

    print("\nTarget:")
    print("Number of 1s per row:", target_tensor.sum(dim=1))
    print(target_tensor)

    # Check if sorting is perfect
    difference = final_rounded - target_tensor
    max_error = torch.abs(difference).max().item()

    print(f"\nMaximum absolute error: {max_error}")
    print(
        "Perfect sorting achieved!"
        if max_error == 0
        else f"Some errors remain (max: {max_error})"
    )

    if max_error == 0:
        print("✓ The neural network successfully learned to sort by counting 1s!")

Final sorted result (after training):
Number of 1s per row: tensor([1.0675, 2.0908, 3.0170, 3.9428, 5.0283, 6.0331, 7.0663, 7.9561, 8.8992,
        9.8989])
tensor([[1.0000e+00, 5.1918e-02, 1.2592e-02, 2.6614e-03, 2.2052e-04, 6.0783e-05,
         4.6084e-06, 1.2963e-06, 3.4263e-07, 1.1421e-07],
        [1.0000e+00, 9.6451e-01, 1.1318e-01, 1.1205e-02, 1.1693e-03, 7.0482e-04,
         6.6414e-06, 3.1076e-06, 2.4658e-07, 4.1117e-08],
        [1.0000e+00, 9.8763e-01, 8.8469e-01, 1.2927e-01, 9.6228e-03, 5.6742e-03,
         6.6642e-05, 1.2689e-05, 2.3181e-06, 3.1993e-07],
        [1.0000e+00, 9.9651e-01, 9.9118e-01, 8.7525e-01, 6.1498e-02, 1.6994e-02,
         1.2856e-03, 6.8157e-05, 2.3096e-05, 1.8056e-06],
        [1.0000e+00, 9.9965e-01, 9.9870e-01, 9.9044e-01, 9.5370e-01, 7.4372e-02,
         9.7898e-03, 1.4965e-03, 1.7134e-04, 1.1284e-05],
        [1.0000e+00, 9.9982e-01, 9.9971e-01, 9.9239e-01, 9.7648e-01, 9.0892e-01,
         1.3861e-01, 1.3740e-02, 3.2069e-03, 1.8409e-04],
        [

In [None]:
# Additional analysis: Test the learned comparator directly
print("Testing the learned comparator on specific pairs:")
print("-" * 50)

# Create test pairs with the same feature size as training data (10 dimensions)
test_pairs = [
    # 1 vs 2 ones - already in order, should not swap (t≈1)
    ([1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
    # 3 vs 1 ones - out of order, should swap (t≈0)
    ([1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    # 2 vs 2 ones - tie, no preferred order (the rule-based comparator gives t=0.5)
    ([1, 1, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
]

with torch.no_grad():
    for i, (vec1, vec2) in enumerate(test_pairs):
        # Prepare input with correct feature size
        pair_input = torch.tensor([[vec1, vec2]], dtype=torch.float32)
        swap_param = learned_comparator(pair_input).item()

        sum1, sum2 = sum(vec1), sum(vec2)
        should_swap = "Yes" if sum1 > sum2 else ("No" if sum1 < sum2 else "Either")

        print(f"Pair {i+1}: sum={sum1} vs sum={sum2}")
        print(f"  Should swap: {should_swap}")
        print(f"  Comparator output: {swap_param:.3f}")
        print(f"  Interpretation: {'Swap' if swap_param < 0.5 else 'No swap'}")
        print()

print("The comparator has learned to compare based on the sum of elements!")

Testing the learned comparator on specific pairs:
--------------------------------------------------
Pair 1: sum=1 vs sum=2
  Should swap: No
  Comparator output: 0.963
  Interpretation: No swap

Pair 2: sum=3 vs sum=1
  Should swap: Yes
  Comparator output: 0.153
  Interpretation: Swap

Pair 3: sum=2 vs sum=2
  Should swap: Either
  Comparator output: 0.807
  Interpretation: No swap

The comparator has learned to compare based on the sum of elements!
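
Three hand-picked pairs are a thin basis for the closing claim. A more systematic check would score every ordered pair of rows and compare the thresholded output with the row-sum rule. The helper below is a sketch with illustrative names. It is shown with a rule-based stand-in so that it runs on its own; passing `learned_comparator` in its place would evaluate the trained network.

```python
import itertools

import torch


def pairwise_agreement(cmp_fun, rows):
    """Fraction of untied ordered pairs where thresholding t at 0.5 matches the row-sum rule."""
    sums = rows.sum(dim=1)
    hits, total = 0, 0
    with torch.no_grad():
        for i, j in itertools.permutations(range(rows.shape[0]), 2):
            if sums[i] == sums[j]:
                continue  # no preferred order for ties
            t = cmp_fun(torch.stack([rows[i], rows[j]]).unsqueeze(0)).item()
            predicted_swap = t < 0.5
            expected_swap = bool(sums[i] > sums[j])
            hits += predicted_swap == expected_swap
            total += 1
    return hits / total


def sum_rule(x):
    """Rule-based stand-in so the sketch runs without a trained model."""
    s = x.sum(dim=-1)
    return 1 - (torch.sign(s[:, 0] - s[:, 1]) + 1) / 2


rows = torch.tril(torch.ones(10, 10))
print(pairwise_agreement(sum_rule, rows))  # 1.0
```

Evaluating on rows that were not part of the training matrix, for example binary vectors whose ones are not left-aligned, would show whether the comparator compares sums in general or only on the training patterns.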
