In [None]:
import jax.numpy as jnp
import jax


In [None]:
def create_random_complex_mult_data(num_points):
    """Sample complex-number pairs and their products, stored as real 2-vectors.

    Column 0 of every returned array holds the real part and column 1 the
    imaginary part.

    Args:
        num_points: Number of (x, y) pairs to draw.

    Returns:
        Tuple ``(x, y, z)`` of arrays with shape ``(num_points, 2)``. ``x`` and
        ``y`` are uniform on ``[-1, 1)`` (fixed keys 0 and 1 respectively) and
        ``z`` is the complex product ``x * y``.
    """
    sample_shape = (num_points, 2)
    x = jax.random.uniform(jax.random.PRNGKey(0), sample_shape, minval=-1, maxval=1)
    y = jax.random.uniform(jax.random.PRNGKey(1), sample_shape, minval=-1, maxval=1)

    x_re, x_im = x[:, 0], x[:, 1]
    y_re, y_im = y[:, 0], y[:, 1]

    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    z = jnp.stack(
        (x_re * y_re - x_im * y_im, x_re * y_im + x_im * y_re),
        axis=1,
    )
    return x, y, z

def expand_dim_complex_dataset(expansion_dim, x, y, z, key=None):
    """Lift a low-dimensional dataset into ``expansion_dim`` dimensions.

    A single random Gaussian matrix ``P`` of shape ``(expansion_dim, d)`` is
    drawn (``d`` is the trailing dimension of ``x``) and applied to all three
    arrays, so the same linear embedding is shared by inputs and targets.

    Args:
        expansion_dim: Dimension of the expanded space.
        x, y, z: Arrays whose trailing dimension is ``d``.
        key: Optional ``jax.random`` key used to draw ``P``. Defaults to
            ``jax.random.PRNGKey(0)`` for backward compatibility. Note that
            this default is also the key used to sample ``x`` in
            ``create_random_complex_mult_data``; pass a distinct key when the
            projection should be statistically independent of the data.

    Returns:
        Tuple ``(x_expanded, y_expanded, z_expanded, projection_matrix)`` where
        each expanded array equals ``array @ projection_matrix.T``.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    # Infer the input width from the data instead of hard-coding 2.
    input_dim = x.shape[-1]
    projection_matrix = jax.random.normal(key, (expansion_dim, input_dim))

    lift = projection_matrix.T
    x_expanded = x @ lift
    y_expanded = y @ lift
    z_expanded = z @ lift
    return x_expanded, y_expanded, z_expanded, projection_matrix

#
# 
# def expand_dim_complex_dataset_poly(expansion_dim, x, y, z, max_degree = 4):

def init_structure_constants(num_basis_fxn, key=None):
    """Draw a random initial guess for the structure-constant tensor.

    Args:
        num_basis_fxn: Number of basis functions ``N``.
        key: Optional ``jax.random`` key. Defaults to ``jax.random.PRNGKey(0)``
            so existing callers get the same initialisation as before.

    Returns:
        Array of shape ``(N, N, N)`` with standard-normal entries.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    tensor_shape = (num_basis_fxn,) * 3
    return jax.random.normal(key, tensor_shape)

# def multiply_with_structure_consts(x, y, structure_constants):
#     # Perform the multiplication with the structure constants
#     z = jnp.einsum('i,j,kij->k', x, y, structure_constants) 
#     return z

# Bilinear product defined by a tensor of structure constants
def multiply_with_structure_consts(x, y, structure_constants):
    """Multiply ``x`` and ``y`` using ``structure_constants``.

    Computes ``z[b, k] = sum_{i, j} C[k, i, j] * x[b, i] * y[b, j]``.

    Args:
        x, y: Either a single vector (1-D) or a batch of vectors (2-D).
        structure_constants: Tensor ``C`` indexed as ``[k, i, j]``.

    Returns:
        A 2-D array with a leading batch axis. A 1-D input is promoted to a
        batch of one, so two 1-D inputs yield an output of shape ``(1, K)``.
    """
    # Promote lone vectors so the contraction always sees a batch axis.
    x_batched = x[None, :] if x.ndim == 1 else x
    y_batched = y[None, :] if y.ndim == 1 else y
    return jnp.einsum('kij,bi,bj->bk', structure_constants, x_batched, y_batched)

