In [None]:
import numpy as np
import jax.numpy as jnp
from jax import random, vmap, jit

from functools import partial
import scipy
import matplotlib.pyplot as plt
import time

from scipy.io import loadmat
import jax.lax as lax
import jax

from mjx_proj import TrajectoryProjector as qp_solver

import os


import matplotlib.pyplot as plt

In [None]:
# num = 4
# num_batch = 4
# t = 0.05

# Projection problem settings, defined once as plain scalars and passed to the
# solver below. (A trailing comma after an assignment, e.g. `num_dof=1,`, binds
# a 1-tuple `(1,)` instead of the number, so none is used here.)
num_dof = 1
num_steps = 10
num_batch = 2
timestep = 0.05
maxiter_projection = 100
# NOTE(review): presumably velocity / acceleration / jerk bounds enforced by the
# projector — confirm against TrajectoryProjector.
v_max = 1.0
a_max = 2.0
j_max = 3.0
rho_ineq = 1.0
rho_projection = 1.0

# Build the trajectory projector from the settings above.
opt_class = qp_solver(
    num_dof=num_dof,
    num_steps=num_steps,
    num_batch=num_batch,
    timestep=timestep,
    maxiter_projection=maxiter_projection,
    v_max=v_max,
    a_max=a_max,
    j_max=j_max,
    rho_ineq=rho_ineq,
    rho_projection=rho_projection,
)

# vel_init = 0.0

# vel_min = -1.2
# vel_max = 1.2

# acc_min = -1.8
# acc_max = 1.8

# jerk_min = -1.8
# jerk_max = 1.8

# vel_samples = np.random.uniform(-1.5, 1.5,  (num_batch, num)    )


In [None]:
# Draw random trajectories from the projector using a fixed PRNG seed so the
# notebook is reproducible.
key = jax.random.PRNGKey(42)
sampled = opt_class.sample_uniform_trajectories(key)
xi_samples, _ = sampled  # the second returned value is not used in this notebook

print(f"Sampled trajectory shape: {xi_samples.shape}")

In [None]:
# Project the trajectory
start_time = time.time()
# JAX dispatches work asynchronously: without an explicit wait, the call can
# return before the computation has finished and any elapsed time derived from
# `start_time` would understate it. block_until_ready waits on every array in
# the returned structure and hands the structure back unchanged.
# NOTE(review): assumes project_trajectories returns JAX arrays — confirm.
xi_filtered, residuals = jax.block_until_ready(
    opt_class.project_trajectories(xi_samples)
)

print("resdiuals", residuals.shape)


In [None]:

# `start_time` is set in the projection cell, so this figure also includes any
# delay between running that cell and this one.
print(f"Projection time: {time.time() - start_time:.3f} seconds")

# Convert to NumPy for saving/analysis. The full arrays are kept as returned by
# the solver; to inspect a single element or an average instead, index
# (e.g. xi_samples[1]) or reduce with np.mean(..., axis=0) before converting.
xi_np = np.asarray(xi_samples)
xi_filtered_np = np.asarray(xi_filtered)

# Save results as CSV files under ./results (created if missing).
os.makedirs('results', exist_ok=True)
np.savetxt('results/original_trajectory.csv', xi_np, delimiter=',')
np.savetxt('results/projected_trajectory.csv', xi_filtered_np, delimiter=',')
np.savetxt('results/residuals.csv', residuals, delimiter=',')

print("Generated sample trajectories")
print(f"Original shape: {xi_np.shape}")
print(f"xi_filtered shape: {xi_filtered_np.shape}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def visualize_trajectory(original, projected, dof_idx=0, dof=1, dt=0.05):
    """Plot the original and projected velocity profiles of one DOF in two stacked subplots.

    Each trajectory is read along its last axis as ``dof`` equal-length
    segments laid end to end (segment ``k`` holds DOF ``k``). A 1-D input is a
    single trajectory; a 2-D input is treated as one trajectory per row and
    every row is drawn as its own line.

    Args:
        original: Array-like trajectory (1-D) or stack of trajectories (2-D)
            before projection.
        projected: Array-like with the same layout as ``original``, after
            projection.
        dof_idx: Index of the DOF to plot, ``0 <= dof_idx < dof``.
        dof: Number of DOFs stored in each trajectory.
        dt: Spacing between consecutive samples on the time axis.

    Raises:
        ValueError: If ``dof`` is not positive, ``dof_idx`` is out of range, or
            an input has more than two dimensions.
    """
    if dof < 1 or not 0 <= dof_idx < dof:
        raise ValueError(
            f"dof_idx must satisfy 0 <= dof_idx < dof (got dof_idx={dof_idx}, dof={dof})"
        )

    # Normalise both inputs to (rows, samples) so single and batched
    # trajectories share one code path.
    orig_rows = np.atleast_2d(np.asarray(original))
    proj_rows = np.atleast_2d(np.asarray(projected))
    if orig_rows.ndim != 2 or proj_rows.ndim != 2:
        raise ValueError("original and projected must be 1-D or 2-D arrays")

    # Number of samples per DOF, taken from the original trajectory length.
    num_steps = orig_rows.shape[-1] // dof
    segment = slice(dof_idx * num_steps, (dof_idx + 1) * num_steps)

    # Time axis; named `t` so the `time` module is not shadowed.
    t = np.arange(num_steps) * dt

    fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

    panels = (
        (axs[0], orig_rows[:, segment], 'b-', 'Original'),
        (axs[1], proj_rows[:, segment], 'r-', 'Projected'),
    )
    for ax, rows, fmt, name in panels:
        for row_idx, row in enumerate(rows):
            # Label every line so legend() has entries to show.
            label = name if len(rows) == 1 else f'{name} (row {row_idx})'
            ax.plot(t, row, fmt, label=label)
        ax.set_ylabel(f'Joint {dof_idx} Velocity')
        ax.set_title(f'{name} Joint Velocity')
        ax.legend()
        ax.grid(True)

    # The subplots share the x axis, so only the bottom one carries its label.
    axs[1].set_xlabel('Time (s)')

    fig.tight_layout()

    try:
        plt.show()
    except Exception:
        # If the figure cannot be displayed, fall back to writing it to disk.
        # Only Exception is caught so KeyboardInterrupt/SystemExit propagate.
        fig.savefig(f"trajectory_dof{dof_idx}_subplots.png")
        print(f"Plot saved as trajectory_dof{dof_idx}_subplots.png")
        


In [None]:
%matplotlib inline
visualize_trajectory(xi_np, xi_filtered_np, dof_idx=0, dof=1, dt=0.05)