In [2]:
!pip install mlx

Collecting mlx
  Downloading mlx-0.21.1-cp311-cp311-macosx_14_0_arm64.whl.metadata (5.1 kB)
Downloading mlx-0.21.1-cp311-cp311-macosx_14_0_arm64.whl (27.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.3/27.3 MB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: mlx
Successfully installed mlx-0.21.1
Note: you may need to restart the kernel to use updated packages.


In [6]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from typing import List, Tuple

class PreNet(nn.Module):
    """Entry projection of the pyramid branch: ``tanh(Linear(input_dim -> hidden_dim)(x))``.

    The tanh bounds every activation to (-1, 1) before the pyramid stack.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def __call__(self, x):
        return mx.tanh(self.linear(x))

class PostNet(nn.Module):
    """Exit projection of the pyramid branch: a plain ``Linear(hidden_dim -> output_dim)`` with no activation."""

    def __init__(self, hidden_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        projected = self.linear(x)
        return projected

class PyramidLayer(nn.Module):
    """One pyramid stage: ``tanh(Linear(input_dim -> output_dim)(x))``."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def __call__(self, x):
        pre_activation = self.linear(x)
        return mx.tanh(pre_activation)

class HyperbolicCube(nn.Module):
    """Chain of PyramidLayer stages whose widths follow ``layer_dims``.

    ``layer_dims = [d0, d1, ..., dk]`` builds k stages mapping d0->d1, ..., d(k-1)->dk,
    applied in order. A list with fewer than two entries builds no stages, so the
    input is returned unchanged.
    """

    def __init__(self, layer_dims: List[int]):
        super().__init__()
        self.pyramid_layers = [
            PyramidLayer(d_in, d_out)
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        ]

    def __call__(self, x):
        out = x
        for stage in self.pyramid_layers:
            out = stage(out)
        return out

class DualHiddenLCM(nn.Module):
    """Two parallel branches over the same input whose outputs are summed.

    * Pyramid branch: ``PreNet -> HyperbolicCube(hidden_dims) -> PostNet``.
    * Bottleneck branch: ``Linear(input_dim -> hidden_dim2) -> ReLU -> Linear(hidden_dim2 -> output_dim)``.

    Args:
        input_dim: size of the input feature axis.
        hidden_dims: pyramid widths; ``hidden_dims[0]`` is the PreNet output and
            ``hidden_dims[-1]`` the PostNet input. Must be non-empty.
        hidden_dim2: width of the bottleneck branch.
        output_dim: size of the output feature axis (shared by both branches).
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], hidden_dim2: int, output_dim: int):
        super().__init__()
        first_width, last_width = hidden_dims[0], hidden_dims[-1]

        # Construction order fixes the order in which layers draw their random
        # initial weights, so keep it stable.
        self.prenet = PreNet(input_dim, first_width)
        self.hyperbolic_cube = HyperbolicCube(hidden_dims)
        self.postnet = PostNet(last_width, output_dim)

        self.hidden_dim2 = nn.Linear(input_dim, hidden_dim2)
        self.hidden_dim2_output = nn.Linear(hidden_dim2, output_dim)

    def __call__(self, x):
        pyramid_out = self.postnet(self.hyperbolic_cube(self.prenet(x)))
        # ReLU written as max(0, .)
        bottleneck_out = self.hidden_dim2_output(mx.maximum(0, self.hidden_dim2(x)))
        return pyramid_out + bottleneck_out

def compute_accuracy(predicted, target, threshold=0.1, eps=1e-8):
    """Fraction of rows whose cosine similarity with the target exceeds ``threshold``.

    Cosine similarity is taken along the last axis. The product of the two norms
    is clamped to at least ``eps`` so that an all-zero row yields similarity 0
    instead of NaN (0/0).

    Args:
        predicted: predictions; assumed to have the same shape as ``target``.
        target: reference vectors.
        threshold: a row counts as correct when its similarity is strictly greater.
        eps: lower bound for the norm product in the denominator.

    Returns:
        Scalar mx.array: mean of the 0/1 correctness indicators.
    """
    norm_pred = mx.sqrt(mx.sum(predicted * predicted, axis=-1))
    norm_target = mx.sqrt(mx.sum(target * target, axis=-1))
    denom = mx.maximum(norm_pred * norm_target, eps)
    cos_sim = mx.sum(predicted * target, axis=-1) / denom
    correct = (cos_sim > threshold).astype(mx.float32)
    return mx.mean(correct)

def train_step(model, x, y, optimizer):
    """Run one optimisation step on ``(x, y)`` using mean-squared-error loss.

    Computes the loss and its gradients with respect to ``model``'s trainable
    parameters, applies ``optimizer.update`` and returns the scalar loss array.
    The loss is not forced with ``mx.eval`` here.
    """

    def mse_loss(net, inputs, targets):
        residual = net(inputs) - targets
        return mx.mean(residual ** 2)

    loss_and_grad = nn.value_and_grad(model, mse_loss)
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    return loss

def main():
    """Train DualHiddenLCM on synthetic near-identity data and print the per-epoch loss.

    Inputs are standard-normal vectors and targets are the inputs plus Gaussian
    noise with std 0.01. Uses Adam with a constant learning rate of 1e-4.
    """
    # Model parameters
    input_dim = 300
    hidden_dims = [512, 256, 128, 64]
    hidden_dim2 = 20
    output_dim = 300
    batch_size = 4
    num_samples = 100
    epochs = 40
    seed = 42

    # Full batches per epoch; a trailing partial batch is dropped.
    num_batches = num_samples // batch_size
    if num_batches == 0:
        raise ValueError(f"num_samples ({num_samples}) must be >= batch_size ({batch_size})")

    # Generate synthetic data
    rng = np.random.default_rng(seed)
    data = rng.normal(0, 1, (num_samples, input_dim)).astype(np.float32)
    target = data + rng.normal(0, 0.01, data.shape).astype(np.float32)

    # Convert to MLX arrays
    data = mx.array(data)
    target = mx.array(target)

    # Seed MLX too: parameter initialisation uses MLX's own RNG, so without this
    # identical runs start from different weights.
    mx.random.seed(seed)
    model = DualHiddenLCM(input_dim, hidden_dims, hidden_dim2, output_dim)
    optimizer = optim.Adam(learning_rate=1e-4)

    # Training loop
    for epoch in range(epochs):
        total_loss = 0.0
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_data = data[start_idx:end_idx]
            batch_target = target[start_idx:end_idx]

            loss = train_step(model, batch_data, batch_target, optimizer)
            # MLX is lazy: evaluate the updated parameters and optimizer state too,
            # not only the loss, so every step is materialised before the next.
            mx.eval(loss, model.parameters(), optimizer.state)
            # Accumulate a Python float rather than a chain of lazy mx.array sums.
            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f}")

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so running this cell trains the model.
    main()

Epoch 1/40 | Loss: 1.0686
Epoch 2/40 | Loss: 0.9604
Epoch 3/40 | Loss: 0.9068
Epoch 4/40 | Loss: 0.8688
Epoch 5/40 | Loss: 0.8396
Epoch 6/40 | Loss: 0.8161
Epoch 7/40 | Loss: 0.7965
Epoch 8/40 | Loss: 0.7796
Epoch 9/40 | Loss: 0.7647
Epoch 10/40 | Loss: 0.7514
Epoch 11/40 | Loss: 0.7392
Epoch 12/40 | Loss: 0.7279
Epoch 13/40 | Loss: 0.7175
Epoch 14/40 | Loss: 0.7077
Epoch 15/40 | Loss: 0.6985
Epoch 16/40 | Loss: 0.6898
Epoch 17/40 | Loss: 0.6816
Epoch 18/40 | Loss: 0.6738
Epoch 19/40 | Loss: 0.6663
Epoch 20/40 | Loss: 0.6591
Epoch 21/40 | Loss: 0.6522
Epoch 22/40 | Loss: 0.6456
Epoch 23/40 | Loss: 0.6392
Epoch 24/40 | Loss: 0.6330
Epoch 25/40 | Loss: 0.6270
Epoch 26/40 | Loss: 0.6212
Epoch 27/40 | Loss: 0.6155
Epoch 28/40 | Loss: 0.6100
Epoch 29/40 | Loss: 0.6046
Epoch 30/40 | Loss: 0.5994
Epoch 31/40 | Loss: 0.5944
Epoch 32/40 | Loss: 0.5894
Epoch 33/40 | Loss: 0.5846
Epoch 34/40 | Loss: 0.5798
Epoch 35/40 | Loss: 0.5752
Epoch 36/40 | Loss: 0.5707
Epoch 37/40 | Loss: 0.5662
Epoch 38/4

Optimized with a cosine-decay learning-rate schedule and linear warmup, passing learning rates as Python floats.

In [9]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from typing import List, Tuple

class PreNet(nn.Module):
    """Entry projection of the pyramid branch: ``tanh(Linear(input_dim -> hidden_dim)(x))``.

    The tanh bounds every activation to (-1, 1) before the pyramid stack.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def __call__(self, x):
        return mx.tanh(self.linear(x))

class PostNet(nn.Module):
    """Exit projection of the pyramid branch: a plain ``Linear(hidden_dim -> output_dim)`` with no activation."""

    def __init__(self, hidden_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        projected = self.linear(x)
        return projected

class PyramidLayer(nn.Module):
    """One pyramid stage: ``tanh(Linear(input_dim -> output_dim)(x))``."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def __call__(self, x):
        pre_activation = self.linear(x)
        return mx.tanh(pre_activation)

class HyperbolicCube(nn.Module):
    """Chain of PyramidLayer stages whose widths follow ``layer_dims``.

    ``layer_dims = [d0, d1, ..., dk]`` builds k stages mapping d0->d1, ..., d(k-1)->dk,
    applied in order. A list with fewer than two entries builds no stages, so the
    input is returned unchanged.
    """

    def __init__(self, layer_dims: List[int]):
        super().__init__()
        self.pyramid_layers = [
            PyramidLayer(d_in, d_out)
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        ]

    def __call__(self, x):
        out = x
        for stage in self.pyramid_layers:
            out = stage(out)
        return out

class DualHiddenLCM(nn.Module):
    """Two parallel branches over the same input whose outputs are summed.

    * Pyramid branch: ``PreNet -> HyperbolicCube(hidden_dims) -> PostNet``.
    * Bottleneck branch: ``Linear(input_dim -> hidden_dim2) -> ReLU -> Linear(hidden_dim2 -> output_dim)``.

    Args:
        input_dim: size of the input feature axis.
        hidden_dims: pyramid widths; ``hidden_dims[0]`` is the PreNet output and
            ``hidden_dims[-1]`` the PostNet input. Must be non-empty.
        hidden_dim2: width of the bottleneck branch.
        output_dim: size of the output feature axis (shared by both branches).
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], hidden_dim2: int, output_dim: int):
        super().__init__()
        first_width, last_width = hidden_dims[0], hidden_dims[-1]

        # Construction order fixes the order in which layers draw their random
        # initial weights, so keep it stable.
        self.prenet = PreNet(input_dim, first_width)
        self.hyperbolic_cube = HyperbolicCube(hidden_dims)
        self.postnet = PostNet(last_width, output_dim)

        self.hidden_dim2 = nn.Linear(input_dim, hidden_dim2)
        self.hidden_dim2_output = nn.Linear(hidden_dim2, output_dim)

    def __call__(self, x):
        pyramid_out = self.postnet(self.hyperbolic_cube(self.prenet(x)))
        # ReLU written as max(0, .)
        bottleneck_out = self.hidden_dim2_output(mx.maximum(0, self.hidden_dim2(x)))
        return pyramid_out + bottleneck_out

