# Simple demonstration of OT4P
We provide a minimal example demonstrating how to use OT4P. Given a matrix $X$ and $Y = PXP^{\top}$, where $P$ is an unknown permutation matrix, the goal is to recover $P$ by solving

$$
\min_{P \in \mathcal{P}_n} \|PXP^{\top} - Y\|_F^2,
$$

where $\mathcal{P}_n$ is the set of $n \times n$ permutation matrices.

We use OT4P to solve the above problem from three different perspectives: deterministic optimization, 
stochastic optimization, and constrained optimization.
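Because the objective is defined over permutation matrices, a tiny instance can be solved exactly by enumerating all $n!$ candidates. The sketch below does this for $n = 4$ to illustrate the objective. The names `objective`, `true_P` and `best_P` are hypothetical and not part of OT4P. Exhaustive search is hopeless at the notebook's `size = 100`, since $100!$ candidates would have to be checked.

```python
import itertools

import torch

torch.manual_seed(0)
n = 4
X = torch.randn(n, n)
true_P = torch.eye(n)[torch.randperm(n)]  # random ground-truth permutation
Y = true_P @ X @ true_P.T

def objective(P):
    # Squared Frobenius norm ||P X P^T - Y||_F^2
    return torch.sum((P @ X @ P.T - Y) ** 2).item()

# Enumerate all n! permutation matrices and keep the one with the smallest objective
best_P, best_val = None, float("inf")
for perm in itertools.permutations(range(n)):
    P = torch.eye(n)[list(perm)]
    val = objective(P)
    if val < best_val:
        best_P, best_val = P, val

print(f"minimum objective: {best_val:.2e}")
print("recovered the true permutation:", torch.equal(best_P, true_P))
```

For a generic random $X$, the true permutation is the unique minimizer, with objective value zero.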

In [None]:
## Install the library
# !pip install torch
# !pip install git+https://github.com/ivan-chai/torch-linear-assignment.git

## Define the problem
import torch
from src.ot4p import OT4P

size = 100
X = torch.randn(size, size)
trueP = torch.eye(size)[torch.randperm(size)]
Y = trueP @ X @ trueP.T

# loss function
def loss_fn(Y, X, P):
    return torch.mean(torch.pow(P @ X @ P.transpose(-2, -1) - Y, 2))
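Before optimizing, it can help to sanity-check `loss_fn`. The standalone sketch below recreates a small instance with `size = 6` instead of reusing the notebook's variables. It confirms that the loss vanishes at the ground-truth permutation. It also shows that, because `loss_fn` transposes with `transpose(-2, -1)`, it accepts a batch of candidate matrices of shape `(k, size, size)` and averages over the whole batch. The stochastic section relies on this broadcasting.

```python
import torch

def loss_fn(Y, X, P):
    # Same definition as in the notebook: mean squared entry of P X P^T - Y
    return torch.mean(torch.pow(P @ X @ P.transpose(-2, -1) - Y, 2))

torch.manual_seed(0)
size = 6
X = torch.randn(size, size)
trueP = torch.eye(size)[torch.randperm(size)]
Y = trueP @ X @ trueP.T

# The loss vanishes at the ground-truth permutation
loss_true = loss_fn(Y, X, trueP)
# The identity is (almost surely) a worse candidate, unless trueP happens to be the identity
loss_eye = loss_fn(Y, X, torch.eye(size))

# A batch of shape (3, size, size) broadcasts against X and Y;
# the result is the average over all samples and all entries
batch = torch.stack([trueP, torch.eye(size), trueP])
loss_batch = loss_fn(Y, X, batch)
per_sample = torch.pow(batch @ X @ batch.transpose(-2, -1) - Y, 2).mean(dim=(-2, -1))

print(loss_true.item(), loss_eye.item(), loss_batch.item())
print(per_sample)
```

Two of the three batch entries are the true permutation, so `loss_batch` equals `loss_eye / 3` up to rounding.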

## Deterministic optimization
As described in Section 3.2, we first address this problem from the perspective of deterministic optimization.
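The loop below computes its validation loss by evaluating the model at `tau=0`. Assuming this setting yields a hard permutation matrix, it helps to recall how a real matrix $W$ is rounded to its nearest permutation. For every permutation matrix $P$, $\|P - W\|_F^2 = n - 2\langle P, W\rangle + \|W\|_F^2$, so the closest permutation maximizes $\langle P, W\rangle$, which is a linear assignment problem. The sketch below solves it with SciPy's `linear_sum_assignment` as a stand-in for a GPU solver such as the `torch-linear-assignment` package installed above. It illustrates the rounding idea only and is not OT4P's internal code. The helper `nearest_permutation` is hypothetical.

```python
import torch
from scipy.optimize import linear_sum_assignment

def nearest_permutation(W: torch.Tensor) -> torch.Tensor:
    """Return the permutation matrix closest to W in Frobenius norm.

    Minimizing ||P - W||_F^2 over permutation matrices is equivalent to
    maximizing <P, W>, i.e. a linear assignment problem.
    """
    rows, cols = linear_sum_assignment(W.detach().cpu().numpy(), maximize=True)
    P = torch.zeros_like(W)
    P[torch.as_tensor(rows), torch.as_tensor(cols)] = 1.0
    return P

torch.manual_seed(0)
n = 5
true_P = torch.eye(n)[torch.randperm(n)]
# A mildly noisy version of the true permutation should round back to it
W = true_P + 0.1 * torch.randn(n, n)
P_hat = nearest_permutation(W)
print(torch.equal(P_hat, true_P))
```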

In [None]:
# Initialize the weight parameter, model, and optimizer
weightP = torch.nn.Parameter(torch.randn(size, size), requires_grad=True)
model = OT4P(size)
optimizer = torch.optim.AdamW([weightP], lr=1e-1)

# Run up to 500 iterations, stopping early once converged
print("Starting Deterministic Optimization...")
for i in range(500):
    optimizer.zero_grad()
    perm_matrix = model(weightP, tau=0.5)
    loss_train = loss_fn(Y, X, perm_matrix)
    loss_train.backward()
    optimizer.step()
    
    # Compute validation loss
    with torch.no_grad():
        perm_matrix_val = model(weightP, tau=0)
        loss_val = loss_fn(Y, X, perm_matrix_val)
        
    # Print training and validation losses
    print(f"Iteration {i+1}: Training Loss = {loss_train.item():.6f}, Validation Loss = {loss_val.item():.6f}")

    # Update base of the model
    model.update_base(weightP)
    
    # Check convergence
    if loss_val < 1e-5:
        print(f"Deterministic optimization converges at iteration {i+1}")
        break
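The loop stops once the validation loss falls below `1e-5`. It can also be informative to measure how many rows of the recovered matrix match the ground truth, e.g. by comparing `perm_matrix_val` with `trueP` after the loop. A minimal sketch of such a metric follows. `matching_accuracy` is a hypothetical helper, not part of OT4P. It assumes both inputs are permutation matrices (or close to them), so that `argmax` along each row gives the matched column.

```python
import torch

def matching_accuracy(P_est: torch.Tensor, P_true: torch.Tensor) -> float:
    """Fraction of rows whose matched column agrees with the ground truth.

    Assumes both arguments are n x n permutation matrices (or close to them),
    so argmax over each row recovers the assigned column.
    """
    return (P_est.argmax(dim=-1) == P_true.argmax(dim=-1)).float().mean().item()

torch.manual_seed(0)
n = 10
P_true = torch.eye(n)[torch.randperm(n)]
# Swap two rows of the ground truth to simulate a partially correct estimate
P_est = P_true.clone()
P_est[[0, 1]] = P_est[[1, 0]]

acc_perfect = matching_accuracy(P_true, P_true)
acc_swapped = matching_accuracy(P_est, P_true)
print(acc_perfect, acc_swapped)
```

The second value is about 0.8, since exactly two of the ten rows were swapped.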

## Stochastic optimization

As described in Section 3.3, we address this problem from the perspective of stochastic optimization.

In [None]:
# Initialize the weight parameter, model, and optimizer
weightP = torch.nn.Parameter(torch.randn(size, size), requires_grad=True)
log_weightP_var = torch.nn.Parameter(torch.randn(size, size), requires_grad=True)
model = OT4P(size)
# Optimize both the mean and the log-variance of the sampling distribution
optimizer = torch.optim.AdamW([weightP, log_weightP_var], lr=1e-1)

# Run up to 500 iterations, stopping early once converged
print("Starting Stochastic Optimization...")
for i in range(500):
    optimizer.zero_grad()
    
    # Re-parameterization trick: draw 5 samples around weightP, with noise scaled by 0.01
    mean = weightP.unsqueeze(0).expand(5, -1, -1)
    std = torch.exp(log_weightP_var / 2).unsqueeze(0).expand(5, -1, -1)
    sample = mean + std * torch.randn_like(mean) * 0.01
    
    perm_matrix = model(sample, tau=0.5)
    loss_train = loss_fn(Y, X, perm_matrix)
    loss_train.backward()
    optimizer.step()
    
    # Compute validation loss
    with torch.no_grad():
        perm_matrix_val = model(weightP, tau=0)
        loss_val = loss_fn(Y, X, perm_matrix_val)
        
    # Print training and validation losses
    print(f"Iteration {i+1}: Training Loss = {loss_train.item():.6f}, Validation Loss = {loss_val.item():.6f}")

    # Update base of the model
    model.update_base(weightP)
    
    # Check convergence
    if loss_val < 1e-5:
        print(f"Stochastic optimization converges at iteration {i+1}")
        break
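The re-parameterization step in the loop above writes each sample as `mean + std * eps * 0.01`, with `eps` drawn from a standard normal. Gradients therefore reach both the mean and the log-variance. The standalone sketch below checks this, using small illustrative sizes and a toy quadratic loss against a random `target` in place of the OT4P pipeline. The names `mean_param`, `log_var_param` and `target` are hypothetical. A parameter that receives gradients is only updated if it is also passed to the optimizer.

```python
import torch

torch.manual_seed(0)
n, num_samples, noise_scale = 4, 5, 0.01  # illustrative sizes, not tied to OT4P

mean_param = torch.nn.Parameter(torch.randn(n, n))
log_var_param = torch.nn.Parameter(torch.zeros(n, n))
target = torch.randn(n, n)  # toy target standing in for the matching loss

# Re-parameterization: sample = mu + sigma * eps * noise_scale, with eps ~ N(0, I)
mu = mean_param.unsqueeze(0).expand(num_samples, -1, -1)
sigma = torch.exp(log_var_param / 2).unsqueeze(0).expand(num_samples, -1, -1)
eps = torch.randn_like(mu)
samples = mu + sigma * eps * noise_scale  # shape (num_samples, n, n)

# Monte Carlo estimate of the expected loss over the sampled matrices
loss = torch.mean((samples - target) ** 2)
loss.backward()

print(samples.shape)
print(mean_param.grad.abs().sum().item() > 0, log_var_param.grad.abs().sum().item() > 0)
```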

## Constrained optimization
When some correspondences are known in advance, we can encode them as constraints in the model, which shrinks the search space and makes the problem easier.
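Fixing a known match $P_{ij} = 1$ forces every other entry in row $i$ and column $j$ to be zero. The mask therefore clears those rows and columns, except for entry $(i, j)$. With $k$ known matches among $n$ items, only $(n-k)!$ of the $n!$ permutations remain feasible. The vectorized sketch below builds the same mask as the loop in the next cell for a toy case with $n = 4$ and $k = 1$. It checks that the true permutation is still allowed and counts the feasible permutations by brute force. The helper name `build_constraint_mask` is hypothetical.

```python
import itertools
import math

import torch

def build_constraint_mask(P_true: torch.Tensor, known_rows: torch.Tensor) -> torch.Tensor:
    """Mask with 1 where an assignment is still allowed, 0 where it is ruled out."""
    n = P_true.shape[0]
    known_cols = P_true[known_rows].argmax(dim=1)  # matched column of each known row
    mask = torch.ones(n, n)
    mask[known_rows, :] = 0  # a known row may only use its known column...
    mask[:, known_cols] = 0  # ...and that column is unavailable to every other row
    mask[known_rows, known_cols] = 1
    return mask

torch.manual_seed(0)
n = 4
P_true = torch.eye(n)[torch.randperm(n)]
known_rows = torch.tensor([2])
mask = build_constraint_mask(P_true, known_rows)

# The ground-truth permutation uses only allowed entries
print(torch.equal(P_true * mask, P_true))

# Count permutations that respect the mask: expected (n - k)! out of n!
feasible = sum(
    all(mask[i, j] == 1 for i, j in enumerate(perm))
    for perm in itertools.permutations(range(n))
)
print(feasible, math.factorial(n))
```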

In [None]:
# Initialize the constraint matrix
constraint_matrix = torch.ones((size, size))
num_selected = int(size * 0.05)
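# Randomly choose 5% of the rows whose true matches are assumed known
selected_rows = torch.randperm(size)[:num_selected]
# For each known match (row, col_index), allow only that entry in its row and column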
for row in selected_rows:
    col_index = trueP[row].nonzero().item()
    constraint_matrix[row, :] = 0
    constraint_matrix[:, col_index] = 0
    constraint_matrix[row, col_index] = 1

# Initialize the weight parameter, model with constraint, and optimizer
weightP = torch.nn.Parameter(torch.randn(size, size), requires_grad=True)
model = OT4P(size)
model.constraint = constraint_matrix.unsqueeze(0)
optimizer = torch.optim.AdamW([weightP], lr=1e-1)

# Run up to 500 iterations, stopping early once converged
print("Starting Constrained Optimization...")
for i in range(500):
    optimizer.zero_grad()
    perm_matrix = model(weightP, tau=0.5)
    loss_train = loss_fn(Y, X, perm_matrix)
    loss_train.backward()
    optimizer.step()
    
    # Compute validation loss
    with torch.no_grad():
        perm_matrix_val = model(weightP, tau=0)
        loss_val = loss_fn(Y, X, perm_matrix_val)
        
    # Print training and validation losses
    print(f"Iteration {i+1}: Training Loss = {loss_train.item():.6f}, Validation Loss = {loss_val.item():.6f}")

    # Update base of the model
    model.update_base(weightP)
    
    # Check convergence
    if loss_val < 1e-5:
        print(f"Constrained optimization converges at iteration {i+1}")
        break