## Sample thed data

In [6]:
import os
import sys
import time
import wandb
import torch
import numpy as np
from model import ModelArgs, MLPEncoderArgs, Transformer, MLPEncoder, MLLMTransformer, TransformerEncoderArgs, TransformerEncoder, CNNEncoderArgs,CNNEncoder
from dataset import ICLDataset, MMDataset, get_mus_label_class, generate_input_seqs, generate_input_seqs_mm_v1,get_mm_mus_label_class
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True

# Sample the data
K1 = 8192
K2=512
N=8
L1 = 32
L2 = 16
D1 = 64
D2 = 512
S = 5
B=1
alpha1 = 0.0
alpha2 = 0.0
eps1 = 0.1
eps2 = 0.1
no_repeats = False
rope = True
rope_theta = 10000
P1 =1.0/(np.arange(1,K1+1)**alpha1)
P1 = P1/np.sum(P1)
P2 =1.0/(np.arange(1,K2+1)**alpha2)
P2 = P2/np.sum(P2)
mus_label_m1, mus_class_m1, labels_class_m1, mus_label_m2, mus_class_m2, labels_class_m2, mapping_m2_to_m1 = get_mm_mus_label_class(K1=K1,K2=K2,L1=L1,L2=L2,D1=D1,D2=D2)
inputs_mm, inputs_2, labels, label_sequences = generate_input_seqs_mm_v1(mus_label_m1=mus_label_m1, mus_class_m1=mus_class_m1, mus_label_m2=mus_label_m2, mus_class_m2=mus_class_m2, labels_class_m2=labels_class_m2, mapping_m2_to_m1=mapping_m2_to_m1, N=N,S=S,eps1=eps1,eps2=eps2, P1 = P1, P2 = P2, B = B, p_B = 1, p_C = 1, no_repeats = no_repeats, seq_labels=True)
print(label_sequences)


tensor([[10., 10., 11.,  1.,  2., 15.,  6.,  0., 15.],
        [ 2., 13.,  4., 13., 15.,  1.,  5., 12., 13.],
        [10., 15.,  9.,  3.,  4., 15., 13.,  0.,  0.],
        [ 7.,  2., 11., 12.,  8.,  2.,  2.,  8.,  7.],
        [ 8.,  0.,  5.,  3.,  5.,  6., 11.,  3.,  6.]])


## Comparison of the feature before and after encoder

### Load the encoder 

In [2]:
import os
import sys
import time
import wandb
import torch
import numpy as np
from model import ModelArgs, MLPEncoderArgs, Transformer, MLPEncoder, MLLMTransformer, TransformerEncoderArgs, TransformerEncoder, CNNEncoderArgs,CNNEncoder
from dataset import ICLDataset, MMDataset, get_mus_label_class, generate_input_seqs, generate_input_seqs_mm_v1,get_mm_mus_label_class
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
K1 = 8192
K2=512
eps0=0.5
D2=512
D1=64
dir = os.getcwd()
ckpt_path_enc = f"{dir}/outs_encoder_transformer/K{K2}_eps{eps0}_feat_dim{D2}_input_dim128_output_dim{D1//2}_num_layers2_num_heads1_niter50000/seed_0/ckpt_49999.pt"
model_args_enc = TransformerEncoderArgs(
            feat_dim=D2,
            input_dim=128,
            output_dim=D1//2,
            num_classes=K2,
            num_layers=2,
            num_heads=1
        )
Encoder = TransformerEncoder(model_args_enc)
Encoder.load_state_dict(torch.load(ckpt_path_enc), strict=True)



<All keys matched successfully>

In [3]:
x2_raw = torch.FloatTensor(mus_class_m2)
x2_vit = Encoder.extract_features(x2_raw).squeeze(1)
x1 = torch.FloatTensor(mus_class_m1)[mapping_m2_to_m1[:,0]]
def l2_norm(v): return v / v.norm(dim=-1, keepdim=True).clamp_min(1e-9)

x1_norm      = l2_norm(x1)
x2_encoder_proj_norm  = l2_norm(x2_raw)
x2_vit_norm = l2_norm(x2_vit)

## Compare x2_raw with x2_vit

#### 1 Global geometry