def compute_accuracy(predicted, target, threshold=0.1, eps=1e-8):
    """Fraction of rows whose cosine similarity with the target exceeds ``threshold``.

    Cosine similarity is taken along the last axis. The product of the two norms
    is clamped to at least ``eps`` so that an all-zero row yields similarity 0
    instead of NaN (0/0).

    Args:
        predicted: predictions; assumed to have the same shape as ``target``.
        target: reference vectors.
        threshold: a row counts as correct when its similarity is strictly greater.
        eps: lower bound for the norm product in the denominator.

    Returns:
        Scalar mx.array: mean of the 0/1 correctness indicators.
    """
    norm_pred = mx.sqrt(mx.sum(predicted * predicted, axis=-1))
    norm_target = mx.sqrt(mx.sum(target * target, axis=-1))
    denom = mx.maximum(norm_pred * norm_target, eps)
    cos_sim = mx.sum(predicted * target, axis=-1) / denom
    correct = (cos_sim > threshold).astype(mx.float32)
    return mx.mean(correct)

def train_step(model, x, y, optimizer):
    """Run one optimisation step on ``(x, y)`` using mean-squared-error loss.

    Computes the loss and its gradients with respect to ``model``'s trainable
    parameters, applies ``optimizer.update`` and returns the scalar loss array.
    The loss is not forced with ``mx.eval`` here.
    """

    def mse_loss(net, inputs, targets):
        residual = net(inputs) - targets
        return mx.mean(residual ** 2)

    loss_and_grad = nn.value_and_grad(model, mse_loss)
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    return loss

