In [1]:
import os
import pandas as pd
import numpy as np
import torch
from sae_lens import SAE
from transformers import AutoModelForCausalLM, AutoTokenizer

### Configuration

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_name = 'meta-llama/Llama-3.1-8B'

# Stored activations are read from data_dir; SAE outputs are written to output_dir.
data_dir = './model_outputs'
output_dir = './latent_outputs_yusser'

# Values of the 'tense' column that are processed one at a time.
labels = ['past', 'present', 'future']

# Stream name (used in parquet file names and column prefixes) -> hook suffix
# (used to build the SAE id 'blocks.{layer}.hook_{suffix}').
streams = dict(
    residual='resid_post',   # post-MLP residual stream
    mlp='mlp_out',           # MLP output
    attention='attn_out',    # layer attention output
)

# Layers 14 through 31 inclusive.
layers = range(14, 32)

os.makedirs(output_dir, exist_ok=True)

# Fixed seeds for numpy and torch; gradients are never needed in this notebook.
np.random.seed(0)
torch.manual_seed(0)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x14e3fc6c62f0>

### Helpers

In [3]:
# Helper to load embeddings for a given stream
def load_embeddings(label, layer, stream_name, split='train'):
    """Load the stored activations of one tense label for a layer and stream.

    Reads ``llama_{split}_layer{layer}_{stream_name}.parquet`` from
    ``data_dir``, keeps the rows whose ``tense`` column equals ``label`` and
    stacks every column whose name starts with ``{stream_name}_`` into a
    float32 tensor.

    Args:
        label: Value of the ``tense`` column to keep.
        layer: Layer index used in the file name.
        stream_name: Stream name used in the file name and as column prefix.
        split: Dataset split used in the file name. Defaults to ``'train'``,
            the only split this helper could read before the parameter
            existed.

    Returns:
        A tuple ``(emb, df)`` where ``emb`` is a float32 tensor of shape
        ``[N, 1, d]`` and ``df`` is the filtered DataFrame (all columns) the
        ``N`` rows were taken from.

    Raises:
        ValueError: If the file contains no ``{stream_name}_*`` columns.
    """
    path = os.path.join(
        data_dir,
        f'llama_{split}_layer{layer}_{stream_name}.parquet'
    )
    df = pd.read_parquet(path)
    df = df[df['tense'] == label]  # keep only the requested label

    prefix = f'{stream_name}_'
    cols = [c for c in df.columns if c.startswith(prefix)]
    if not cols:
        # Without this check the result would be a silent [N, 1, 0] tensor
        # that only fails later, far away from the actual cause.
        raise ValueError(f"No columns starting with '{prefix}' in {path}")

    emb = torch.tensor(df[cols].values, dtype=torch.float32)
    # [N, d] -> [N, 1, d] so the tensor reads as [batch, seq_len, dim]
    emb = emb.unsqueeze(1)
    return emb, df

def run_sae(label, split, streams, layers):
    """Encode stored activations with pretrained SAEs and save the results.

    For every (stream, layer) pair this reads
    ``llama_{split}_layer{layer}_{stream_name}.parquet`` from ``data_dir``,
    keeps the rows whose ``tense`` column equals ``label``, runs the
    ``{stream_name}_*`` columns through the SAE with id
    ``blocks.{layer}.hook_{stream_code}`` and writes three files to
    ``output_dir``: the feature activations (``*_feature_acts.pt``), the
    reconstruction (``*_sae_out.pt``) and the filtered rows
    (``*_metadata.parquet``).

    Input files that do not exist, or that contain no matching columns, are
    reported with ``print`` and skipped.

    Args:
        label: Value of the ``tense`` column to process.
        split: Dataset split used in the input file name. The split
            ``"train"`` is written out under the name ``"nontemporal"``;
            any other split keeps its name.
        streams: Mapping of stream name (file name / column prefix) to the
            hook suffix used in the SAE id.
        layers: Iterable of layer indices.

    Returns:
        None. Results are written to ``output_dir``.
    """
    # "train" inputs are stored under the "nontemporal" output name.
    split_name = "nontemporal" if split == "train" else split

    for stream_name, stream_code in streams.items():
        prefix = f'{stream_name}_'
        for layer in layers:
            path = os.path.join(
                data_dir,
                f'llama_{split}_layer{layer}_{stream_name}.parquet'
            )
            if not os.path.exists(path):
                print(f"Missing {path}")
                continue

            df = pd.read_parquet(path)
            df = df[df['tense'] == label]
            cols = [c for c in df.columns if c.startswith(prefix)]
            if not cols:
                # A zero-width input would only fail inside the SAE with an
                # opaque shape error; report it like a missing file instead.
                print(f"No '{prefix}*' columns in {path}")
                continue
            # [N, d] -> [N, 1, d], on the same device the SAE is loaded to.
            emb = torch.from_numpy(df[cols].values).float().unsqueeze(1).to(device)

            # Load the SAE on the configured device (not a hard-coded "cuda")
            # so that weights and activations always live on the same device.
            loaded = SAE.from_pretrained(
                release="Yusser/multilingual_llama3.1-8B_saes",
                sae_id=f'blocks.{layer}.hook_{stream_code}',
                device=device
            )
            # The original code unpacked (sae, cfg_dict, sparsity).
            # NOTE(review): some sae_lens versions appear to return the SAE
            # alone; both forms are accepted here - confirm against the
            # installed version.
            sae = loaded[0] if isinstance(loaded, tuple) else loaded
            sae.eval()

            with torch.no_grad():
                feats = sae.encode(emb)
                recon = sae.decode(feats)
            feats, recon = feats.cpu(), recon.cpu()

            base = f'{split_name}_{label}_l{layer}_{stream_name}'
            torch.save(feats, os.path.join(output_dir, f'{base}_feature_acts.pt'))
            torch.save(recon, os.path.join(output_dir, f'{base}_sae_out.pt'))
            df.to_parquet(os.path.join(output_dir, f'{base}_metadata.parquet'), index=False)

In [4]:
# Run the SAE pipeline for both input splits and every tense label.
for split in ('train', 'temporal'):
    for label in labels:
        run_sae(label=label, split=split, streams=streams, layers=layers)

cfg.json:   0%|          | 0.00/641 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/1.07G [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pret

cfg.json:   0%|          | 0.00/638 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/1.07G [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pret

cfg.json:   0%|          | 0.00/639 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/1.07G [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pret