In [4]:
# Spectrum
Xraw = (x2_raw - x2_raw.mean(0)) / x2_raw.std(0, unbiased=False).clamp_min(1e-9)
Xvit = (x2_vit - x2_vit.mean(0)) / x2_vit.std(0, unbiased=False).clamp_min(1e-9)
vals_raw = torch.linalg.svdvals(Xraw)
vals_vit = torch.linalg.svdvals(Xvit)
print(
    "Top-5 singular values\n raw :",
    (torch.round(vals_raw[:10],  decimals=3)).tolist(),
    "\n vit :",
    (torch.round(vals_vit[:10],  decimals=3)).tolist()
)
# Effective rank (lower ⇢ fewer useful dims)
def eff_rank(v): 
    s=v/v.sum()
    return torch.exp(-(s*torch.log(s)).sum()) 
print("eff-rank raw =", eff_rank(vals_raw).item(), "\n vit =", eff_rank(vals_vit).item())

# Isotropy ratio (≈trace / λ₁; bigger = flatter)
iso = lambda v: v.sum()/(v.max()*len(v)) 
print("isotropy raw =", iso(vals_raw).item(), "\n vit =", iso(vals_vit).item())


Top-5 singular values
 raw : [44.54999923706055, 44.275001525878906, 44.17100143432617, 43.595001220703125, 43.33700180053711, 43.20899963378906, 42.939998626708984, 42.729000091552734, 42.680999755859375, 42.29999923706055] 
 vit : [37.85599899291992, 34.77799987792969, 33.06999969482422, 32.689998626708984, 31.996999740600586, 31.75, 30.42799949645996, 29.457000732421875, 29.24799919128418, 28.30299949645996]
eff-rank raw = 412.1483154296875 
 vit = 25.776538848876953
isotropy raw = 0.43117275834083557 
 vit = 0.5104237794876099


#### 2 Pair-wise structure (cluster separation)

In [5]:
with torch.no_grad():
    # cosine distance matrix (fits on GPU with K2=512)
    D_raw = 1 - x2_raw_norm @ x2_raw_norm.T          # [K2,K2], 0 means vectors are very similar (angle = 0°), 1 means vectors are orthogonal (angle = 90°)
    D_vit = 1 - x2_vit_norm @ x2_vit_norm.T

def stats(D):
    tri = D.triu(1)                      # upper triangle, no diag
    return tri[tri>0].mean(), tri[tri>0].std()

mu_raw,  sd_raw  = stats(D_raw)
mu_vit,  sd_vit  = stats(D_vit) 
print(f"mean cosine-dist  raw={mu_raw:.3f}±{sd_raw:.3f}   vit={mu_vit:.3f}±{sd_vit:.3f}")


NameError: name 'x2_raw_norm' is not defined

In [36]:
from torch.linalg import svd
# ❷ Compute cross-covariance and SVD
M_raw = x2_raw_norm.T @ x1_norm                          # [D2, D1]
U_raw, _, Vt_raw = svd(M_raw, full_matrices=False)   # U:[D2,r], Vt:[r,D1], r = min(D1,D2)
M_vit = x2_vit_norm.T @ x1_norm                          # [D2, D1]
U_vit, _, Vt_vit = svd(M_vit, full_matrices=False)   # U:[D2,r], Vt:[r,D1], r = min(D1,D2)

# ❸ Best orthogonal map W*:      X2 · W* ≈ X1
Wstar_raw = U_raw @ Vt_raw                          # [D2, D1]
Wstar_vit = U_vit @ Vt_vit                          # [D2, D1]

# ❹ Apply and evaluate
X2_to_X1_raw = x2_raw_norm @ Wstar_raw                  # [K2, D1]
X2_to_X1_vit = x2_vit_norm @ Wstar_vit                  # [K2, D1]
paired_cos_raw = torch.sum(x1_norm * X2_to_X1_raw, dim=-1)  # cosine for every class
paired_cos_vit = torch.sum(x1_norm * X2_to_X1_vit, dim=-1)  # cosine for every class
print("⟨cos⟩ after Procrustes, raw:", paired_cos_raw.mean().item(), "vit:", paired_cos_vit.mean().item())