def _build_lr_schedule(total_steps: int, warmup_steps: int, min_lr: float, max_lr: float) -> np.ndarray:
    """Return one learning rate per optimizer step (linear warmup, then cosine decay).

    The first ``warmup_steps`` entries (clamped to ``total_steps``) rise linearly as
    ``min_lr + (max_lr - min_lr) * t / warmup_steps`` for ``t = 0, 1, ...``. The
    remaining ``T`` entries follow ``min_lr + (max_lr - min_lr) * 0.5 * (1 + cos(pi * t / T))``.
    """
    warmup_steps = max(0, min(warmup_steps, total_steps))
    schedule = np.zeros(total_steps)

    warmup_idx = np.arange(warmup_steps)
    schedule[:warmup_steps] = min_lr + (max_lr - min_lr) * warmup_idx / max(warmup_steps, 1)

    decay_steps = total_steps - warmup_steps
    if decay_steps > 0:
        decay_idx = np.arange(decay_steps)
        cosine = 0.5 * (1 + np.cos(np.pi * decay_idx / decay_steps))
        schedule[warmup_steps:] = min_lr + (max_lr - min_lr) * cosine
    return schedule


def main():
    """Train DualHiddenLCM on synthetic near-identity data with warmup + cosine LR decay.

    Inputs are standard-normal vectors and targets are the inputs plus Gaussian
    noise with std 0.01. The Adam learning rate is set every step from a
    precomputed schedule, and the mean training loss of each epoch is printed.
    """
    # Model parameters
    input_dim = 300
    hidden_dims = [512, 256, 128, 64]
    hidden_dim2 = 20
    output_dim = 300
    batch_size = 4
    num_samples = 100
    epochs = 40
    warmup_steps = 100
    max_lr = 2e-3
    min_lr = 1e-4
    seed = 42

    # Full batches per epoch; a trailing partial batch is dropped.
    num_batches = num_samples // batch_size
    if num_batches == 0:
        raise ValueError(f"num_samples ({num_samples}) must be >= batch_size ({batch_size})")

    # Generate synthetic data
    rng = np.random.default_rng(seed)
    data = rng.normal(0, 1, (num_samples, input_dim)).astype(np.float32)
    target = data + rng.normal(0, 0.01, data.shape).astype(np.float32)

    # Convert to MLX arrays
    data = mx.array(data)
    target = mx.array(target)

    # Seed MLX too: parameter initialisation uses MLX's own RNG, so without this
    # identical runs start from different weights.
    mx.random.seed(seed)
    model = DualHiddenLCM(input_dim, hidden_dims, hidden_dim2, output_dim)

    # One schedule entry per optimizer step actually taken. epochs * num_samples // batch_size
    # over-counts whenever num_samples is not a multiple of batch_size, which would
    # leave the cosine decay short of its end.
    total_steps = epochs * num_batches
    lr_schedule = _build_lr_schedule(total_steps, warmup_steps, min_lr, max_lr)

    # Learning rates are handed to MLX as plain Python floats.
    initial_lr = float(lr_schedule[0]) if total_steps > 0 else min_lr
    optimizer = optim.Adam(learning_rate=initial_lr)

    # Training loop
    for epoch in range(epochs):
        total_loss = 0.0
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_data = data[start_idx:end_idx]
            batch_target = target[start_idx:end_idx]

            optimizer.learning_rate = float(lr_schedule[epoch * num_batches + i])
            loss = train_step(model, batch_data, batch_target, optimizer)
            # MLX is lazy: evaluate the updated parameters and optimizer state too,
            # not only the loss, so every step is materialised before the next.
            mx.eval(loss, model.parameters(), optimizer.state)
            # Accumulate a Python float rather than a chain of lazy mx.array sums.
            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f}")

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so running this cell trains the model.
    main()

