In [1]:
import os
import subprocess
from pathlib import Path
import gdown
import zipfile


def get_project_root():
    """Return the top-level directory of the enclosing git repository.

    Runs ``git rev-parse --show-toplevel`` and wraps the reported path in a
    ``pathlib.Path``.

    Raises:
        subprocess.CalledProcessError: if the git command exits with a
            non-zero status.
    """
    git_command = ["git", "rev-parse", "--show-toplevel"]
    raw_output = subprocess.check_output(git_command)
    # git terminates its answer with a newline; drop it before decoding
    toplevel = raw_output.strip().decode("utf-8")
    return Path(toplevel)

# resolve the repository root; the data directory is created next to it
project_root = get_project_root()


# create data directory for embeddings (one level above the repository root)
base_path = os.path.normpath(os.path.join(project_root, ".."))
data_path = os.path.join(base_path, "data")
os.makedirs(data_path, exist_ok=True)

# download dataset if it doesn't exist
dataset = "argo"  # "waymo" / "argo"
dataset_path = os.path.join(data_path, f"{dataset}_data")
if not os.path.exists(dataset_path):
    # Google Drive file ids of the zipped datasets, keyed by dataset name
    drive_file_ids = {
        "waymo": "1FbMXOT5Upqhm51ZxPHVz6g64KK2Cgbc6",
        "argo": "1s4pKBaz8bb3ZvRwyDFwl-YUvN-TAPLCp",
    }
    # fail loudly on an unknown name instead of silently fetching the argo archive
    if dataset not in drive_file_ids:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(drive_file_ids)}"
        )
    file_id = drive_file_ids[dataset]

    download_url = f"https://drive.google.com/uc?id={file_id}"
    zip_path = dataset_path + ".zip"
    # gdown may signal a failed download by returning None; stop before unzipping
    if gdown.download(download_url, zip_path, quiet=False) is None:
        raise RuntimeError(f"Download of {download_url} failed")

    # unzip into the data directory, then remove the archive
    # NOTE(review): presumably the archive unpacks to `<dataset>_data`; otherwise
    # the existence check above never succeeds and the download repeats
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall(data_path)
    os.remove(zip_path)


Load target embeddings; torch.Size([48, 11, 128])

In [2]:
### Load embeddings

In [3]:
import os
from glob import glob
from utils.embs_all import load_embeddings
from utils.embs_contrastive import load_contrastive_embed_pairs


# collect the input files and the target-embedding files, ordered by file name
paths_inputs = glob(f"{dataset_path}/input*")
paths_inputs.sort()
paths_embeds = glob(f"{dataset_path}/target_embs*")
paths_embeds.sort()

# stack the embeddings by type
embs = load_embeddings(paths_inputs, paths_embeds)

# trim and stack the embeddings into contrastive pairs
contrastive_embs = load_contrastive_embed_pairs(embs)

100%|██████████| 211/211 [00:00<00:00, 620.50it/s]


In [None]:
### Get training data

In [None]:
import torch
from tqdm import tqdm

# get all embeddings (every target-embedding file, mapped onto the CPU)
# NOTE(review): torch.load deserializes with pickle-based machinery; only load
# files that come from a trusted source
embs_all = []
for path_embs in tqdm(paths_embeds):
    # dedicated name so the `embs` returned by load_embeddings is not overwritten
    embs_file = torch.load(path_embs, map_location=torch.device('cpu'))
    embs_all.append(embs_file)

# we have 3 modules; hence 3 hidden states in embs_all
# 48 batch size, 11 past (Waymo), 128 hidden state size
len(embs_all), embs_all[0][0].shape

In [None]:
# get the last module's hidden state
# use as many embedding files as there are input files (the count that pairing
# both lists element-wise would yield); the input paths themselves are not read
n_files = min(len(paths_inputs), len(paths_embeds))
last_states = []
# iterating a list (rather than a zip) lets tqdm show the total and an ETA
for path_embs in tqdm(paths_embeds[:n_files]):
    embs_file = torch.load(path_embs, map_location=torch.device('cpu'))
    # take the last module (layer) and the last past time step (the current one) of it
    last_states.append(embs_file[-1][:, -1])
# one entry per file along a new leading dimension
embs_all_last = torch.stack(last_states, dim=0)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# create a TensorDataset and DataLoader
train_dataset = TensorDataset(embs_all_last, embs_all_last)  # Autoencoder output = input
# never request more worker processes than the machine has CPU cores;
# a fixed 15 oversubscribes smaller machines
num_workers = min(15, os.cpu_count() or 1)
# set 'drop_last=False' to get exactly the same loss values as in the paper
train_loader = DataLoader(
    train_dataset,
    batch_size=32,  # Adjust batch size as needed
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
)

### Create control vectors using PCA

In [None]:
# output dir for the control vectors of the selected dataset
out_reldir = f"out/control-vectors/{dataset}/"
out_path = os.path.join(base_path, out_reldir)
# exist_ok avoids the separate existence check (and the race between check and create)
os.makedirs(out_path, exist_ok=True)

