In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax as dfx
import optax
import numpy as np
from jax import random
import matplotlib.pyplot as plt

# ============================================================================
# CONFIGURATION
# ============================================================================

# --- Data -------------------------------------------------------------------
DATA_PATH = 'spirals.npz'   # .npz archive read by load_and_preprocess_data
TRAIN_SAMPLES = 10_000      # cap on training trajectories loaded (falsy -> all)
TEST_SAMPLES = 10_000       # cap on test trajectories loaded (falsy -> all)
VALIDATION_SPLIT = 0.2      # trailing fraction of the training set held out

# --- Model ------------------------------------------------------------------
INPUT_DIM = 5     # per-step features: [r, theta, log_r, log_theta, r/theta]
HIDDEN_DIM = 32   # width of the GRU / ODE hidden state
OUTPUT_DIM = 1    # single regression target: alpha

# --- Training ---------------------------------------------------------------
NUM_EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 3e-3   # initial value of the exponentially decayed schedule
RANDOM_SEED = 42

# --- Outputs ----------------------------------------------------------------
OUTPUT_FILE = 'alpha_predictions.npy'   # denormalized test-set predictions
PLOT_FILE = 'training_results.png'
ENABLE_PLOTS = True

# ============================================================================
# PHYSICS-INFORMED FEATURE EXTRACTION
# ============================================================================

def cartesian_to_polar_sequential(xy_trajectory):
    """
    Convert an (x, y) trajectory to polar coordinates (r, theta).

    The angle is unwrapped causally: the correction applied at step i depends
    only on samples i-1 and i, so no future information leaks into earlier
    time steps. Works for any number of turns.

    Args:
        xy_trajectory: (seq_len, 2) array of (x, y) coordinates
    Returns:
        features: (seq_len, 2) array of [r, theta_unwrapped]
    """
    x = xy_trajectory[:, 0]
    y = xy_trajectory[:, 1]

    # Small epsilon keeps r strictly positive (later used in log(r)).
    eps = 1e-6
    r = jnp.sqrt(x**2 + y**2 + eps)
    theta = jnp.arctan2(y, x)

    # Differences of the *wrapped* angles: both endpoints lie in (-pi, pi],
    # so each raw step is within [-2pi, 2pi] and a single 2pi correction is
    # always sufficient. (Differencing against the accumulated unwrapped angle
    # instead fails once the spiral exceeds ~1.5 turns, because the gap then
    # exceeds 3pi and one correction no longer brings it back into range.)
    step = jnp.diff(theta)
    step = jnp.where(step > jnp.pi, step - 2 * jnp.pi, step)
    step = jnp.where(step < -jnp.pi, step + 2 * jnp.pi, step)

    # Running sum of corrected steps: element i uses only steps 1..i (causal).
    # theta[:1] (not theta[0]) keeps this valid for a length-1 trajectory.
    theta_unwrapped = jnp.concatenate([theta[:1], theta[:1] + jnp.cumsum(step)])

    return jnp.stack([r, theta_unwrapped], axis=1)

def enhanced_polar_features(xy_trajectory):
    """
    Physics-motivated per-step features for an Archimedean-style spiral.

    For a spiral r = alpha * theta we have log(r) = log(alpha) + log(theta)
    and r / theta = alpha, so these features expose alpha fairly directly.

    Args:
        xy_trajectory: (seq_len, 2) array of (x, y) coordinates
    Returns:
        features: (seq_len, 5) array of [r, theta, log_r, log_theta, r/theta]
    """
    polar = cartesian_to_polar_sequential(xy_trajectory)
    r, theta = polar[:, 0], polar[:, 1]

    eps = 1e-6

    log_r = jnp.log(r + eps)                    # log(r) = log(alpha) + log(theta)
    log_theta = jnp.log(jnp.abs(theta) + eps)   # log(|theta|)

    # Move the denominator *away* from zero on both sides. A plain
    # `theta + eps` is exactly 0 at theta == -eps (-> inf) and tiny for
    # slightly negative angles; identical to the old form for theta >= 0.
    # NOTE(review): near theta ~ 0 (e.g. the first sample) this ratio is still
    # ~r / eps, i.e. very large and unnormalized -- consider clipping.
    safe_theta = jnp.where(theta >= 0, theta + eps, theta - eps)
    r_over_theta = r / safe_theta               # direct estimate of alpha

    return jnp.stack([r, theta, log_r, log_theta, r_over_theta], axis=1)