Epoch 1/40 | Loss: 1.0573
Epoch 2/40 | Loss: 0.9141
Epoch 3/40 | Loss: 0.9152
Epoch 4/40 | Loss: 0.9080
Epoch 5/40 | Loss: 0.8923
Epoch 6/40 | Loss: 0.7997
Epoch 7/40 | Loss: 0.6875
Epoch 8/40 | Loss: 0.5849
Epoch 9/40 | Loss: 0.5081
Epoch 10/40 | Loss: 0.4551
Epoch 11/40 | Loss: 0.4126
Epoch 12/40 | Loss: 0.3799
Epoch 13/40 | Loss: 0.3520
Epoch 14/40 | Loss: 0.3279
Epoch 15/40 | Loss: 0.3070
Epoch 16/40 | Loss: 0.2899
Epoch 17/40 | Loss: 0.2745
Epoch 18/40 | Loss: 0.2608
Epoch 19/40 | Loss: 0.2483
Epoch 20/40 | Loss: 0.2373
Epoch 21/40 | Loss: 0.2276
Epoch 22/40 | Loss: 0.2191
Epoch 23/40 | Loss: 0.2115
Epoch 24/40 | Loss: 0.2049
Epoch 25/40 | Loss: 0.1990
Epoch 26/40 | Loss: 0.1939
Epoch 27/40 | Loss: 0.1893
Epoch 28/40 | Loss: 0.1854
Epoch 29/40 | Loss: 0.1820
Epoch 30/40 | Loss: 0.1790
Epoch 31/40 | Loss: 0.1765
Epoch 32/40 | Loss: 0.1744
Epoch 33/40 | Loss: 0.1725
Epoch 34/40 | Loss: 0.1710
Epoch 35/40 | Loss: 0.1697
Epoch 36/40 | Loss: 0.1686
Epoch 37/40 | Loss: 0.1677
Epoch 38/4

In [None]:
# Further optimizations

In [10]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from typing import List, Tuple

class PreNet(nn.Module):
    """Entry projection of the pyramid branch: ``tanh(Linear(input_dim -> hidden_dim)(x))``.

    The tanh bounds every activation to (-1, 1) before the pyramid stack.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def __call__(self, x):
        return mx.tanh(self.linear(x))

class PostNet(nn.Module):
    """Exit projection of the pyramid branch: a plain ``Linear(hidden_dim -> output_dim)`` with no activation."""

    def __init__(self, hidden_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        projected = self.linear(x)
        return projected

class PyramidLayer(nn.Module):
    """One pyramid stage: ``tanh(Linear(input_dim -> output_dim)(x))``."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def __call__(self, x):
        pre_activation = self.linear(x)
        return mx.tanh(pre_activation)

class HyperbolicCube(nn.Module):
    """Chain of PyramidLayer stages whose widths follow ``layer_dims``.

    ``layer_dims = [d0, d1, ..., dk]`` builds k stages mapping d0->d1, ..., d(k-1)->dk,
    applied in order. A list with fewer than two entries builds no stages, so the
    input is returned unchanged.
    """

    def __init__(self, layer_dims: List[int]):
        super().__init__()
        self.pyramid_layers = [
            PyramidLayer(d_in, d_out)
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        ]

    def __call__(self, x):
        out = x
        for stage in self.pyramid_layers:
            out = stage(out)
        return out

