# Deep-Flow: PCA Manifold Exploration
Use the sliders to see how each principal component affects the shape of the driving maneuver.

**Coefficients are whitened:** a value of 0 reproduces the mean maneuver, and a value of ±1.0 moves one standard deviation along that principal component.

In [1]:
import matplotlib
# Compatibility patch for Matplotlib 3.6.1 and VS Code/matplotlib-inline
# If this Matplotlib build's RcParams class has no `_get` attribute, add one
# that simply delegates to the public `RcParams.get(key)`. When `_get` already
# exists the class is left untouched.
# NOTE(review): presumably the inline backend calls `rcParams._get(...)`, which
# older Matplotlib releases do not provide — confirm against the installed
# matplotlib-inline version. The patch is placed before the pyplot import and
# the `%matplotlib inline` magic below.
if not hasattr(matplotlib.RcParams, '_get'):
    matplotlib.RcParams._get = lambda self, key: self.get(key)

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import json
import os
from ipywidgets import interact, FloatSlider, Layout

## 1. Load PCA Basis
We load the basis generated by `compute_pca.py`.

In [2]:
# Location of the PCA basis. The default is the original dataset path; set the
# DEEPFLOW_PCA_PATH environment variable to point at a different file without
# editing the notebook.
pca_path = os.environ.get(
    "DEEPFLOW_PCA_PATH",
    "/mnt/d/waymo_datasets/Deep-Flow_Dataset/pca_basis.json",
)

if not os.path.exists(pca_path):
    raise FileNotFoundError(f"PCA basis not found at {pca_path}. Run compute_pca.py first.")

with open(pca_path, "r") as f:
    pca_data = json.load(f)

# Convert to numpy.
# K = number of principal components, D = length of a flattened trajectory.
COMPONENTS = np.array(pca_data['components']) # [K, D]
MEAN = np.array(pca_data['mean'])             # [D]
STDS = np.array(pca_data['stds'])             # [K]
SCALE_POS = pca_data.get('scale_pos', 50.0)   # falls back to 50.0 when the key is absent

# Fail early with a readable message if the pieces of the basis disagree,
# instead of a broadcasting error later inside the reconstruction.
if COMPONENTS.ndim != 2 or COMPONENTS.shape != (STDS.size, MEAN.size):
    raise ValueError(
        "Inconsistent PCA basis: "
        f"components {COMPONENTS.shape}, mean {MEAN.shape}, stds {STDS.shape}"
    )

print(f"✅ Loaded PCA basis with {len(STDS)} components.")
print(f"Variance Explained: {np.sum(pca_data['explained_variance'])*100:.2f}%")

✅ Loaded PCA basis with 12 components.
Variance Explained: 100.00%


## 2. Reconstruction Function
This function maps a vector of whitened PCA coefficients (one per loaded component) back to a physical trajectory in meters.

In [3]:
def reconstruct(coeffs):
    """
    Map whitened PCA coefficients back to a physical trajectory.

    Input:  coeffs [..., K] (Whitened), where K = len(STDS). Any array-like is
            accepted; leading dimensions are treated as a batch.
    Output: traj [..., T, 2] (Meters), where T is half the flattened
            trajectory length (80 for a 160-dimensional basis).

    Raises:
        ValueError: if the last dimension of `coeffs` does not match the
            number of components in the loaded basis.
    """
    coeffs = np.asarray(coeffs, dtype=float)
    if coeffs.shape[-1:] != (STDS.size,):
        raise ValueError(
            f"Expected {STDS.size} coefficients in the last dimension, "
            f"got shape {coeffs.shape}"
        )

    # 1. Un-whiten: coeffs * stds
    unwhitened = coeffs * STDS

    # 2. Project back to coordinate space (Linear combination of components)
    # x_norm = sum(coeff_i * component_i) + mean
    x_norm = (unwhitened @ COMPONENTS) + MEAN

    # 3. Reshape to [..., T, 2] (x, y); T is inferred from the basis rather
    #    than hard-coded, so other horizons work too.
    x_norm = x_norm.reshape(*coeffs.shape[:-1], -1, 2)

    # 4. Scale back to meters
    return x_norm * SCALE_POS

## 3. Interactive Visualization
Move the sliders to see the maneuver change.

In [5]:
# %%
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- 1. Define the UI Components ---

# The number of sliders follows the loaded basis instead of a hard-coded count,
# so the dashboard keeps working when compute_pca.py is re-run with a different
# number of components.
n_components = len(STDS)

# One slider per principal component, in whitened units (standard deviations).
slider_list = []
for i in range(n_components):
    s = widgets.FloatSlider(
        min=-4.0, max=4.0, step=0.1, value=0.0,
        description=f'PC {i+1} ({pca_data["explained_variance"][i]*100:.1f}%)',
        orientation='horizontal',
        layout=widgets.Layout(width='300px'),
        style={'description_width': '100px'}
    )
    slider_list.append(s)

