In [1]:
from transformer_lens import (HookedTransformer, utils)
from transformer_lens.hook_points import HookPoint
import functools
import torch

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from torch import Tensor
from torch.nn import functional as F
device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
from transformers import PatchTSTForPrediction
from transformers.models.patchtst.modeling_patchtst import (
    PatchTSTForPredictionOutput
)
from data_loader import *
import pandas as pd

import torch
import os

import plotly.express as px
from sae_lens import (
    SAE,
    upload_saes_to_huggingface,
    LanguageModelSAERunnerConfig,
    TimeSeriesModelSAERunnerConfig,
    TimeSeriesModelSAETrainingRunner,
    SAETrainingRunner,
    StandardTrainingSAEConfig,
    LoggingConfig,
    HookedSAETransformer,
    ActivationsStore,
    run_evals,
)
import json

from sae_lens.evals import EvalConfig
from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
from sae_lens.training.activation_scaler import ActivationScaler

In [2]:
# Load the same pretrained checkpoint twice: once as a plain HookedTransformer and once as a
# HookedSAETransformer, both moved to `device`.
# NOTE(review): this keeps two full copies of the weights on the GPU; drop one if later cells
# only use the other.
MODEL_NAME = "patchtst_relu"
model = HookedTransformer.from_pretrained(MODEL_NAME, center_unembed=False).to(device)
sae_model = HookedSAETransformer.from_pretrained(MODEL_NAME, center_unembed=False).to(device)

# TSMixup dataset parameters. The processed-cache filename encodes these values, so it is
# derived from them below instead of being hardcoded: if the loader finds an existing file at
# `processed_cache_path` it loads it, so a stale hardcoded name could silently reuse a cache
# that was built with different settings.
MAX_SAMPLES = 300_000
CONTEXT_LENGTH = 512
PREDICTION_LENGTH = 96  # 1 or 96
# Cache location; set the TSMIXUP_CACHE_DIR environment variable to use another machine/path.
# os.path.join(..., "") guarantees the trailing separator the original literal had.
# NOTE(review): the processed cache is a .pkl file (presumably unpickled on load) -- only point
# this at a trusted directory.
TSMIXUP_CACHE_DIR = os.path.join(
    os.environ.get(
        "TSMIXUP_CACHE_DIR",
        "/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/",
    ),
    "",
)
processed_cache_path = os.path.join(
    TSMIXUP_CACHE_DIR,
    f"tsmixup_processed_{MAX_SAMPLES}_{CONTEXT_LENGTH}_{PREDICTION_LENGTH}.pkl",
)

# Build (or load from cache) the train/validation splits.
train_dataset, val_dataset = create_cached_tsmixup_datasets(
    max_samples=MAX_SAMPLES,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    num_workers=16,
    cache_dir=TSMIXUP_CACHE_DIR,
    processed_cache_path=processed_cache_path,
    batch_size=4000,
)



Loaded pretrained model patchtst_relu into HookedTransformer
Moving model to device:  cuda:7




Loaded pretrained model patchtst_relu into HookedTransformer
Moving model to device:  cuda:7
🚀 CREATING CACHED TSMIXUP DATASETS
📂 Found existing processed data at /extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_512_96.pkl
⚡ Loading preprocessed data from cache...
✅ Loaded 174,209 preprocessed samples
📅 Cache created: 2025-08-03 15:05:17

📊 DATASET SUMMARY:
  Total processed samples: 174,209
  Context length: 512
  Prediction length: 96
🔀 Shuffling data...
📈 Data split:
  Training samples: 156,788
  Validation samples: 17,421
  Train ratio: 90.0%
🏗️  Creating PyTorch datasets...
🏗️  Dataset created with 156,788 samples
📊 Augmentation: ON
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=608, max=2043, mean=1320
  Value ranges: min=-49.2103, max=70.9532
  Value stats: mean=0.9038, std=2.2952
🏗️  Dataset created with 17,421 samples
📊 Augmentation: OFF
📈 Dataset Statistics (from 1000 samples):
  Seq

In [3]:
# SAE training schedule. Each step consumes `batch_size` activation vectors
# ("tokens" in sae_lens terms, see train_batch_size_tokens below).
total_training_steps = 1000  # TODO: probably too few; consider training longer
batch_size = 4096
total_training_tokens = total_training_steps * batch_size  # 4,096,000

