In [None]:
class NormalizedDynamics(torch.nn.Module):
    """Iterative kernel-smoothing transform with per-feature std normalisation.

    Each ``forward`` call moves every row part of the way (step size
    ``dim ** -alpha``) towards a kernel-weighted average of all rows. The
    kernel is ``exp(-d / (2 * sigma_i**2))``, where ``sigma_i`` is row i's
    k-th smallest distance. The result is then rescaled so each feature
    column has the input's (unbiased) std, and the input mean is added
    back. ``fit_transform`` applies ``forward`` ``max_iter`` times.

    The output has the same shape as the input: ``dim`` does not set an
    output dimensionality, it is only the base of the step size.

    Args:
        dim: Base of the step size ``dim ** -alpha``.
        alpha: Initial exponent of the step size. It is stored as a
            ``torch.nn.Parameter`` (registered with the module, requires grad).
        max_iter: Number of ``forward`` applications in ``fit_transform``.
        n_neighbors: Upper bound on k, the distance rank used as each row's
            bandwidth. The row itself counts as the first neighbour.
            Must be >= 1.

    Raises:
        ValueError: If ``n_neighbors`` < 1.
    """

    def __init__(self, dim=2, alpha=1.0, max_iter=50, n_neighbors=15):
        super().__init__()
        if n_neighbors < 1:
            raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")
        self.dim = dim
        # float() so an integer alpha (e.g. alpha=1) still yields a floating
        # Parameter; integer tensors cannot require gradients.
        self.alpha = torch.nn.Parameter(torch.tensor(float(alpha)))
        self.max_iter = max_iter
        self.n_neighbors = n_neighbors

    def forward(self, x):
        """Apply one smoothing step to ``x``.

        Args:
            x: 2-D floating tensor of shape (n_samples, n_features).

        Returns:
            A tensor with the same shape as ``x``. Inputs with fewer than two
            rows are returned unchanged: there are no neighbours to move
            towards and the unbiased std is undefined.
        """
        n_samples = x.size(0)
        if n_samples < 2:
            return x

        original_mean = torch.mean(x, dim=0, keepdim=True)
        original_std = torch.std(x, dim=0, keepdim=True)

        x_centered = x - original_mean
        dists = torch.cdist(x_centered, x_centered)

        # topk(largest=False) includes each row's zero distance to itself, so
        # the bandwidth is the distance to the (k-1)-th nearest *other* row.
        k = min(self.n_neighbors, n_samples - 1)
        kth_dists, _ = torch.topk(dists, k, dim=1, largest=False)
        sigma = kth_dists[:, -1].view(-1, 1)
        # Exact duplicates can make the k-th distance 0, which would give
        # -0 / 0 = NaN on the diagonal and poison the whole kernel row;
        # substitute a tiny bandwidth for exact zeros only.
        sigma = sigma.masked_fill(sigma == 0, torch.finfo(x.dtype).eps)

        # NOTE(review): a Gaussian kernel normally uses squared distances,
        # exp(-d**2 / (2 * sigma**2)). This uses the plain distance, which
        # makes the effective kernel width depend on the data's overall
        # scale. Kept as-is to preserve behaviour -- confirm intent.
        kernel = torch.exp(-dists / (2 * sigma**2))
        # Row-normalise so each row's drift is a weighted average of rows.
        kernel = kernel / torch.sum(kernel, dim=1, keepdim=True)

        drift = torch.matmul(kernel, x_centered)

        step_size = self.dim ** (-self.alpha)
        h = x_centered + step_size * (drift - x_centered)

        # Rescale each feature to the input's std. A feature whose std is 0
        # after the step (e.g. a constant input column) would otherwise give
        # 0 * (0 / 0) = NaN, so divide by 1 there instead.
        h_std = torch.std(h, dim=0, keepdim=True)
        h_std = h_std.masked_fill(h_std == 0, 1.0)
        h = h * (original_std / h_std)
        # NOTE(review): h is not re-centred before this, and the kernel is
        # only row-normalised, so the output mean can drift from the input
        # mean. Confirm whether exact mean preservation is intended.
        h = h + original_mean

        return h

    def fit_transform(self, X):
        """Apply ``forward`` ``max_iter`` times to ``X`` and return a NumPy array.

        Args:
            X: Tensor or array-like of shape (n_samples, n_features).
                Non-tensors are converted to float32.

        Returns:
            A NumPy array with the same shape as ``X``. ``X`` itself is not
            modified and the result does not share memory with it.
        """
        if not torch.is_tensor(X):
            X = torch.tensor(X, dtype=torch.float32)

        # No gradients are needed for a plain transform; without no_grad every
        # iteration would be recorded in the autograd graph (alpha is a
        # Parameter), growing memory with max_iter.
        with torch.no_grad():
            X_embedded = X.clone()
            for _ in range(self.max_iter):
                X_embedded = self.forward(X_embedded)
        # .cpu() so tensors living on an accelerator can be converted.
        return X_embedded.detach().cpu().numpy()