In [2]:
import numpy as np

In [4]:
# Preallocate a 4x4 float64 array filled with zeros.
transformation_matrix = np.zeros(shape=(4, 4), dtype=float)

In [5]:
import numpy as np

def forward_kinematics(link_transformations, verbose=True):
    """
    Computes the cumulative transformation for each link in a kinematic chain.

    Poses are accumulated by right-multiplication:
        T_base_to_i = T_base_to_(i-1) @ T_(i-1)_to_i

    Args:
        link_transformations: An iterable of 4x4 homogeneous transformation matrices
                              (array-likes are accepted). Matrix i is the transform
                              from link i-1 to link i.
        verbose: If True (the default), print the base frame and every accumulated
                 link pose as it is computed.

    Returns:
        A list of 4x4 NumPy arrays with one more element than there are
        transformations: element 0 is the base frame (identity) and element i is
        the pose of link i relative to the base frame.

    Raises:
        ValueError: If any transformation does not have shape (4, 4).
    """
    # Start at the base frame (identity matrix).
    current_pose = np.eye(4)
    all_poses = [current_pose]

    if verbose:
        print("Base Frame:\n", current_pose)

    for i, transform in enumerate(link_transformations, start=1):
        transform = np.asarray(transform)
        # Reject malformed inputs up front: a (4, N) or stacked (k, 4, 4) array would
        # otherwise broadcast through `@` and silently produce a wrongly shaped pose.
        if transform.shape != (4, 4):
            raise ValueError(
                f"Transformation {i} must have shape (4, 4), got {transform.shape}"
            )
        current_pose = current_pose @ transform
        all_poses.append(current_pose)

        if verbose:
            print(f"\nLink {i} Pose (relative to base):\n", current_pose)

    return all_poses

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from ipywidgets import interact, FloatSlider
import ipywidgets as widgets

# --- 1. Math Functions (First Principles) ---

def get_rotation_matrix_z(theta_deg):
    """
    Build the 3x3 matrix for a rotation of `theta_deg` degrees about the Z axis.

    The first column, (cos, sin, 0), is where the X axis ends up after the
    rotation; the Z axis is left unchanged.
    """
    angle = np.radians(theta_deg)
    cos_t = np.cos(angle)
    sin_t = np.sin(angle)
    # Same matrix as derived in the accompanying text.
    row_x = [cos_t, -sin_t, 0]
    row_y = [sin_t, cos_t, 0]
    row_z = [0, 0, 1]
    return np.array([row_x, row_y, row_z])

def get_homogeneous_transform(rotation_matrix, translation_vector):
    """
    Pack a 3x3 rotation and a length-3 translation into a 4x4 homogeneous transform.

    Layout:
        [ R  p ]
        [ 0  1 ]
    The bottom row keeps the [0, 0, 0, 1] of the identity the matrix starts from.
    """
    homogeneous = np.identity(4)
    homogeneous[:3, :3] = rotation_matrix      # upper-left block: rotation R
    homogeneous[:3, 3] = translation_vector    # last column (top 3 rows): translation p
    return homogeneous

# --- 2. Visualization Logic ---

def plot_coordinate_frame(ax, T, length=1.0, label="Frame"):
    """
    Draw the coordinate frame encoded by the 4x4 homogeneous matrix T on a 3D axes.

    The frame origin is T[:3, 3]. Columns 0, 1 and 2 of the rotation block are
    drawn as arrows from that origin in red (X), green (Y) and blue (Z), each
    normalized and scaled to `length`. `label` is written at the origin in black.
    """
    ox, oy, oz = T[:3, 3]
    # One arrow per basis vector, in X, Y, Z order.
    for column, color in ((0, 'r'), (1, 'g'), (2, 'b')):
        dx, dy, dz = T[:3, column]
        ax.quiver(ox, oy, oz, dx, dy, dz,
                  color=color, length=length, normalize=True, arrow_length_ratio=0.1)

    ax.text(ox, oy, oz, label, color='k')

# --- 3. The Interactive Loop ---