class DualHiddenLCM(nn.Module):
    """Two parallel branches over the same input whose outputs are summed.

    * Pyramid branch: ``PreNet -> HyperbolicCube(hidden_dims) -> PostNet``.
    * Bottleneck branch: ``Linear(input_dim -> hidden_dim2) -> ReLU -> Linear(hidden_dim2 -> output_dim)``.

    Args:
        input_dim: size of the input feature axis.
        hidden_dims: pyramid widths; ``hidden_dims[0]`` is the PreNet output and
            ``hidden_dims[-1]`` the PostNet input. Must be non-empty.
        hidden_dim2: width of the bottleneck branch.
        output_dim: size of the output feature axis (shared by both branches).
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], hidden_dim2: int, output_dim: int):
        super().__init__()
        first_width, last_width = hidden_dims[0], hidden_dims[-1]

        # Construction order fixes the order in which layers draw their random
        # initial weights, so keep it stable.
        self.prenet = PreNet(input_dim, first_width)
        self.hyperbolic_cube = HyperbolicCube(hidden_dims)
        self.postnet = PostNet(last_width, output_dim)

        self.hidden_dim2 = nn.Linear(input_dim, hidden_dim2)
        self.hidden_dim2_output = nn.Linear(hidden_dim2, output_dim)

    def __call__(self, x):
        pyramid_out = self.postnet(self.hyperbolic_cube(self.prenet(x)))
        # ReLU written as max(0, .)
        bottleneck_out = self.hidden_dim2_output(mx.maximum(0, self.hidden_dim2(x)))
        return pyramid_out + bottleneck_out

def compute_accuracy(predicted, target, threshold=0.1, eps=1e-8):
    """Fraction of rows whose cosine similarity with the target exceeds ``threshold``.

    Cosine similarity is taken along the last axis. The product of the two norms
    is clamped to at least ``eps`` so that an all-zero row yields similarity 0
    instead of NaN (0/0).

    Args:
        predicted: predictions; assumed to have the same shape as ``target``.
        target: reference vectors.
        threshold: a row counts as correct when its similarity is strictly greater.
        eps: lower bound for the norm product in the denominator.

    Returns:
        Scalar mx.array: mean of the 0/1 correctness indicators.
    """
    norm_pred = mx.sqrt(mx.sum(predicted * predicted, axis=-1))
    norm_target = mx.sqrt(mx.sum(target * target, axis=-1))
    denom = mx.maximum(norm_pred * norm_target, eps)
    cos_sim = mx.sum(predicted * target, axis=-1) / denom
    correct = (cos_sim > threshold).astype(mx.float32)
    return mx.mean(correct)

def train_step(model, x, y, optimizer):
    """Run one optimisation step on ``(x, y)`` using mean-squared-error loss.

    Computes the loss and its gradients with respect to ``model``'s trainable
    parameters, applies ``optimizer.update`` and returns the scalar loss array.
    The loss is not forced with ``mx.eval`` here.
    """

    def mse_loss(net, inputs, targets):
        residual = net(inputs) - targets
        return mx.mean(residual ** 2)

    loss_and_grad = nn.value_and_grad(model, mse_loss)
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    return loss

def _build_lr_schedule(total_steps: int, warmup_steps: int, min_lr: float, max_lr: float) -> np.ndarray:
    """Return one learning rate per optimizer step (linear warmup, then cosine decay).

    The first ``warmup_steps`` entries (clamped to ``total_steps``) rise linearly as
    ``min_lr + (max_lr - min_lr) * t / warmup_steps`` for ``t = 0, 1, ...``. The
    remaining ``T`` entries follow ``min_lr + (max_lr - min_lr) * 0.5 * (1 + cos(pi * t / T))``.
    """
    warmup_steps = max(0, min(warmup_steps, total_steps))
    schedule = np.zeros(total_steps)

    warmup_idx = np.arange(warmup_steps)
    schedule[:warmup_steps] = min_lr + (max_lr - min_lr) * warmup_idx / max(warmup_steps, 1)

    decay_steps = total_steps - warmup_steps
    if decay_steps > 0:
        decay_idx = np.arange(decay_steps)
        cosine = 0.5 * (1 + np.cos(np.pi * decay_idx / decay_steps))
        schedule[warmup_steps:] = min_lr + (max_lr - min_lr) * cosine
    return schedule


def main():
    """Train DualHiddenLCM on synthetic near-identity data with warmup + cosine LR decay.

    Inputs are standard-normal vectors and targets are the inputs plus Gaussian
    noise with std 0.01. The Adam learning rate is set every step from a
    precomputed schedule, and the mean training loss of each epoch is printed.
    """
    # Model parameters
    input_dim = 300
    hidden_dims = [512, 256, 128, 64]
    hidden_dim2 = 20
    output_dim = 300
    batch_size = 4
    num_samples = 100
    epochs = 40
    warmup_steps = 100
    max_lr = 2e-3
    min_lr = 1e-4
    seed = 42

    # Full batches per epoch; a trailing partial batch is dropped.
    num_batches = num_samples // batch_size
    if num_batches == 0:
        raise ValueError(f"num_samples ({num_samples}) must be >= batch_size ({batch_size})")

    # Generate synthetic data
    rng = np.random.default_rng(seed)
    data = rng.normal(0, 1, (num_samples, input_dim)).astype(np.float32)
    target = data + rng.normal(0, 0.01, data.shape).astype(np.float32)

    # Convert to MLX arrays
    data = mx.array(data)
    target = mx.array(target)

    # Seed MLX too: parameter initialisation uses MLX's own RNG, so without this
    # identical runs start from different weights.
    mx.random.seed(seed)
    model = DualHiddenLCM(input_dim, hidden_dims, hidden_dim2, output_dim)

    # One schedule entry per optimizer step actually taken. epochs * num_samples // batch_size
    # over-counts whenever num_samples is not a multiple of batch_size, which would
    # leave the cosine decay short of its end.
    total_steps = epochs * num_batches
    lr_schedule = _build_lr_schedule(total_steps, warmup_steps, min_lr, max_lr)

    # Learning rates are handed to MLX as plain Python floats.
    initial_lr = float(lr_schedule[0]) if total_steps > 0 else min_lr
    optimizer = optim.Adam(learning_rate=initial_lr)

    # Training loop
    for epoch in range(epochs):
        total_loss = 0.0
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_data = data[start_idx:end_idx]
            batch_target = target[start_idx:end_idx]

            optimizer.learning_rate = float(lr_schedule[epoch * num_batches + i])
            loss = train_step(model, batch_data, batch_target, optimizer)
            # MLX is lazy: evaluate the updated parameters and optimizer state too,
            # not only the loss, so every step is materialised before the next.
            mx.eval(loss, model.parameters(), optimizer.state)
            # Accumulate a Python float rather than a chain of lazy mx.array sums.
            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f}")

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so running this cell trains the model.
    main()

Epoch 1/40 | Loss: 1.0551
Epoch 2/40 | Loss: 0.9188
Epoch 3/40 | Loss: 0.9216
Epoch 4/40 | Loss: 0.9098
Epoch 5/40 | Loss: 0.8804
Epoch 6/40 | Loss: 0.7935
Epoch 7/40 | Loss: 0.6859
Epoch 8/40 | Loss: 0.5945
Epoch 9/40 | Loss: 0.5181
Epoch 10/40 | Loss: 0.4635
Epoch 11/40 | Loss: 0.4200
Epoch 12/40 | Loss: 0.3852
Epoch 13/40 | Loss: 0.3559
Epoch 14/40 | Loss: 0.3318
Epoch 15/40 | Loss: 0.3110
Epoch 16/40 | Loss: 0.2928
Epoch 17/40 | Loss: 0.2766
Epoch 18/40 | Loss: 0.2625
Epoch 19/40 | Loss: 0.2501
Epoch 20/40 | Loss: 0.2391
Epoch 21/40 | Loss: 0.2296
Epoch 22/40 | Loss: 0.2211
Epoch 23/40 | Loss: 0.2135
Epoch 24/40 | Loss: 0.2067
Epoch 25/40 | Loss: 0.2006
Epoch 26/40 | Loss: 0.1954
Epoch 27/40 | Loss: 0.1908
Epoch 28/40 | Loss: 0.1869
Epoch 29/40 | Loss: 0.1835
Epoch 30/40 | Loss: 0.1805
Epoch 31/40 | Loss: 0.1780
Epoch 32/40 | Loss: 0.1758
Epoch 33/40 | Loss: 0.1740
Epoch 34/40 | Loss: 0.1725
Epoch 35/40 | Loss: 0.1712
Epoch 36/40 | Loss: 0.1701
Epoch 37/40 | Loss: 0.1693
Epoch 38/4

In [None]:
# Inference: save the trained model and load it back for embedding

In [18]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from typing import List, Tuple
from pathlib import Path
import json

class PreNet(nn.Module):
    """Entry projection of the pyramid branch: ``tanh(Linear(input_dim -> hidden_dim)(x))``.

    The tanh bounds every activation to (-1, 1) before the pyramid stack.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def __call__(self, x):
        return mx.tanh(self.linear(x))

