In [1]:
import numpy as np

In [2]:
def generate_simplex_vectors_with_projection(dimension, apply_random_rotation=False):
    """
    Generate the unit vertex directions of a regular simplex in n-dimensional space.

    The n + 1 standard basis points of R^(n+1) are centered so that they lie in
    the hyperplane orthogonal to the all-ones vector. They are then expressed in
    an orthonormal basis of that hyperplane (obtained from an SVD) and scaled to
    unit length. Optionally the whole configuration is turned by a random
    proper rotation (determinant +1).

    Parameters:
        dimension (int): The dimension n of the space.
        apply_random_rotation (bool): Whether to apply a random rotation, drawn
            from the global np.random state, to the output vectors.

    Returns:
        np.ndarray: An (n + 1, n) array whose rows are unit vectors with pairwise
        dot product -1/n. For dimension == 1 the single vector [[1.0]] is
        returned instead and no rotation is applied.
    """
    if dimension == 1:
        return np.array([[1.0]])

    # Step 1: Create n+1 identity points in R^(n+1)
    points = np.eye(dimension + 1)

    # Step 2: Center points so they lie on the hyperplane sum(coords) == 0
    points -= np.mean(points, axis=0)

    # Step 3: SVD -> orthonormal basis. Singular values are sorted in descending
    # order, so the last column of U belongs to the zero singular value (the
    # all-ones direction) and is dropped.
    U, _, _ = np.linalg.svd(points.T, full_matrices=False)
    reduced_vectors = U[:, :-1]  # Shape: (n+1, n)

    # Step 4: Normalize each row vector to unit length
    reduced_vectors /= np.linalg.norm(reduced_vectors, axis=1, keepdims=True)

    # Step 5: Optional random rotation
    if apply_random_rotation:
        # Generate a random orthogonal matrix using QR decomposition
        Q, _ = np.linalg.qr(np.random.randn(dimension, dimension))
        # An orthogonal Q may have determinant -1 (a reflection). Flipping one
        # column keeps it orthogonal and makes it a proper rotation.
        if np.linalg.det(Q) < 0:
            Q[:, 0] = -Q[:, 0]
        reduced_vectors = reduced_vectors @ Q.T  # Apply rotation

    return reduced_vectors

In [3]:
def generate_multi_scale_positional_encoding(omega, x, embedding_dim):
    """
    Compute multi-scale phase angles and their complex exponentials for a position.

    Each wave vector is projected onto the position (omega @ x) and the projection
    is multiplied by S geometrically decreasing scales 1 / 10000 ** (s / S),
    where S = embedding_dim // (2 * n) and s = 0, ..., S - 1.

    Parameters:
    - omega: (n, d) array, wave vectors.
    - x: (d,) array, input position vector.
    - embedding_dim: int, dimensionality of the embedding space; must be a
      multiple of 2 * n.

    Returns:
    - pe_multi: (n, S) complex-valued matrix, exp(1j * theta).
    - theta: (n, S) real matrix of phase angles.
    """
    num_vectors, _ = omega.shape
    pair_width = 2 * num_vectors
    assert embedding_dim % pair_width == 0, "embedding_dim must be divisible by 2*n."
    num_scales = embedding_dim // pair_width  # S

    # Column vector of inverse frequencies, one per scale: shape (S, 1)
    exponents = pair_width * np.arange(num_scales)[:, None] / embedding_dim
    inv_freq = 1 / (10000 ** exponents)

    # Broadcasting (S, 1) against (n,) gives (S, n); transpose to (n, S)
    projections = omega @ x
    theta = (inv_freq * projections).T
    return np.exp(1j * theta), theta

# Full Rotation Matrix

In [4]:
def complex_to_rotation_matrix(angles):
    """
    Build a 2D rotation matrix for every angle in the input.

    Parameters:
    - angles: array of real rotation angles in radians, e.g. shape (n, S).

    Returns:
    - rotation_matrices: array of shape angles.shape + (2, 2); each trailing
      2x2 slice is [[cos, -sin], [sin, cos]] for the corresponding angle.
    """
    c = np.cos(angles)
    s = np.sin(angles)
    # Assemble the two matrix rows separately, then stack them as the row axis.
    top_row = np.stack((c, -s), axis=-1)
    bottom_row = np.stack((s, c), axis=-1)
    return np.stack((top_row, bottom_row), axis=-2)