⟨cos⟩ after Procrustes, raw: 0.17050622403621674 vit: 0.17902471125125885


## Comparison of the feature w/wo encoder after projector

### load the proj only model

In [1]:
import os
import sys
import time
import wandb
import torch
import numpy as np
from model import ModelArgs, Transformer, MMTransformer
from dataset import ICLDataset, MMDataset, get_mus_label_class, generate_input_seqs, generate_input_seqs_mm_v1,get_mm_mus_label_class
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
K1 = 8192
K2=512
eps0=0.5
D2=512
D1=64
N=8
model_args_stage2 = ModelArgs(
        m1_dim=D1,
        m2_dim=D2,
        dim=D1,
        n_layers=2,
        n_heads=1,
        n_labels=32,
        max_position_embeddings=3*N+1,
        rope_theta=10000,
        mlp_bias=True,
        rms_norm=True,
        rope=True,
        norm_eps=1e-5,
        L_pos=64
    )
model_stage2 = MMTransformer(model_args_stage2)
ckpt_path_stage2 = "/home/aoq609/ICL/outs_torch/K1_8192_K2_512_N8_D1_64_D2_512_L1_32_L2_16_alpha1_0.0_alpha2_0.0_B2_pB1.0_pC0.0_eps1_0.1_eps2_0.1_no_repeatsFalse_rope_True_rope_theta10000_freeze_layersFalse_n_heads1_n_layers2_rms_normTrue_optimizerSGD_niters80000_n_epochs1/seed_0/ckpt_20000.pt"
model_stage2.load_state_dict(torch.load(ckpt_path_stage2), strict=False)
model_stage2.eval()


MMTransformer(
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attn): Attention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=64, bias=False)
        (wv): Linear(in_features=64, out_features=64, bias=False)
        (rotary_emb): RotaryEmbedding()
      )
      (mlp): MLP(
        (fc1): Linear(in_features=64, out_features=64, bias=True)
        (fc2): Linear(in_features=64, out_features=64, bias=True)
        (fc3): Linear(in_features=64, out_features=64, bias=True)
        (act): SiLU()
      )
      (attn_norm): RMSNorm()
      (mlp_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (out): Linear(in_features=64, out_features=32, bias=False)
  (projector): Projector(
    (fc1): Linear(in_features=512, out_features=64, bias=False)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=64, out_features=64, bias=False)
  )
)

### Load stage 3 model

In [4]:
import os
import sys
import time
import wandb
import torch
import numpy as np
from model import ModelArgs, MLPEncoderArgs, Transformer, MLPEncoder, MLLMTransformer, TransformerEncoderArgs, TransformerEncoder, CNNEncoderArgs,CNNEncoder
from dataset import ICLDataset, MMDataset, get_mus_label_class, generate_input_seqs, generate_input_seqs_mm_v1,get_mm_mus_label_class
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
model_args_mm = ModelArgs(
        m1_dim=D1,
        m2_dim=D1//2,
        dim=D1,
        n_layers=2,
        n_heads=1,
        n_labels=32,
        max_position_embeddings=3*N+1,
        rope_theta=10000,
        mlp_bias=True,
        rms_norm=True,
        rope=True,
        norm_eps=1e-5,
        L_pos=64
    )
model_args_enc = TransformerEncoderArgs(
            feat_dim=D2,
            input_dim=128,
            output_dim=D1//2,
            num_classes=K2,
            num_layers=2,
            num_heads=1
        )
Encoder = TransformerEncoder(model_args_enc)
K1 = 8192
K2=512
eps0=0.5
D2=512
D1=64
dir = os.getcwd()
ckpt_path_enc = f"{dir}/outs_encoder_transformer/K{K2}_eps{eps0}_feat_dim{D2}_input_dim128_output_dim{D1//2}_num_layers2_num_heads1_niter50000/seed_0/ckpt_49999.pt"
Encoder.load_state_dict(torch.load(ckpt_path_enc), strict=True)
model_stage3 = MLLMTransformer(model_args_mm)
model_stage3.init_encoder(Encoder)
ckpt_path = "/home/aoq609/ICL/outs_torch/K1_8192_K2_512_N8_D1_64_D2_512_L1_32_L2_16_alpha1_0.0_alpha2_0.0_B2_pB1.0_pC0.0_eps00.5_eps1_0.1_eps2_0.1_no_repeatsFalse_rope_True_encoder_transformer_freeze_layersFalse_freeze_encoderFalse_n_heads1_n_layers2_niters150000/seed_0/ckpt_140000.pt"
model_stage3.load_state_dict(torch.load(ckpt_path),strict=False)
model_stage3.eval()
print(model_stage3)