def visualize_transform(x, y, z, theta_z):
    """
    Draw the world frame {W} and a body frame {B}, then print {B}'s 4x4 transform.

    {B} is rotated by `theta_z` degrees about Z and translated to (x, y, z), i.e.
    T_body = [R_z(theta_z)  P; 0 0 0 1] with P = (x, y, z).

    Args:
        x, y, z: Translation of {B}'s origin in world coordinates.
        theta_z: Rotation of {B} about the Z axis, in degrees.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    # 1. World frame {W}: identity pose at the origin.
    T_world = np.eye(4)
    plot_coordinate_frame(ax, T_world, label="World {W}", length=1.5)

    # 2. Body frame {B}: rotation about Z combined with translation P.
    R = get_rotation_matrix_z(theta_z)
    P = np.array([x, y, z])
    T_body = get_homogeneous_transform(R, P)

    # 3. Draw {B} and the translation vector from the world origin to it.
    plot_coordinate_frame(ax, T_body, label="Body {B}", length=1.0)
    ax.plot([0, x], [0, y], [0, z], 'k--', alpha=0.5, label='Translation Vector P')

    # Fixed axis limits keep the view identical across slider values.
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5); ax.set_zlim(0, 5)
    ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
    ax.legend()
    # Title goes on the explicit Axes instead of through pyplot's "current axes" state.
    ax.set_title(f"Homogeneous Transform Visualization\nRotation: {theta_z}° around Z")
    plt.show()

    # --- Print the matrix to build intuition ---
    print("Homogeneous Transformation Matrix T:")
    with np.printoptions(precision=2, suppress=True):
        print(T_body)
    print("\nObserve:")
    # Round the inputs like the matrix above; raw slider floats can carry
    # representation noise (e.g. 0.30000000000000004).
    print(f"1. The top three entries of the last column are exactly your position input: "
          f"[{x:.2f}, {y:.2f}, {z:.2f}]")
    print("2. The top-left 3x3 is the Rotation Matrix.")
    print("3. Column 0 is the NEW X-axis direction.")
    print("4. Column 1 is the NEW Y-axis direction.")
    print("5. Column 2 is the NEW Z-axis direction.")

# --- 4. Run the Widget ---
# interact builds one slider control per keyword below and calls
# visualize_transform with the current slider values.
interact(
    visualize_transform,
    x=FloatSlider(min=-3, max=3, step=0.1, value=2),            # translation along X
    y=FloatSlider(min=-3, max=3, step=0.1, value=2),            # translation along Y
    z=FloatSlider(min=0, max=3, step=0.1, value=0),             # translation along Z (>= 0)
    theta_z=FloatSlider(min=-180, max=180, step=5, value=45),   # rotation about Z, degrees
);

interactive(children=(FloatSlider(value=2.0, description='x', max=3.0, min=-3.0), FloatSlider(value=2.0, descr…

In [2]:
!pip install jax

Collecting jax
  Downloading jax-0.8.2-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib<=0.8.2,>=0.8.2 (from jax)
  Downloading jaxlib-0.8.2-cp312-cp312-macosx_11_0_arm64.whl.metadata (1.3 kB)
Collecting numpy>=2.0 (from jax)
  Using cached numpy-2.4.0-cp312-cp312-macosx_14_0_arm64.whl.metadata (6.6 kB)
Collecting scipy>=1.13 (from jax)
  Using cached scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl.metadata (62 kB)
Downloading jax-0.8.2-py3-none-any.whl (2.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading jaxlib-0.8.2-cp312-cp312-macosx_11_0_arm64.whl (55.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.9/55.9 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hUsing cached numpy-2.4.0-cp312-cp312-macosx_14_0_arm64.whl (5.2 MB)
Downloading scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl (20.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import jax.numpy as jnp
import numpy as np
def norm(X):
    """
    Standardize each column of X to zero mean and unit (population) standard deviation.

    Uses only array methods and arithmetic, so it works on NumPy arrays and on JAX
    arrays, including when wrapped with `jax.jit`.

    A column whose standard deviation is exactly zero (e.g. a constant column) would
    otherwise produce 0/0 = NaN. Its divisor is replaced by 1, so it comes out as zeros.
    """
    centered = X - X.mean(0)
    std = centered.std(0)
    # Adding the boolean mask turns exact-zero divisors into 1 and leaves every other
    # divisor unchanged. Avoiding jnp.where keeps NumPy inputs returning NumPy arrays
    # and still traces cleanly under jit.
    return centered / (std + (std == 0))
  

In [6]:
from jax import jit 
norm_compiled = jit(norm)

# Seeded Generator (NumPy's recommended random API): the benchmark data is
# reproducible without mutating NumPy's global random state.
rng = np.random.default_rng(1701)
X = jnp.array(rng.random((10000, 10)))

# This first call of norm_compiled also triggers its jit compilation for this
# input shape/dtype, so the %timeit runs that follow time execution only.
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

True

In [7]:
# Compare the plain function against the jit-compiled one on the same input X.
# `.block_until_ready()` waits for the result, so each measurement covers the
# computation itself rather than only JAX's asynchronous dispatch of it.
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

149 μs ± 2.96 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
126 μs ± 283 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
