In [None]:
# %pip install latentis==0.0.8  # uncomment to install; %pip targets the active kernel's environment

# Define spaces

In [None]:
import torch
import torch.nn.functional as F

N_SAMPLES = 10_000
SPACE1_DIM = 1024
SPACE2_DIM = 768  # dimensionality of space2 in the "random" space2_mode
# Global torch seed so that space1 (and any later sampling that uses the global RNG)
# is reproducible under Restart & Run All.
SEED = 0

torch.manual_seed(SEED)
space1 = torch.randn(N_SAMPLES, SPACE1_DIM)

In [None]:
from latentis.transform.functional import random_isometry_state

In [None]:
# How space2 is built:
#   "isometry(space1)" -> space1 mapped through a random isometry (fixed random_seed=51)
#   "random"           -> an independent Gaussian space with SPACE2_DIM columns
space2_mode = "isometry(space1)"
if space2_mode == "random":
    space2 = torch.randn(N_SAMPLES, SPACE2_DIM)
elif space2_mode == "isometry(space1)":
    random_isometry = random_isometry_state(x=space1, random_seed=51)["matrix"]
    space2 = space1 @ random_isometry
else:
    raise ValueError(f"Unknown space2_mode: {space2_mode}")

In [None]:
add_noise: bool = True
# Relative noise level: each row of space2 receives a perturbation whose L2 norm is
# noise_ratio * ||row||, independent of N_SAMPLES.
noise_ratio: float = 1.0
# Dedicated generator seed so the noise is reproducible regardless of global RNG state.
noise_seed: int = 0

if add_noise:
    # NOTE: space2 is overwritten; re-running this cell without first re-running the cell
    # that builds space2 stacks a second layer of noise.
    noise_generator = torch.Generator(device=space2.device).manual_seed(noise_seed)
    # Row-normalized Gaussian samples are isotropic unit directions. Uniform [0, 1)
    # samples would all point into the positive orthant and shift the data's mean.
    noise_direction = F.normalize(
        torch.randn(
            space2.shape,
            generator=noise_generator,
            dtype=space2.dtype,
            device=space2.device,
        ),
        dim=-1,
    )
    # Scale along the same axis as the normalization (per row). A column-wise norm
    # (dim=0) grows with sqrt(N_SAMPLES) and would swamp the signal as N_SAMPLES grows.
    row_norms = space2.norm(p=2, dim=-1, keepdim=True)
    space2 = space2 + noise_direction * (noise_ratio * row_norms)

In [None]:
# Translators are fit on paired rows, so both spaces must contain the same number of samples.
assert space1.shape[0] == space2.shape[0], "space1 and space2 must have the same number of samples"
space1.shape, space2.shape

# Define Translator

In [None]:
from latentis.transform.base import StandardScaling
from latentis.transform.dim_matcher import ZeroPadding
from latentis.transform.translate.aligner import MatrixAligner, Translator
from latentis.transform.translate.functional import (
    svd_align_state,
    lstsq_align_state,
    lstsq_ortho_align_state,
)


# "ortho" in the paper: Procrustes analysis, i.e. an orthogonal transformation
# (estimated via svd_align_state) fitted after standard scaling of both spaces.
# ZeroPadding is used as the dim matcher — presumably it zero-pads the smaller space so
# both share a dimensionality; confirm against the latentis docs.
translator_ortho = Translator(
    aligner=MatrixAligner(name="ortho", align_fn_state=svd_align_state),
    x_transform=StandardScaling(),
    y_transform=StandardScaling(),
    dim_matcher=ZeroPadding(),
)

# Least-squares linear map on the raw spaces: no x/y transforms and no dim matcher are passed.
translator_linear = Translator(
    aligner=MatrixAligner(name="linear", align_fn_state=lstsq_align_state),
)

# Least-squares linear map fitted after standard scaling of both spaces.
# No dim matcher is passed: padding is not needed with lstsq.
translator_linear_scaling = Translator(
    aligner=MatrixAligner(
        name="linear+standard_scaling",
        align_fn_state=lstsq_align_state,
    ),
    x_transform=StandardScaling(),
    y_transform=StandardScaling(),
)

# Scaled least-squares map estimated with lstsq_ortho_align_state
# (presumably constrained to be orthogonal, as its name suggests).
translator_linear_ortho = Translator(
    aligner=MatrixAligner(
        name="ortho(linear+standard_scaling)",
        align_fn_state=lstsq_ortho_align_state,
    ),
    x_transform=StandardScaling(),
    y_transform=StandardScaling(),
)

# Latent Translation

In [None]:
dtype = torch.double
# Cast once, outside the loop: the same double-precision anchors are reused for every
# translator, and the metrics compare tensors of a single dtype.
x_anchor = space1.to(dtype)
y_anchor = space2.to(dtype)

for translator in (
    translator_ortho,
    translator_linear,
    translator_linear_scaling,
    translator_linear_ortho,
):
    print(f"Testing {translator.aligner.name}")

    # fit the translator using the anchor data. In this case, the anchor data are the whole space1 and space2
    translator.fit(x=x_anchor, y=y_anchor)

    # first method to transform the space X into the space Y: `transform` returns a tuple
    space1_transformed1 = translator.transform(x_anchor)[0]

    # second method (same result): calling the translator goes through forward, which
    # implicitly calls "transform" and returns a dictionary, not a tuple
    space1_transformed2 = translator(x_anchor)["x"]

    assert torch.allclose(space1_transformed1, space1_transformed2)

    # Mean *squared* error; the mean absolute error is reported separately as MAE.
    mse = F.mse_loss(space1_transformed1, y_anchor)
    mae = (y_anchor - space1_transformed1).abs().mean()
    print(f"MSE: {mse}")
    print(f"MAE: {mae}")

    cos_sim = F.cosine_similarity(y_anchor, space1_transformed1).mean()
    print(f"Cosine similarity: {cos_sim}")
    print()

In [None]:
# Functional version: no translator object is needed, but the code is considerably more verbose

from latentis.transform.functional import (
    standard_scaling_transform,
    standard_scaling_state,
    standard_scaling_inverse,
)
from latentis.transform.translate.functional import svd_align

In [None]:
def translator_ortho_fn(x, y):
    """Functional counterpart of ``translator_ortho`` (same flow as Figure 2 of the paper).

    Both spaces are standard-scaled, the scaled ``x`` is aligned to the scaled ``y``
    with ``svd_align``, and the result is mapped back through the inverse of ``y``'s
    scaling.

    NOTE(review): unlike ``translator_ortho``, no dim matcher (ZeroPadding) is applied
    here — presumably the two only agree when ``x`` and ``y`` share a dimensionality.

    Args:
        x: source space to translate.
        y: target space, whose scaling statistics are also used to undo the scaling.

    Returns:
        The aligned ``x``, expressed in ``y``'s original (unscaled) coordinates.
    """
    target_stats = standard_scaling_state(y)
    source_stats = standard_scaling_state(x)

    x_scaled = standard_scaling_transform(x=x, **source_stats)
    y_scaled = standard_scaling_transform(x=y, **target_stats)

    aligned = svd_align(x_scaled, y_scaled)

    # Undo the target scaling so the output lives in y's original coordinates
    return standard_scaling_inverse(x=aligned, **target_stats)

In [None]:
# The functional pipeline must reproduce the object-based translator's output
# (translator_ortho was fit on the same double-precision data in the loop above).
transformed_fn = translator_ortho_fn(space1.to(dtype), space2.to(dtype))
transformed_obj = translator_ortho(space1.to(dtype))["x"]

assert torch.allclose(transformed_obj, transformed_fn)