# ============================================================================
# DATA LOADING AND PREPROCESSING
# ============================================================================

def load_and_preprocess_data(filepath, n_train=None, n_test=None):
    """
    Load spiral data and convert each trajectory to enhanced polar features.

    Args:
        filepath: path to an .npz archive containing 'xy_train',
            'alpha_train' and 'xy_test'.
        n_train: keep only the first n_train training samples (None/0 = all).
        n_test: keep only the first n_test test samples (None/0 = all).
    Returns:
        (train_features, alpha_train_norm, test_features, alpha_mean,
         alpha_std, alpha_train, xy_train, xy_test), where the feature arrays
        stack one (seq_len, 5) enhanced_polar_features result per sample,
        alpha_train_norm is the z-scored (and squeezed) training alpha, and
        alpha_train / xy_train / xy_test are the raw NumPy arrays.
    """
    print("Loading data...")
    # np.load on an .npz returns an NpzFile that keeps the file open;
    # the context manager closes it once the arrays have been read.
    with np.load(filepath) as data:
        # `limit or None` keeps the original semantics: None or 0 -> all rows.
        xy_train = data['xy_train'][:n_train or None]
        alpha_train = data['alpha_train'][:n_train or None]
        xy_test = data['xy_test'][:n_test or None]

    print(f"Train shape: {xy_train.shape}, Alpha shape: {alpha_train.shape}")
    print(f"Test shape: {xy_test.shape}")

    # Convert to enhanced polar coordinates in one compiled, vectorized pass
    # instead of eagerly featurizing every trajectory inside a Python loop.
    print("Converting to enhanced polar coordinates...")
    featurize = jax.jit(jax.vmap(enhanced_polar_features))
    train_features = featurize(jnp.asarray(xy_train))
    test_features = featurize(jnp.asarray(xy_test))

    print(f"Enhanced polar feature shape: {train_features.shape}")
    print(f"Features are: [r, theta, log_r, log_theta, r/theta]")

    # Features are deliberately NOT normalized, to preserve r = alpha * theta.

    # Normalize alpha (z-score); the epsilon guards a constant-alpha dataset.
    alpha_mean = alpha_train.mean()
    alpha_std = alpha_train.std()
    alpha_train_norm = (alpha_train - alpha_mean) / (alpha_std + 1e-8)
    alpha_train_norm = alpha_train_norm.squeeze()

    print(f"\nAlpha stats (raw): min={alpha_train.min():.4f}, max={alpha_train.max():.4f}, mean={alpha_mean:.4f}, std={alpha_std:.4f}")
    print(f"r stats: min={train_features[:,:,0].min():.4f}, max={train_features[:,:,0].max():.4f}")
    print(f"theta stats: min={train_features[:,:,1].min():.4f}, max={train_features[:,:,1].max():.4f}")

    return (train_features,
            jnp.array(alpha_train_norm),
            test_features,
            float(alpha_mean),
            float(alpha_std),
            alpha_train,
            xy_train,
            xy_test)

# ============================================================================
# GRU-ODE MODEL
# ============================================================================

class ODEFunc(eqx.Module):
    """Learned vector field dh/dt = MLP(h) used between observations.

    The field is autonomous: ``t`` and ``args`` are accepted only to match
    the (t, y, args) signature that dfx.ODETerm calls, and are ignored.
    """
    mlp: eqx.nn.MLP
    hidden_size: int

    def __init__(self, hidden_size: int, *, key):
        # Square map hidden -> hidden through two hidden layers twice as wide.
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size,
            width_size=2 * hidden_size,
            depth=2,
            activation=jax.nn.softplus,
            key=key,
        )
        self.hidden_size = hidden_size

    def __call__(self, t, h, args):
        flat_state = h.reshape(-1)
        return self.mlp(flat_state)


