In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

device = torch.device("cuda")

In [9]:
def iterate_over(data, nexts):
    """
    Iterate through data following the permutation defined by nexts matrix.
    nexts[i] defines which element comes after position i.
    """
    data_len = data.shape[0]

    # Start at position 0
    current_pos = torch.zeros(data_len, device=data.device)
    current_pos[0] = 1.0

    result = torch.zeros(data_len, device=data.device, dtype=data.dtype)

    for step in range(data_len):
        # Get current data value using matrix multiplication
        current_value = torch.dot(current_pos, data)
        result[step] = current_value

        # Move to next position using the nexts transition matrix
        current_pos = current_pos @ nexts

    return result


data = torch.tensor([1.0, 3.0, 2.0], device=device, requires_grad=True)
target = torch.tensor([1.0, 2.0, 3.0], device=device)

# Define permutation: 0->2, 1->0, 2->1 (each position points to next)
perm_indices = torch.tensor([2, 0, 1], device=device)
nexts = F.one_hot(perm_indices, num_classes=len(data)).float().requires_grad_(True)

# Forward pass
result = iterate_over(data, nexts)
loss = F.mse_loss(result, target)

print(f"Data: {data}")
print(f"Nexts: {torch.argmax(nexts, dim=-1)}")
print(f"Target: {target}")
print(f"Result: {result}")
print(f"Loss: {loss:.4f}")

# Compute gradients using backward()
loss.backward()

print(f"\nGradients:")
print(f"data.grad: {data.grad}")
print(f"nexts.grad:\n{nexts.grad}")

Data: tensor([1., 3., 2.], device='cuda:0', requires_grad=True)
Nexts: tensor([2, 0, 1], device='cuda:0')
Target: tensor([1., 2., 3.], device='cuda:0')
Result: tensor([1., 2., 3.], device='cuda:0', grad_fn=<CopySlices>)
Loss: 0.0000

Gradients:
data.grad: tensor([0., 0., 0.], device='cuda:0')
nexts.grad:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


In [10]:
Q = F.one_hot(torch.tensor([2, 4, 1, 0, 3]), num_classes=5).float()
Q @ Q @ Q @ Q @ Q

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

In [18]:
class BistableLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        a = x**2
        b = (x - 1) ** 2
        return a * b


class PermuteMatrixLoss(nn.Module):
    def __init__(self, cycle_length=1, cycle_weight=0):
        super().__init__()
        self.cycle_length = cycle_length
        self.cycle_weight = cycle_weight
        self.bistable_loss = BistableLoss()

    def forward(self, P):
        loss = 0

        P_square = P**2
        axis_1_sum = P_square.sum(dim=1)
        axis_0_sum = P_square.sum(dim=0)

        # Penalize axes not adding up to one
        loss += F.mse_loss(axis_1_sum, torch.ones_like(axis_1_sum)) * 0.5
        loss += F.mse_loss(axis_0_sum, torch.ones_like(axis_0_sum)) * 0.5

        # Penalize numbers outside [0, 1]
        loss += self.bistable_loss(P).sum()

        # Cycle loss
        Q = P
        for _ in range(self.cycle_length - 1):
            Q = P @ Q
        cycle_loss = F.mse_loss(Q, torch.eye(Q.shape[0], device=P.device)) * 0.5
        loss += cycle_loss * self.cycle_weight

        return loss

In [19]:
# Test cases
permute_loss = PermuteMatrixLoss()

