In [2]:
import os
import sys
import subprocess
from pathlib import Path


def get_project_root():
    """Return the top-level directory of the enclosing git repository as a ``Path``.

    Runs ``git rev-parse --show-toplevel`` from the current working directory.
    ``subprocess.check_output`` raises ``CalledProcessError`` if git exits with a
    non-zero status (e.g. when not inside a repository).
    """
    raw_output = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    toplevel = raw_output.strip().decode("utf-8")
    return Path(toplevel)

# Make the project root importable (the workspace may be launched from somewhere else).
# The membership check keeps this cell idempotent: re-running it no longer appends
# duplicate entries to sys.path.
project_root = get_project_root()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Output directory: <parent of project root>/out/control-vectors/<dataset>/
dataset = "waymo"
out_reldir = f"out/control-vectors/{dataset}/"
base_path = os.path.normpath(os.path.join(project_root, ".."))
out_path = os.path.join(base_path, out_reldir)

In [3]:
import torch


# Load PCA-based control vectors: every `pca*.pt` file in `out_path`, keyed by the
# second underscore-separated token of its file name.
pca_files = list(
    filter(lambda entry: entry.startswith("pca") and entry.endswith(".pt"), os.listdir(out_path))
)
pca_control_vectors = {}
for fname in pca_files:
    file_path = os.path.join(out_path, fname)
    # NOTE: torch.load unpickles the file; only load trusted, locally produced files.
    loaded_vector = torch.load(file_path)
    pca_control_vectors[fname.split("_")[1]] = loaded_vector



In [4]:
# Load SAE-based control vectors, grouped by SAE hidden dimension.
# File names are parsed as `sae<dim>_<key>...`: the characters after the "sae"
# prefix of the first token give the hidden dimension, the second token gives the key.
sae_files = [file for file in os.listdir(out_path) if file.startswith("sae") and file.endswith(".pt")]
sae_control_vectors = {}
for fname in sae_files:
    dim = fname.split("_")[0][3:]
    file_path = os.path.join(out_path, fname)
    # setdefault keeps vectors already loaded for this dim; the previous
    # `sae_control_vectors[dim] = {}` discarded every earlier file with the same dim.
    # NOTE: torch.load unpickles the file; only load trusted, locally produced files.
    sae_control_vectors.setdefault(dim, {})[fname.split("_")[1]] = torch.load(file_path)

In [5]:
from typing import Dict
import numpy as np
from collections import defaultdict
from future_motion.utils.similarity.vector import VectorComparison


# Helper function to inspect the similarity between two sets of control vectors
def inspect_vector(control_vectors1: Dict[str, np.ndarray], control_vectors2: Dict[str, np.ndarray]):
    """Compare every vector in ``control_vectors1`` with every vector in ``control_vectors2``.

    For each pair ``(name_a, name_b)`` a ``VectorComparison`` is built and its
    ``cos_sim_deg()`` result is stored.

    Returns:
        A ``defaultdict`` mapping ``"<name_a>_<name_b>"`` to ``{"cos_sim_deg": value}``,
        in iteration order of the first mapping, then the second.

    NOTE(review): the annotations say ``np.ndarray``, but in this notebook the values
    come from ``torch.load``; presumably tensors are accepted by ``VectorComparison``.
    """
    pairwise = defaultdict(dict)
    for name_a, vec_a in control_vectors1.items():
        for name_b, vec_b in control_vectors2.items():
            angle = VectorComparison(vec_a, vec_b).cos_sim_deg()
            pairwise["_".join((name_a, name_b))]["cos_sim_deg"] = angle
    return pairwise

In [6]:
import pandas as pd


# Pairwise angles between the PCA control vectors themselves
# (the diagonal compares each vector with itself).
comparison = {}
pca_vs_pca = inspect_vector(pca_control_vectors, pca_control_vectors)
comparison["PCA-PCA"] = pca_vs_pca
comparison_df = pd.DataFrame(pca_vs_pca).T

# Printing is off by default; the resulting table is transcribed in the markdown below.
SHOW_PCA_TABLE = False
if SHOW_PCA_TABLE:
    print(comparison_df)

**PCA w/ PCA** (angle in degrees between each pair of control vectors, as returned by `cos_sim_deg`)

| **Angle (°)**  | speed      | acceleration | direction  | agent      |
|----------------|------------|--------------|------------|------------|
| speed          | 0.0        | 11.458136    | 122.603544 | 10.865894  |
| acceleration   |            | 0.0          | 126.78761  | 6.82372    |
| direction      |            |              | 0.0        | 128.655917 |
| agent          |            |              |            | 0.0        |