class GRUCell(eqx.Module):
    """Bias-free GRU cell that folds one observation into the hidden state.

    Every weight matrix maps the concatenation [h_prev, x] to the hidden
    size. Gate convention: the update gate ``z`` weights the *previous*
    state, and (1 - z) weights the new candidate.
    """
    Wz: jnp.ndarray  # update gate, (hidden_size + input_size, hidden_size)
    Wr: jnp.ndarray  # reset gate, same shape
    Wh: jnp.ndarray  # candidate state, same shape

    def __init__(self, input_size, hidden_size, key):
        k_update, k_reset, k_cand = random.split(key, 3)
        weight_shape = (hidden_size + input_size, hidden_size)
        # Gaussian init scaled by 1/sqrt(hidden_size).
        init_scale = 1.0 / jnp.sqrt(hidden_size)
        self.Wz = init_scale * random.normal(k_update, weight_shape)
        self.Wr = init_scale * random.normal(k_reset, weight_shape)
        self.Wh = init_scale * random.normal(k_cand, weight_shape)

    def __call__(self, x, h_prev):
        stacked = jnp.concatenate([h_prev, x], axis=-1)
        update_gate = jax.nn.sigmoid(stacked @ self.Wz)
        reset_gate = jax.nn.sigmoid(stacked @ self.Wr)
        gated = jnp.concatenate([reset_gate * h_prev, x], axis=-1)
        candidate = jnp.tanh(gated @ self.Wh)
        return (1 - update_gate) * candidate + update_gate * h_prev


class PhysicsInformedGRUODE(eqx.Module):
    """
    GRU-ODE regressor for the spiral parameter alpha.

    Between consecutive observations the hidden state evolves continuously
    under ``ode_func`` (integrated with diffrax's Dopri5); at each
    observation ``gru_cell`` folds the new feature vector into the state.
    The final state is mapped to the prediction by an MLP readout.
    """
    ode_func: ODEFunc
    gru_cell: GRUCell
    readout: eqx.nn.MLP
    hidden_size: int

    def __init__(self, input_size: int, hidden_size: int, output_size: int, *, key):
        key_ode, key_gru, key_readout = random.split(key, 3)
        self.hidden_size = hidden_size
        self.ode_func = ODEFunc(hidden_size, key=key_ode)
        self.gru_cell = GRUCell(input_size, hidden_size, key=key_gru)

        # Deeper readout network
        self.readout = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=output_size,
            width_size=hidden_size * 2,
            depth=3,
            activation=jax.nn.softplus,
            key=key_readout
        )

    def __call__(self, trajectory):
        """
        Run the GRU-ODE over one trajectory.

        Args:
            trajectory: (seq_len, input_size) feature sequence (here the 5
                enhanced polar features); observations are placed at evenly
                spaced times on [0, 1].
        Returns:
            The squeezed readout output (the model's alpha prediction).
        """
        seq_len = trajectory.shape[0]
        h = jnp.zeros((self.hidden_size,), dtype=jnp.float32)
        if seq_len == 0:
            # Nothing observed: read out the initial state.
            return jnp.squeeze(self.readout(h))

        ts = jnp.linspace(0, 1, seq_len)
        solver = dfx.Dopri5()
        term = dfx.ODETerm(self.ode_func)
        saveat = dfx.SaveAt(t1=True)
        stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6)

        def evolve_then_observe(h, step_inputs):
            # Integrate the hidden state from t0 to t1, then absorb x_obs.
            t0, t1, x_obs = step_inputs
            solution = dfx.diffeqsolve(
                term,
                solver,
                t0=t0,
                t1=t1,
                dt0=(t1 - t0) / 10.0,  # initial step guess; controller adapts it
                y0=h,
                saveat=saveat,
                max_steps=32,
                stepsize_controller=stepsize_controller,
            )
            h = jnp.reshape(solution.ys, (-1,))
            return self.gru_cell(x_obs, h), None

        # First observation needs no ODE evolution beforehand.
        h = self.gru_cell(trajectory[0], h)
        # lax.scan traces the ODE-solve + GRU step once, instead of unrolling
        # seq_len separate diffeqsolve calls into the JIT graph (the old
        # Python loop made compile time grow linearly with seq_len).
        h, _ = jax.lax.scan(
            evolve_then_observe, h, (ts[:-1], ts[1:], trajectory[1:])
        )

        alpha_pred = self.readout(h)
        return jnp.squeeze(alpha_pred)