In [None]:
from future_motion.utils.interpretability.control_vectors import fit_control_vector


# layer index that is written into the file names of the saved control vectors
idx_layer = 2
PCA_control_vectors = {}
for key in contrastive_embs.keys():
    # fit one control vector per contrastive pair and keep it in memory ...
    control_vector = fit_control_vector(contrastive_embs[key])
    PCA_control_vectors[key] = control_vector
    # ... and store it on disk as a tensor
    save_file = f"{out_path}/pca_{key}_layer{idx_layer}.pt"
    torch.save(torch.tensor(control_vector), f=save_file)

In [None]:
### Create control vectors using SAE

In [None]:
from future_motion.utils.interpretability.sparse_autoencoder import SparseAutoencoder, SAE
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl


# best SAE per hidden width, filled on the rank-0 process only
autoencoders = {}
# SAE input width; presumably the 128-dim hidden state of the embeddings - TODO confirm
d_mlp = 128
# one epoch budget shared by the SAE constructor and the Trainer so they cannot drift apart
max_epochs = 10000

# Create specific subdirectories under out_path
# (independent of n_hidden, so this is done once instead of on every iteration)
log_dir = os.path.join(out_path, "logs")
checkpoint_dir = os.path.join(out_path, "checkpoints")
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

for n_hidden in [512, 256, 128, 64, 32, 16]:

    model = SAE(d_mlp, n_hidden, max_epochs=max_epochs)

    logger = TensorBoardLogger(log_dir, name="sae_training")
    # keep the three checkpoints with the lowest monitored "loss"
    # (assumes SAE logs a metric with that name - TODO confirm)
    checkpoint_callback = ModelCheckpoint(
        monitor="loss",
        dirpath=checkpoint_dir,
        filename=f"sae{n_hidden}-" + "{epoch:02d}-{loss:.4f}",
        save_top_k=3,
        mode="min"
    )

    # Create the trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        logger=logger,
        callbacks=[checkpoint_callback],
        accelerator="auto",
        devices="auto"
    )

    # Train the model
    trainer.fit(model=model, train_dataloaders=train_loader)

    # Only load and save the checkpoint on the main process (global_rank == 0)
    if trainer.global_rank == 0:
        # Load the best checkpoint
        best_model_path = checkpoint_callback.best_model_path
        best_model = SAE.load_from_checkpoint(best_model_path)
        autoencoders[n_hidden] = best_model

        print(f"Best model loaded from {best_model_path}")
        print(f"Best loss: {checkpoint_callback.best_model_score.item():.6f}")

        # Save model to out_path
        # NOTE(review): the file name says "waymo" regardless of the selected
        # dataset; left unchanged because other scripts may load this exact path
        model_save_path = os.path.join(out_path, f"sae_waymo_n{n_hidden}.pth")
        torch.save(best_model.state_dict(), model_save_path)
        print(f"Best model saved to {model_save_path}")


Training for 10,000 epochs - SAE
(Google Colab T4-instance; seed not set / cuda=12.5 / torch=2.6.0+cu124 / sklearn=2.0.2 / python=3.11.11 (final))

| Hidden Dim | Epoch         | Total Loss   | L1 Loss       | L2 Loss        | Total Reconst Loss |
|------------|---------------|--------------|---------------|----------------|--------------------|
| 512        | 9805/10000    | 4.005656276  | 1.524447083   | 8270.697265625 | 0.001645120210014  |
| 256        | 9845/10000    | 3.724161590  | 1.376968503   | 7823.977050781 | 0.001388887991197  |
| 128        | 9820/10000    | 4.139010770  | 1.556326985   | 8608.9453125   | 0.001653907238506  |
| 64         | 9348/10000    | 4.561335734  | 1.892843366   | 8894.974609375 | 0.001926084747538  |
| 32         | 9864/10000    | 7.141473430  | 3.902811527   | 10795.541015625| 0.004311752039939  |
| 16         | 9956/10000    | 17.441959654 | 13.368986130  | 13576.573242188| 0.014228038489819  |



In [None]:
from future_motion.utils.interpretability.control_vectors import fit_control_vector


# Fit one control vector per (SAE width, contrastive pair) and store each on disk
SAE_control_vectors = {}
for hidden_dims, autoencoder in autoencoders.items():
    # register the per-width dictionary first, then fill it through this alias
    vectors_for_width = SAE_control_vectors[hidden_dims] = {}
    for key in contrastive_embs.keys():
        print(key)
        control_vector = fit_control_vector(
            contrastive_embs[key],
            autoencoder=autoencoder,
            verbose_explained_variance=True,
        )
        vectors_for_width[key] = control_vector
        save_file = f"{out_path}/sae{hidden_dims}_{key}_layer{idx_layer}.pt"
        torch.save(torch.tensor(control_vector), f=save_file)