class PostNet(nn.Module):
    """Exit projection of the pyramid branch: a plain ``Linear(hidden_dim -> output_dim)`` with no activation."""

    def __init__(self, hidden_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        projected = self.linear(x)
        return projected

class PyramidLayer(nn.Module):
    """One pyramid stage: ``tanh(Linear(input_dim -> output_dim)(x))``."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def __call__(self, x):
        pre_activation = self.linear(x)
        return mx.tanh(pre_activation)

class HyperbolicCube(nn.Module):
    """Chain of PyramidLayer stages whose widths follow ``layer_dims``.

    ``layer_dims = [d0, d1, ..., dk]`` builds k stages mapping d0->d1, ..., d(k-1)->dk,
    applied in order. A list with fewer than two entries builds no stages, so the
    input is returned unchanged.
    """

    def __init__(self, layer_dims: List[int]):
        super().__init__()
        self.pyramid_layers = [
            PyramidLayer(d_in, d_out)
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        ]

    def __call__(self, x):
        out = x
        for stage in self.pyramid_layers:
            out = stage(out)
        return out

class DualHiddenLCM(nn.Module):
    """Two parallel branches over the same input whose outputs are summed.

    * Pyramid branch: ``PreNet -> HyperbolicCube(hidden_dims) -> PostNet``.
    * Bottleneck branch: ``Linear(input_dim -> hidden_dim2) -> ReLU -> Linear(hidden_dim2 -> output_dim)``.

    Args:
        input_dim: size of the input feature axis.
        hidden_dims: pyramid widths; ``hidden_dims[0]`` is the PreNet output and
            ``hidden_dims[-1]`` the PostNet input. Must be non-empty.
        hidden_dim2: width of the bottleneck branch.
        output_dim: size of the output feature axis (shared by both branches).
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], hidden_dim2: int, output_dim: int):
        super().__init__()
        first_width, last_width = hidden_dims[0], hidden_dims[-1]

        # Construction order fixes the order in which layers draw their random
        # initial weights, so keep it stable.
        self.prenet = PreNet(input_dim, first_width)
        self.hyperbolic_cube = HyperbolicCube(hidden_dims)
        self.postnet = PostNet(last_width, output_dim)

        self.hidden_dim2 = nn.Linear(input_dim, hidden_dim2)
        self.hidden_dim2_output = nn.Linear(hidden_dim2, output_dim)

    def __call__(self, x):
        pyramid_out = self.postnet(self.hyperbolic_cube(self.prenet(x)))
        # ReLU written as max(0, .)
        bottleneck_out = self.hidden_dim2_output(mx.maximum(0, self.hidden_dim2(x)))
        return pyramid_out + bottleneck_out