# NOTE: this redefinition shadows the earlier function of the same name in this cell.
def create_random_complex_mult_data(num_points, key=None):
    """Sample complex-number pairs and their products as real 2-vectors.

    Column 0 of every returned array is the real part, column 1 the imaginary
    part.

    Args:
        num_points: Number of (x, y) pairs to draw.
        key: Optional ``jax.random`` key. When omitted, ``x`` is drawn with
            ``PRNGKey(0)`` and ``y`` with ``PRNGKey(1)`` (the historical
            behaviour). When given, the key is split once so ``x`` and ``y``
            use independent sub-keys.

    Returns:
        Tuple ``(x, y, z)`` of arrays with shape ``(num_points, 2)``; ``x`` and
        ``y`` are uniform on ``[-1, 1)`` and ``z`` is the complex product.
    """
    if key is None:
        key_x, key_y = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
    else:
        key_x, key_y = jax.random.split(key)

    x = jax.random.uniform(key_x, (num_points, 2), minval=-1, maxval=1)
    y = jax.random.uniform(key_y, (num_points, 2), minval=-1, maxval=1)

    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    prod_real = x[:, 0] * y[:, 0] - x[:, 1] * y[:, 1]
    prod_imag = x[:, 0] * y[:, 1] + x[:, 1] * y[:, 0]
    z = jnp.stack((prod_real, prod_imag), axis=1)
    return x, y, z



# Example usage
# def pred_with_dim_red(x, y, dim_reducer, struct_consts):
#     # Initialize the structure constants
#     # Perform the multiplication with the structure constants
#     x_hat = dim_reducer(x)
#     y_hat = dim_reducer(y)
#     z = multiply_with_structure_consts(x, y, struct_consts)
#     return z

# Smoke test: build a tiny 3-sample dataset and lift it to 100 dimensions.
# These module-level names are reassigned by a later cell before training.
x, y, z = create_random_complex_mult_data(3)
x_expanded, y_expanded, z_expanded, projection_matrix = expand_dim_complex_dataset(100, x, y, z)




In [None]:
from jax import value_and_grad
import optax

In [None]:
# Training setup
def test_complex_right_vector_struct(num_points=1000, num_epochs=5000, learning_rate=1e-3, log_every=500):
    """Fit 2x2x2 structure constants to complex multiplication and print them.

    Random complex pairs and their products are generated, a randomly
    initialised structure-constant tensor is trained with Adam on the mean
    squared error of the predicted product, and the learned tensor is printed
    next to the known constants for the basis (1, i).

    Args:
        num_points: Number of training samples to generate.
        num_epochs: Number of full-batch gradient steps.
        learning_rate: Adam step size.
        log_every: Print the loss every this many epochs; ``0`` disables logging.

    Returns:
        None. Results are printed.
    """
    # Generate data
    x, y, z_true = create_random_complex_mult_data(num_points)

    # Initialize learnable parameters
    structure_constants = init_structure_constants(2)

    def loss(params, x, y, z_true):
        z_pred = multiply_with_structure_consts(x, y, params)
        return jnp.mean((z_pred - z_true)**2)

    # Build the differentiated function once rather than on every iteration.
    loss_and_grad = value_and_grad(loss)

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(structure_constants)

    # Full-batch training loop
    for epoch in range(num_epochs):
        l, grads = loss_and_grad(structure_constants, x, y, z_true)
        updates, opt_state = optimizer.update(grads, opt_state)
        structure_constants = optax.apply_updates(structure_constants, updates)

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {l:.4f}")

    # Compare learned vs true constants (indexed [k, i, j])
    true_C = jnp.array([
        [[1., 0.], [0., -1.]],  # Real part coefficients
        [[0., 1.], [1., 0.]]     # Imaginary part coefficients
    ])

    print("\nLearned structure constants:")
    print(structure_constants)
    print("\nTrue structure constants:")
    print(true_C)


