In [153]:
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment

In [143]:
# Seed a dedicated generator so the synthetic point cloud is identical on every
# Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
data = rng.random((100, 2)) * 10  # 100 points, each coordinate in [0, 10)

grid_size = 100
# Every integer cell (x, y) of the grid_size x grid_size grid in x-major order:
# (0, 0), (0, 1), ..., (0, grid_size - 1), (1, 0), ...
grid = np.indices((grid_size, grid_size)).reshape(2, -1).T

In [144]:
# Per-axis extremes of the point cloud (column 0 -> x, column 1 -> y).
max_x, max_y = data.max(axis=0)
min_x, min_y = data.min(axis=0)

In [145]:
# Extent of the data along each axis.
x_spread, y_spread = max_x - min_x, max_y - min_y

# The larger of the two extents; used as a single scale factor for both axes.
max_spread = np.maximum(x_spread, y_spread)

In [146]:
# Shift the points so each axis starts at 0, then scale both axes by the same
# factor so the wider axis spans [0, grid_size - 1] (aspect ratio is preserved).
# The offset and scale are derived from the current `data` rather than from
# values cached by earlier cells, so re-running this cell leaves `data`
# unchanged (up to floating-point rounding) instead of rescaling it again.
data_min = data.min(axis=0)
data_spread = (data.max(axis=0) - data_min).max()
data = (data - data_min) * (grid_size - 1) / data_spread

In [147]:
# Euclidean distance from every data point to every grid cell, computed with
# broadcasting instead of a nested Python loop over all (point, cell) pairs.
# cost_matrix[i, j] is the distance between data[i] and grid[j].
cost_matrix = np.sqrt(((data[:, None, :] - grid[None, :, :]) ** 2).sum(axis=-1))

In [148]:
# Solve the linear assignment problem on the distance matrix: each data point
# (row) is matched to a different grid cell (column) so that the summed
# distance over all matches is minimal.
row_indices, col_indices = linear_sum_assignment(cost_matrix)

In [151]:
# Grid coordinates selected by the assignment, one row per entry of col_indices.
data_new = grid[col_indices]

In [152]:
# Display the grid coordinates selected by the assignment.
data_new

