In [5]:
%cd /ibex/user/slimhy/PADS/code
import torch
from torch.nn import functional as F
from torch_linear_assignment import batch_linear_assignment


def debug_rec_loss(x, x_rec, use_loop=False):
    """
    Call the loss function.
    """
    B, N, D = (
        x.shape
    )  # B: batch size, D: latent dimension (512), N: number of vectors (8)

    # Compute the cost matrix using cdist
    cost_matrix = torch.cdist(x, x_rec, p=2)

    # Compute the linear assignment
    assignment = batch_linear_assignment(cost_matrix)
    
    # Compute the loss
    total_loss = 0
    if use_loop:
        x_rec_matched = []
        for b in range(B):
            x_rec_matched += [x_rec[b, assignment[b]]]
        x_rec_matched = torch.stack(x_rec_matched)
    else:
        x_rec_matched = x_rec[torch.arange(B).unsqueeze(1), assignment]
        # total_loss = loss.mean() * 2.

        # for b in range(B):
        #     print(x[b])
        #     print(x_rec_matched[b])
        #     print('---')

    total_loss = F.mse_loss(x, x_rec_matched)

    return total_loss


# Define two sets of vectors
set_A = torch.tensor([
    [5, 5, 5, 5],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
]).type(torch.float32).unsqueeze(0)

set_B = torch.tensor([
    [4, 4, 4, 4],
    [5, 5, 5, 5],
    [6, 6, 6, 6],
]).type(torch.float32).unsqueeze(0)

set_C = torch.tensor([
    [4, 4, 4, 4],
    [5, 5, 5, 5],
    [6, 6, 6, 6],
]).type(torch.float32).unsqueeze(0)

set_D = torch.tensor([
    [1, 1, 1, 1],
    [6, 6, 6, 6],
    [3, 3, 3, 3],
]).type(torch.float32).unsqueeze(0)

x_a = torch.cat([set_A, set_B, set_A, set_B], dim=0)
x_b = torch.cat([set_C, set_D, set_C, set_D], dim=0)

debug_rec_loss(x_a, x_b, use_loop=True)

/ibex/user/slimhy/PADS/code


tensor(5.)

In [3]:
import numpy as np

def test_debug_rec_loss(debug_rec_loss, num_tests=10, rtol=1e-5, atol=1e-8):
    for _ in range(num_tests):
        # Generate random dimensions
        B = np.random.randint(1, 5)  # Batch size
        N = np.random.randint(3, 10)  # Number of vectors
        D = np.random.randint(2, 8)  # Latent dimension

        # Generate random tensors
        x_a = torch.rand(B, N, D)
        x_b = torch.rand(B, N, D)

        # Compute loss with use_loop=True
        loss_with_loop = debug_rec_loss(x_a, x_b, use_loop=True)

        # Compute loss with use_loop=False
        loss_without_loop = debug_rec_loss(x_a, x_b, use_loop=False)

        # Check if the results are close
        assert torch.isclose(loss_with_loop, loss_without_loop, rtol=rtol, atol=atol), \
            f"Test failed: loss_with_loop ({loss_with_loop}) and loss_without_loop ({loss_without_loop}) are not close enough"

    print(f"All {num_tests} tests passed successfully!")

# Run the test
test_debug_rec_loss(debug_rec_loss)


All 10 tests passed successfully!