In [8]:
# Compare each SAE variant (one per hidden dimension) with the PCA control vectors
# and with itself. Results are kept per hidden dimension under
# "PCA-SAE-<dim>" / "SAE-SAE-<dim>"; previously every iteration overwrote the same
# keys, so all but the last dimension's results were lost. The generic
# "PCA-SAE" / "SAE-SAE" keys are still set (to the last-processed dimension) for
# backward compatibility.
SHOW_SAE_TABLES = False  # set True to print the tables transcribed in the markdown below
for hidden_dim, sae_cv in sae_control_vectors.items():
    # Compare PCA-based and SAE-based control vectors
    pca_vs_sae = inspect_vector(pca_control_vectors, sae_cv)

    # Compare SAE-based control vectors with themselves
    sae_vs_sae = inspect_vector(sae_cv, sae_cv)

    comparison[f"PCA-SAE-{hidden_dim}"] = pca_vs_sae
    comparison[f"SAE-SAE-{hidden_dim}"] = sae_vs_sae
    comparison["PCA-SAE"] = pca_vs_sae
    comparison["SAE-SAE"] = sae_vs_sae

    if SHOW_SAE_TABLES:
        print(hidden_dim)
        print(pd.DataFrame(pca_vs_sae).T)
        print(pd.DataFrame(sae_vs_sae).T)

**PCA w/ PCA**

| **Angle (°)**  | speed      | acceleration | direction  | agent      |
|----------------|------------|--------------|------------|------------|
| speed          | 0.0        | 11.458136    | 122.603544 | 10.865894  |
| acceleration   |            | 0.0          | 126.78761  | 6.82372    |
| direction      |            |              | 0.0        | 128.655917 |
| agent          |            |              |            | 0.0        |

---

**PCA w/ SAE (512 hidden-dim)**


| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 20.7  | 28.6         | 123.8     | 23.4  |
| acceleration   | 19.1  | 23.0         | 128.5     | 18.6  |
| direction      | 115.9 | 116.6        | 13.7      | 120.8 |
| agent          | 19.4  | 24.4         | 130.2     | 18.3  |

---

**SAE w/ SAE (512 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 0.0   | 10.2         | 121.8     | 7.6   |
| acceleration   |       | 0.0          | 123.7     | 7.6   |
| direction      |       |              | 0.0       | 126.9 |
| agent          |       |              |           | 0.0   |

---

**PCA w/ SAE (256 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 21.5  | 26.8         | 123.8     | 23.3  |
| acceleration   | 20.3  | 21.0         | 128.7     | 18.7  |
| direction      | 114.7 | 116.9        | 13.7      | 120.1 |
| agent          | 20.8  | 23.1         | 130.2     | 18.7  |


---

**SAE w/ SAE (256 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 0.0   | 9.9          | 120.9     | 7.9   |
| acceleration   |       | 0.0          | 123.7     | 7.2   |
| direction      |       |              | 0.0       | 126.3 |
| agent          |       |              |           | 0.0   |

---

**PCA w/ SAE (128 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 19.7  | 25.3         | 124.3     | 21.6  |
| acceleration   | 19.2  | 20.0         | 128.8     | 17.5  |
| direction      | 115.2 | 117.1        | 12.1      | 120.5 |
| agent          | 19.5  | 21.8         | 130.4     | 17.1  |

---

**SAE w/ SAE (128 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 0.0   | 9.5          | 120.6     | 7.8   |
| acceleration   |       | 0.0          | 122.9     | 7.0   |
| direction      |       |              | 0.0       | 125.8 |
| agent          |       |              |           | 0.0   |

---

**PCA w/ SAE (64 hidden-dim)**


| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 18.1  | 23.7         | 124.7     | 19.3  |
| acceleration   | 19.3  | 19.9         | 128.9     | 16.5  |
| direction      | 115.0 | 116.6        | 13.3      | 120.5 |
| agent          | 19.8  | 21.9         | 130.5     | 16.4  |

---

**SAE w/ SAE (64 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 0.0   | 9.7          | 121.0     | 8.0   |
| acceleration   |       | 0.0          | 123.2     | 7.5   |
| direction      |       |              | 0.0       | 126.3 |
| agent          |       |              |           | 0.0   |

---

**PCA w/ SAE (32 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 14.7  | 18.8         | 126.4     | 15.5  |
| acceleration   | 18.0  | 15.5         | 130.3     | 14.1  |
| direction      | 114.4 | 116.9        | 10.9      | 120.2 |
| agent          | 18.1  | 17.6         | 132.0     | 13.4  |

---

**SAE w/ SAE (32 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 0.0   | 9.8          | 120.3     | 8.3   |
| acceleration   |       | 0.0          | 122.8     | 7.0   |
| direction      |       |              | 0.0       | 125.8 |
| agent          |       |              |           | 0.0   |

---

**PCA w/ SAE (16 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent |
|----------------|-------|--------------|-----------|-------|
| speed          | 23.5  | 25.1         | 126.6     | 21.8  |
| acceleration   | 28.4  | 26.0         | 128.9     | 23.5  |
| direction      | 110.2 | 111.9        | 24.6      | 116.6 |
| agent          | 28.0  | 26.8         | 131.0     | 22.5  |

---

**SAE w/ SAE (16 hidden-dim)**

| **Angle (°)**  | speed | acceleration | direction | agent  |
|----------------|-------|--------------|-----------|--------|
| speed          | 0.0   | 9.5          | 124.1     | 9.3    |
| acceleration   |       | 0.0          | 125.2     | 7.5    |
| direction      |       |              | 0.0       | 129.3  |
| agent          |       |              |           | 0.0    |