# Advanced LTC Networks with MLX

This notebook demonstrates advanced features of LTC (Liquid Time-Constant) networks using the MLX implementations of neural circuit policies provided by the `ncps.mlx` package. We'll explore:
- Time-aware processing with variable time steps
- Bidirectional processing
- Multi-layer architectures with backbones
- Comparison with CfC models

In [None]:
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from ncps.mlx import LTC, CfC  # Import both for comparison

## Create Advanced LTC Models

We'll create several model variants to demonstrate different LTC capabilities:

In [None]:
class TimeAwareLTC(nn.Module):
    """LTC layer plus linear readout, with optional time deltas forwarded to the LTC.

    The LTC is built with ``return_sequences=True``; the readout is applied
    to the last entry along axis 1 of its output (presumably the final time
    step -- the data in this notebook is laid out as (batch, step, feature)).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        ltc_config = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "return_sequences": True,
        }
        self.ltc = LTC(**ltc_config)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def __call__(self, x, time_delta=None):
        # time_delta is handed to the LTC unchanged (None when not supplied).
        hidden_sequence = self.ltc(x, time_delta=time_delta)
        final_hidden = hidden_sequence[:, -1]
        return self.output_layer(final_hidden)


class BidirectionalLTC(nn.Module):
    """LTC built with ``bidirectional=True`` followed by a linear readout.

    The readout takes ``hidden_size * 2`` input features (presumably the
    forward and backward states concatenated -- confirm against ncps.mlx)
    and is applied to the last entry along axis 1 of the LTC output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        ltc_config = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "bidirectional": True,
            "return_sequences": True,
        }
        self.ltc = LTC(**ltc_config)
        # Twice the hidden width to match the bidirectional output size.
        readout_width = hidden_size * 2
        self.output_layer = nn.Linear(readout_width, output_size)

    def __call__(self, x, time_delta=None):
        hidden_sequence = self.ltc(x, time_delta=time_delta)
        final_hidden = hidden_sequence[:, -1]
        return self.output_layer(final_hidden)