In [None]:
# Training setup
def test_complex_high_dim_vector_struct(num_points=1000, expansion_dim=100, num_epochs=5000,
                                        learning_rate=1e-3, log_every=500):
    """Fit 2x2x2 structure constants to complex multiplication and print them.

    NOTE(review): the expanded dataset is computed below but never used.
    Training still runs on the original 2-D ``x``, ``y``, ``z_true`` with a
    2x2x2 tensor, so this currently behaves like the low-dimensional test.
    Presumably the intent was to fit in the expanded space -- confirm before
    relying on this function.

    Args:
        num_points: Number of training samples to generate.
        expansion_dim: Dimension passed to ``expand_dim_complex_dataset``.
        num_epochs: Number of full-batch gradient steps.
        learning_rate: Adam step size.
        log_every: Print the loss every this many epochs; ``0`` disables logging.

    Returns:
        None. Results are printed.
    """
    # Generate data
    x, y, z_true = create_random_complex_mult_data(num_points)
    # Currently unused -- see NOTE(review) in the docstring.
    x_expanded, y_expanded, z_expanded, projection_matrix = expand_dim_complex_dataset(
        expansion_dim, x, y, z_true)

    # Initialize learnable parameters
    structure_constants = init_structure_constants(2)

    def loss(params, x, y, z_true):
        z_pred = multiply_with_structure_consts(x, y, params)
        return jnp.mean((z_pred - z_true)**2)

    # Build the differentiated function once rather than on every iteration.
    loss_and_grad = value_and_grad(loss)

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(structure_constants)

    # Full-batch training loop
    for epoch in range(num_epochs):
        l, grads = loss_and_grad(structure_constants, x, y, z_true)
        updates, opt_state = optimizer.update(grads, opt_state)
        structure_constants = optax.apply_updates(structure_constants, updates)

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {l:.4f}")

    # Compare learned vs true constants (indexed [k, i, j])
    true_C = jnp.array([
        [[1., 0.], [0., -1.]],  # Real part coefficients
        [[0., 1.], [1., 0.]]     # Imaginary part coefficients
    ])

    print("\nLearned structure constants:")
    print(structure_constants)
    print("\nTrue structure constants:")
    print(true_C)


In [None]:
# Ground-truth structure constants for complex multiplication in the basis (1, i),
# indexed [k, i, j]: z_k = sum_ij true_C[k, i, j] * x_i * y_j.
true_C = jnp.array([
    [[1., 0.], [0., -1.]],  # Real part coefficients
    [[0., 1.], [1., 0.]]     # Imaginary part coefficients
])

In [None]:
# Train the 2x2x2 structure constants on 2-D data and print them beside the true values.
test_complex_right_vector_struct()

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial 

