In [None]:
import os
import sys
import subprocess
from pathlib import Path


def get_project_root():
    """Return the top-level directory of the enclosing git repository.

    Runs ``git rev-parse --show-toplevel`` and wraps the whitespace-stripped,
    UTF-8 decoded output in a ``pathlib.Path``.

    Returns:
        Path: absolute path to the repository root as reported by git.

    Raises:
        subprocess.CalledProcessError: if the git command exits with a
            non-zero status (e.g. when not inside a git work tree).
    """
    git_cmd = ["git", "rev-parse", "--show-toplevel"]
    raw_output = subprocess.check_output(git_cmd)
    # git terminates the path with a newline; strip it before decoding
    toplevel = raw_output.strip().decode("utf-8")
    return Path(toplevel)

# get project root and append it to path
# (guarded so that re-running this cell does not add duplicate entries to sys.path)
project_root = get_project_root()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# embeddings path
# `data_dir` is the dataset-specific folder name; `base_path` is the parent
# directory of the repository root.
dataset = "waymo"
data_dir = f"{dataset}_data"
base_path = os.path.normpath(os.path.join(project_root, ".."))

# output dir
# `exist_ok=True` replaces the separate exists-check: it is idempotent and avoids
# the race between checking for the directory and creating it.
out_reldir = f"out/control-vectors/{dataset}/"
out_path = os.path.join(base_path, out_reldir)
os.makedirs(out_path, exist_ok=True)

Load target embeddings; each hidden state has shape torch.Size([48, 11, 128]) (batch size, past time steps, hidden state size)

In [22]:
### Load embeddings

In [23]:
import os
from glob import glob
from utils.embs_all import load_embeddings
from utils.embs_contrastive import load_contrastive_embed_pairs


# load data
# Collect the input and target-embedding files of the selected dataset in sorted order.
data_path = os.path.join(base_path, "data", data_dir)
paths_inputs = sorted(glob(f"{data_path}/input*"))
paths_embeds = sorted(glob(f"{data_path}/target_embs*"))

# glob silently returns an empty list for a missing or wrong directory; fail here
# with a clear message instead of with an unrelated error further down.
if not paths_inputs or not paths_embeds:
    raise FileNotFoundError(
        f"Expected 'input*' and 'target_embs*' files in {data_path}, "
        f"found {len(paths_inputs)} and {len(paths_embeds)}"
    )

# stack embeddings wrt types
embs = load_embeddings(paths_inputs, paths_embeds)

# trim and stack contrastive pairs of embeddings
contrastive_embs = load_contrastive_embed_pairs(embs)

100%|██████████| 204/204 [00:00<00:00, 454.09it/s]


In [24]:
### Get training data

In [25]:
import torch
from tqdm import tqdm

# get all embeddings (one list entry per target_embs file), mapped onto the CPU
# NOTE: the loop variable is deliberately not named `embs`, so that running this
# cell does not overwrite the `embs` returned by `load_embeddings(...)`.
embs_all = [
    torch.load(path_embs, map_location=torch.device('cpu'))
    for path_embs in tqdm(paths_embeds)
]

# we have 3 modules; hence 3 hidden states in each entry of embs_all
# each hidden state: 48 batch size, 11 past (Waymo), 128 hidden state size
len(embs_all), embs_all[0][0].shape

100%|██████████| 204/204 [00:00<00:00, 1996.55it/s]


(204, torch.Size([48, 11, 128]))

In [26]:
# get the last module's hidden state
# Reuse the tensors already held in `embs_all` instead of reading every file from
# disk a second time. For each file, take the last module (layer) and the last past
# time step (the current one) of it, then stack the results along a new first axis.
embs_all_last = torch.stack(
    [file_embs[-1][:, -1] for file_embs in embs_all],
    dim=0,
)

204it [00:00, 2438.35it/s]


In [27]:
from torch.utils.data import DataLoader, TensorDataset

