# Vec2Vec: Unsupervised Embedding Translation

This notebook provides a PyTorch implementation of the Vec2Vec model, a method for unsupervised translation of vector embeddings between different semantic spaces. The approach is inspired by unsupervised translation techniques such as CycleGAN (originally proposed for unpaired image-to-image translation) and aims to learn a shared universal latent space in which embeddings from different source models can be aligned.

This implementation is based on the concepts described in the research paper "Harnessing the Universal Geometry of Embeddings".

## Core Concept

The primary goal is to translate an embedding `u` from a source space (generated by Model 1) to a target space (generated by Model 2) to obtain `v_approx`, such that `v_approx` is close to the "true" embedding `v` that Model 2 would have generated for the same underlying data. This translation is learned *without* any paired `(u, v)` examples.

The method relies on the **Strong Platonic Representation Hypothesis**: neural networks trained on similar tasks/modalities, even with different architectures or data, might converge to a universal latent space. Vec2Vec attempts to find and leverage this space.

## Model Architecture

The Vec2Vec model consists of the following main components:

1.  **Input Adapters (A1, A2):** Neural networks (MLPs) that map embeddings from their original high-dimensional spaces (dim1, dim2) to a common, lower-dimensional `latent_dim`.
2.  **Shared Backbone (T):** An MLP that processes these latent representations, aiming to refine and align them within the universal latent space.
3.  **Output Adapters (B1, B2):** MLPs that map representations from the shared backbone's output back to one of the original embedding spaces (dim1 or dim2).
4.  **Discriminators (D1, D2, D_l1, D_l2):**
    *   `D1`, `D2`: Distinguish real embeddings in their respective original spaces from "fake" embeddings translated from the other space.
    *   `D_l1`, `D_l2`: Operate in the latent space, trying to distinguish latent codes derived from Model 1 embeddings versus Model 2 embeddings. This encourages the latent representations to become domain-agnostic.
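
To make the discriminators' role concrete, here is a minimal, dependency-free sketch of the standard GAN objective for a single discriminator that outputs a probability. The names `discriminator_loss` and `generator_loss` are hypothetical, and this is an assumption about the general technique: the notebook's own training loop may weight or formulate these terms differently.

```python
import math
from typing import Sequence

EPS = 1e-12  # guards against log(0)


def discriminator_loss(p_real: Sequence[float], p_fake: Sequence[float]) -> float:
    """Binary cross-entropy: real samples are labelled 1, translated samples 0."""
    real_term = sum(-math.log(p + EPS) for p in p_real) / len(p_real)
    fake_term = sum(-math.log(1.0 - p + EPS) for p in p_fake) / len(p_fake)
    return real_term + fake_term


def generator_loss(p_fake: Sequence[float]) -> float:
    """Non-saturating loss: push D's output on translated samples toward 1."""
    return sum(-math.log(p + EPS) for p in p_fake) / len(p_fake)


# Example probabilities from D2 on real space-2 embeddings and on F1(x1) translations.
p_real = [0.9, 0.8]
p_fake = [0.2, 0.1]
print(discriminator_loss(p_real, p_fake))  # small: D separates real from fake well
print(generator_loss(p_fake))              # larger: the translator is not fooling D yet
```

The same pair of losses would apply to the latent discriminators, with latent codes from one model playing the role of "real" and codes from the other playing "fake".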

**Translation & Reconstruction:**
*   Translate from space 1 to 2: `F1(x) = B2(T(A1(x)))`
*   Translate from space 2 to 1: `F2(x) = B1(T(A2(x)))`
*   Reconstruct in space 1: `R1(x) = B1(T(A1(x)))`
*   Reconstruct in space 2: `R2(x) = B2(T(A2(x)))`
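
The four compositions above can be sketched with plain Python callables. This is only a toy illustration: `compose`, `scale`, and `shift` are hypothetical helpers, the stand-ins for `A1`, `A2`, `T`, `B1`, and `B2` are simple invertible maps rather than the MLPs used in the notebook, and both toy spaces share one dimension for brevity.

```python
from typing import Callable, List

Vector = List[float]
Fn = Callable[[Vector], Vector]


def compose(*fns: Fn) -> Fn:
    """Return a function that applies `fns` from left to right."""
    def run(x: Vector) -> Vector:
        for fn in fns:
            x = fn(x)
        return x
    return run


def scale(k: float) -> Fn:
    return lambda x: [k * v for v in x]


def shift(c: float) -> Fn:
    return lambda x: [v + c for v in x]


# Toy stand-ins (assumptions for illustration only).
A1, A2 = scale(2.0), shift(-1.0)   # input adapters: space -> latent
T = scale(1.0)                     # shared backbone (identity here)
B1, B2 = scale(0.5), shift(1.0)    # output adapters: latent -> space

F1 = compose(A1, T, B2)  # translate space 1 -> space 2
F2 = compose(A2, T, B1)  # translate space 2 -> space 1
R1 = compose(A1, T, B1)  # reconstruct in space 1
R2 = compose(A2, T, B2)  # reconstruct in space 2

x1 = [1.0, -0.5]
print(R1(x1))      # reconstruction returns x1 for these toy maps
print(F2(F1(x1)))  # the cycle 1 -> 2 -> 1 also returns x1
```

Because every path goes through the same `T`, whatever structure the backbone learns is shared by both translation directions and both reconstructions.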

## Training

The model is trained using a combination of several loss functions in a GAN-like setup:

*   **Adversarial Loss:** Encourages the generated (translated) embeddings and latent codes to be indistinguishable from real ones.
*   **Reconstruction Loss:** Ensures that mapping an embedding to the latent space and then back to its original space reconstructs the original embedding faithfully.
*   **Cycle Consistency Loss:** Enforces that translating an embedding to the other space and then back to its original space should yield the original embedding (e.g., `x1 -> F1 -> x2_fake -> F2 -> x1_reconstructed ≈ x1`).
*   **Vector Space Preservation (VSP) Loss:** Aims to preserve the geometric structure (pairwise similarities) of the embedding space after translation.
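
The non-adversarial terms can be illustrated on tiny batches without any framework. The sketch below assumes mean squared error for reconstruction and cycle consistency, and a VSP term that compares pairwise dot products before and after translation; the function names and the exact form of each term are assumptions, not necessarily what the notebook's training code uses.

```python
from typing import List, Sequence

Batch = Sequence[Sequence[float]]


def mse(a: Batch, b: Batch) -> float:
    """Mean squared error over all entries of two equally shaped batches."""
    diffs = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)


def gram(batch: Batch) -> List[List[float]]:
    """Matrix of pairwise dot products within a batch."""
    return [[sum(x * y for x, y in zip(u, v)) for v in batch] for u in batch]


def vsp_loss(source: Batch, translated: Batch) -> float:
    """Penalize changes in pairwise dot products caused by translation."""
    return mse(gram(source), gram(translated))


x1 = [[1.0, 0.0], [0.0, 1.0]]                 # batch in space 1 (dim 2)
x1_rec = [[1.0, 0.0], [0.0, 0.5]]             # stand-in for R1(x1)
x2_fake = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # stand-in for F1(x1) (dim 3)

rec_loss = mse(x1, x1_rec)   # reconstruction error; a cycle loss would use F2(F1(x1)) instead
vsp = vsp_loss(x1, x2_fake)  # pairwise dot products are unchanged here
print(rec_loss, vsp)
```

Note that the VSP term works across spaces of different dimensionality, since it only compares batch-by-batch similarity matrices.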

## Demo Functionality (`demo()`)

The `demo()` function in the notebook:
1.  Loads two pre-trained sentence embedding models from `sentence-transformers` (`all-MiniLM-L6-v2` - 384 dims, and `all-mpnet-base-v2` - 768 dims) to simulate two different embedding spaces.
2.  Encodes a small sample of 10 texts using both models to create `data1` and `data2` (unpaired training data).
3.  Initializes the `Vec2Vec` model.
4.  Trains the model for a set number of epochs.
5.  Plots the generator and discriminator training losses.
6.  Evaluates the translation quality by:
    *   Calculating the Mean Cosine Similarity between translated embeddings `F1(data1)` and the target embeddings `data2`.
    *   Calculating Top-1 Accuracy (how often the closest vector in `data2` to a translated vector `F1(data1[i])` is actually `data2[i]`).
7.  Visualizes the learned universal latent space using t-SNE, plotting `T(A1(data1))` and `T(A2(data2))`.
8.  Tests translation on a single sample text.
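
The two evaluation metrics from step 6 can be sketched in a few lines of plain Python. This is an illustrative version with hypothetical function names; it assumes row `i` of the translated batch and row `i` of the target batch come from the same text, which is only true for a paired evaluation set.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return sum(a * b for a, b in zip(u, v)) / (norm_u * norm_v)


def mean_cosine(translated: Sequence[Vector], targets: Sequence[Vector]) -> float:
    """Average cosine similarity between each translation and its paired target."""
    return sum(cosine(t, v) for t, v in zip(translated, targets)) / len(targets)


def top1_accuracy(translated: Sequence[Vector], targets: Sequence[Vector]) -> float:
    """Fraction of translations whose nearest target (by cosine) is the paired one."""
    hits = 0
    for i, t in enumerate(translated):
        nearest = max(range(len(targets)), key=lambda j: cosine(t, targets[j]))
        hits += int(nearest == i)
    return hits / len(translated)


targets = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
translated = [[0.9, 0.1], [0.2, 0.8], [1.0, 0.0]]  # the third row lands on the wrong target
print(mean_cosine(translated, targets))
print(top1_accuracy(translated, targets))
```

Top-1 accuracy is the stricter of the two: a translation can have a high cosine similarity to its target and still be closer to a different target vector.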

**Disclaimer:** This notebook serves as a demonstration of the Vec2Vec architecture and training procedure. Achieving results comparable to the research paper would require significantly more data, computational resources, and hyperparameter optimization.

## References
[Harnessing the Universal Geometry of Embeddings (arXiv:2505.12540)](https://arxiv.org/abs/2505.12540)

In [None]:
# CELL BLOCK 1: DATA LOADING AND PREPARATION
# Run this cell before the class definitions; later cells use the tensors it creates.

import torch
from sentence_transformers import SentenceTransformer
from sklearn.datasets import fetch_20newsgroups  # For a sample dataset
import numpy as np
from tqdm.auto import tqdm
import nltk

# Download NLTK's sentence tokenizer data (if you haven't already)
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)


print("=== Data Loading and Preparation ===")

# --- Configuration ---
NUM_SAMPLES_FROM_NEWSGROUPS = 20000  # Upper bound on documents taken from 20 Newsgroups
# With subset="all" the corpus holds 18,846 documents, so larger values are capped by the slice below.
# More data generally helps, but be mindful of Colab RAM when generating embeddings.

MODEL1_NAME = "all-MiniLM-L6-v2"
MODEL2_NAME = "all-mpnet-base-v2"  # Or 'paraphrase-MiniLM-L6-v2' for faster embedding if mpnet is too slow
BATCH_SIZE_EMBEDDING = 64  # Batch size for encoding with sentence-transformers

# --- 1. Load Sentence Transformer Models ---
print("Loading embedding models...")
# It's good practice to define these once if they are used globally
# Or pass them into the data loading function if you prefer
try:
    sbert_model1 = SentenceTransformer(MODEL1_NAME)
    sbert_model2 = SentenceTransformer(MODEL2_NAME)
    print(
        f"Model 1 ({MODEL1_NAME}) loaded. Dim: {sbert_model1.get_sentence_embedding_dimension()}"
    )
    print(
        f"Model 2 ({MODEL2_NAME}) loaded. Dim: {sbert_model2.get_sentence_embedding_dimension()}"
    )
except Exception as e:
    print(f"Error loading sentence transformer models: {e}")
    sbert_model1, sbert_model2 = None, None

# --- 2. Load 20 Newsgroups Dataset ---
if sbert_model1 and sbert_model2:
    print(
        f"\nFetching {NUM_SAMPLES_FROM_NEWSGROUPS} documents from 20 Newsgroups dataset..."
    )
    newsgroups_data = fetch_20newsgroups(
        subset="all",
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )
    # We take a slice of the data
    documents = newsgroups_data.data[:NUM_SAMPLES_FROM_NEWSGROUPS]
    print(f"Fetched {len(documents)} documents.")

    # --- 3. Extract Sentences ---
    # Split every document into sentences with NLTK so that each embedding covers a single sentence.
    # Tokenization can be slow for many documents; to embed whole documents instead,
    # use the commented-out assignment at the end of this step.
    all_sentences = []
    for doc in tqdm(documents, desc="Tokenizing documents into sentences"):
        all_sentences.extend(nltk.sent_tokenize(doc))
    print(f"Extracted {len(all_sentences)} sentences in total.")
    texts_to_embed = all_sentences
    # texts_to_embed = documents  # Alternative: treat each full document as one text

    if len(texts_to_embed) < 2:
        print("Not enough texts to proceed. Increase NUM_SAMPLES_FROM_NEWSGROUPS.")
        data1_loaded, data2_loaded = None, None
    else:
        # --- 4. Split Texts into Two Disjoint Sets ---
        # Seeded in-place shuffle so the split is reproducible across runs.
        np.random.default_rng(42).shuffle(texts_to_embed)
        split_point = len(texts_to_embed) // 2
        texts_for_model1 = texts_to_embed[:split_point]
        texts_for_model2 = texts_to_embed[split_point:]
        print(
            f"\nSplit into {len(texts_for_model1)} texts for model 1 and {len(texts_for_model2)} for model 2."
        )

        # --- 5. Generate Embeddings ---
        print("\nGenerating embeddings for model 1...")
        embeddings1_list = sbert_model1.encode(
            texts_for_model1, show_progress_bar=True, batch_size=BATCH_SIZE_EMBEDDING
        )
        data1_loaded = torch.FloatTensor(embeddings1_list)

        print("\nGenerating embeddings for model 2...")
        embeddings2_list = sbert_model2.encode(
            texts_for_model2, show_progress_bar=True, batch_size=BATCH_SIZE_EMBEDDING
        )
        data2_loaded = torch.FloatTensor(embeddings2_list)

        print(f"\nShape of loaded data1: {data1_loaded.shape}")
        print(f"Shape of loaded data2: {data2_loaded.shape}")
        print("=== Data Loading and Preparation Complete ===")
else:
    print("Sentence Transformer models not loaded. Cannot prepare data.")
    data1_loaded, data2_loaded = None, None

# These tensors 'data1_loaded' and 'data2_loaded' are now available globally
# for your demo function to use.

=== Data Loading and Preparation ===
Loading embedding models...


Model 1 (all-MiniLM-L6-v2) loaded. Dim: 384
Model 2 (all-mpnet-base-v2) loaded. Dim: 768

Fetching 20000 documents from 20 Newsgroups dataset...
Fetched 18846 documents.


Extracted 208743 sentences in total.

Split into 104371 texts for model 1 and 104372 for model 2.

Generating embeddings for model 1...



Generating embeddings for model 2...



Shape of loaded data1: torch.Size([104371, 384])
Shape of loaded data2: torch.Size([104372, 768])
=== Data Loading and Preparation Complete ===
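
The split in step 4 is what makes the training data unpaired: each text is embedded by only one of the two models. A minimal sketch of such a split, using only the standard library, is shown below. The helper name `split_disjoint` and the seed value are hypothetical and not part of the notebook.

```python
import random
from typing import List, Sequence, Tuple


def split_disjoint(texts: Sequence[str], seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle a copy of `texts` and cut it into two halves with no shared positions."""
    shuffled = list(texts)                 # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # a local RNG keeps the split reproducible
    mid = len(shuffled) // 2
    return shuffled[:mid], shuffled[mid:]


texts = [f"sentence {i}" for i in range(7)]
for_model1, for_model2 = split_disjoint(texts, seed=42)
print(len(for_model1), len(for_model2))  # with an odd count, the second half gets the extra item
```

If the corpus contains duplicate sentences, the same string can still end up in both halves; deduplicating before the split would avoid that.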


In [None]:
# @title <<< Yeo-Johnson Power Transform and λ Estimation Helpers (Stability Focus) >>>

# Epsilon for numerical stability
EPS = 1e-5


def psi(h, lambda_):
    """
    Applies the Yeo-Johnson power transform element-wise.

    Non-finite values in `h` and `lambda_` are replaced before the transform,
    and `lambda_` must be broadcastable to the shape of `h`.
    """
    # --- Input Checks ---
    if torch.isnan(h).any() or torch.isinf(h).any():
        # warnings.warn("NaN/Inf detected in input 'h' to psi. Clipping and continuing.")
        h = torch.nan_to_num(h, nan=0.0, posinf=1e6, neginf=-1e6)  # Replace NaN/Inf
        # Consider more aggressive clipping if needed: h = torch.clamp(h, -10.0, 10.0)

    if torch.isnan(lambda_).any() or torch.isinf(lambda_).any():
        # warnings.warn("NaN/Inf detected in input 'lambda_' to psi. Clipping to 1.0.")
        lambda_ = torch.nan_to_num(
            lambda_, nan=1.0, posinf=1.0, neginf=1.0
        )  # Fallback lambda

    # --- Original psi logic (with existing safety checks) ---
    try:
        _ = torch.broadcast_shapes(h.shape, lambda_.shape)
    except RuntimeError as e:
        raise RuntimeError(
            f"Shape mismatch: Cannot broadcast lambda_ {lambda_.shape} to h {h.shape}. Error: {e}"
        )

    h_ge_0 = h >= 0

    # --- Case: h >= 0 ---
    psi_ge0_lam0 = torch.log1p(h)
    denominator_ge0 = lambda_ + torch.sign(lambda_) * EPS
    denominator_ge0 = torch.where(
        denominator_ge0 == 0,
        torch.copysign(torch.full_like(lambda_, EPS), lambda_),
        denominator_ge0,
    )  # Use signed EPS if zero; full_like keeps EPS on lambda_'s device and dtype
    # Clip input to pow to prevent overflow/NaN with extreme lambdas
    pow_input_ge0 = torch.clamp(1 + h, min=EPS)  # Ensure base is positive
    # Handle potential NaN from pow itself
    pow_result_ge0 = torch.pow(pow_input_ge0, lambda_)
    pow_result_ge0 = torch.nan_to_num(
        pow_result_ge0, nan=1.0
    )  # If pow fails, treat as if (1+h)^0=1
    psi_ge0_lam_ne0 = (pow_result_ge0 - 1) / denominator_ge0
    psi_ge0 = torch.where(lambda_ == 0, psi_ge0_lam0, psi_ge0_lam_ne0)

    # --- Case: h < 0 ---
    lambda_minus_2 = 2.0 - lambda_  # Note: despite its name, this holds 2 - lambda
    psi_lt0_lam2 = -torch.log1p(-h)  # Used where 2 - lambda == 0
    denominator_lt0 = lambda_minus_2 + torch.sign(lambda_minus_2) * EPS
    denominator_lt0 = torch.where(
        denominator_lt0 == 0,
        torch.copysign(torch.full_like(lambda_minus_2, EPS), lambda_minus_2),
        denominator_lt0,
    )  # Use signed EPS if zero; full_like keeps EPS on the same device and dtype
    # Clip input to pow
    pow_input_lt0 = torch.clamp(1 - h, min=EPS)  # Ensure base is positive
    # Handle potential NaN from pow
    pow_result_lt0 = torch.pow(pow_input_lt0, lambda_minus_2)
    pow_result_lt0 = torch.nan_to_num(
        pow_result_lt0, nan=1.0
    )  # If pow fails, treat as if (1-h)^0=1
    psi_lt0_lam_ne2 = (1.0 - pow_result_lt0) / denominator_lt0
    psi_lt0 = torch.where(lambda_minus_2 == 0, psi_lt0_lam2, psi_lt0_lam_ne2)

    # --- Combine ---
    result = torch.where(h_ge_0, psi_ge0, psi_lt0)

    # --- Final Output Check ---
    if torch.isnan(result).any() or torch.isinf(result).any():
        # warnings.warn("NaN/Inf encountered in psi function final output. Replacing with 0.")
        result = torch.nan_to_num(
            result, nan=0.0, posinf=0.0, neginf=0.0
        )  # Last resort

    return result


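def estimate_lambda_hat(h, dims):
    """
    Estimates the Yeo-Johnson parameter lambda_hat for `h` over the dimensions `dims`.

    Applies a single Newton-style update starting from lambda = 1
    (lambda_hat = 1 - L'(1) / L''(1)) and clamps the result to [-5, 5].
    """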
    with torch.no_grad():
        # --- Input Check ---
        if torch.isnan(h).any() or torch.isinf(h).any():
            # warnings.warn("NaN/Inf detected in input 'h' to estimate_lambda_hat. Clipping and continuing.")
            h = torch.nan_to_num(h, nan=0.0, posinf=1e6, neginf=-1e6)  # Replace NaN/Inf

        # --- Original estimation logic ---
        h_abs = torch.abs(h)
        # log1p(|h|) is well defined for every finite h because |h| >= 0
        log1p_h_abs = torch.log1p(h_abs)

        s3 = torch.mean(torch.pow(h, 3), dim=dims, keepdim=True)
        k = torch.mean(h * log1p_h_abs, dim=dims, keepdim=True)
        g = torch.mean(
            torch.pow(h, 2) * torch.pow(log1p_h_abs, 2), dim=dims, keepdim=True
        )

        # Check moments for NaN/Inf before using them
        s3 = torch.nan_to_num(s3, nan=0.0)
        k = torch.nan_to_num(k, nan=0.0)
        g = torch.nan_to_num(g, nan=1.0)  # Replace g NaN with 1 to avoid issues in L''

        L_prime_at_1 = k - 0.5 * s3
        # Ensure L'' denominator is reasonably large and positive
        L_double_prime_at_1 = torch.clamp(g - k + 1.0, min=EPS)  # Clamp denominator > 0

        lambda_hat = 1.0 - L_prime_at_1 / L_double_prime_at_1

        # --- Final Checks on lambda_hat ---
        if torch.isnan(lambda_hat).any() or torch.isinf(lambda_hat).any():
            # warnings.warn("NaN/Inf encountered after lambda_hat calculation. Setting lambda_hat=1.0.")
            lambda_hat = torch.nan_to_num(
                lambda_hat, nan=1.0, posinf=1.0, neginf=1.0
            )  # Fallback

        # Optional but recommended: Clamp lambda_hat to a reasonable range
        lambda_hat = torch.clamp(
            lambda_hat, -5.0, 5.0
        )  # Clamp to avoid extreme transforms

    return lambda_hat
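
As a reference for the piecewise definition that `psi` implements, here is a scalar, dependency-free version of the Yeo-Johnson transform. The name `yeo_johnson` is hypothetical, and this sketch uses the exact textbook form without the epsilon guards or clamping found in the tensor version above.

```python
import math


def yeo_johnson(x: float, lam: float) -> float:
    """Exact scalar Yeo-Johnson transform of x with parameter lam."""
    if x >= 0:
        if lam == 0:
            return math.log1p(x)
        return ((1.0 + x) ** lam - 1.0) / lam
    if lam == 2:
        return -math.log1p(-x)
    return -(((1.0 - x) ** (2.0 - lam)) - 1.0) / (2.0 - lam)


# lam = 1 leaves the input unchanged on both branches.
print(yeo_johnson(0.5, 1.0), yeo_johnson(-2.0, 1.0))
# lam = 0 (x >= 0) and lam = 2 (x < 0) are the logarithmic special cases.
print(yeo_johnson(math.e - 1.0, 0.0), yeo_johnson(-(math.e - 1.0), 2.0))
```

The tensor `psi` above divides by `lambda_ + sign(lambda_) * EPS` rather than `lambda_`, so its outputs differ from this exact form by a relative error on the order of `EPS / |lambda_|`.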

In [None]:
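# @title <<< Normality Normalization Layer Implementation (CNN Adaptation) >>>
import math

import torch
import torch.nn as nn


class NormalityNormalization(nn.Module):
    def __init__(
        self, normalized_shape, eps=1e-5, elementwise_affine=True, noise_factor=0.0
    ):
        """
        Args:
            normalized_shape: int or tuple; an int C is converted to (C,).
            eps: stability constant; the variance is clamped to at least eps**2.
            elementwise_affine: if True, learn scale (gamma) and shift (beta) parameters.
            noise_factor: scale of the Gaussian noise added in training mode; 0.0 disables it.
        """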
        super().__init__()
        # Allow integer or tuple for normalized_shape
        if isinstance(normalized_shape, int):
            # For CNNs, assume integer means number of channels
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.noise_factor = noise_factor

        # Determine expected number of dimensions based on normalized_shape length
        # This is a heuristic: if shape is like (C,), expect 4D (B,C,H,W) or 2D (B,C) input
        # If shape is like (H, W), expect 3D (B, H, W) ?? - less common
        # Let's focus on the CNN case: normalized_shape = (Channels,)

        if len(self.normalized_shape) == 1:
            self.num_features = self.normalized_shape[0]  # Expect this to be channels
        else:
            # Handle other cases if needed, for now assume LayerNorm-like for non-1D shape
            self.num_features = None  # Not used in LayerNorm-like mode

        if self.elementwise_affine:
            # gamma/beta should match the number of features (channels in CNN case)
            param_shape = (
                (self.num_features,)
                if self.num_features is not None
                else self.normalized_shape
            )
            self.gamma = nn.Parameter(torch.ones(param_shape))
            self.beta = nn.Parameter(torch.zeros(param_shape))
        else:
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

    def forward(self, x):
        input_shape = x.shape
        input_ndim = x.ndim

        # --- 0. Input Check ---
        if torch.isnan(x).any() or torch.isinf(x).any():
            # warnings.warn("NaN/Inf detected in input 'x'. Clipping.")
            x = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)

        # --- Determine Normalization Dimensions ---
        if input_ndim == 4 and len(self.normalized_shape) == 1:
            # CNN Case: Input (B, C, H, W), normalized_shape=(C,)
            # Normalize across H, W dimensions (like BatchNorm, but instance-wise stats here)
            # Keep stats per channel per batch element.
            assert self.normalized_shape[0] == input_shape[1], (
                f"Expected {self.normalized_shape[0]} channels, got {input_shape[1]}"
            )
            dims = (2, 3)  # Dimensions H, W
            # Calculate N for noise scaling based on spatial dims
            N = input_shape[2] * input_shape[3]
            # Gamma/beta shape (1, C, 1, 1) for broadcasting
            affine_shape = (1, self.num_features, 1, 1)

        elif (
            input_ndim >= 2
            and self.normalized_shape == input_shape[-len(self.normalized_shape) :]
        ):
            # LayerNorm Case: Input (B, ..., F1, F2,..), normalized_shape=(F1, F2,..)
            norm_shape_len = len(self.normalized_shape)
            dims = tuple(range(input_ndim - norm_shape_len, input_ndim))
            N = math.prod(self.normalized_shape)
            affine_shape = [1] * (input_ndim - norm_shape_len) + list(
                self.normalized_shape
            )

        else:
            raise ValueError(
                f"Input shape {input_shape} and normalized_shape {self.normalized_shape} are incompatible."
            )

        # --- 1. Standard Normalization ---
        # Keep channel/feature dim separate, normalize over others specified in 'dims'
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, keepdim=True, unbiased=False)
        var = torch.clamp(var, min=self.eps * self.eps)  # Clamp variance
        std = torch.sqrt(var)
        h = (x - mean) / std

        # --- Check h ---
        if torch.isnan(h).any() or torch.isinf(h).any():
            # warnings.warn("NaN/Inf detected after standardization 'h'. Clipping.")
            h = torch.nan_to_num(h, nan=0.0, posinf=1e6, neginf=-1e6)
        h = torch.clamp(h, -1e5, 1e5)
        # --- 2. Estimate lambda_hat ---
        # Estimate lambda per element/channel group, over the normalized dims
        lambda_hat = estimate_lambda_hat(h.detach(), dims=dims)  # Pass correct dims

        # --- 3. Apply Power Transform ---
        x_transformed = psi(h, lambda_hat)
        x_transformed = torch.clamp(x_transformed, -1e5, 1e5)

        # --- 4. Add Scaled Gaussian Noise ---
        y = x_transformed
        if self.training and self.noise_factor > 0.0:
            with torch.no_grad():
                if torch.isnan(y).any() or torch.isinf(y).any():
                    # warnings.warn("NaN/Inf detected before noise. Clipping.")
                    y = torch.nan_to_num(y, nan=0.0, posinf=1e6, neginf=-1e6)

                xt_mean = y.mean(dim=dims, keepdim=True)  # Mean over H,W or Features
                if N == 0:
                    N = 1
                # Noise scale: L2 norm of the centered values divided by sqrt(N),
                # i.e. the population standard deviation over `dims`
                s = torch.linalg.norm(
                    y - xt_mean, ord=2, dim=dims, keepdim=True
                ) / torch.sqrt(
                    torch.tensor(N, device=y.device, dtype=y.dtype)
                )  # L2 norm scaled by sqrt(N)
                s = torch.nan_to_num(s, nan=0.0)
                s = torch.clamp(s, min=0.0, max=1e6)

            noise = torch.randn_like(y)
            y = y + noise * self.noise_factor * s

        # --- 5. Affine Transform ---
        if torch.isnan(y).any() or torch.isinf(y).any():
            # warnings.warn("NaN/Inf detected before affine. Clipping.")
            y = torch.nan_to_num(y, nan=0.0, posinf=1e6, neginf=-1e6)
        y = torch.clamp(y, -1e5, 1e5)
        if self.elementwise_affine:
            # Use pre-determined affine_shape for broadcasting
            gamma_reshaped = self.gamma.view(affine_shape)
            beta_reshaped = self.beta.view(affine_shape)
            out = y * gamma_reshaped + beta_reshaped
        else:
            out = y

        # --- Final Output Check ---
        if torch.isnan(out).any() or torch.isinf(out).any():
            # warnings.warn("NaN/Inf detected in final output. Clipping.")
            # Use the same replacement values as the earlier checks in this method
            out = torch.nan_to_num(out, nan=0.0, posinf=1e6, neginf=-1e6)
        out = torch.clamp(out, -1e5, 1e5)  # Final clamping to a wide but finite range

        return out

    # Extra fields shown when the module is printed
    def extra_repr(self):
        return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}, noise_factor={noise_factor}".format(
            normalized_shape=self.normalized_shape,
            eps=self.eps,
            elementwise_affine=self.elementwise_affine,
            noise_factor=self.noise_factor,
        )
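
The noise scale `s` in step 4 of `forward` is the L2 norm of the centered activations divided by `sqrt(N)`, which equals the population standard deviation over the normalized dimensions. The short check below illustrates that identity on a plain list; `noise_scale` is a hypothetical helper written for this sketch, not part of the layer.

```python
import math
import statistics
from typing import Sequence


def noise_scale(values: Sequence[float]) -> float:
    """L2 norm of the centered values divided by sqrt(N)."""
    n = len(values)
    mean = sum(values) / n
    l2_norm = math.sqrt(sum((v - mean) ** 2 for v in values))
    return l2_norm / math.sqrt(n)


values = [1.0, 2.0, 4.0, 7.0]
print(noise_scale(values))
print(statistics.pstdev(values))  # population standard deviation of the same values
```

Scaling the Gaussian noise by this quantity keeps the perturbation proportional to the spread of the transformed activations, so `noise_factor` acts as a relative noise level.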

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
import warnings

import math
import sys  # Used to flush stdout in check_nan_inf

warnings.filterwarnings("ignore")

# NormalityNormalization, psi, and estimate_lambda_hat are defined in the cells above.


# Helper function to check for NaNs/Infs and report location
def check_nan_inf(tensor, name, location):
    if torch.isnan(tensor).any():
        print(f"WARNING: NaN detected in tensor '{name}' after {location}")
        sys.stdout.flush()  # Ensure print statement is shown immediately
        # Optional: Add a debugger breakpoint here:
        # import pdb; pdb.set_trace()
    if torch.isinf(tensor).any():
        print(f"WARNING: Inf detected in tensor '{name}' after {location}")
        sys.stdout.flush()  # Ensure print statement is shown immediately
        # Optional: Add a debugger breakpoint here:
        # import pdb; pdb.set_trace()


class InputAdapter(nn.Module):
    """Input adapter that transforms embeddings to universal latent space"""

    def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.norm1 = NormalityNormalization(normalized_shape=hidden_dim)
        self.silu1 = nn.SiLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.norm2 = NormalityNormalization(normalized_shape=hidden_dim)
        self.silu2 = nn.SiLU()
        self.linear3 = nn.Linear(hidden_dim, latent_dim)
        self.norm3 = NormalityNormalization(normalized_shape=latent_dim)

    def forward(self, x):
        check_nan_inf(x, "InputAdapter_input", "input")
        x = self.linear1(x)
        check_nan_inf(x, "InputAdapter_linear1", "linear1")
        x = self.norm1(x)
        check_nan_inf(x, "InputAdapter_norm1", "norm1")
        x = self.silu1(x)
        check_nan_inf(x, "InputAdapter_silu1", "silu1")
        x = self.linear2(x)
        check_nan_inf(x, "InputAdapter_linear2", "linear2")
        x = self.norm2(x)
        check_nan_inf(x, "InputAdapter_norm2", "norm2")
        x = self.silu2(x)
        check_nan_inf(x, "InputAdapter_silu2", "silu2")
        x = self.linear3(x)
        check_nan_inf(x, "InputAdapter_linear3", "linear3")
        x = self.norm3(x)
        check_nan_inf(x, "InputAdapter_norm3", "norm3")
        return x


class OutputAdapter(nn.Module):
    """Output adapter that transforms from universal latent space to target space"""

    def __init__(self, latent_dim: int, output_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.norm1 = NormalityNormalization(normalized_shape=hidden_dim)
        self.silu1 = nn.SiLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.norm2 = NormalityNormalization(normalized_shape=hidden_dim)
        self.silu2 = nn.SiLU()
        self.linear3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        check_nan_inf(x, "OutputAdapter_input", "input")
        x = self.linear1(x)
        check_nan_inf(x, "OutputAdapter_linear1", "linear1")
        x = self.norm1(x)
        check_nan_inf(x, "OutputAdapter_norm1", "norm1")
        x = self.silu1(x)
        check_nan_inf(x, "OutputAdapter_silu1", "silu1")
        x = self.linear2(x)
        check_nan_inf(x, "OutputAdapter_linear2", "linear2")
        x = self.norm2(x)
        check_nan_inf(x, "OutputAdapter_norm2", "norm2")
        x = self.silu2(x)
        check_nan_inf(x, "OutputAdapter_silu2", "silu2")
        x = self.linear3(x)
        check_nan_inf(x, "OutputAdapter_linear3", "linear3")
        # OutputAdapter has no final NormalityNormalization layer.
        # To add one, first define self.norm3 in __init__, then uncomment the lines below.
        # x = self.norm3(x)
        # check_nan_inf(x, "OutputAdapter_norm3", "norm3")
        return x


class SharedBackbone(nn.Module):
    """Shared backbone network that processes universal latent representations"""

    def __init__(self, latent_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.norm1 = NormalityNormalization(normalized_shape=hidden_dim)
        self.silu1 = nn.SiLU()
        self.linear2 = nn.Linear(hidden_dim, latent_dim)
        self.norm2 = NormalityNormalization(normalized_shape=latent_dim)

        # Residual connection
        self.residual = nn.Identity()

    def forward(self, x):
        check_nan_inf(x, "SharedBackbone_input", "input")
        identity = self.residual(x)
        check_nan_inf(identity, "SharedBackbone_identity", "residual")

        out = self.linear1(x)
        check_nan_inf(out, "SharedBackbone_linear1", "linear1")
        out = self.norm1(out)
        check_nan_inf(out, "SharedBackbone_norm1", "norm1")
        out = self.silu1(out)
        check_nan_inf(out, "SharedBackbone_silu1", "silu1")
        out = self.linear2(out)
        check_nan_inf(out, "SharedBackbone_linear2", "linear2")
        out = self.norm2(out)
        check_nan_inf(out, "SharedBackbone_norm2", "norm2")

        out = out + identity  # Residual connection
        check_nan_inf(out, "SharedBackbone_output", "residual_addition")
        return out


class Discriminator(nn.Module):
    """Discriminator for adversarial training"""

    def __init__(self, input_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.LeakyReLU(0.2)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.LeakyReLU(0.2)
        self.linear3 = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        check_nan_inf(x, "Discriminator_input", "input")
        out = self.linear1(x)
        check_nan_inf(out, "Discriminator_linear1", "linear1")
        out = self.relu1(out)
        check_nan_inf(out, "Discriminator_relu1", "relu1")
        out = self.linear2(out)
        check_nan_inf(out, "Discriminator_linear2", "linear2")
        out = self.relu2(out)
        check_nan_inf(out, "Discriminator_relu2", "relu2")
        out = self.linear3(out)
        check_nan_inf(out, "Discriminator_linear3", "linear3")
        output = self.sigmoid(out)  # Output values between 0 and 1
        check_nan_inf(output, "Discriminator_output", "sigmoid")
        return output


class Vec2Vec(nn.Module):
    """Main Vec2Vec model for unsupervised embedding translation"""

    def __init__(
        self, dim1: int, dim2: int, latent_dim: int = 256, hidden_dim: int = 256
    ):
        super().__init__()

        # Adapters
        self.A1 = InputAdapter(dim1, latent_dim, hidden_dim)
        self.A2 = InputAdapter(dim2, latent_dim, hidden_dim)
        self.B1 = OutputAdapter(latent_dim, dim1, hidden_dim)
        self.B2 = OutputAdapter(latent_dim, dim2, hidden_dim)

        # Shared backbone
        self.T = SharedBackbone(latent_dim, hidden_dim)

        # Discriminators
        self.D1 = Discriminator(dim1, hidden_dim)
        self.D2 = Discriminator(dim2, hidden_dim)
        self.D_l1 = Discriminator(latent_dim, hidden_dim)
        self.D_l2 = Discriminator(latent_dim, hidden_dim)

    def F1(self, x):
        """Translation from space 1 to space 2"""
        check_nan_inf(x, "F1_input", "F1_start")
        a1_out = self.A1(x)
        check_nan_inf(a1_out, "F1_A1_out", "A1")
        t_out = self.T(a1_out)
        check_nan_inf(t_out, "F1_T_out", "T")
        b2_out = self.B2(t_out)
        check_nan_inf(b2_out, "F1_B2_out", "B2")
        return b2_out

    def F2(self, x):
        """Translation from space 2 to space 1"""
        check_nan_inf(x, "F2_input", "F2_start")
        a2_out = self.A2(x)
        check_nan_inf(a2_out, "F2_A2_out", "A2")
        t_out = self.T(a2_out)
        check_nan_inf(t_out, "F2_T_out", "T")
        b1_out = self.B1(t_out)
        check_nan_inf(b1_out, "F2_B1_out", "B1")
        return b1_out

    def R1(self, x):
        """Reconstruction in space 1"""
        check_nan_inf(x, "R1_input", "R1_start")
        a1_out = self.A1(x)
        check_nan_inf(a1_out, "R1_A1_out", "A1")
        t_out = self.T(a1_out)
        check_nan_inf(t_out, "R1_T_out", "T")
        b1_out = self.B1(t_out)
        check_nan_inf(b1_out, "R1_B1_out", "B1")
        return b1_out

    def R2(self, x):
        """Reconstruction in space 2"""
        check_nan_inf(x, "R2_input", "R2_start")
        a2_out = self.A2(x)
        check_nan_inf(a2_out, "R2_A2_out", "A2")
        t_out = self.T(a2_out)
        check_nan_inf(t_out, "R2_T_out", "T")
        b2_out = self.B2(t_out)
        check_nan_inf(b2_out, "R2_B2_out", "B2")
        return b2_out

    def get_latent_1(self, x):
        """Get latent representation from space 1"""
        check_nan_inf(x, "get_latent_1_input", "start")
        a1_out = self.A1(x)
        check_nan_inf(a1_out, "get_latent_1_A1_out", "A1")
        t_out = self.T(a1_out)
        check_nan_inf(t_out, "get_latent_1_T_out", "T")
        return t_out

    def get_latent_2(self, x):
        """Get latent representation from space 2"""
        check_nan_inf(x, "get_latent_2_input", "start")
        a2_out = self.A2(x)
        check_nan_inf(a2_out, "get_latent_2_A2_out", "A2")
        t_out = self.T(a2_out)
        check_nan_inf(t_out, "get_latent_2_T_out", "T")
        return t_out


def adversarial_loss(discriminator, real_data, fake_data):
    """Standard GAN loss"""
    check_nan_inf(real_data, "adv_loss_real_data", "input")
    check_nan_inf(fake_data, "adv_loss_fake_data", "input")

    real_pred = discriminator(real_data)
    check_nan_inf(real_pred, "adv_loss_real_pred", "discriminator_real")
    fake_pred = discriminator(fake_data)
    check_nan_inf(fake_pred, "adv_loss_fake_pred", "discriminator_fake")

    # BCELoss expects inputs in the range [0, 1]. Discriminator output already has Sigmoid.
    try:
        real_loss = nn.BCELoss()(real_pred, torch.ones_like(real_pred))
        check_nan_inf(real_loss, "adv_loss_real_loss", "BCELoss_real")
    except RuntimeError as e:
        print(f"Error in BCELoss for real_pred: {e}")
        check_nan_inf(real_pred, "adv_loss_real_pred_before_bce", "before_BCELoss_real")
        raise  # Re-raise the exception after checking

    try:
        fake_loss = nn.BCELoss()(fake_pred, torch.zeros_like(fake_pred))
        check_nan_inf(fake_loss, "adv_loss_fake_loss", "BCELoss_fake")
    except RuntimeError as e:
        print(f"Error in BCELoss for fake_pred: {e}")
        check_nan_inf(fake_pred, "adv_loss_fake_pred_before_bce", "before_BCELoss_fake")
        raise  # Re-raise the exception after checking

    total_loss = real_loss + fake_loss
    check_nan_inf(total_loss, "adv_loss_total", "sum")
    return total_loss


def generator_adversarial_loss(discriminator, fake_data):
    """Generator loss for fooling discriminator"""
    check_nan_inf(fake_data, "gen_adv_loss_fake_data", "input")
    fake_pred = discriminator(fake_data)
    check_nan_inf(fake_pred, "gen_adv_loss_fake_pred", "discriminator_fake")

    # BCELoss expects inputs in the range [0, 1]. Discriminator output already has Sigmoid.
    try:
        loss = nn.BCELoss()(
            fake_pred, torch.ones_like(fake_pred)
        )  # Generator wants fake to be classified as real
        check_nan_inf(loss, "gen_adv_loss", "BCELoss")
    except RuntimeError as e:
        print(f"Error in BCELoss for gen_adv_loss: {e}")
        check_nan_inf(fake_pred, "gen_adv_loss_fake_pred_before_bce", "before_BCELoss")
        raise  # Re-raise the exception after checking

    return loss


def reconstruction_loss(model, x1, x2):
    """Reconstruction loss"""
    check_nan_inf(x1, "rec_loss_x1", "input")
    check_nan_inf(x2, "rec_loss_x2", "input")

    recon1 = model.R1(x1)
    check_nan_inf(recon1, "rec_loss_recon1", "R1")
    recon2 = model.R2(x2)
    check_nan_inf(recon2, "rec_loss_recon2", "R2")

    loss1 = nn.MSELoss()(recon1, x1)
    check_nan_inf(loss1, "rec_loss_loss1", "MSELoss1")
    loss2 = nn.MSELoss()(recon2, x2)
    check_nan_inf(loss2, "rec_loss_loss2", "MSELoss2")

    total_loss = loss1 + loss2
    check_nan_inf(total_loss, "rec_loss_total", "sum")
    return total_loss


def cycle_consistency_loss(model, x1, x2):
    """Cycle consistency loss"""
    check_nan_inf(x1, "cc_loss_x1", "input")
    check_nan_inf(x2, "cc_loss_x2", "input")

    cycle1 = model.F2(model.F1(x1))
    check_nan_inf(cycle1, "cc_loss_cycle1", "F2(F1(x1))")
    cycle2 = model.F1(model.F2(x2))
    check_nan_inf(cycle2, "cc_loss_cycle2", "F1(F2(x2))")

    loss1 = nn.MSELoss()(cycle1, x1)
    check_nan_inf(loss1, "cc_loss_loss1", "MSELoss1")
    loss2 = nn.MSELoss()(cycle2, x2)
    check_nan_inf(loss2, "cc_loss_loss2", "MSELoss2")

    total_loss = loss1 + loss2
    check_nan_inf(total_loss, "cc_loss_total", "sum")
    return total_loss


def vector_space_preservation_loss(model, x1, x2):
    """Vector space preservation loss"""
    check_nan_inf(x1, "vsp_loss_x1", "input")
    check_nan_inf(x2, "vsp_loss_x2", "input")

    batch_size = x1.size(0)
    if batch_size == 0:
        print("WARNING: Batch size is 0 in VSP loss, returning 0.")
        return torch.tensor(0.0, device=x1.device)

    # Original pairwise similarities (dot product)
    sim1_orig = torch.mm(x1, x1.t())
    check_nan_inf(sim1_orig, "vsp_loss_sim1_orig", "sim1_orig")
    sim2_orig = torch.mm(x2, x2.t())
    check_nan_inf(sim2_orig, "vsp_loss_sim2_orig", "sim2_orig")

    # Translated pairwise similarities (dot product)
    x1_trans = model.F1(x1)
    check_nan_inf(x1_trans, "vsp_loss_x1_trans", "F1")
    x2_trans = model.F2(x2)
    check_nan_inf(x2_trans, "vsp_loss_x2_trans", "F2")

    sim1_trans = torch.mm(x1_trans, x1_trans.t())
    check_nan_inf(sim1_trans, "vsp_loss_sim1_trans", "sim1_trans")
    sim2_trans = torch.mm(x2_trans, x2_trans.t())
    check_nan_inf(sim2_trans, "vsp_loss_sim2_trans", "sim2_trans")

    loss1 = nn.MSELoss()(sim1_trans, sim1_orig)
    check_nan_inf(loss1, "vsp_loss_loss1", "MSELoss1")
    loss2 = nn.MSELoss()(sim2_trans, sim2_orig)
    check_nan_inf(loss2, "vsp_loss_loss2", "MSELoss2")

    total_loss = loss1 + loss2
    check_nan_inf(total_loss, "vsp_loss_total", "sum")
    return total_loss


def train_vec2vec(
    model,
    data1,
    data2,
    epochs=100,
    lr=0.0002,
    lambda_rec=10.0,
    lambda_cc=10.0,
    lambda_vsp=1.0,
):
    """Train the Vec2Vec model"""

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using device: {device}")

    # Batches are moved to `device` explicitly inside the training loop below;
    # a DataLoader does not move tensors to the GPU on its own.

    # Optimizers
    gen_params = (
        list(model.A1.parameters())
        + list(model.A2.parameters())
        + list(model.B1.parameters())
        + list(model.B2.parameters())
        + list(model.T.parameters())
    )

    disc_params = (
        list(model.D1.parameters())
        + list(model.D2.parameters())
        + list(model.D_l1.parameters())
        + list(model.D_l2.parameters())
    )

    gen_optimizer = optim.Adam(gen_params, lr=lr, betas=(0.5, 0.999))
    disc_optimizer = optim.Adam(disc_params, lr=lr, betas=(0.5, 0.999))

    losses = {"gen": [], "disc": []}

    # Use DataLoader for more efficient batching
    # Create TensorDatasets and DataLoaders
    # Need to ensure data1 and data2 are Tensors before creating TensorDataset
    if not isinstance(data1, torch.Tensor):
        data1 = torch.FloatTensor(data1)
    if not isinstance(data2, torch.Tensor):
        data2 = torch.FloatTensor(data2)

    dataset1 = torch.utils.data.TensorDataset(data1)
    dataset2 = torch.utils.data.TensorDataset(data2)

    batch_size_train = min(
        64, len(data1), len(data2)
    )  # Use a reasonable training batch size
    dataloader1 = torch.utils.data.DataLoader(
        dataset1, batch_size=batch_size_train, shuffle=True, num_workers=2
    )  # Add num_workers
    dataloader2 = torch.utils.data.DataLoader(
        dataset2, batch_size=batch_size_train, shuffle=True, num_workers=2
    )  # Add num_workers

    for epoch in range(epochs):
        model.train()  # Set model to training mode
        total_gen_loss = 0
        total_disc_loss = 0
        num_batches = min(
            len(dataloader1), len(dataloader2)
        )  # Process up to the smaller number of batches

        # Use iterators to get batches from both DataLoaders
        iter1 = iter(dataloader1)
        iter2 = iter(dataloader2)

        for i in tqdm(range(num_batches), desc=f"Epoch {epoch + 1}/{epochs}"):
            try:
                # Get batches from both data loaders and move to device
                (x1_batch,) = next(iter1)
                x1_batch = x1_batch.to(device)
                (x2_batch,) = next(iter2)
                x2_batch = x2_batch.to(device)

            except StopIteration:
                # This should not happen if we iterate up to min(len(dataloader1), len(dataloader2))
                # but good practice to handle
                break
            except Exception as e:
                print(f"Error loading batch {i} in epoch {epoch}: {e}")
                continue  # Skip this batch

            check_nan_inf(x1_batch, "x1_batch", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(x2_batch, "x2_batch", f"Epoch {epoch}, Batch {i}")

            # --- Train Discriminators ---
            disc_optimizer.zero_grad()

            # Adversarial losses for discriminators
            # Use .detach() for fake data when training discriminators
            with torch.no_grad():
                # Ensure these forward passes are stable
                try:
                    fake_x1_from_x2 = model.F2(x2_batch)
                    fake_x2_from_x1 = model.F1(x1_batch)
                    latent1 = model.get_latent_1(x1_batch)
                    latent2 = model.get_latent_2(x2_batch)
                except Exception as e:
                    print(
                        f"Error during NO GRAD forward pass (Discriminator inputs) in epoch {epoch}, batch {i}: {e}"
                    )
                    check_nan_inf(
                        x1_batch,
                        "x1_batch_before_disc_forward",
                        f"Epoch {epoch}, Batch {i}",
                    )
                    check_nan_inf(
                        x2_batch,
                        "x2_batch_before_disc_forward",
                        f"Epoch {epoch}, Batch {i}",
                    )
                    continue  # Skip this batch if forward pass fails

            disc_loss1 = adversarial_loss(model.D1, x1_batch, fake_x1_from_x2)
            disc_loss2 = adversarial_loss(model.D2, x2_batch, fake_x2_from_x1)
            # Note: D_l1 treats latent1 as "real" and latent2 as "fake"; D_l2 uses the
            # opposite labelling. This may differ in detail from the paper's formulation,
            # but it aims for the same goal: latent codes that do not reveal which
            # embedding model they came from (a domain-agnostic latent space).
            disc_loss_l1 = adversarial_loss(
                model.D_l1, latent1.detach(), latent2.detach()
            )  # Ensure detach
            disc_loss_l2 = adversarial_loss(
                model.D_l2, latent2.detach(), latent1.detach()
            )  # Ensure detach

            check_nan_inf(disc_loss1, "disc_loss1", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(disc_loss2, "disc_loss2", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(disc_loss_l1, "disc_loss_l1", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(disc_loss_l2, "disc_loss_l2", f"Epoch {epoch}, Batch {i}")

            disc_total_loss = disc_loss1 + disc_loss2 + disc_loss_l1 + disc_loss_l2
            check_nan_inf(
                disc_total_loss, "disc_total_loss", f"Epoch {epoch}, Batch {i}"
            )

            # Only step if loss is not NaN/Inf
            if torch.isfinite(disc_total_loss):
                disc_total_loss.backward()
                disc_optimizer.step()
                total_disc_loss += disc_total_loss.item()
            else:
                print(
                    f"WARNING: Disc loss is NaN/Inf in Epoch {epoch}, Batch {i}. Skipping optimizer step."
                )
                sys.stdout.flush()

            # --- Train Generators ---
            gen_optimizer.zero_grad()

            # Adversarial losses for generators (try to fool discriminators)
            # Feed *non-detached* fake data to calculate gradients for generators
            # Ensure these forward passes are stable
            try:
                fake_x1_from_x2 = model.F2(x2_batch)
                fake_x2_from_x1 = model.F1(x1_batch)
                latent1 = model.get_latent_1(x1_batch)
                latent2 = model.get_latent_2(x2_batch)
            except Exception as e:
                print(
                    f"Error during Generator forward pass in epoch {epoch}, batch {i}: {e}"
                )
                check_nan_inf(
                    x1_batch, "x1_batch_before_gen_forward", f"Epoch {epoch}, Batch {i}"
                )
                check_nan_inf(
                    x2_batch, "x2_batch_before_gen_forward", f"Epoch {epoch}, Batch {i}"
                )
                continue  # Skip this batch

            # F2 maps space 2 -> space 1, so its output must be judged by D1;
            # F1 maps space 1 -> space 2, so its output must be judged by D2.
            gen_adv_loss1 = generator_adversarial_loss(model.D1, fake_x1_from_x2)
            gen_adv_loss2 = generator_adversarial_loss(model.D2, fake_x2_from_x1)
            # D_l1 is trained to label latent2 as fake and D_l2 to label latent1 as fake,
            # so the generator tries to get each of them classified as real.
            gen_adv_loss_l1 = generator_adversarial_loss(model.D_l1, latent2)
            gen_adv_loss_l2 = generator_adversarial_loss(model.D_l2, latent1)

            check_nan_inf(gen_adv_loss1, "gen_adv_loss1", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(gen_adv_loss2, "gen_adv_loss2", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(
                gen_adv_loss_l1, "gen_adv_loss_l1", f"Epoch {epoch}, Batch {i}"
            )
            check_nan_inf(
                gen_adv_loss_l2, "gen_adv_loss_l2", f"Epoch {epoch}, Batch {i}"
            )

            # Other losses - Ensure these forward passes are stable
            try:
                rec_loss = reconstruction_loss(model, x1_batch, x2_batch)
                cc_loss = cycle_consistency_loss(model, x1_batch, x2_batch)
                vsp_loss = vector_space_preservation_loss(model, x1_batch, x2_batch)
            except Exception as e:
                print(
                    f"Error during Reconstruction/Cycle/VSP loss calculation in epoch {epoch}, batch {i}: {e}"
                )
                check_nan_inf(
                    x1_batch,
                    "x1_batch_before_other_losses",
                    f"Epoch {epoch}, Batch {i}",
                )
                check_nan_inf(
                    x2_batch,
                    "x2_batch_before_other_losses",
                    f"Epoch {epoch}, Batch {i}",
                )
                continue  # Skip this batch

            check_nan_inf(rec_loss, "rec_loss", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(cc_loss, "cc_loss", f"Epoch {epoch}, Batch {i}")
            check_nan_inf(vsp_loss, "vsp_loss", f"Epoch {epoch}, Batch {i}")

            gen_total_loss = (
                gen_adv_loss1
                + gen_adv_loss2
                + gen_adv_loss_l1
                + gen_adv_loss_l2
                + lambda_rec * rec_loss
                + lambda_cc * cc_loss
                + lambda_vsp * vsp_loss
            )
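            # Step the generator once per batch, and only if the loss is finite
            if torch.isfinite(gen_total_loss):
                gen_total_loss.backward()
                gen_optimizer.step()
                total_gen_loss += gen_total_loss.item()
            else:
                print(
                    f"WARNING: Gen loss is NaN/Inf in Epoch {epoch}, Batch {i}. Skipping optimizer step."
                )
                sys.stdout.flush()

        # Record per-epoch averages over the batches
        avg_gen_loss = total_gen_loss / max(num_batches, 1)
        avg_disc_loss = total_disc_loss / max(num_batches, 1)
        losses["gen"].append(avg_gen_loss)
        losses["disc"].append(avg_disc_loss)

        if epoch % 20 == 0:
            print(
                f"Epoch {epoch}: Gen Loss = {avg_gen_loss:.4f}, "
                f"Disc Loss = {avg_disc_loss:.4f}"
            )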

    return losses


def evaluate_translation_batched(
    model, data1, data2, original_texts=None, batch_size_eval=1024
):
    """Evaluate translation quality in batches to reduce RAM usage."""
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        # Move data to the same device as the model
        device = next(model.parameters()).device
        data1 = data1.to(device)
        data2 = data2.to(device)

        # Normalize target data once
        norm_data2 = F.normalize(data2, p=2, dim=1)

        all_max_sim_values = []
        all_max_sim_indices = []

        # Process data1 in batches for translation and similarity calculation
        for i in tqdm(
            range(0, data1.size(0), batch_size_eval), desc="Evaluating in batches"
        ):
            batch_data1 = data1[i : i + batch_size_eval]
            translated_batch = model.F1(batch_data1)

            # Normalize the translated batch
            norm_translated_batch = F.normalize(translated_batch, p=2, dim=1)

            # Calculate cosine similarity for this batch against all of data2
            # (batch_size, D) @ (D, N2) -> (batch_size, N2)
            batch_cos_sim = torch.matmul(norm_translated_batch, norm_data2.t())

            # Get the max similarity and its index for each row in the batch
            batch_max_sim_values, batch_max_sim_indices = torch.max(
                batch_cos_sim, dim=1
            )

            all_max_sim_values.append(batch_max_sim_values.cpu())
            all_max_sim_indices.append(batch_max_sim_indices.cpu())

        # Concatenate results from all batches
        max_sim_values = torch.cat(all_max_sim_values, dim=0)
        max_sim_indices = torch.cat(all_max_sim_indices, dim=0)

        # Calculate Mean Cosine Similarity
        mean_similarity = torch.mean(max_sim_values).item()

        # Calculate Top-1 Accuracy
        # This assumes index-wise correspondence (data1[i] pairs with data2[i]), which
        # only holds if both embedding sets were built from the same texts in the same order.
        n_samples_data1 = data1.size(0)

        # The expected nearest neighbour of the translation of data1[k] is data2[k].
        # If the sets differ in length, only the first min(len(data1), len(data2)) rows
        # can be compared.
        ideal_matches = torch.arange(n_samples_data1)  # On CPU, like max_sim_indices

        if n_samples_data1 <= data2.size(0):
            top1_matches = (
                (max_sim_indices[:n_samples_data1] == ideal_matches).sum().item()
            )
            top1_accuracy = top1_matches / n_samples_data1 if n_samples_data1 > 0 else 0
        else:  # If data1 is larger than data2, this top-1 interpretation is problematic.
            # We can only evaluate for the first len(data2) items of data1.
            # print_once("Warning: len(data1) > len(data2). Top-1 accuracy calculated for the first len(data2) samples of data1.")
            eval_count = data2.size(0)
            top1_matches = (
                (max_sim_indices[:eval_count] == ideal_matches[:eval_count])
                .sum()
                .item()
            )
            top1_accuracy = top1_matches / eval_count if eval_count > 0 else 0

        print(f"Mean Cosine Similarity: {mean_similarity:.4f}")
        print(
            f"Top-1 Accuracy (assuming index-wise correspondence): {top1_accuracy:.4f}"
        )

        model.train()  # Set model back to training mode
        return mean_similarity, top1_accuracy


# Helper function to print a warning only once
_printed_warnings = set()


def print_once(message):
    if message not in _printed_warnings:
        print(message)
        _printed_warnings.add(message)


def visualize_embeddings(model, data1, data2, title="Embedding Visualization"):
    """Visualize embeddings using t-SNE"""
    with torch.no_grad():
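        # Get latent representations; inputs must be on the same device as the model
        device = next(model.parameters()).device
        latent1 = model.get_latent_1(data1.to(device)).cpu().numpy()
        latent2 = model.get_latent_2(data2.to(device)).cpu().numpy()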

        # Combine for t-SNE
        combined = np.vstack([latent1, latent2])

        # Apply t-SNE
        tsne = TSNE(
            n_components=2, random_state=42, perplexity=min(30, len(combined) // 4)
        )
        embedded = tsne.fit_transform(combined)

        # Plot
        plt.figure(figsize=(10, 8))

        # Plot latent representations from model 1
        plt.scatter(
            embedded[: len(latent1), 0],
            embedded[: len(latent1), 1],
            c="red",
            alpha=0.6,
            label="Model 1 Latents",
            s=50,
        )

        # Plot latent representations from model 2
        plt.scatter(
            embedded[len(latent1) :, 0],
            embedded[len(latent1) :, 1],
            c="blue",
            alpha=0.6,
            label="Model 2 Latents",
            s=50,
        )

        plt.title(title)
        plt.legend()
        plt.xlabel("t-SNE Dimension 1")
        plt.ylabel("t-SNE Dimension 2")
        plt.grid(True, alpha=0.3)
        plt.show()


def demo():
    """Run a complete demo of Vec2Vec"""
    print("Vec2Vec Demo: Unsupervised Embedding Translation")
    print("=" * 50)

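    # Check if data was loaded successfully from the cell above.
    # globals().get avoids a NameError if the data loading cell was never run.
    if globals().get("data1_loaded") is None or globals().get("data2_loaded") is None: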
        print("Error: Data not loaded. Please run the data loading cell first.")
        print("Falling back to original small demo data...")
        # Original small demo texts (as a fallback)
        texts = [
            "The cat sat on the mat",
            "A dog runs in the park",
            "Machine learning is fascinating",
            "Natural language processing",
            "Deep neural networks",
            "Artificial intelligence revolution",
            "Computer vision applications",
            "Data science and analytics",
            "Quantum computing future",
            "Robotics and automation",
        ]
        # Use globally defined sbert_model1 and sbert_model2 if they exist from data loading cell
        if (
            "sbert_model1" in globals()
            and "sbert_model2" in globals()
            and sbert_model1
            and sbert_model2
        ):
            print("Using pre-loaded SBERT models for fallback demo data.")
            model1_sbert_instance = sbert_model1
            model2_sbert_instance = sbert_model2
        else:
            print("Loading SBERT models for fallback demo data...")
            model1_sbert_instance = SentenceTransformer("all-MiniLM-L6-v2")
            model2_sbert_instance = SentenceTransformer("all-mpnet-base-v2")

        embeddings1 = model1_sbert_instance.encode(texts)
        embeddings2 = model2_sbert_instance.encode(texts)
        data1 = torch.FloatTensor(embeddings1)
        data2 = torch.FloatTensor(embeddings2)
        # For fallback, original_texts for evaluation are the small 'texts' list
        original_texts_for_eval = texts
        epochs_to_run = 200  # Epochs for small fallback demo
    else:
        print("Using data loaded from the 20 Newsgroups dataset.")
        data1 = data1_loaded
        data2 = data2_loaded
        # For the loaded data, we don't have a direct "paired" original_texts list for simple eval.
        # The evaluate_translation function will compare against the entire data2 set.
        original_texts_for_eval = (
            None  # Or you could pass a subset of texts_for_model1 for sample display
        )
        # Number of full passes over the data; tune this for your dataset size.
        # For about 5,000 samples, 500-1000 epochs is a starting point, and good
        # convergence may need many more (e.g., 5000-10000+). Larger datasets take
        # proportionally longer per epoch.
        epochs_to_run = 5000

    print(
        f"Model 1 embedding dimension: {data1.shape[1]}, Number of samples: {data1.shape[0]}"
    )
    print(
        f"Model 2 embedding dimension: {data2.shape[1]}, Number of samples: {data2.shape[0]}"
    )

    # Initialize Vec2Vec model
    print("\nInitializing Vec2Vec model...")
    # You might want to adjust latent_dim based on your data size and complexity
    vec2vec = Vec2Vec(
        dim1=data1.shape[1], dim2=data2.shape[1], latent_dim=128, hidden_dim=256
    )

    print(f"Training Vec2Vec for {epochs_to_run} epochs...")
    losses = train_vec2vec(
        vec2vec,
        data1,
        data2,
        epochs=epochs_to_run,
        lr=0.0002,
        lambda_rec=10.0,
        lambda_cc=10.0,
        lambda_vsp=1.0,
    )  # lambdas can also be tuned

    # Plot training losses (same as before)
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(losses["gen"], label="Generator")
    plt.plot(losses["disc"], label="Discriminator")
    plt.title("Training Losses")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    recent_gen = losses["gen"][-50:] if len(losses["gen"]) > 50 else losses["gen"]
    recent_disc = losses["disc"][-50:] if len(losses["disc"]) > 50 else losses["disc"]
    plt.plot(recent_gen, label="Generator (Recent)")
    plt.plot(recent_disc, label="Discriminator (Recent)")
    plt.title("Recent Training Losses")
    plt.xlabel("Recent Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Evaluate translation
    print("\nEvaluating translation quality...")
    # original_texts is accepted by evaluate_translation_batched but is not used
    # in its metrics; pass it only if you extend the function to show examples.
    mean_sim, top1_acc = evaluate_translation_batched(
        vec2vec,
        data1,
        data2,
        original_texts=original_texts_for_eval,
        batch_size_eval=1024,
    )
    print(f"Mean Cosine Similarity: {mean_sim:.4f}")
    print(f"Top-1 Accuracy: {top1_acc:.4f}")

    # Visualize embeddings
    print("\nVisualizing learned universal latent space...")
    # For t-SNE with larger data, you might want to sample a subset to avoid long computation times
    sample_size_tsne = min(
        1000, data1.shape[0], data2.shape[0]
    )  # e.g., plot 1000 points
    idx1_tsne = torch.randperm(data1.shape[0])[:sample_size_tsne]
    idx2_tsne = torch.randperm(data2.shape[0])[:sample_size_tsne]
    visualize_embeddings(
        vec2vec,
        data1[idx1_tsne],
        data2[idx2_tsne],
        "Vec2Vec Universal Latent Space (t-SNE on Sample)",
    )

    # Test translation on a specific example (if using fallback or have access to texts_for_model1)
    print("\nTesting translation on sample text...")
    if original_texts_for_eval:  # If using the small demo data
        sample_idx = 0
        sample_text_display = original_texts_for_eval[sample_idx]
    elif "texts_for_model1" in globals() and texts_for_model1:  # If using 20 newsgroups
        sample_idx = 0  # Or a random index
        sample_text_display = (
            texts_for_model1[sample_idx][:100] + "..."
        )  # Display snippet
    else:
        sample_text_display = "N/A (original text not available for this sample)"
        sample_idx = 0

    with torch.no_grad():
        # Ensure sample_idx is valid for the current data1
        if sample_idx < data1.shape[0] and sample_idx < data2.shape[0]:
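            # The model may be on the GPU after training, so move the samples to its device
            device = next(vec2vec.parameters()).device
            original_emb1 = data1[sample_idx : sample_idx + 1].to(device)
            translated_emb = vec2vec.F1(original_emb1)
            # Note: this is NOT necessarily the paired target; it is just
            # data2[sample_idx] for comparison. The full evaluation in
            # evaluate_translation_batched searches the whole of data2.
            target_emb2 = data2[sample_idx : sample_idx + 1].to(device)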

            similarity = torch.cosine_similarity(translated_emb, target_emb2).item()
            print(
                f"Sample text (from source set 1, index {sample_idx}): '{sample_text_display}'"
            )
            print(
                f"Similarity of its translation with text at index {sample_idx} from source set 2: {similarity:.4f}"
            )
        else:
            print("Sample index out of bounds for loaded data.")

    print("\nVec2Vec demo completed!")
    print("With a larger dataset and more epochs, translation quality should improve.")
    print("Current Top-1 Accuracy is a key metric to watch.")


if __name__ == "__main__":
    demo()

Vec2Vec Demo: Unsupervised Embedding Translation
Using data loaded from the 20 Newsgroups dataset.
Model 1 embedding dimension: 384, Number of samples: 104371
Model 2 embedding dimension: 768, Number of samples: 104372

Initializing Vec2Vec model...
Training Vec2Vec for 5000 epochs...
Using device: cuda


Epoch 1/5000:   0%|          | 0/1631 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x768 and 384x256)
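
The shapes in this error are consistent with a space-2 batch (64 x 768) reaching a layer built for space-1 inputs (384 -> 256), which is what happens when the output of `F1` is passed to `D1` instead of `D2`. The sketch below reproduces the mismatch and the matched pairing. `make_discriminator` is a hypothetical stand-in rather than the notebook's `Discriminator` class, and the random tensors stand in for translated batches.

```python
import torch
import torch.nn as nn

dim1, dim2, hidden_dim, batch = 384, 768, 256, 64


def make_discriminator(input_dim: int) -> nn.Module:
    # Hypothetical stand-in with the same first-layer shape as the notebook's Discriminator
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_dim, 1),
        nn.Sigmoid(),
    )


D1, D2 = make_discriminator(dim1), make_discriminator(dim2)

# Stand-ins for translated batches: F1 outputs live in space 2, F2 outputs in space 1
fake_x2_from_x1 = torch.randn(batch, dim2)
fake_x1_from_x2 = torch.randn(batch, dim1)

# Mismatched pairing: a 768-dim batch into a discriminator built for 384-dim inputs
try:
    D1(fake_x2_from_x1)
    mismatch_message = None
except RuntimeError as e:
    mismatch_message = str(e)
print("Mismatched pairing:", mismatch_message)

# Matched pairing: each fake batch goes to the discriminator of its target space
pred1 = D1(fake_x1_from_x2)
pred2 = D2(fake_x2_from_x1)
print("Matched pairing:", tuple(pred1.shape), tuple(pred2.shape))
```

A quick way to catch this class of bug early is to assert `x.shape[1]` against the discriminator's expected input size before the forward pass.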

# Very difficult to get the Groups close together. We've gotten them closer. Going to try a few more hyperparams.
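
When comparing hyperparameter settings, it can help to put a number on how close the two groups are instead of relying on the t-SNE plot alone. The sketch below is one possible diagnostic and not part of the original method: `domain_separability` is a hypothetical helper that measures how well a nearest-centroid classifier tells the two latent sets apart. Values near 0.5 suggest the groups overlap, while values near 1.0 mean they are still easy to separate. It only looks at centroids, so it can miss differences in shape or spread.

```python
import torch


def domain_separability(latent1: torch.Tensor, latent2: torch.Tensor) -> float:
    """Accuracy of a nearest-centroid classifier at telling two latent sets apart."""
    c1 = latent1.mean(dim=0, keepdim=True)  # (1, D) centroid of group 1
    c2 = latent2.mean(dim=0, keepdim=True)  # (1, D) centroid of group 2
    x = torch.cat([latent1, latent2], dim=0)
    labels = torch.cat([torch.zeros(len(latent1)), torch.ones(len(latent2))])
    d1 = torch.cdist(x, c1).squeeze(1)  # distance of every point to centroid 1
    d2 = torch.cdist(x, c2).squeeze(1)  # distance of every point to centroid 2
    pred = (d2 < d1).float()  # predict group 2 when closer to its centroid
    return (pred == labels).float().mean().item()


# Synthetic check with random data (not results from the notebook)
torch.manual_seed(0)
group_a = torch.randn(500, 16)
separated = domain_separability(group_a, torch.randn(500, 16) + 10.0)
overlapping = domain_separability(group_a, torch.randn(500, 16))
print(f"separated: {separated:.3f}, overlapping: {overlapping:.3f}")
```

In the notebook, the inputs would be the outputs of `model.get_latent_1(...)` and `model.get_latent_2(...)` computed under `torch.no_grad()` on a held-out sample.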