MLLMTransformer(
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attn): Attention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=64, bias=False)
        (wv): Linear(in_features=64, out_features=64, bias=False)
        (rotary_emb): RotaryEmbedding()
      )
      (mlp): MLP(
        (fc1): Linear(in_features=64, out_features=64, bias=True)
        (fc2): Linear(in_features=64, out_features=64, bias=True)
        (fc3): Linear(in_features=64, out_features=64, bias=True)
        (act): SiLU()
      )
      (attn_norm): RMSNorm()
      (mlp_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (out): Linear(in_features=64, out_features=32, bias=False)
  (projector): Projector(
    (fc1): Linear(in_features=32, out_features=64, bias=False)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=64, out_features=64, bias=False)
  )
  (encoder): TransformerEncoder(
    (patch_embed): Linear(in_feature



In [14]:
mapping_m2_to_m1.shape

(512, 256)

In [None]:
x2_raw = torch.FloatTensor(mus_class_m2)
x2_proj = model_stage2.projector(x2_raw)
x2_encoder_proj = model_stage3.projector(model_stage3.encoder.extract_features(x2_raw)).squeeze(1)
# x1 = torch.FloatTensor(mus_class_m1)[mapping_m2_to_m1[:,0]]
mus_class_m1_t = torch.as_tensor(mus_class_m1, dtype=torch.float32)     # [K1, D1]

x1 = torch.stack([
        mus_class_m1_t[torch.as_tensor(idxs)].mean(dim=0)               # mean along class axis
        for idxs in mapping_m2_to_m1                                    # one row per modality-2 class
    ])                                                                  # -> [K2, D1]
def l2_norm(v): return v / v.norm(dim=-1, keepdim=True).clamp_min(1e-9)

# x1_norm      = l2_norm(x1).detach()
# x2_proj_norm = l2_norm(x2_proj).detach()
# x2_encoder_proj_norm = l2_norm(x2_encoder_proj).detach()

In [8]:
print("x1_norm.shape", x1_norm.shape)
print("x2_proj_norm.shape", x2_proj_norm.shape)
print("x2_encoder_proj_norm.shape", x2_encoder_proj_norm.shape)

x1_norm.shape torch.Size([512, 64])
x2_proj_norm.shape torch.Size([512, 64])
x2_encoder_proj_norm.shape torch.Size([512, 64])


In [9]:
import torch
import math

def effective_rank(X: torch.Tensor, center: bool = True, use_squared: bool = True, eps: float = 1e-12):
    """
    Effective rank of matrix X via entropy of singular-value distribution.
      - center: subtract column mean before SVD (recommended for embeddings).
      - use_squared: p_i ∝ σ_i^2 (energy). If False, p_i ∝ σ_i (amplitude).
    Returns: dict with e_rank, entropy, p (distribution), and singular values.
    """
    assert X.ndim == 2, "X must be 2D [n_samples, n_features]"
    Xc = X - X.mean(dim=0, keepdim=True) if center else X

    # Singular values (no need for full U,V)
    s = torch.linalg.svdvals(Xc)  # shape [min(n,d)]

    if use_squared:
        w = s**2
    else:
        w = s

    w = torch.clamp(w, min=eps)
    p = w / w.sum()

    H = -(p * torch.log(p)).sum()          # natural log
    e_rank = torch.exp(H)                  # effective rank
    return {"e_rank": e_rank.item(), "H": H.item(), "p": p, "s": s}

# --- Compute for your tensors ---
tensors = {
    "x1_norm": x1_norm.detach().cpu(),
    "x2_raw": x2_raw.detach().cpu(),
    "x2_proj_norm": x2_proj_norm.detach().cpu(),
    "x2_encoder_proj_norm": x2_encoder_proj_norm.detach().cpu(),
}

print("=== Effective Rank (energy-based, centered) ===")
results = {}
for name, X in tensors.items():
    res = effective_rank(X, center=True, use_squared=True)
    results[name] = res["e_rank"]
    print(f"{name:24s}: {res['e_rank']:.3f}")

# (Optional) also report the amplitude-based variant for completeness
print("\n=== Effective Rank (amplitude-based, centered) ===")
for name, X in tensors.items():
    res = effective_rank(X, center=True, use_squared=False)
    print(f"{name:24s}: {res['e_rank']:.3f}")

# (Optional) sanity: show matrix shapes
for name, X in tensors.items():
    print(f"{name:24s} shape = {tuple(X.shape)}")



=== Effective Rank (energy-based, centered) ===
x1_norm                 : 13.615
x2_raw                  : 310.255
x2_proj_norm            : 10.101
x2_encoder_proj_norm    : 17.148

=== Effective Rank (amplitude-based, centered) ===
x1_norm                 : 14.621
x2_raw                  : 411.954
x2_proj_norm            : 29.342
x2_encoder_proj_norm    : 33.910
x1_norm                  shape = (512, 64)
x2_raw                   shape = (512, 512)
x2_proj_norm             shape = (512, 64)
x2_encoder_proj_norm     shape = (512, 64)


In [11]:
import torch
import math

def effective_rank(X: torch.Tensor, center: bool = True, use_squared: bool = True, eps: float = 1e-12):
    """
    Effective rank of matrix X via entropy of singular-value distribution.
      - center: subtract column mean before SVD (recommended for embeddings).
      - use_squared: p_i ∝ σ_i^2 (energy). If False, p_i ∝ σ_i (amplitude).
    Returns: dict with e_rank, entropy, p (distribution), and singular values.
    """
    assert X.ndim == 2, "X must be 2D [n_samples, n_features]"
    Xc = X - X.mean(dim=0, keepdim=True) if center else X

    # Singular values (no need for full U,V)
    s = torch.linalg.svdvals(Xc)  # shape [min(n,d)]

    if use_squared:
        w = s**2
    else:
        w = s

    w = torch.clamp(w, min=eps)
    p = w / w.sum()

    H = -(p * torch.log(p)).sum()          # natural log
    e_rank = torch.exp(H)                  # effective rank
    return {"e_rank": e_rank.item(), "H": H.item(), "p": p, "s": s}

# --- Compute for your tensors ---
tensors = {
    "x1": x1.detach().cpu(),
    "x2_raw": x2_raw.detach().cpu(),
    "x2_proj": x2_proj.detach().cpu(),
    "x2_encoder_proj": x2_encoder_proj.detach().cpu(),
}

print("=== Effective Rank (energy-based, centered) ===")
results = {}
for name, X in tensors.items():
    res = effective_rank(X, center=True, use_squared=True)
    results[name] = res["e_rank"]
    print(f"{name:24s}: {res['e_rank']:.3f}")

# (Optional) also report the amplitude-based variant for completeness
print("\n=== Effective Rank (amplitude-based, centered) ===")
for name, X in tensors.items():
    res = effective_rank(X, center=True, use_squared=False)
    print(f"{name:24s}: {res['e_rank']:.3f}")

# (Optional) sanity: show matrix shapes
for name, X in tensors.items():
    print(f"{name:24s} shape = {tuple(X.shape)}")



=== Effective Rank (energy-based, centered) ===
x1                      : 13.524
x2_raw                  : 310.255
x2_proj                 : 8.654
x2_encoder_proj         : 16.556

=== Effective Rank (amplitude-based, centered) ===
x1                      : 14.589
x2_raw                  : 411.954
x2_proj                 : 27.071
x2_encoder_proj         : 33.195
x1                       shape = (512, 64)
x2_raw                   shape = (512, 512)
x2_proj                  shape = (512, 64)
x2_encoder_proj          shape = (512, 64)


In [10]:
import torch

def gram_linear(X: torch.Tensor) -> torch.Tensor:
    """Compute linear kernel Gram matrix."""
    return X @ X.T

def center_gram(G: torch.Tensor) -> torch.Tensor:
    """Center the Gram matrix (like in HSIC)."""
    n = G.size(0)
    H = torch.eye(n, device=G.device) - torch.ones((n,n), device=G.device)/n
    return H @ G @ H

def cka(X: torch.Tensor, Y: torch.Tensor, center: bool = True) -> float:
    """
    Compute linear CKA between two sets of representations.
    X: [n, d1], Y: [n, d2]
    """
    # ensure same number of rows (samples)
    assert X.size(0) == Y.size(0), "X and Y must have same number of samples"

    Gx = gram_linear(X)
    Gy = gram_linear(Y)

    if center:
        Gx = center_gram(Gx)
        Gy = center_gram(Gy)

    hsic = (Gx * Gy).sum()
    norm_x = torch.sqrt((Gx * Gx).sum())
    norm_y = torch.sqrt((Gy * Gy).sum())
    return (hsic / (norm_x * norm_y)).item()

# --- compute CKA ---
pairs = {
    "x1_norm vs x2_raw": (x1_norm, x2_raw),
    "x1_norm vs x2_proj_norm": (x1_norm, x2_proj_norm),
    "x1_norm vs x2_encoder_proj_norm": (x1_norm, x2_encoder_proj_norm),
}

print("=== CKA with x1_norm ===")
for name, (X, Y) in pairs.items():
    val = cka(X.detach(), Y.detach())
    print(f"{name:32s}: {val:.4f}")



=== CKA with x1_norm ===
x1_norm vs x2_raw               : 0.1097
x1_norm vs x2_proj_norm         : 0.3793
x1_norm vs x2_encoder_proj_norm : 0.0373


In [12]:
import torch

def gram_linear(X: torch.Tensor) -> torch.Tensor:
    """Compute linear kernel Gram matrix."""
    return X @ X.T

def center_gram(G: torch.Tensor) -> torch.Tensor:
    """Center the Gram matrix (like in HSIC)."""
    n = G.size(0)
    H = torch.eye(n, device=G.device) - torch.ones((n,n), device=G.device)/n
    return H @ G @ H

def cka(X: torch.Tensor, Y: torch.Tensor, center: bool = True) -> float:
    """
    Compute linear CKA between two sets of representations.
    X: [n, d1], Y: [n, d2]
    """
    # ensure same number of rows (samples)
    assert X.size(0) == Y.size(0), "X and Y must have same number of samples"

    Gx = gram_linear(X)
    Gy = gram_linear(Y)

    if center:
        Gx = center_gram(Gx)
        Gy = center_gram(Gy)

    hsic = (Gx * Gy).sum()
    norm_x = torch.sqrt((Gx * Gx).sum())
    norm_y = torch.sqrt((Gy * Gy).sum())
    return (hsic / (norm_x * norm_y)).item()

# --- compute CKA ---
pairs = {
    "x1 vs x2_raw": (x1, x2_raw),
    "x1 vs x2_proj": (x1, x2_proj),
    "x1 vs x2_encoder_proj": (x1, x2_encoder_proj),
}

print("=== CKA with x1 ===")
for name, (X, Y) in pairs.items():
    val = cka(X.detach(), Y.detach())
    print(f"{name:32s}: {val:.4f}")



=== CKA with x1 ===
x1 vs x2_raw                    : 0.1089
x1 vs x2_proj                   : 0.3241
x1 vs x2_encoder_proj           : 0.0349


In [13]:
def l2_distance(a, b):
    return (a - b).norm(p=2, dim=1).mean().item()

dist_proj_vs_x1      = l2_distance(x2_proj, x1)
dist_encoder_vs_x1   = l2_distance(x2_encoder_proj, x1)
dist_proj_vs_encoder = l2_distance(x2_proj, x2_encoder_proj)

print(f"L2(x2_proj, x1)      = {dist_proj_vs_x1:.4f}")
print(f"L2(x2_encoder_proj, x1) = {dist_encoder_vs_x1:.4f}")
print(f"L2(x2_proj, x2_encoder_proj) = {dist_proj_vs_encoder:.4f}")



L2(x2_proj, x1)      = 0.0778
L2(x2_encoder_proj, x1) = 1.4049
L2(x2_proj, x2_encoder_proj) = 1.4071