# ============================================================================
# PHYSICS-INFORMED LOSS
# ============================================================================

def physics_informed_loss(model, trajectory, alpha_target, alpha_std,
                          alpha_mean=None, physics_weight=0.05):
    """
    MSE on normalized alpha plus a physics-consistency penalty.

    The penalty pulls the prediction toward a per-trajectory estimate of
    alpha read off the raw features: median(r / theta) over time steps (the
    median keeps the blow-up near theta ~ 0 from dominating).

    Args:
        model: callable mapping ``trajectory`` to a scalar prediction of the
            normalized alpha, (alpha - alpha_mean) / alpha_std.
        trajectory: (seq_len, F) features whose columns 0 and 1 are r, theta.
        alpha_target: normalized target alpha (squeezed to a scalar).
        alpha_std: standard deviation used to normalize alpha.
        alpha_mean: mean used to normalize alpha. When given, the physical
            estimate is normalized the same way before comparison, so both
            terms share units. When None (legacy behaviour) the raw physical
            estimate is compared against the *normalized* prediction -- a
            unit mismatch, kept only for backward compatibility; callers
            should pass alpha_mean.
        physics_weight: weight of the physics penalty (default 0.05, small
            to avoid over-constraining).
    Returns:
        Scalar loss.
    """
    pred = model(trajectory)
    alpha_target = jnp.squeeze(alpha_target)

    # Standard MSE in normalized units
    mse_loss = jnp.mean((pred - alpha_target) ** 2)

    # Physical estimate of alpha from r = alpha * theta.
    r, theta = trajectory[:, 0], trajectory[:, 1]
    eps = 1e-6
    # Sign-preserving epsilon: `theta + eps` is exactly 0 at theta == -eps.
    safe_theta = jnp.where(theta >= 0, theta + eps, theta - eps)
    alpha_physical_median = jnp.median(r / safe_theta)

    if alpha_mean is not None:
        # Same normalization as the training targets (mean / std + 1e-8).
        alpha_physical_median = (alpha_physical_median - alpha_mean) / (alpha_std + 1e-8)

    physics_loss = jnp.mean((pred - alpha_physical_median) ** 2)

    return mse_loss + physics_weight * physics_loss

# ============================================================================
# TRAINING
# ============================================================================

@eqx.filter_jit
def train_step(model, opt_state, trajectory, alpha_target, optimizer, alpha_std):
    """Apply one optimizer update from a single trajectory.

    Returns the updated model, the updated optimizer state, and the loss
    evaluated at the pre-update parameters.
    """
    value_and_grad_fn = eqx.filter_value_and_grad(physics_informed_loss)
    loss, grads = value_and_grad_fn(model, trajectory, alpha_target, alpha_std)
    updates, new_opt_state = optimizer.update(grads, opt_state, model)
    return eqx.apply_updates(model, updates), new_opt_state, loss

@eqx.filter_jit
def _mean_validation_loss(model, features, targets, alpha_std):
    """Average physics_informed_loss over a stacked set in one compiled call."""
    per_sample = jax.vmap(
        lambda trajectory, alpha_target: physics_informed_loss(
            model, trajectory, alpha_target, alpha_std
        )
    )(features, targets)
    return jnp.mean(per_sample)