class ComplexMultiplicationModel(nn.Module):
    """Encode two inputs, multiply them with learnable structure constants, decode.

    Both inputs go through the same encoder into an ``embed_dim``-dimensional
    latent space, are combined by ``multiply_with_structure_consts`` using the
    learnable tensor ``C``, and the result is mapped by the decoder to
    ``expansion_dim`` outputs. The decoder is also applied to the raw
    embeddings so that reconstructions of the inputs are available.
    """
    # Size of the latent space; C has shape (embed_dim, embed_dim, embed_dim).
    embed_dim: int = 2
    # Hidden width of the encoder/decoder MLPs (only used when use_mlp is True).
    mlp_hidden: int = 64
    # True: one-hidden-layer ReLU MLPs; False: single Dense layers.
    use_mlp: bool = True
    # Output width of the decoder.
    expansion_dim: int = 100
    
    def setup(self):
        """Create the encoder, the decoder and the structure-constant parameter."""
        # Learnable encoder/decoder projections
        if self.use_mlp:
            self.encoder = nn.Sequential([
                nn.Dense(self.mlp_hidden),
                nn.relu,
                nn.Dense(self.embed_dim)
            ])
            self.decoder = nn.Sequential([
                nn.Dense(self.mlp_hidden),
                nn.relu,
                nn.Dense(self.expansion_dim)
            ])
        else:
            self.encoder = nn.Dense(self.embed_dim)
            self.decoder = nn.Dense(self.expansion_dim)
            
        # Learnable structure constants, registered under the parameter name 'C'
        # (this is the key other cells read via params['params']['C']).
        self.structure_constants = self.param('C', 
            nn.initializers.normal(0.1), 
            (self.embed_dim, self.embed_dim, self.embed_dim))

    def __call__(self, x, y):
        """Return ``(z_projected, x_reconstructed, y_reconstructed)``.

        ``z_projected`` is the decoded latent product of ``x`` and ``y``; the
        other two are the decoder applied to each input's own embedding.
        """
        # Encode inputs to latent space (the encoder is shared between x and y)
        x_embed = self.encoder(x)
        y_embed = self.encoder(y)
        
        # Multiplication in latent space via the structure constants C[k, i, j]
        product = multiply_with_structure_consts(x_embed, y_embed, self.structure_constants)
        # (An earlier variant contracted directly with einsum 'bi,bj,ijk->bk',
        #  i.e. with the output index last instead of first.)
        #product = jnp.einsum('bi,bj,ijk->bk', x_embed, y_embed, self.structure_constants)
        # Decode the product and both embeddings back to the expanded space
        z_projected = self.decoder(product)
        x_reconstructed = self.decoder(x_embed)
        y_reconstructed = self.decoder(y_embed)
        
        return z_projected, x_reconstructed, y_reconstructed


def structure_const_entropy_loss(matrix):
    """Sparsity penalty: mean Shannon entropy of the rows and columns of each slice.

    The structure constants form an NxNxN tensor, where N is the number of
    basis functions. For every slice ``matrix[k, :, :]`` the absolute values
    along each row and each column are normalised to sum to one, and the
    Shannon entropy ``H(p) = -sum(p * log(p))`` of the resulting distribution
    is computed. The loss is the mean of ``H(rows) + H(columns)``; it is
    non-negative and reaches zero when every row and column is one-hot (up to
    sign and scale).

    Previously the function returned ``sum(|c| * log|c|)`` without
    normalisation or the leading minus sign, which is not an entropy and does
    not distinguish one-hot rows from uniform ones.

    Args:
        matrix: Tensor of structure constants; the last two axes are treated
            as rows and columns.

    Returns:
        Scalar loss.
    """
    eps = 1e-10  # guards the division for all-zero rows and the log at p == 0

    def entropy(weights):
        # Turn non-negative weights into a distribution along the last axis.
        probs = weights / (jnp.sum(weights, axis=-1, keepdims=True) + eps)
        return -jnp.sum(probs * jnp.log(probs + eps), axis=-1)

    magnitudes = jnp.abs(matrix)
    row_entropy = entropy(magnitudes)
    # Swapping the last two axes makes the same reduction run over columns.
    col_entropy = entropy(jnp.swapaxes(magnitudes, -1, -2))
    return jnp.mean(row_entropy + col_entropy)