def compute_accuracy(predicted, target, threshold=0.1, eps=1e-8):
    """Fraction of rows whose cosine similarity with the target exceeds ``threshold``.

    Cosine similarity is taken along the last axis. The product of the two norms
    is clamped to at least ``eps`` so that an all-zero row yields similarity 0
    instead of NaN (0/0).

    Args:
        predicted: predictions; assumed to have the same shape as ``target``.
        target: reference vectors.
        threshold: a row counts as correct when its similarity is strictly greater.
        eps: lower bound for the norm product in the denominator.

    Returns:
        Scalar mx.array: mean of the 0/1 correctness indicators.
    """
    norm_pred = mx.sqrt(mx.sum(predicted * predicted, axis=-1))
    norm_target = mx.sqrt(mx.sum(target * target, axis=-1))
    denom = mx.maximum(norm_pred * norm_target, eps)
    cos_sim = mx.sum(predicted * target, axis=-1) / denom
    correct = (cos_sim > threshold).astype(mx.float32)
    return mx.mean(correct)

def train_step(model, x, y, optimizer):
    """Run one optimisation step on ``(x, y)`` using mean-squared-error loss.

    Computes the loss and its gradients with respect to ``model``'s trainable
    parameters, applies ``optimizer.update`` and returns the scalar loss array.
    The loss is not forced with ``mx.eval`` here.
    """

    def mse_loss(net, inputs, targets):
        residual = net(inputs) - targets
        return mx.mean(residual ** 2)

    loss_and_grad = nn.value_and_grad(model, mse_loss)
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    return loss