# create a TensorDataset and DataLoader
train_dataset = TensorDataset(embs_all_last, embs_all_last)  # Autoencoder output = input

# Never request more worker processes than the machine has CPUs: a fixed count of 15
# oversubscribes small machines (PyTorch warns about this and loading may slow down).
num_workers = min(15, os.cpu_count() or 1)

# keep 'drop_last=False' to get exactly the same loss values as in the paper
train_loader = DataLoader(
    train_dataset,
    batch_size=32,  # adjust batch size as needed
    shuffle=True,
    drop_last=False,
    num_workers=num_workers,
)

### Create control vectors using PCA

In [28]:
from future_motion.utils.interpretability.control_vectors import fit_control_vector


# layer index; only used here to tag the names of the saved files
idx_layer = 2

# fit one control vector per entry of `contrastive_embs`, keep it in memory
# and store it on disk as a tensor
PCA_control_vectors = {}
for key, pair_embs in contrastive_embs.items():
    control_vector = fit_control_vector(pair_embs)
    PCA_control_vectors[key] = control_vector
    torch.save(
        torch.tensor(control_vector),
        f=f"{out_path}/pca_{key}_layer{idx_layer}.pt",
    )

In [29]:
### Create control vectors using SAE

In [None]:
from future_motion.utils.interpretability.sparse_autoencoder import SparseAutoencoder, SAE
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl


# Train one sparse autoencoder (SAE) per hidden size, keep the checkpoint with the
# lowest monitored "loss" for each, and export its state dict to `out_path`.
autoencoders = {}  # hidden size -> best SAE reloaded from its checkpoint
d_mlp = 128  # input size passed to the SAE (matches the 128-dim hidden states)

# single source of truth for the training length; passed to both the SAE and the Trainer
MAX_EPOCHS = 10000

# The log/checkpoint sub-directories under out_path are identical for every hidden
# size, so they are created once instead of on every loop iteration.
log_dir = os.path.join(out_path, "logs")
checkpoint_dir = os.path.join(out_path, "checkpoints")
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

for n_hidden in [512, 256, 128, 64, 32, 16]:

    model = SAE(d_mlp, n_hidden, max_epochs=MAX_EPOCHS)

    logger = TensorBoardLogger(log_dir, name="sae_training")
    # keep the three checkpoints with the lowest logged "loss"; the hidden size is
    # part of the file name so the different runs do not collide
    checkpoint_callback = ModelCheckpoint(
        monitor="loss",
        dirpath=checkpoint_dir,
        filename=f"sae{n_hidden}-" + "{epoch:02d}-{loss:.4f}",
        save_top_k=3,
        mode="min"
    )

    # Create the trainer
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        logger=logger,
        callbacks=[checkpoint_callback],
        accelerator="auto",
        devices="auto"
    )

    # Train the model
    trainer.fit(model=model, train_dataloaders=train_loader)

    # Only load and save the checkpoint on the main process (global_rank == 0)
    if trainer.global_rank == 0:
        # Load the best checkpoint
        best_model_path = checkpoint_callback.best_model_path
        best_model = SAE.load_from_checkpoint(best_model_path)
        autoencoders[n_hidden] = best_model

        print(f"Best model loaded from {best_model_path}")
        print(f"Best loss: {checkpoint_callback.best_model_score.item():.6f}")

        # Save model to out_path; the file name follows the configured `dataset`
        # instead of a hard-coded "waymo" (identical name for dataset == "waymo")
        model_save_path = os.path.join(out_path, f"sae_{dataset}_n{n_hidden}.pth")
        torch.save(best_model.state_dict(), model_save_path)
        print(f"Best model saved to {model_save_path}")


Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name         | Type | Params | Mode
---------------------------------------------
  | other params | n/a  | 33.0 K | n/a 
---------------------------------------------
33.0 K    Trainable params
0         Non-trainable params
33.0 K    Total params
0.132     Total estimated model params size (MB)
0         Modules in train mode
0         Modules in eval mode
/home/tas/.virtualenvs/words/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_ste