In [5]:
def construct_block_diagonal_rotation_matrix(rotation_matrices):
    """
    Place every 2x2 rotation on the diagonal of one large matrix.

    Blocks are ordered scale-major: the block for wave vector i at scale s
    occupies rows/columns 2 * (s * n + i) and 2 * (s * n + i) + 1.

    Parameters:
    - rotation_matrices: (n, S, 2, 2) array.

    Returns:
    - R: (2 * S * n, 2 * S * n) block diagonal matrix; zero outside the blocks.
    """
    n, S, _, _ = rotation_matrices.shape
    size = 2 * S * n
    R = np.zeros((size, size))

    # np.ndindex(S, n) walks s in the outer loop and i in the inner loop, so the
    # running counter equals s * n + i.
    for block, (s, i) in enumerate(np.ndindex(S, n)):
        start = 2 * block
        stop = start + 2
        R[start:stop, start:stop] = rotation_matrices[i, s]

    return R

In [6]:
# Example usage: build the explicit block-diagonal rotation for one position
dimension = 2  # 2D space
S = 4  # assume 4 scales
embedding_length = 2 * S * (dimension + 1)  # Embedding length
x = np.array([1.0, 2.0])  # Position vector
omega = generate_simplex_vectors_with_projection(dimension)  # Generate omega

# Multi-scale positional encoding (complex values and the underlying angles)
pe_multi, theta = generate_multi_scale_positional_encoding(omega, x, embedding_length)
# Angles -> per-pair 2x2 rotations -> one block diagonal matrix
rotation_matrices = complex_to_rotation_matrix(theta)
R = construct_block_diagonal_rotation_matrix(rotation_matrices)

# Space-Saving Format

In [7]:
def compute_rotation_vectors(theta):
    """
    Return the element-wise cosine and sine of the positional encoding angles.

    Parameters:
    - theta: (n, S) matrix of angles in radians.

    Returns:
    - cos_vec: (n, S) cosine components.
    - sin_vec: (n, S) sine components.
    """
    return np.cos(theta), np.sin(theta)

In [8]:
def apply_rotation(q, cos_vec, sin_vec):
    """
    Apply the per-pair 2D rotations to q without building the full matrix.

    q is read as S consecutive groups of 2n entries. Within group s, entries
    (2i, 2i + 1) form the pair rotated by the angle whose cosine and sine are
    cos_vec[i, s] and sin_vec[i, s]. The output uses the same layout as the
    input, so the result equals R @ q for the block diagonal matrix built by
    construct_block_diagonal_rotation_matrix from the same angles.

    Parameters:
    - q: (2nS,) Input vector representing token embedding.
    - cos_vec: (n, S) Cosine components.
    - sin_vec: (n, S) Sine components.

    Returns:
    - q_rotated: (2nS,) Output vector after transformation, laid out like q.

    Raises:
    - AssertionError: if len(q) is not a multiple of 2n or implies a number of
      scales different from S.
    """
    n, S = cos_vec.shape
    assert q.shape[0] % (2 * n) == 0, f"Input q's dimension {q.shape[0]} must be a multiple of 2n = {2 * n}"

    # Compute number of scales
    S_from_q = q.shape[0] // (2 * n)
    assert S_from_q == S, f"Inconsistent S: expected {S}, but got {S_from_q} from q.shape"

    # Reshape q to (S, 2n): one row per scale
    q = q.reshape(S, 2 * n)
    # Split each row into the first and second component of every pair
    q_even, q_odd = q[:, ::2], q[:, 1::2]  # Shape: (S, n)

    # cos_vec / sin_vec are (n, S); transpose to line up with the (S, n) layout
    cos_t, sin_t = cos_vec.T, sin_vec.T

    # Standard 2D rotation of every (even, odd) pair
    q_rotated_even = q_even * cos_t - q_odd * sin_t  # Shape: (S, n)
    q_rotated_odd = q_even * sin_t + q_odd * cos_t  # Shape: (S, n)

    # Re-interleave the pairs as (S, n, 2) and flatten in row-major order so
    # that pair (s, i) lands back at flat positions 2*(s*n + i) and +1. No
    # transpose is applied here: that would scatter the pairs and break the
    # input layout.
    q_rotated = np.stack([q_rotated_even, q_rotated_odd], axis=-1)
    return q_rotated.reshape(-1)  # Shape: (2nS,)

In [9]:
# Example usage: rotate a random embedding without materializing the full matrix
dimension = 2  # 2D space
S = 4  # assume 4 scales
embedding_length = 2 * S * (dimension + 1)  # Embedding length
x = np.array([1.0, 2.0])  # Position vector
omega = generate_simplex_vectors_with_projection(dimension)  # Generate omega

# Multi-scale positional encoding (complex values and the underlying angles)
pe_multi, theta = generate_multi_scale_positional_encoding(omega, x, embedding_length)

# Rotate q using only the cosine and sine components of the angles
cos_vec, sin_vec = compute_rotation_vectors(theta)
q = np.random.randn(embedding_length)  # Random input vector
q_rotated = apply_rotation(q, cos_vec, sin_vec)  # Apply rotation transformation