def compute_validation_loss(model, features_val, alpha_val, alpha_std):
    """
    Mean physics_informed_loss over the validation set.

    The per-sample losses are computed in a single jitted, vmapped call
    rather than by eagerly running the ODE model once per sample.

    Args:
        model: the model passed through to physics_informed_loss.
        features_val: stacked validation trajectories (leading axis = sample).
        alpha_val: normalized validation targets, one per trajectory.
        alpha_std: alpha normalization std, forwarded to the loss.
    Returns:
        Scalar mean loss.
    Raises:
        ValueError: if the validation set is empty (previously a
            ZeroDivisionError).
    """
    if len(features_val) == 0:
        raise ValueError("compute_validation_loss: validation set is empty")
    return _mean_validation_loss(
        model, jnp.asarray(features_val), jnp.asarray(alpha_val), alpha_std
    )

def _mean_batch_loss(model, trajectories, alpha_targets, alpha_std):
    """Mean physics_informed_loss over a stacked mini-batch."""
    per_sample = jax.vmap(
        lambda trajectory, alpha_target: physics_informed_loss(
            model, trajectory, alpha_target, alpha_std
        )
    )(trajectories, alpha_targets)
    return jnp.mean(per_sample)


@eqx.filter_jit
def _batched_train_step(model, opt_state, trajectories, alpha_targets, optimizer, alpha_std):
    """One optimizer update using the gradient averaged over a mini-batch."""
    loss, grads = eqx.filter_value_and_grad(_mean_batch_loss)(
        model, trajectories, alpha_targets, alpha_std
    )
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss


def train_model(model, features_train, alpha_train, features_val, alpha_val,
                alpha_std, num_epochs, batch_size, key):
    """
    Mini-batch training loop with LR decay and early stopping.

    Each mini-batch of ``batch_size`` trajectories produces ONE optimizer
    update from the batch-averaged gradient. (Previously every sample
    triggered its own update, so ``batch_size`` only affected logging and
    the LR schedule -- sized in batches -- decayed batch_size times too fast.)

    Args:
        model: initial model.
        features_train, alpha_train: stacked training trajectories and their
            normalized targets (leading axis = sample).
        features_val, alpha_val: validation counterparts.
        alpha_std: alpha normalization std, forwarded to the loss.
        num_epochs: maximum number of epochs.
        batch_size: trajectories per optimizer update.
        key: JAX PRNG key used to shuffle each epoch.
    Returns:
        (best_model, train_losses, val_losses): the model with the lowest
        validation loss and the per-epoch loss histories.
    Raises:
        ValueError: if the training set is empty.
    """
    num_samples = features_train.shape[0]
    if num_samples == 0:
        raise ValueError("train_model: training set is empty")

    # Decay by 0.95 every 10 epochs' worth of optimizer steps.
    steps_per_epoch = -(-num_samples // batch_size)  # ceil division
    scheduler = optax.exponential_decay(
        init_value=LEARNING_RATE,
        transition_steps=steps_per_epoch * 10,
        decay_rate=0.95
    )
    optimizer = optax.chain(
        optax.clip(1.0),
        optax.adam(scheduler)
    )
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    train_losses = []
    val_losses = []

    # Early stopping
    best_val_loss = float('inf')
    best_model = model
    patience = 8
    patience_counter = 0

    print("Starting training...")
    for epoch in range(num_epochs):
        key, subkey = random.split(key)
        perm = random.permutation(subkey, num_samples)

        epoch_loss = 0.0
        num_batches = 0

        for start in range(0, num_samples, batch_size):
            batch_idx = perm[start:start + batch_size]

            # The final batch may be smaller, which triggers one extra compile.
            model, opt_state, batch_loss = _batched_train_step(
                model, opt_state,
                features_train[batch_idx], alpha_train[batch_idx],
                optimizer, alpha_std
            )
            epoch_loss += batch_loss
            num_batches += 1

            if (num_batches % 10) == 0:
                print(f"  Batch {num_batches}, Loss: {batch_loss:.6f}")

        avg_train_loss = epoch_loss / num_batches
        train_losses.append(float(avg_train_loss))

        # Validation
        val_loss = compute_validation_loss(model, features_val, alpha_val, alpha_std)
        val_losses.append(float(val_loss))

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = model
            patience_counter = 0
            print(f"  New best validation loss: {best_val_loss:.6f}")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{patience})")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    print(f"Best validation loss: {best_val_loss:.6f}")
    return best_model, train_losses, val_losses