Epoch 9999: 100%|██████████| 7/7 [00:00<00:00, 23.34it/s, v_num=0, loss=5.740]

`Trainer.fit` stopped: `max_epochs=10000` reached.


Epoch 9999: 100%|██████████| 7/7 [00:00<00:00, 23.02it/s, v_num=0, loss=5.740]
Best model loaded from /home/tas/00_workspaces/words_in_motion/out/control-vectors/waymo/checkpoints/sae128-epoch=9435-loss=4.3303.ckpt
Best loss: 4.330253
Best model saved to /home/tas/00_workspaces/words_in_motion/out/control-vectors/waymo/sae_waymo_n128.pth


Training for 10,000 epochs - SAE
(Google Colab T4-instance; seed not set / cuda=12.5 / torch=2.6.0+cu124 / sklearn=2.0.2 / python=3.11.11 (final))

| Hidden Dim | Epoch         | Total Loss   | L1 Loss       | L2 Loss        | Total Reconst Loss |
|------------|---------------|--------------|---------------|----------------|--------------------|
| 512        | 9805/10000    | 4.005656276  | 1.524447083   | 8270.697265625 | 0.001645120210014  |
| 256        | 9845/10000    | 3.724161590  | 1.376968503   | 7823.977050781 | 0.001388887991197  |
| 128        | 9820/10000    | 4.139010770  | 1.556326985   | 8608.9453125   | 0.001653907238506  |
| 64         | 9348/10000    | 4.561335734  | 1.892843366   | 8894.974609375 | 0.001926084747538  |
| 32         | 9864/10000    | 7.141473430  | 3.902811527   | 10795.541015625| 0.004311752039939  |
| 16         | 9956/10000    | 17.441959654 | 13.368986130  | 13576.573242188| 0.014228038489819  |



In [None]:
from future_motion.utils.interpretability.control_vectors import fit_control_vector


# Create control vectors with SAE
# Outer level: one entry per trained autoencoder (keyed by its hidden size);
# inner level: one control vector per entry of `contrastive_embs`.
SAE_control_vectors = {}
for hidden_dims, autoencoder in autoencoders.items():
    vectors_for_size = {}
    SAE_control_vectors[hidden_dims] = vectors_for_size
    for key, pair_embs in contrastive_embs.items():
        print(key)
        cv = fit_control_vector(
            pair_embs,
            autoencoder=autoencoder,
            verbose_explained_variance=True,
        )
        vectors_for_size[key] = cv
        # persist the vector as a tensor, tagged with hidden size, key and layer index
        torch.save(
            torch.tensor(cv),
            f=f"{out_path}/sae{hidden_dims}_{key}_layer{idx_layer}.pt",
        )

speed
explained variance: [0.66795915 0.0779107  0.05768754 0.03374871 0.02376009 0.01854349
 0.01611649 0.01115513 0.01029056 0.00726391]
acceleration
explained variance: [0.63380283 0.08891191 0.08114744 0.03292668 0.02905993 0.02080279
 0.01766137 0.01140689 0.00912073 0.00772376]
direction
explained variance: [0.40223372 0.13826407 0.10420155 0.07386495 0.04946833 0.03335578
 0.02050195 0.01822362 0.01589583 0.01390625]
agent
explained variance: [0.59231377 0.10558699 0.0751505  0.03179755 0.02696528 0.02449778
 0.02359529 0.01325316 0.01098    0.00962443]
speed
explained variance: [0.68807936 0.073489   0.0685495  0.02807502 0.02159631 0.01922441
 0.0177188  0.00981717 0.00746597 0.00639294]
acceleration
explained variance: [0.63591397 0.098262   0.08801161 0.03279012 0.02852907 0.02120169
 0.01530502 0.01093495 0.00692878 0.00642024]
direction
explained variance: [0.40637618 0.14300221 0.11172867 0.08087072 0.04480727 0.03594213
 0.02242988 0.01866051 0.01497646 0.01319762]
agent