def train_model(x_e, y_e, z_e, config, num_epochs=5000, learning_rate=1e-3,
                weight_decay=1e-4, entropy_lambda=0, seed=0, log_every=500):
    """Train a ``ComplexMultiplicationModel`` with full-batch AdamW.

    The objective is the sum of three mean-squared errors -- predicted product
    vs ``z``, and the decoder reconstructions vs ``x`` and ``y`` -- plus
    ``entropy_lambda`` times ``structure_const_entropy_loss`` of the learned
    tensor ``C``.

    Args:
        x_e, y_e, z_e: Batched inputs and targets; the first sample of
            ``x_e``/``y_e`` is used to initialise the model parameters.
        config: Keyword arguments forwarded to ``ComplexMultiplicationModel``.
        num_epochs: Number of full-batch update steps.
        learning_rate: AdamW step size.
        weight_decay: AdamW weight-decay coefficient.
        entropy_lambda: Weight of the entropy penalty on ``C``. The default of
            0 leaves the penalty without effect on the loss value.
        seed: Seed for the parameter-initialisation PRNG key.
        log_every: Print the loss every this many epochs; ``0`` disables logging.

    Returns:
        The trained parameter tree (as produced by ``model.init`` and updated
        by the optimiser).
    """
    model = ComplexMultiplicationModel(**config)

    def loss(params, x, y, z):
        pred, x_recon, y_recon = model.apply(params, x, y)
        mse_z = jnp.mean((pred - z) ** 2)
        mse_x = jnp.mean((x_recon - x) ** 2)
        mse_y = jnp.mean((y_recon - y) ** 2)
        entropy_loss = structure_const_entropy_loss(params['params']['C'])
        return mse_z + mse_x + mse_y + entropy_lambda * entropy_loss

    grad_fn = jax.value_and_grad(loss)

    # Initialize parameters from a single example and set up the optimiser
    rng = jax.random.PRNGKey(seed)
    params = model.init(rng, x_e[0], y_e[0])
    optimizer = optax.adamw(learning_rate, weight_decay=weight_decay)
    opt_state = optimizer.init(params)

    @jax.jit
    def update(params, opt_state, x, y, z):
        l, grads = grad_fn(params, x, y, z)
        # AdamW needs the current params to apply weight decay.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, l

    # Training loop
    for epoch in range(num_epochs):
        params, opt_state, current_loss = update(params, opt_state, x_e, y_e, z_e)
        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {current_loss:.4f}")

    return params

In [None]:
# Quick check of the entropy penalty on a hand-written 2x2x2 tensor (includes a negative and a zero entry).
structure_const_entropy_loss(jnp.array([[[5, 5], [2, -1.]], [[2., 1.], [1., 0.]]]))


In [None]:
# Generate 1000 complex-multiplication samples and lift them to 100 dimensions
x, y, z = create_random_complex_mult_data(1000)
x_e, y_e, z_e, proj_mat = expand_dim_complex_dataset(100, x, y, z)

# Model configuration: linear encoder/decoder (use_mlp=False), 2-D latent space.
# 'expansion_dim' must match the dimension passed to expand_dim_complex_dataset above.
config = {
    'embed_dim': 2,
    'mlp_hidden': 64,
    'use_mlp': False,
    'expansion_dim': 100
}

# Train on the expanded data; `params` is used by the inspection cells that follow.
params = train_model(x_e, y_e, z_e, config)

In [None]:
# Entropy penalty evaluated on the learned structure constants.
structure_const_entropy_loss(params['params']['C'])

In [None]:
# Display the full trained parameter tree.
params

In [None]:
# After training, inspect the learned structure constants
learned_constants = params['params']['C']  # 'C' is the name given to self.param in ComplexMultiplicationModel.setup

print("Learned Structure Constants:")
print(learned_constants)

In [None]:
from typing import List, Tuple