# ============================================================================
# EVALUATION
# ============================================================================

@eqx.filter_jit
def predict_single(model, trajectory):
    """JIT-compiled forward pass of ``model`` on one trajectory."""
    prediction = model(trajectory)
    return prediction

def predict_batch(model, trajectories):
    """Apply predict_single to each trajectory along the leading axis.

    Returns a JAX array of the stacked per-trajectory predictions.
    """
    count = trajectories.shape[0]
    outputs = [predict_single(model, trajectories[idx]) for idx in range(count)]
    return jnp.array(outputs)

def evaluate_model(model, features_test, alpha_mean, alpha_std):
    """Predict alpha for every test trajectory, in physical (denormalized) units.

    The model outputs normalized alpha; this undoes the z-scoring with the
    training-set mean and std.
    """
    print("Evaluating on test set...")
    normalized = predict_batch(model, features_test)
    return alpha_mean + alpha_std * normalized

# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # Load data with enhanced physics-informed features
    features_train, alpha_train, features_test, alpha_mean, alpha_std, alpha_train_raw, xy_train, xy_test = \
        load_and_preprocess_data(DATA_PATH, TRAIN_SAMPLES, TEST_SAMPLES)

    # Hold out the trailing VALIDATION_SPLIT fraction (no shuffling).
    # NOTE(review): assumes the .npz samples are not ordered by alpha -- confirm.
    n_train = int(features_train.shape[0] * (1 - VALIDATION_SPLIT))
    features_val = features_train[n_train:]
    alpha_val = alpha_train[n_train:]
    alpha_val_raw = alpha_train_raw[n_train:]
    xy_val = xy_train[n_train:]
    features_train = features_train[:n_train]
    alpha_train = alpha_train[:n_train]
    alpha_train_raw = alpha_train_raw[:n_train]

    # Raw alphas arrive as (N, 1); flatten once so element access yields real
    # scalars (float() on a 1-element array is deprecated since NumPy 1.25)
    # and every comparison with the (N,) predictions uses the same 1-D view.
    alpha_val_true = np.asarray(alpha_val_raw).reshape(-1)

    print(f"\nTraining samples: {features_train.shape[0]}")
    print(f"Validation samples: {features_val.shape[0]}")
    print(f"Input dimension: {INPUT_DIM}")

    # Initialize model
    key = random.PRNGKey(RANDOM_SEED)
    key, model_key = random.split(key)

    model = PhysicsInformedGRUODE(
        input_size=INPUT_DIM,
        hidden_size=HIDDEN_DIM,
        output_size=OUTPUT_DIM,
        key=model_key
    )

    print("\nModel: Enhanced Physics-Informed GRU-ODE")
    print("Input: Enhanced polar coordinates [r, theta, log_r, log_theta, r/theta]")
    print("Learning: r = alpha * theta with physics-informed loss")

    # Train model
    model, train_losses, val_losses = train_model(
        model, features_train, alpha_train, features_val, alpha_val,
        alpha_std, NUM_EPOCHS, BATCH_SIZE, key
    )

    print("\nTraining complete!")

    # Evaluate on validation set (denormalize back to physical alpha)
    print("\nEvaluating on validation set...")
    val_predictions_norm = predict_batch(model, features_val)
    val_predictions = val_predictions_norm * alpha_std + alpha_mean

    print(f"Val predictions (denormalized) - min: {val_predictions.min():.4f}, max: {val_predictions.max():.4f}, mean: {val_predictions.mean():.4f}")
    print(f"Val targets (actual) - min: {alpha_val_true.min():.4f}, max: {alpha_val_true.max():.4f}, mean: {alpha_val_true.mean():.4f}")

    residuals = val_predictions - alpha_val_true
    val_mse = jnp.mean(residuals ** 2)
    val_mae = jnp.mean(jnp.abs(residuals))
    val_rmse = jnp.sqrt(val_mse)

    print(f"Validation MSE: {val_mse:.6f}")
    print(f"Validation MAE: {val_mae:.6f}")
    print(f"Validation RMSE: {val_rmse:.6f}")

    # Evaluate on test set
    test_predictions = evaluate_model(model, features_test, alpha_mean, alpha_std)

    # Save predictions
    np.save(OUTPUT_FILE, np.array(test_predictions))
    print(f"\nPredictions saved to {OUTPUT_FILE}")

    # Plot results
    if ENABLE_PLOTS:
        fig = plt.figure(figsize=(18, 14))
        gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

        # Training and validation loss
        ax1 = fig.add_subplot(gs[0, 0])
        ax1.plot(train_losses, label='Train Loss')
        ax1.plot(val_losses, label='Val Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss (Normalized)')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True)

        # Predictions vs actual
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.scatter(alpha_val_true, val_predictions, alpha=0.5)
        min_val = min(alpha_val_true.min(), val_predictions.min())
        max_val = max(alpha_val_true.max(), val_predictions.max())
        ax2.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect')
        ax2.set_xlabel('True Alpha')
        ax2.set_ylabel('Predicted Alpha')
        ax2.set_title(f'Validation (MSE: {val_mse:.4f}, MAE: {val_mae:.4f})')
        ax2.legend()
        ax2.grid(True)
        ax2.axis('equal')

        # Residuals
        ax3 = fig.add_subplot(gs[0, 2])
        ax3.scatter(alpha_val_true, residuals, alpha=0.5)
        ax3.axhline(y=0, color='r', linestyle='--')
        ax3.set_xlabel('True Alpha')
        ax3.set_ylabel('Residual')
        ax3.set_title('Residuals')
        ax3.grid(True)

        # Error distribution
        ax4 = fig.add_subplot(gs[1, 0])
        ax4.hist(residuals, bins=20, alpha=0.7, edgecolor='black')
        ax4.axvline(x=0, color='r', linestyle='--')
        ax4.set_xlabel('Prediction Error')
        ax4.set_ylabel('Frequency')
        ax4.set_title('Error Distribution')
        ax4.grid(True)

        # Plot 6 example spirals (2 rows x 3 cols), evenly spread over the
        # validation set
        for plot_idx in range(6):
            row = 2 + plot_idx // 3
            col = plot_idx % 3
            ax = fig.add_subplot(gs[row, col])

            val_idx = plot_idx * (len(xy_val) // 6)
            trajectory_actual = xy_val[val_idx]

            true_alpha = float(alpha_val_true[val_idx])
            pred_alpha = float(val_predictions[val_idx])

            # Plot actual trajectory
            ax.plot(trajectory_actual[:, 0], trajectory_actual[:, 1],
                   'b-', linewidth=2, label=f'True α={true_alpha:.2f}', alpha=0.7)
            ax.plot(trajectory_actual[0, 0], trajectory_actual[0, 1],
                   'go', markersize=8, label='Start')

            # Generate predicted spiral r = alpha * theta over two turns
            theta_pred = np.linspace(0, 4*np.pi, 100)
            r_pred = pred_alpha * theta_pred
            x_pred = r_pred * np.cos(theta_pred)
            y_pred = r_pred * np.sin(theta_pred)
            ax.plot(x_pred, y_pred, 'r--', linewidth=1.5,
                   label=f'Pred α={pred_alpha:.2f}', alpha=0.7)

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_title(f'Example {plot_idx+1} (Error: {pred_alpha-true_alpha:.2f})')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
            ax.axis('equal')

        plt.savefig(PLOT_FILE, dpi=150, bbox_inches='tight')
        print(f"Plot saved to {PLOT_FILE}")
        plt.show()

Loading data...
Train shape: (1000, 100, 2), Alpha shape: (1000, 1)
Test shape: (1000, 100, 2)
Converting to enhanced polar coordinates...