array([[74, 34],
       [33, 58],
       [88, 45],
       [89, 37],
       [51,  3],
       [68, 87],
       [60, 32],
       [ 5, 21],
       [83,  6],
       [81, 10],
       [ 0, 11],
       [62, 56],
       [ 1, 65],
       [73, 51],
       [11,  9],
       [83, 21],
       [15, 51],
       [ 9, 10],
       [87, 27],
       [76,  4],
       [83,  3],
       [15, 17],
       [65,  3],
       [54,  5],
       [ 8, 65],
       [ 5, 27],
       [43, 51],
       [65, 25],
       [60, 38],
       [92, 54],
       [99, 93],
       [27, 42],
       [78, 42],
       [96, 24],
       [96,  9],
       [56, 96],
       [13, 43],
       [71, 54],
       [36, 86],
       [64, 59],
       [ 1, 30],
       [48, 55],
       [84,  0],
       [31, 65],
       [27, 75],
       [98, 62],
       [50,  4],
       [ 5, 42],
       [27, 77],
       [11, 46],
       [71, 21],
       [41, 57],
       [28, 20],
       [62, 27],
       [23, 87],
       [56, 35],
       [66, 33],
       [12, 21],
       [12, 30

In [64]:
# Positions of all data points, reused as both halves of every (i, j) pair.
indices = list(range(len(data)))

# Lookup table of pairwise Euclidean distances keyed by the index pair (i, j);
# contains both orderings as well as the zero-distance (i, i) entries.
dist = {
    (i, j): np.sqrt(((data[i] - data[j]) ** 2).sum())
    for i in indices
    for j in indices
}

In [77]:
# Draw len(data) distinct grid cells uniformly at random (no cell is used twice).
# NumPy performs the draw because `random` is not imported by the notebook's
# import cell and `random.sample` rejects NumPy arrays such as `grid`.
sampled_rows = np.random.choice(len(grid), size=len(data), replace=False)
# Plain (x, y) tuples of Python ints, one per data point.
mapping = [tuple(cell) for cell in grid[sampled_rows].tolist()]

In [78]:
# Display the randomly sampled grid coordinates.
mapping

[(9, 40),
 (90, 39),
 (83, 97),
 (43, 80),
 (88, 20),
 (3, 65),
 (68, 29),
 (72, 5),
 (65, 19),
 (77, 65),
 (76, 90),
 (46, 83),
 (56, 90),
 (32, 22),
 (1, 52),
 (86, 41),
 (10, 15),
 (75, 44),
 (54, 84),
 (41, 81),
 (41, 13),
 (97, 39),
 (5, 51),
 (9, 73),
 (1, 11),
 (33, 21),
 (13, 95),
 (86, 46),
 (84, 91),
 (53, 54),
 (63, 33),
 (38, 34),
 (45, 59),
 (13, 28),
 (75, 22),
 (49, 64),
 (52, 65),
 (51, 39),
 (89, 17),
 (76, 3),
 (39, 43),
 (75, 18),
 (88, 21),
 (98, 47),
 (88, 12),
 (8, 49),
 (85, 11),
 (89, 90),
 (42, 81),
 (20, 4),
 (84, 31),
 (64, 47),
 (35, 62),
 (86, 79),
 (14, 86),
 (50, 92),
 (67, 5),
 (36, 66),
 (3, 68),
 (62, 74),
 (53, 25),
 (30, 36),
 (82, 2),
 (99, 95),
 (8, 79),
 (70, 5),
 (87, 60),
 (19, 52),
 (27, 26),
 (23, 40),
 (10, 79),
 (52, 73),
 (26, 47),
 (70, 22),
 (29, 0),
 (82, 56),
 (6, 91),
 (21, 31),
 (31, 81),
 (33, 60),
 (42, 45),
 (20, 46),
 (43, 75),
 (68, 43),
 (36, 34),
 (39, 78),
 (97, 56),
 (47, 5),
 (23, 81),
 (41, 61),
 (28, 53),
 (46, 87),
 (28, 

In [48]:
class Model(torch.nn.Module):
    """Three-layer fully connected network with ReLU activations.

    Args:
        in_features: Size of each input vector (default 10).
        hidden_features: Width of both hidden layers (default 10).
        out_features: Size of each output vector (default 2).

    The defaults reproduce the original fixed 10 -> 10 -> 10 -> 2 layout, so
    ``Model()`` builds the same architecture as before.
    """

    def __init__(self, in_features=10, hidden_features=10, out_features=2):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, hidden_features)
        self.linear2 = torch.nn.Linear(hidden_features, hidden_features)
        self.linear3 = torch.nn.Linear(hidden_features, out_features)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Map a batch ``x`` of shape (..., in_features) to (..., out_features).

        ReLU is applied after the first two linear layers; the final layer's
        output is returned without an activation.
        """
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)

In [49]:
# Instantiate the network with freshly initialised weights.
model = Model()

In [50]:
def loss(x):
    """Return the mean of ``sin(2*pi*x - pi/2) + 1`` over all elements of ``x``.

    Each element's contribution lies in [0, 2] and is zero when that element
    is an integer, so minimising this pushes values towards whole numbers.
    """
    phase = 2 * np.pi * x - 0.5 * np.pi
    return (torch.sin(phase) + 1).mean()

In [51]:
# Adam optimiser over all parameters of `model`, with learning rate 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [52]:
# Run 1000 optimiser steps, minimising `loss` of the model output on the
# input batch `a`.
# NOTE(review): `a` is not defined in any visible cell — presumably a float
# tensor with 10 features per row (linear1 takes 10 inputs); confirm and
# define it before this cell so the notebook survives Restart & Run All.
for i in range(1000):
    optimizer.zero_grad()  # clear gradients left over from the previous step

    output = model(a)
    l = loss(output)
    l.backward()  # backpropagate the scalar loss

    optimizer.step()  # apply the parameter update

In [53]:
# Display the model output on `a` after training.
model(a)

tensor([[-1.0000e+00,  4.4703e-07],
        [-1.0000e+00,  4.0978e-07],
        [-1.0000e+00,  4.3213e-07],
        [-1.0000e+00,  4.3213e-07],
        [-1.0000e+00,  4.6194e-07],
        [-1.0000e+00,  3.5763e-07],
        [-1.0000e+00,  4.0233e-07],
        [-1.0000e+00,  3.8743e-07],
        [-1.0000e+00,  4.3213e-07],
        [-1.0000e+00,  4.1723e-07],
        [-1.0000e+00,  3.8743e-07],
        [-1.0000e+00,  4.6194e-07],
        [-1.0000e+00,  3.5763e-07],
        [-1.0000e+00,  2.2352e-07],
        [-1.0000e+00,  4.6194e-07],
        [-1.0000e+00,  4.7684e-07],
        [-1.0000e+00,  4.5449e-07],
        [-1.0000e+00,  4.7684e-07],
        [-1.0000e+00,  4.0978e-07],
        [-1.0000e+00,  5.0664e-07],
        [-1.0000e+00,  3.8743e-07],
        [-1.0000e+00,  4.6194e-07],
        [-1.0000e+00,  3.8743e-07],
        [-1.0000e+00,  4.3213e-07],
        [-1.0000e+00,  3.7253e-07],
        [-1.0000e+00,  4.6194e-07],
        [-1.0000e+00,  4.6194e-07],
        [-1.0000e+00,  4.172

# Algorithm Sketch

Given:
- high-dimensional embeddings $x \in X$
- a mapping $g$ from these embeddings to grid coordinates
- a distance function $d$

1. Randomly assign a distinct grid coordinate to each high-dimensional embedding (sampling without replacement)
2. Sample three data points $x_1, x_2, x_3$
3. Calculate $d_x = d(x_1, x_2) / d(x_1, x_3)$
4. Calculate $d_g = d(g(x_1), g(x_2)) / d(g(x_1), g(x_3))$
5. Calculate the loss $l = (d_g - d_x)^2$
6. Calculate the gradients of $l$ w.r.t. $g(x_i)$ with $i \in \{1, 2, 3\}$
7. Apply the gradients to the grid coordinates
8. For each of these data points, find the nearest free grid coordinate and assign it to that point
9. Go to 2. and repeat until convergence