In [None]:
def get_structure_constants(basis_elts, proj_func, multiplication_function) -> jnp.ndarray:
    """Compute the structure constants of a product with respect to a basis.

    For basis elements ``e_0, ..., e_{N-1}`` the result satisfies
    ``C[gamma, alpha, beta] = proj_func(multiplication_function(e_alpha, e_beta), e_gamma)``.
    These are the expansion coefficients of ``e_alpha * e_beta`` only if
    ``proj_func`` returns the coefficient along its second argument (e.g. an
    inner product with an orthonormal basis).

    Args:
        basis_elts: Sequence of basis elements.
        proj_func: Callable ``(element, target_basis) -> scalar``.
        multiplication_function: Callable ``(a, b) -> element`` giving the product.

    Returns:
        Array of shape ``(N, N, N)`` indexed ``[gamma, alpha, beta]``.
    """
    num_basis = len(basis_elts)
    structure_constants = jnp.zeros((num_basis, num_basis, num_basis))

    for alpha, a in enumerate(basis_elts):
        for beta, b in enumerate(basis_elts):
            # The product does not depend on the target basis element, so it is
            # computed once per (alpha, beta) pair rather than once per triple.
            product = multiplication_function(a, b)
            for gamma, target_basis in enumerate(basis_elts):
                projection_val = proj_func(product, target_basis)
                structure_constants = structure_constants.at[gamma, alpha, beta].set(projection_val)

    return structure_constants

In [None]:
def complex_proj_func(z, target_basis):
    """Inner product of two complex numbers viewed as real 2-vectors (re, im)."""
    z_vec = jnp.array([z.real, z.imag])
    basis_vec = jnp.array([target_basis.real, target_basis.imag])
    return jnp.dot(z_vec, basis_vec)

def complex_multiplication_function(z1, z2):
    """Return the product of ``z1`` and ``z2`` using the ``*`` operator."""
    product = z1 * z2
    return product

# Structure constants of complex multiplication in the standard basis (1, i).
basis_elts = [1, 1j]
structure_constants = get_structure_constants(basis_elts, complex_proj_func, complex_multiplication_function)
print(structure_constants)


In [None]:

# Same computation in a different basis: (1 - i)/sqrt(2) and (1 + i)/sqrt(2).
# Viewed as real 2-vectors these are unit length and mutually orthogonal.
basis_elts = [ (1 - 1j)/jnp.sqrt(2), (1 + 1j)/jnp.sqrt(2)]
structure_constants = get_structure_constants(basis_elts, complex_proj_func, complex_multiplication_function)
print(structure_constants)


In [None]:
# Candidate change-of-basis matrix. P is symmetric and orthogonal (P @ P = I),
# so P_inv equals P up to floating-point error.
P = jnp.array([[1, 1],[1, -1]]) * (1/jnp.sqrt(2))
P_inv = jnp.linalg.inv(P)

In [None]:
# 2x2 real matrices standing in for the complex numbers 1 and i:
# one_mat is the identity and i_mat @ i_mat == -one_mat.
one_mat = jnp.array([[1, 0], [0, 1]])
i_mat = jnp.array([[0, -1], [1, 0]])

In [None]:
def complex_proj_func_mat(z, target_basis):
    """Half the trace of ``target_basis.T @ z`` (a scaled Frobenius inner product)."""
    overlap = target_basis.T @ z
    return 1/2 * jnp.trace(overlap)

def complex_multiplication_function_mat(z1, z2):
    """Return the matrix product of ``z1`` and ``z2`` using the ``@`` operator."""
    product = z1 @ z2
    return product

# Structure constants for the matrix representation: the basis is (one_mat, i_mat),
# the product is matrix multiplication and the projection is (1/2) * trace(B.T @ Z).
basis_elts = [one_mat, i_mat]
structure_constants = get_structure_constants(basis_elts, complex_proj_func_mat, complex_multiplication_function_mat)
print(structure_constants)

def two_d_rotation_matrix(theta):
    """Return the 2x2 matrix ``[[cos, -sin], [sin, cos]]`` for angle ``theta`` (radians)."""
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array([[cos_t, -sin_t], [sin_t, cos_t]])
# Conjugate i_mat by a rotation of pi/4. R_inv is built as the rotation by -theta.
theta = jnp.pi/4
R = two_d_rotation_matrix(theta)
R_inv = two_d_rotation_matrix(-theta)

# i_mat is itself a rotation matrix (by pi/2) and 2-D rotations commute, so this
# should reproduce i_mat up to floating-point error.
R_inv @ i_mat @ R