lr_warm_up_steps = 0  # no LR warm-up
lr_decay_steps = total_training_steps // 5  # decay the LR to 0 over the final 20% of training
l1_warm_up_steps = total_training_steps // 20  # ramp the L1 coefficient up over the first 5% of training
d_in = 256  # width of the hooked activation (presumably the model's d_model -- TODO confirm)
expansion_factor = 16
num_patches = 32  # presumably the number of PatchTST patches per 512-step context -- TODO confirm

cfg = TimeSeriesModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="patchtst_relu",  # custom PatchTST checkpoint, loadable by name via from_pretrained
    hook_name="blocks.0.hook_mlp_out",  # MLP output of the first transformer block
    dataset_path="autogluon/chronos_datasets",  # Hugging Face dataset id (see override_dataset below)
    is_dataset_tokenized=True,
    dataset_dtype="torch.float32",
    streaming=False,  # load the dataset eagerly rather than streaming it
    # SAE parameters
    sae=StandardTrainingSAEConfig(
        d_in=d_in,  # must equal the width of the hooked activation
        d_sae=d_in * expansion_factor,  # number of SAE latents; larger = more features, slower training
        apply_b_dec_to_input=False,  # don't subtract the decoder bias from the input before encoding
        normalize_activations="none",
        l1_coefficient=0.5,  # sparsity penalty weight; higher -> sparser feature activations
        l1_warm_up_steps=l1_warm_up_steps,  # gradual L1 ramp-up can help avoid early dead features
    ),
    # Training parameters
    lr=5e-5,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant LR, but lr_decay_steps still decays it to 0 at the end
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    train_batch_size_tokens=batch_size,
    context_size=512,  # must match the context_length used to build the TSMixup datasets
    num_patches=num_patches,
    # Activation store parameters
    n_batches_in_buffer=64 * 16,  # how many store batches are held in the shuffle buffer
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,  # presumably series per model forward pass when filling the buffer
    # Feature-sparsity / dead-feature tracking
    feature_sampling_window=1000,  # controls reporting of feature-sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if they were enabled
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if they were enabled
    # Weights & Biases logging
    logger=LoggingConfig(
        log_to_wandb=True,  # disable only when just testing code
        wandb_project="patchtst_sae_metric_tests",
        wandb_log_frequency=10,
        eval_every_n_wandb_logs=20,
        run_name="patchtst_relu_test",
    ),
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,  # no intermediate checkpoints
    checkpoint_path="checkpoints",
    dtype="float32",
)

# NOTE(review): the SAE is trained on `val_dataset` (the 17,421-sample split with augmentation
# OFF), not on `train_dataset`. If that is deliberate, document why in a markdown cell;
# otherwise switch to train_dataset.
sparse_autoencoder = TimeSeriesModelSAETrainingRunner(cfg, override_dataset=val_dataset).run()



Loaded pretrained model patchtst_relu into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
[34m[1mwandb[0m: Currently logged in as: [33mcoaster41[0m ([33mcoaster41-uci[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training SAE:   0%|          | 0/4096000 [00:00<?, ?it/s]

  yield torch.tensor(
  yield torch.tensor(
  yield torch.tensor(
  yield torch.tensor(
  yield torch.tensor(
  yield torch.tensor(


0,1
details/current_learning_rate,███████████████████████████████████▅▄▃▃▁
details/l1_coefficient,▁███████████████████████████████████████
details/n_training_samples,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
losses/l1_loss,█▇▇▆▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▂▁▂▁▂▁▁▁▁▁▁▁
losses/mse_loss,██▇▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,██▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/explained_variance,▂▁▁▃▄▅▅▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
metrics/explained_variance_legacy,▁▄▃▃▄▆▆▆▆▆▇▇▇▇▇▇▇▇▇█████████████████████
metrics/explained_variance_legacy_std,▁██▇▆▆▅▅▅▅▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂▂▂▂▂▁▂▁
metrics/l0,▇█▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
details/current_learning_rate,0.0
details/l1_coefficient,0.5
details/n_training_samples,4096000.0
losses/l1_loss,0.64185
losses/mse_loss,0.50712
losses/overall_loss,1.14896
metrics/explained_variance,0.7203
metrics/explained_variance_legacy,0.67434
metrics/explained_variance_legacy_std,0.12833
metrics/l0,240.58301
