# Policy Derivatives Analysis

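This notebook computes and visualizes the derivatives of a trained PPO policy's actions with respect to every state (observation) variable. The derivatives are evaluated at a state sampled from the environment.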

These derivatives form the policy Jacobian $J_{ij} = \partial a_i / \partial s_j$. It shows how sensitive each action dimension $a_i$ is to a change in each state variable $s_j$.


## 1. Setup and Imports


In [None]:
import os
import sys
import warnings

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from stable_baselines3 import PPO

# Import environment and utilities
from snake_env import FixedWavelengthXZOnlyContinuumSnakeEnv
from Utilities.policy_gradient_utils import (
    compute_policy_gradient,
    compute_policy_jacobian,
    get_policy_sensitivity
)
import config

# Global plot styling shared by every figure in this notebook
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})


## 2. Load Trained Model


In [None]:
# Resolve which saved model to analyse: prefer the final model and fall back to
# the checkpoint; fail loudly if neither archive exists.
model_path = os.path.join(config.PATHS["model_dir"], config.PATHS["model_name"])

if os.path.exists(model_path + ".zip"):
    print(f"Using model: {model_path}")
else:
    checkpoint_path = os.path.join(config.PATHS["model_dir"], config.PATHS["checkpoint_name"])
    if not os.path.exists(checkpoint_path + ".zip"):
        raise FileNotFoundError(f"Model not found at {model_path}.zip or {checkpoint_path}.zip")
    model_path = checkpoint_path
    print(f"Using checkpoint model: {model_path}")

# Build the environment first; it is passed to PPO.load below.
env = FixedWavelengthXZOnlyContinuumSnakeEnv(
    fixed_wavelength=config.ENV_CONFIG["fixed_wavelength"],
    obs_keys=config.ENV_CONFIG["obs_keys"],
)
# Remaining simulation/episode settings are assigned as attributes after construction
for env_attr in ("period", "ratio_time", "rut_ratio", "max_episode_length"):
    setattr(env, env_attr, config.ENV_CONFIG[env_attr])

model = PPO.load(model_path, env=env)
print("Model loaded successfully!")
print(f"Action space: {env.action_space.shape}")
print(f"Observation space: {env.observation_space.shape}")


## 3. Get Sample States


In [None]:
# Roll the trained policy forward deterministically and record the observations
# it visits; the first one is used for the detailed derivative analysis.
num_sample_states = 5

print("Collecting sample states from environment...")
obs, info = env.reset()
sample_states = [obs.copy()]

for step_idx in range(num_sample_states - 1):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    sample_states.append(obs.copy())
    # Start a fresh episode when this one ends so the next prediction sees a valid state
    if terminated or truncated:
        obs, info = env.reset()

sample_states = np.array(sample_states)
print(f"Collected {len(sample_states)} sample states")
print(f"State shape: {sample_states[0].shape}")

# Detailed analysis below is performed at the first (post-reset) state
state = sample_states[0]
print("\nUsing first state for derivative computation:")
print(f"State shape: {state.shape}")


## 4. Understand State Space Structure


In [None]:
# Split the flat observation vector into named components, in obs_keys order.
# NOTE(review): the per-key sizes below are assumptions about what
# FixedWavelengthXZOnlyContinuumSnakeEnv emits (it keeps only the X/Z
# components for some keys). The consistency check at the end of this cell
# warns if they do not add up to the real observation size.
obs_keys = config.ENV_CONFIG["obs_keys"]
n_elem = env._n_elem
n_nodes = n_elem + 1

# Expected number of observation entries contributed by each known key
expected_component_sizes = {
    "avg_velocity": 2,              # XZ only (2D)
    "curvature": (n_elem - 1) * 2,  # XZ only (2D)
    "velocity": n_nodes * 2,        # XZ only (2D)
    "tangents": n_elem * 2,         # XZ only (2D)
    "position": n_nodes * 3,        # Full 3D
    "director": n_elem * 9,         # Full 3x3 matrices
    "avg_position": 3,              # Full 3D
    "time": 1,
}

state_component_sizes = {}
state_component_indices = {}
idx = 0

for key in obs_keys:
    size = expected_component_sizes.get(key)
    if size is None:
        # Unknown key: keep the zero-width fallback, but make it visible.
        # If the key really contributes entries, every later slice is shifted.
        warnings.warn(
            f"Unknown observation key {key!r}: assigning size 0; "
            "indices of later components may be wrong."
        )
        size = 0
    state_component_sizes[key] = size
    state_component_indices[key] = (idx, idx + size)
    idx += size

print("State space structure:")
print("=" * 60)
for key, (start, end) in state_component_indices.items():
    print(f"{key:20s}: indices [{start:3d}:{end:3d}] (size: {end-start:3d})")
print("=" * 60)
print(f"Total state dimension: {state.shape[0]}")
print(f"Action dimension: {env.action_space.shape[0]}")

# The parsed layout must cover exactly the observation vector.
# Otherwise every per-component slice in the cells below is misaligned.
if idx != state.shape[0]:
    warnings.warn(
        f"Parsed state layout covers {idx} entries but the observation has "
        f"{state.shape[0]}; check the per-key sizes above."
    )


## 5. Compute Policy Jacobian (All Derivatives)


In [None]:
# Jacobian of the policy output w.r.t. the observation, evaluated at `state`.
# Expected layout (per the helper's contract, as used throughout this notebook):
# (action_dim, state_dim), with jacobian[i, j] = d action[i] / d state[j].
print("Computing policy Jacobian matrix...")
jacobian = compute_policy_jacobian(model, state)

print(f"Jacobian shape: {jacobian.shape}")
print(f"Action dimension: {jacobian.shape[0]}")
print(f"State dimension: {jacobian.shape[1]}")

print("\nJacobian statistics:")
jacobian_summary = (
    ("Min", jacobian.min()),
    ("Max", jacobian.max()),
    ("Mean", jacobian.mean()),
    ("Std", jacobian.std()),
    ("Norm", np.linalg.norm(jacobian)),
)
for stat_name, stat_value in jacobian_summary:
    print(f"  {stat_name}: {stat_value:.6f}")


## 6. Compute Additional Gradient Information


In [None]:
# Query the policy for its action mean and value estimate at `state`, together
# with the gradient of each w.r.t. the state.
grad_info = compute_policy_gradient(
    model, state,
    wrt_action_mean=True,
    wrt_value=True,
)

print("Gradient information:")
print(f"Action mean shape: {grad_info['action_mean'].shape}")
print(f"Action mean: {grad_info['action_mean']}")
# NOTE(review): indexing [0, 0] assumes the value comes back with a leading
# batch axis, e.g. shape (1, 1) — confirm in policy_gradient_utils.
print(f"\nValue estimate: {grad_info['value'][0, 0]:.6f}")
for shape_label, grad_key in (("\nAction mean gradient shape", 'action_mean_grad'),
                              ("Value gradient shape", 'value_grad')):
    print(f"{shape_label}: {grad_info[grad_key].shape}")


## 7. Visualize Jacobian Matrix


In [None]:
# Heatmap of the full Jacobian. The diverging colormap is centred on zero, so
# white means "no sensitivity" and red/blue mark positive/negative derivatives.
# Without symmetric limits, imshow maps [min, max] and the neutral colour lands
# on whatever value sits midway, which misrepresents the sign.
jacobian_abs_max = float(np.nanmax(np.abs(jacobian)))
if not np.isfinite(jacobian_abs_max) or jacobian_abs_max == 0.0:
    jacobian_abs_max = 1.0  # degenerate (all-zero / all-NaN) Jacobian: any symmetric range works

fig, ax = plt.subplots(figsize=(16, 6))

im = ax.imshow(jacobian, aspect='auto', cmap='RdBu_r', interpolation='nearest',
               vmin=-jacobian_abs_max, vmax=jacobian_abs_max)
ax.set_xlabel('State Variable Index', fontsize=12)
ax.set_ylabel('Action Dimension', fontsize=12)
ax.set_title('Policy Jacobian Matrix: ∂action/∂state', fontsize=14, fontweight='bold')

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Derivative Value', fontsize=11)

# Grid lines drawn over an image obscure individual cells. Mark where each
# observation component starts instead, so columns map back to obs keys.
ax.grid(False)
for start, _end in state_component_indices.values():
    if start > 0:
        ax.axvline(x=start - 0.5, color='gray', linestyle=':', linewidth=0.8, alpha=0.7)

plt.tight_layout()
plt.show()

# Print some example derivatives
print("\nExample derivatives (first 5 actions, first 10 state variables):")
print(jacobian[:5, :10])


## 8. Derivatives by State Component


In [None]:
# Per-component summary of how strongly the actions depend on each observation key.
# 'mean'/'max' are over |derivative|; 'std' is over the signed derivatives.
component_stats = {}

for key, (start, end) in state_component_indices.items():
    component_jacobian = jacobian[:, start:end]
    n_cols = component_jacobian.shape[1]
    # Zero-width slices (e.g. unknown keys sized 0, or indices past the real
    # state dimension) have no derivatives, and np.max would raise on them.
    if n_cols == 0:
        print(f"Skipping component '{key}': no state variables in jacobian[:, {start}:{end}]")
        continue

    abs_component_jacobian = np.abs(component_jacobian)
    component_stats[key] = {
        'mean': np.mean(abs_component_jacobian),
        'std': np.std(component_jacobian),
        'max': np.max(abs_component_jacobian),
        'norm': np.linalg.norm(component_jacobian),
        # Actual number of columns (can be < end - start if the slice ran past the state)
        'size': n_cols,
    }

# Print statistics
print("Derivative statistics by state component:")
print("=" * 80)
print(f"{'Component':<20s} {'Size':<8s} {'Mean |∂|':<12s} {'Max |∂|':<12s} {'Norm':<12s}")
print("=" * 80)
for key, stats in component_stats.items():
    print(f"{key:<20s} {stats['size']:<8d} {stats['mean']:<12.6f} {stats['max']:<12.6f} {stats['norm']:<12.6f}")
print("=" * 80)


## 9. Visualize Derivatives by State Component


In [None]:
# Four-panel summary of derivative magnitude per observation component:
# three bar charts (mean |∂|, max |∂|, Frobenius norm) and a component × action heatmap.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

components = list(component_stats.keys())
means = [component_stats[c]['mean'] for c in components]
maxs = [component_stats[c]['max'] for c in components]
norms = [component_stats[c]['norm'] for c in components]

# Panels 1-3: (axes, values, bar colour or None for the default cycle colour, x label, title)
bar_panels = (
    (axes[0, 0], means, None, 'Mean Absolute Derivative',
     'Mean |∂action/∂state| by Component'),
    (axes[0, 1], maxs, 'orange', 'Maximum Absolute Derivative',
     'Max |∂action/∂state| by Component'),
    (axes[1, 0], norms, 'green', 'Frobenius Norm',
     'Frobenius Norm of Derivatives by Component'),
)
for panel_ax, panel_values, bar_color, panel_xlabel, panel_title in bar_panels:
    panel_ax.barh(components, panel_values, color=bar_color)
    panel_ax.set_xlabel(panel_xlabel, fontsize=11)
    panel_ax.set_title(panel_title, fontsize=12, fontweight='bold')
    panel_ax.grid(True, alpha=0.3, axis='x')

# Panel 4: rows = components, columns = actions; each cell is the mean
# |derivative| over that component's state variables.
ax = axes[1, 1]
summary_matrix = np.zeros((len(components), jacobian.shape[0]))
for row, key in enumerate(components):
    start, end = state_component_indices[key]
    summary_matrix[row] = np.abs(jacobian[:, start:end]).mean(axis=1)

im = ax.imshow(summary_matrix, aspect='auto', cmap='YlOrRd', interpolation='nearest')
ax.set_yticks(range(len(components)))
ax.set_yticklabels(components)
ax.set_xlabel('Action Dimension', fontsize=11)
ax.set_title('Mean |∂action/∂state| Heatmap by Component', fontsize=12, fontweight='bold')
plt.colorbar(im, ax=ax, label='Mean |Derivative|')

plt.tight_layout()
plt.show()


In [None]:
# One line plot per action dimension showing its derivative w.r.t. every state
# variable; dotted vertical lines mark where each observation component starts.
num_actions = jacobian.shape[0]
# squeeze=False always yields a 2-D array of Axes, so one action needs no special case
fig, axes = plt.subplots(num_actions, 1, figsize=(16, 3 * num_actions), squeeze=False)

# Component boundaries are the same for every action; compute them once
component_starts = [start for start, _end in state_component_indices.values() if start > 0]

for action_idx, ax in enumerate(axes[:, 0]):
    ax.plot(jacobian[action_idx, :], linewidth=1.5, alpha=0.7)
    ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)

    for boundary in component_starts:
        ax.axvline(x=boundary, color='gray', linestyle=':', linewidth=0.5, alpha=0.5)

    ax.set_xlabel('State Variable Index', fontsize=10)
    ax.set_ylabel(f'∂action[{action_idx}]/∂state', fontsize=10)
    ax.set_title(f'Derivatives for Action Dimension {action_idx}', fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary statistics for each action. The label column widens for >= 10
# actions so the numeric columns stay aligned with the header.
label_width = max(10, len(f"Action[{max(num_actions - 1, 0)}]:"))
print("\nDerivative statistics for each action dimension:")
print("=" * 70)
print(f"{'Action':<{label_width}s} {'Mean |∂|':<12s} {'Max |∂|':<12s} {'Std':<12s} {'Norm':<12s}")
print("=" * 70)
for action_idx in range(num_actions):
    action_grad = jacobian[action_idx, :]
    action_label = f"Action[{action_idx}]:"
    print(f"{action_label:<{label_width}s} {np.mean(np.abs(action_grad)):<12.6f} "
          f"{np.max(np.abs(action_grad)):<12.6f} {np.std(action_grad):<12.6f} "
          f"{np.linalg.norm(action_grad):<12.6f}")
print("=" * 70)


## 11. Derivatives for Selected State Components


In [None]:
# Detailed view of selected observation components: one subplot per component,
# one line per action, x-axis = position *within* that component.
important_components = ['avg_velocity', 'curvature', 'velocity', 'tangents']

# Keep only components that exist in this env's observation and have at least one
# column in the Jacobian. Filtering up front avoids blank subplots for skipped keys.
components_to_plot = [
    key for key in important_components
    if key in state_component_indices
    and jacobian[:, slice(*state_component_indices[key])].shape[1] > 0
]

if not components_to_plot:
    print("None of the selected components are present in the observation; nothing to plot.")
else:
    fig, axes = plt.subplots(len(components_to_plot), 1,
                             figsize=(16, 4 * len(components_to_plot)), squeeze=False)

    for ax, key in zip(axes[:, 0], components_to_plot):
        start, end = state_component_indices[key]
        component_jacobian = jacobian[:, start:end]
        component_width = component_jacobian.shape[1]
        # Local indices so the x-axis matches its "within component" label
        local_indices = np.arange(component_width)

        # Plot derivatives for all actions (rows of the Jacobian)
        for action_idx in range(component_jacobian.shape[0]):
            ax.plot(
                local_indices,
                component_jacobian[action_idx],
                label=f'Action[{action_idx}]',
                linewidth=1.5,
                alpha=0.7
            )

        ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
        ax.set_xlabel(f'{key} State Variable Index (within component)', fontsize=10)
        ax.set_ylabel('∂action/∂state', fontsize=10)
        ax.set_title(
            f'Derivatives for State Component: {key} '
            f'(indices {start}-{start + component_width - 1})',
            fontsize=11, fontweight='bold'
        )
        ax.legend(loc='best', fontsize=9)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


## 12. Sensitivity Analysis