class DeepLTC(nn.Module):
    """Two-layer LTC with a backbone network, followed by a linear readout.

    The LTC is configured with ``num_layers=2``, ``backbone_units=64`` and
    ``backbone_layers=2``; the readout is applied to the last entry along
    axis 1 of its output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        ltc_config = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "num_layers": 2,
            "backbone_units": 64,
            "backbone_layers": 2,
            "return_sequences": True,
        }
        self.ltc = LTC(**ltc_config)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def __call__(self, x, time_delta=None):
        hidden_sequence = self.ltc(x, time_delta=time_delta)
        final_hidden = hidden_sequence[:, -1]
        return self.output_layer(final_hidden)

## Generate Complex Time Series Data

We'll create data with variable sampling rates and multiple frequency components:

In [None]:
def generate_complex_data(batch_size, seq_length, include_time=True):
    """Generate batches of a multi-frequency signal for next-sample prediction.

    The underlying signal is ``sin(t) + 0.5*sin(2t) + 0.25*sin(4t)`` evaluated
    on ``2 * seq_length`` evenly spaced points over ``[0, 8*pi]``.

    Args:
        batch_size: Number of sequences to generate.
        seq_length: Number of samples drawn per sequence. The first
            ``seq_length - 1`` samples form the input; the last one is the
            prediction target.
        include_time: If True, each sequence is sampled at sorted random
            (irregular) positions and the spacing between consecutive
            positions is returned. If False, a contiguous window is used and
            no time deltas are returned.

    Returns:
        A tuple ``(X, y, time_delta)``:
        ``X`` has shape ``(batch_size, seq_length - 1, 2)`` holding
        ``[value, derivative]`` per step; ``y`` has shape ``(batch_size, 2)``
        holding the target value and derivative; ``time_delta`` has shape
        ``(batch_size, seq_length - 1, 1)`` or is ``None`` when
        ``include_time`` is False. Positions and deltas are measured in
        signal-index units.
    """
    # Generate base time points with higher resolution
    base_t = np.linspace(0, 8 * np.pi, seq_length * 2)

    # Create signal with multiple frequency components
    signal = (
        np.sin(base_t)                  # Fundamental
        + 0.5 * np.sin(2 * base_t)      # 2x the fundamental frequency
        + 0.25 * np.sin(4 * base_t)     # 4x the fundamental frequency
    )

    # Create batches
    X = np.zeros((batch_size, seq_length - 1, 2))  # [value, derivative]
    y = np.zeros((batch_size, 2))

    if include_time:
        time_delta = np.zeros((batch_size, seq_length - 1, 1))
        for i in range(batch_size):
            # Sorted random positions give irregular, increasing sample times.
            steps = np.sort(
                np.random.uniform(0, len(signal) - seq_length, seq_length)
            )
            # Values are read at the integer part of each position.
            indices = steps.astype(int)

            # Spacing between consecutive sample positions.
            time_delta[i, :, 0] = np.diff(steps)

            sampled_signal = signal[indices]
            X[i, :, 0] = sampled_signal[:-1]
            # np.gradient interprets an array argument as the sample
            # *coordinates*, not the spacings, so pass the positions
            # themselves. Only the input samples are used here so the target
            # value does not leak into the last input derivative.
            X[i, :, 1] = np.gradient(sampled_signal[:-1], steps[:-1])

            y[i, 0] = sampled_signal[-1]
            # Same coordinate system as the input derivative feature.
            y[i, 1] = np.gradient(sampled_signal, steps)[-1]
    else:
        time_delta = None
        # The full-signal derivative does not depend on the window: compute once.
        signal_derivative = np.gradient(signal)
        for i in range(batch_size):
            start_idx = np.random.randint(0, len(signal) - seq_length)
            target_idx = start_idx + seq_length - 1
            window = signal[start_idx:target_idx]
            X[i, :, 0] = window
            X[i, :, 1] = np.gradient(window)
            y[i, 0] = signal[target_idx]
            y[i, 1] = signal_derivative[target_idx]

    time_delta_out = mx.array(time_delta) if include_time else None
    return mx.array(X), mx.array(y), time_delta_out

## Training Functions

In [None]:
def loss_fn(model, X, y, time_delta=None):
    """Return the mean squared error between ``model``'s output and ``y``.

    ``time_delta`` is forwarded to the model as a keyword argument.
    """
    residual = model(X, time_delta=time_delta) - y
    squared_error = residual ** 2
    return mx.mean(squared_error)


def train_model(model, n_epochs=100, batch_size=32, seq_length=50, use_time=True):
    """Train ``model`` on freshly generated batches and return the loss history.

    Args:
        model: An ``nn.Module`` callable as ``model(X, time_delta=...)``.
        n_epochs: Number of optimisation steps; a new batch is generated for
            every step.
        batch_size: Number of sequences per generated batch.
        seq_length: Sequence length passed to ``generate_complex_data``.
        use_time: Whether to generate irregularly sampled data with time
            deltas (passed as ``include_time``).

    Returns:
        A list with one Python float loss per epoch.
    """
    # Optimizers live in mlx.optimizers; mlx.nn does not provide Adam.
    optimizer = optim.Adam(learning_rate=0.001)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    losses = []

    for epoch in range(n_epochs):
        X, y, time_delta = generate_complex_data(batch_size, seq_length, use_time)

        # Compute loss and gradients w.r.t. the model's trainable parameters
        loss, grads = loss_and_grad_fn(model, X, y, time_delta)

        # Update parameters
        optimizer.update(model, grads)
        # MLX evaluates lazily: materialise the updated parameters and
        # optimizer state now so the computation graph does not keep growing.
        mx.eval(model.parameters(), optimizer.state)
        losses.append(float(loss))

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")

    return losses

## Compare Different LTC Architectures

Let's train and compare our different LTC model variants:

In [None]:
# Instantiate the three LTC variants with one shared configuration
model_config = {"input_size": 2, "hidden_size": 32, "output_size": 2}
time_aware_model = TimeAwareLTC(**model_config)
bidir_model = BidirectionalLTC(**model_config)
deep_model = DeepLTC(**model_config)

# Fit each variant with the default training settings
print("Training time-aware model...")
time_aware_losses = train_model(time_aware_model)

print("\nTraining bidirectional model...")
bidir_losses = train_model(bidir_model)

print("\nTraining deep model...")
deep_losses = train_model(deep_model)

# Overlay the three loss histories on a log-scaled axis
plt.figure(figsize=(10, 6))
for loss_history, curve_label in (
    (time_aware_losses, 'Time-Aware LTC'),
    (bidir_losses, 'Bidirectional LTC'),
    (deep_losses, 'Deep LTC'),
):
    plt.plot(loss_history, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Comparison')
plt.legend()
plt.yscale('log')
plt.grid(True)
plt.show()

## Visualize Predictions with Different Time Patterns

Let's examine how our models handle different sampling patterns:

In [None]:
def evaluate_predictions(model, seq_length=100):
    """Plot next-sample predictions of ``model`` on two independently sampled sequences.

    Two single-sequence batches are drawn from ``generate_complex_data`` and
    the model is run on each together with its time deltas. A three-panel
    figure is shown: the value channel of sequence 1 with its predicted and
    true next value, the derivative channel of sequence 2 with its predicted
    and true next derivative, and the time deltas of both sequences.
    """
    # Two independent draws, each with its own random sampling pattern
    inputs_a, target_a, deltas_a = generate_complex_data(1, seq_length)
    inputs_b, target_b, deltas_b = generate_complex_data(1, seq_length)

    output_a = model(inputs_a, time_delta=deltas_a)
    output_b = model(inputs_b, time_delta=deltas_b)

    plt.figure(figsize=(15, 10))

    # Top panel: value channel of sequence 1; markers sit one step past the input
    plt.subplot(3, 1, 1)
    next_step_a = len(inputs_a[0])
    plt.plot(inputs_a[0, :, 0], 'b-', label='Signal 1')
    plt.plot(next_step_a, float(output_a[0, 0]), 'bo', label='Pred 1')
    plt.plot(next_step_a, float(target_a[0, 0]), 'go', label='True 1')
    plt.legend()
    plt.title(f'Predictions with Different Sampling Patterns - {model.__class__.__name__}')

    # Middle panel: derivative channel of sequence 2
    plt.subplot(3, 1, 2)
    next_step_b = len(inputs_b[0])
    plt.plot(inputs_b[0, :, 1], 'r-', label='Derivative 2')
    plt.plot(next_step_b, float(output_b[0, 1]), 'ro', label='Pred 2')
    plt.plot(next_step_b, float(target_b[0, 1]), 'go', label='True 2')
    plt.legend()

    # Bottom panel: spacing between consecutive samples for both sequences
    plt.subplot(3, 1, 3)
    plt.plot(deltas_a[0], 'b-', label='Time Delta 1')
    plt.plot(deltas_b[0], 'r-', label='Time Delta 2')
    plt.legend()
    plt.title('Sampling Patterns')
    plt.xlabel('Step')
    plt.ylabel('Delta t')

    plt.tight_layout()
    plt.show()

# Run the evaluation plots for every trained LTC variant, in training order
for trained_model in (time_aware_model, bidir_model, deep_model):
    evaluate_predictions(trained_model)

## Compare LTC with CfC

Let's compare the performance of LTC and CfC on the same task:

In [None]:
# Create comparable CfC model.
# train_model relies on nn.value_and_grad(model, ...) and
# optimizer.update(model, ...), both of which need an nn.Module that exposes
# its trainable parameters. A plain forward function cannot be trained, so
# the CfC and its readout are wrapped in a module mirroring DeepLTC.
class DeepCfC(nn.Module):
    """Two-layer CfC with a backbone network, followed by a linear readout."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.cfc = CfC(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            backbone_units=64,
            backbone_layers=2,
            return_sequences=True
        )
        self.output_layer = nn.Linear(hidden_size, output_size)

    def __call__(self, x, time_delta=None):
        x = self.cfc(x, time_delta=time_delta)
        # Readout on the last entry along axis 1, as in the LTC models.
        return self.output_layer(x[:, -1])


deep_cfc_model = DeepCfC(input_size=2, hidden_size=32, output_size=2)

# Keep the original module-level names available; they refer to the same
# sub-modules that are trained through deep_cfc_model.
cfc_model = deep_cfc_model.cfc
cfc_output_layer = deep_cfc_model.output_layer


def cfc_forward(x, time_delta=None):
    """Run the CfC model and its readout on ``x``."""
    return deep_cfc_model(x, time_delta=time_delta)


# Train CfC model
print("Training CfC model...")
cfc_losses = train_model(deep_cfc_model)

# Compare losses
plt.figure(figsize=(10, 6))
plt.plot(deep_losses, label='Deep LTC')
plt.plot(cfc_losses, label='Deep CfC')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('LTC vs CfC Comparison')
plt.legend()
plt.yscale('log')
plt.grid(True)
plt.show()

# Evaluate CfC predictions (the module's class name appears in the plot title)
evaluate_predictions(deep_cfc_model)