In [1]:
import numpy as np
import torch
import sys
sys.path.append("netrep")
from netrep.metrics import LinearMetric
sys.path.append("..")
from Procrustes import ProcrustesDistance

In [2]:
random_state = torch.manual_seed(17)
X = torch.randn(256, 32, requires_grad=True)

A = torch.randn(32, 32, requires_grad=True)
# Full rank
while torch.linalg.matrix_rank(A) < 32:
    A = torch.randn(32, 32, requires_grad=True)
# QR decomp
Q, R = torch.linalg.qr(A)
# Ensure positive det
signs = torch.sign(torch.diag(R))
signs[signs == 0] = 1
Q = Q @ torch.diag(signs)
if torch.det(Q) < 0:
    Q[:, 0] = -Q[:, 0]
# Apply transformation with noise
Y = X @ Q + torch.randn(256, 32) * 0.1

In [3]:
proc_metric = LinearMetric(alpha=1.0, center_columns=True, score_method='euclidean')
proc_metric.fit(X.detach().numpy(), Y.detach().numpy())
dist = proc_metric.score(X.detach().numpy(), Y.detach().numpy())
print("Procrustes distance:", dist)

Procrustes distance: 8.728781203289941


In [4]:
diff_metric = ProcrustesDistance()
loss = diff_metric(X, Y)
print("Procrustes distance:", loss.item())

Procrustes distance: 8.728781170289293


In [5]:
loss.backward(retain_graph=True)
print(X.grad, Y.grad)

tensor([[-1.5946e-04,  6.0550e-05, -1.6906e-04,  ...,  2.0220e-04,
         -2.9361e-04,  1.9688e-04],
        [-1.8422e-06,  1.8042e-04,  1.0639e-04,  ...,  1.7494e-04,
         -4.0179e-04, -1.6341e-06],
        [ 3.0343e-04,  6.1424e-05,  3.5554e-04,  ...,  2.9094e-04,
         -1.6500e-04,  2.7779e-04],
        ...,
        [-2.6535e-04,  1.1842e-05,  1.9980e-04,  ...,  4.6371e-04,
         -1.1459e-04, -5.0998e-05],
        [-1.2970e-05,  4.3654e-04,  1.9298e-04,  ..., -2.8490e-04,
          4.1537e-04,  4.7382e-04],
        [-5.6980e-04, -1.1011e-04,  1.4352e-04,  ...,  1.3642e-05,
         -5.2949e-05, -3.5602e-04]]) None


  print(X.grad, Y.grad)