class LCMInference:
    """Load a saved DualHiddenLCM from disk and run it on NumPy inputs.

    A model directory written by ``save_model`` contains:
      * ``config.json`` - constructor arguments for DualHiddenLCM;
      * ``weights.npz`` - every parameter in the module tree (weights *and*
        biases of all layers, including the HyperbolicCube pyramid layers).
    """

    CONFIG_FILE = "config.json"
    WEIGHTS_FILE = "weights.npz"

    def __init__(self, model_path: str):
        self.model = self._load_model(model_path)
        self.model.eval()

    def _load_model(self, path: str) -> DualHiddenLCM:
        """Rebuild the architecture from ``config.json`` and load all saved parameters.

        ``load_weights(..., strict=True)`` raises if a parameter is missing,
        unexpected, or has a different shape from the rebuilt architecture.
        """
        path = Path(path)
        with open(path / self.CONFIG_FILE, "r") as f:
            config = json.load(f)

        model = DualHiddenLCM(
            input_dim=config["input_dim"],
            hidden_dims=config["hidden_dims"],
            hidden_dim2=config["hidden_dim2"],
            output_dim=config["output_dim"],
        )
        # Replace the random initialisation with the trained parameters.
        model.load_weights(str(path / self.WEIGHTS_FILE), strict=True)
        return model

    def embed(self, input_data: np.ndarray) -> np.ndarray:
        """Run the model on ``input_data`` and return the outputs as a NumPy array.

        ``input_data`` is cast to float32; its last axis presumably must equal the
        model's ``input_dim``. No gradient guard is needed: MLX only computes
        gradients when a function is wrapped in a transform such as ``value_and_grad``.
        """
        x = mx.array(np.asarray(input_data, dtype=np.float32))
        embeddings = self.model(x)
        # mx.array has no .numpy() method; np.array() converts (and evaluates) it.
        return np.array(embeddings)

    @staticmethod
    def save_model(model: DualHiddenLCM, save_path: str):
        """Write ``model``'s constructor config and full parameter set to ``save_path``.

        The directory is created if needed. The output is readable by ``LCMInference(save_path)``.
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # save_weights stores every parameter (all weights and biases of every
        # sub-layer), so the reloaded model matches the trained one exactly.
        model.save_weights(str(save_path / LCMInference.WEIGHTS_FILE))

        prenet_weight = model.prenet.linear.weight  # nn.Linear weight is (out, in)
        config = {
            "input_dim": int(prenet_weight.shape[1]),
            # hidden_dims[0] is the PreNet output width; each pyramid layer then
            # contributes its own output width. Without the first entry the
            # rebuilt model would have the wrong shapes.
            "hidden_dims": [int(prenet_weight.shape[0])]
            + [int(layer.linear.weight.shape[0]) for layer in model.hyperbolic_cube.pyramid_layers],
            "hidden_dim2": int(model.hidden_dim2.weight.shape[0]),
            "output_dim": int(model.postnet.linear.weight.shape[0]),
        }

        with open(save_path / LCMInference.CONFIG_FILE, "w") as f:
            json.dump(config, f)
            
def _build_lr_schedule(total_steps: int, warmup_steps: int, min_lr: float, max_lr: float) -> np.ndarray:
    """Return one learning rate per optimizer step (linear warmup, then cosine decay).

    The first ``warmup_steps`` entries (clamped to ``total_steps``) rise linearly as
    ``min_lr + (max_lr - min_lr) * t / warmup_steps`` for ``t = 0, 1, ...``. The
    remaining ``T`` entries follow ``min_lr + (max_lr - min_lr) * 0.5 * (1 + cos(pi * t / T))``.
    """
    warmup_steps = max(0, min(warmup_steps, total_steps))
    schedule = np.zeros(total_steps)

    warmup_idx = np.arange(warmup_steps)
    schedule[:warmup_steps] = min_lr + (max_lr - min_lr) * warmup_idx / max(warmup_steps, 1)

    decay_steps = total_steps - warmup_steps
    if decay_steps > 0:
        decay_idx = np.arange(decay_steps)
        cosine = 0.5 * (1 + np.cos(np.pi * decay_idx / decay_steps))
        schedule[warmup_steps:] = min_lr + (max_lr - min_lr) * cosine
    return schedule


def main():
    """Train DualHiddenLCM with warmup + cosine LR decay, then save it to ``lcm_model/``.

    Inputs are standard-normal vectors and targets are the inputs plus Gaussian
    noise with std 0.01. The Adam learning rate is set every step from a
    precomputed schedule, the mean training loss of each epoch is printed, and
    the trained model is written with ``LCMInference.save_model``.
    """
    # Model parameters
    input_dim = 300
    hidden_dims = [512, 256, 128, 64]
    hidden_dim2 = 20
    output_dim = 300
    batch_size = 4
    num_samples = 100
    epochs = 40
    warmup_steps = 100
    max_lr = 2e-3
    min_lr = 1e-4
    seed = 42

    # Full batches per epoch; a trailing partial batch is dropped.
    num_batches = num_samples // batch_size
    if num_batches == 0:
        raise ValueError(f"num_samples ({num_samples}) must be >= batch_size ({batch_size})")

    # Generate synthetic data
    rng = np.random.default_rng(seed)
    data = rng.normal(0, 1, (num_samples, input_dim)).astype(np.float32)
    target = data + rng.normal(0, 0.01, data.shape).astype(np.float32)

    # Convert to MLX arrays
    data = mx.array(data)
    target = mx.array(target)

    # Seed MLX too: parameter initialisation uses MLX's own RNG, so without this
    # identical runs start from different weights.
    mx.random.seed(seed)
    model = DualHiddenLCM(input_dim, hidden_dims, hidden_dim2, output_dim)

    # One schedule entry per optimizer step actually taken. epochs * num_samples // batch_size
    # over-counts whenever num_samples is not a multiple of batch_size, which would
    # leave the cosine decay short of its end.
    total_steps = epochs * num_batches
    lr_schedule = _build_lr_schedule(total_steps, warmup_steps, min_lr, max_lr)

    # Learning rates are handed to MLX as plain Python floats.
    initial_lr = float(lr_schedule[0]) if total_steps > 0 else min_lr
    optimizer = optim.Adam(learning_rate=initial_lr)

    # Training loop
    for epoch in range(epochs):
        total_loss = 0.0
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_data = data[start_idx:end_idx]
            batch_target = target[start_idx:end_idx]

            optimizer.learning_rate = float(lr_schedule[epoch * num_batches + i])
            loss = train_step(model, batch_data, batch_target, optimizer)
            # MLX is lazy: evaluate the updated parameters and optimizer state too,
            # not only the loss, so every step is materialised before the next.
            mx.eval(loss, model.parameters(), optimizer.state)
            # Accumulate a Python float rather than a chain of lazy mx.array sums.
            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f}")

    # Save the trained model (config + parameters) for LCMInference
    LCMInference.save_model(model, "lcm_model")

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so running this cell trains and saves the model.
    main()

Epoch 1/40 | Loss: 1.0639
Epoch 2/40 | Loss: 0.9222
Epoch 3/40 | Loss: 0.9168
Epoch 4/40 | Loss: 0.9170
Epoch 5/40 | Loss: 0.8929
Epoch 6/40 | Loss: 0.8112
Epoch 7/40 | Loss: 0.7029
Epoch 8/40 | Loss: 0.5975
Epoch 9/40 | Loss: 0.5120
Epoch 10/40 | Loss: 0.4548
Epoch 11/40 | Loss: 0.4111
Epoch 12/40 | Loss: 0.3786
Epoch 13/40 | Loss: 0.3507
Epoch 14/40 | Loss: 0.3281
Epoch 15/40 | Loss: 0.3085
Epoch 16/40 | Loss: 0.2915
Epoch 17/40 | Loss: 0.2767
Epoch 18/40 | Loss: 0.2629
Epoch 19/40 | Loss: 0.2502
Epoch 20/40 | Loss: 0.2387
Epoch 21/40 | Loss: 0.2288
Epoch 22/40 | Loss: 0.2199
Epoch 23/40 | Loss: 0.2122
Epoch 24/40 | Loss: 0.2055
Epoch 25/40 | Loss: 0.1996
Epoch 26/40 | Loss: 0.1945
Epoch 27/40 | Loss: 0.1901
Epoch 28/40 | Loss: 0.1862
Epoch 29/40 | Loss: 0.1828
Epoch 30/40 | Loss: 0.1799
Epoch 31/40 | Loss: 0.1774
Epoch 32/40 | Loss: 0.1752
Epoch 33/40 | Loss: 0.1734
Epoch 34/40 | Loss: 0.1719
Epoch 35/40 | Loss: 0.1706
Epoch 36/40 | Loss: 0.1696
Epoch 37/40 | Loss: 0.1687
Epoch 38/4