# Reset Button
reset_btn = widgets.Button(
    description='Reset to Mean',
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    icon='refresh',
    layout=widgets.Layout(width='300px', margin='10px 0 0 100px')
)

# True while the reset handler is zeroing the sliders; the slider observer
# checks it so a reset produces a single redraw instead of one per slider.
_reset_in_progress = False

def on_reset_clicked(b):
    """Set every slider back to 0 (the mean maneuver) and redraw once."""
    global _reset_in_progress
    _reset_in_progress = True
    try:
        for s in slider_list:
            s.value = 0.0
    finally:
        _reset_in_progress = False
    update_visualization()

reset_btn.on_click(on_reset_clicked)

# Output widget for the plot
out = widgets.Output()

# The mean maneuver does not depend on the sliders, so reconstruct it once.
mean_traj = reconstruct(np.zeros(n_components))

# Default viewport (meters) and the margin kept around paths that exceed it.
DEFAULT_XLIM = (-20, 80)
DEFAULT_YLIM = (-50, 50)
VIEW_PADDING = 5.0

# --- 2. The Main Interaction Logic ---

def update_visualization(*args):
    """Redraw the plot for the current slider values.

    Accepts and ignores any positional arguments so it can be used both as a
    slider observer (which passes a change notification) and called directly.
    """
    if _reset_in_progress:
        # A reset is changing many sliders at once; it redraws when finished.
        return

    # Get values from sliders
    coeffs = np.array([s.value for s in slider_list])

    # Reconstruct
    traj = reconstruct(coeffs)
    avg_traj = mean_traj

    with out:
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(8, 8))

        # Plot styling
        ax.set_facecolor('#f9f9f9')
        ax.grid(True, linestyle=':', alpha=0.6, color='#bdc3c7')

        # Plot Mean Reference
        ax.plot(avg_traj[:, 0], avg_traj[:, 1], color='#bdc3c7',
                linestyle='--', linewidth=1.5, label='Mean Maneuver', alpha=0.8)

        # Plot Generated Trajectory
        ax.plot(traj[:, 0], traj[:, 1], color='#3498db',
                linewidth=4, label='Manifold Path', zorder=4)

        # Start and End points
        ax.scatter(0, 0, color='#e74c3c', s=150, marker='X', edgecolors='black', label='Start', zorder=5)
        ax.scatter(traj[-1, 0], traj[-1, 1], color='#2ecc71', s=150, marker='o', edgecolors='black', label='End (Goal)', zorder=5)

        # Viewport: keep the default window so the frame stays stable while
        # dragging sliders, but grow it when a path would otherwise be clipped.
        all_x = np.concatenate([traj[:, 0], avg_traj[:, 0]])
        all_y = np.concatenate([traj[:, 1], avg_traj[:, 1]])
        ax.set_xlim(min(DEFAULT_XLIM[0], all_x.min() - VIEW_PADDING),
                    max(DEFAULT_XLIM[1], all_x.max() + VIEW_PADDING))
        ax.set_ylim(min(DEFAULT_YLIM[0], all_y.min() - VIEW_PADDING),
                    max(DEFAULT_YLIM[1], all_y.max() + VIEW_PADDING))

        ax.set_aspect('equal')
        ax.legend(loc='upper left', frameon=True, framealpha=0.9)
        ax.set_xlabel("X (Meters Forward)", fontsize=10)
        ax.set_ylabel("Y (Meters Lateral)", fontsize=10)

        # Annotate components in the plot
        info_text = "\n".join([f"PC{i+1}: {s.value:+.1f}" for i, s in enumerate(slider_list)])
        ax.text(0.95, 0.05, info_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='bottom', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        plt.show()

# Link sliders to the update function
for s in slider_list:
    s.observe(update_visualization, names='value')

# --- 3. Layout and Display ---

# Group sliders into two balanced columns (first half / second half); the
# reset button sits under the second column.
half = (n_components + 1) // 2
col1 = widgets.VBox(slider_list[:half])
col2 = widgets.VBox(slider_list[half:] + [reset_btn])
controls = widgets.HBox([col1, col2])

# Main Layout: controls on the left, plot on the right (side by side)
dashboard = widgets.HBox([
    widgets.VBox([widgets.Label(value="MANIFOLD CONTROLS (STDs)"), controls]),
    out
], layout=widgets.Layout(align_items='center'))

display(dashboard)

# Initial Plot
update_visualization()

HBox(children=(VBox(children=(Label(value='MANIFOLD CONTROLS (STDs)'), HBox(children=(VBox(children=(FloatSlid…