test1 = torch.tensor(
    [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

test2 = torch.tensor(
    [
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

test3 = torch.tensor(
    [
        [-1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

test4 = torch.tensor(
    [
        [2, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

test5 = torch.tensor(
    [
        [0.1, 0, 0],
        [0, 0.1, 0],
        [0, 0, 0.1],
    ],
    dtype=torch.float32,
)

test6 = torch.tensor(
    [
        [0.5, 0.5, 0],
        [0.5, 0.5, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

test7 = torch.tensor(
    [
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

print("Test results:")
print(f"test1 (identity matrix): {permute_loss(test1):.6f}")
print(f"test2 (valid swap permutation): {permute_loss(test2):.6f}")
print(f"test3 (negative values): {permute_loss(test3):.6f}")
print(f"test4 (values > 1): {permute_loss(test4):.6f}")
print(f"test5 (rows don't sum to 1): {permute_loss(test5):.6f}")

# test7 with cycle parameters
permute_loss_cycle = PermuteMatrixLoss(cycle_length=3, cycle_weight=1)
print(f"test7 (cycle constraint violation): {permute_loss_cycle(test7):.6f}")

Test results:
test1 (identity matrix): 0.000000
test2 (valid swap permutation): 0.000000
test3 (negative values): 4.000000
test4 (values > 1): 7.000000
test5 (rows don't sum to 1): 1.004400
test7 (cycle constraint violation): 0.222222


In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# lr = 1e-3
lr = 1e-1
num_epochs = 10000
log_interval = 1000

# Data setup
# data = torch.tensor([10, 30, 20, 50, 40], dtype=torch.float32, device=device)
# target_data = torch.tensor([10, 20, 30, 40, 50], dtype=torch.float32, device=device)
data = torch.tensor([54, 87, 67, 29, 85], dtype=torch.float32, device=device)
target_data = torch.tensor([54, 67, 87, 85, 29], dtype=torch.float32, device=device)
data_len = len(data)

# Initialize nexts parameter - try different initializations
# nexts = F.one_hot(torch.tensor([2, 4, 1, 0, 3]), data_len).float()
nexts = F.one_hot(torch.tensor([1, 1, 1, 1, 1]), data_len).float()
# nexts = F.one_hot(torch.tensor([0, 0, 0, 0, 0]), data_len).float()
# nexts = torch.rand(data_len, data_len)

nexts = nexts.to(device).requires_grad_(True)

# Setup optimizer and loss
optimizer = optim.Adam([nexts], lr=lr)
permute_loss_fn = PermuteMatrixLoss(cycle_length=data_len, cycle_weight=1.0).to(device)


def train_step():
    optimizer.zero_grad()

    # Apply softmax to get valid probability matrix
    prob_nexts = F.softmax(nexts, dim=1)
    # prob_nexts = nexts

    # Forward pass
    predicted_data = iterate_over(data, prob_nexts)

    # Compute losses
    data_loss = F.mse_loss(predicted_data, target_data)
    regularization_loss = permute_loss_fn(prob_nexts)
    total_loss = data_loss + regularization_loss

    # Backward pass
    total_loss.backward()
    optimizer.step()

    return total_loss, predicted_data, prob_nexts


# Training loop
print(
    "| Epoch |   Loss   |   Predicted   | Defuzzified |  Pred Indices | Target Indices |"
)
print("-" * 85)

for epoch in range(num_epochs):
    loss, predicted_data, prob_nexts = train_step()

    if epoch % log_interval == 0:
        with torch.no_grad():
            # Get hard assignment for visualization
            pred_indices = torch.argmax(nexts, dim=1)
            defuzzified_nexts = F.one_hot(pred_indices, data_len).float().to(device)
            defuzzified_data = iterate_over(data, defuzzified_nexts)

            target_indices = [2, 4, 1, 0, 3]  # Known correct permutation

            print(
                f"| {epoch:5d} | {loss.item():8.4f} | {torch.round(predicted_data).cpu().numpy()} | "
                f"{defuzzified_data.cpu().numpy()} | {pred_indices.cpu().tolist()} | {target_indices} |"
            )

print(f"\nFinal nexts matrix:")
print(nexts.detach().cpu().numpy())

| Epoch |   Loss   |   Predicted   | Defuzzified |  Pred Indices | Target Indices |
-------------------------------------------------------------------------------------
|     0 | 442.9864 | [54. 70. 70. 70. 70.] | [54. 87. 87. 87. 87.] | [1, 1, 1, 1, 1] | [2, 4, 1, 0, 3] |
|  1000 |   0.3668 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 3, 3] | [2, 4, 1, 0, 3] |
|  2000 |   0.3098 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 3, 3] | [2, 4, 1, 0, 3] |
|  3000 |   0.0044 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 0, 3] | [2, 4, 1, 0, 3] |
|  4000 |   0.0015 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 0, 3] | [2, 4, 1, 0, 3] |
|  5000 |   0.0008 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 0, 3] | [2, 4, 1, 0, 3] |
|  6000 |   0.0004 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 0, 3] | [2, 4, 1, 0, 3] |
|  7000 |   0.0002 | [54. 67. 87. 85. 29.] | [54. 67. 87. 85. 29.] | [2, 4, 1, 0, 3] | [2, 4, 1

In [35]:
P = F.softmax(nexts, dim=1)
print(torch.argmax(P, dim=1))
print((P @ P @ P @ P @ P).round())

tensor([2, 4, 1, 0, 3], device='cuda:0')
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]], device='cuda:0', grad_fn=<RoundBackward0>)


In [37]:
argmax_next = torch.argmax(P, dim=1)
DQ = F.one_hot(argmax_next, num_classes=data_len).float()
DQ @ DQ @ DQ @ DQ @ DQ

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]], device='cuda:0')