In [1]:
import os
import argparse
import pandas as pd
import numpy as np
import torch
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
STREAM_HOOKS = {
    'attention': 'blocks.{layer}.hook_attn_out',
    'mlp': 'blocks.{layer}.hook_mlp_out',
    'residual': 'blocks.{layer}.hook_resid_post'
}

In [4]:
# args
train_csv = "../all_sentences_train.csv"
temporal_csv = "../all_sentences_temporal.csv"
hf_token = "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"
torch_dtype='float16'
batch_size = 16
model_name = 'meta-llama/Llama-3.1-8B'
layers = range(15, 32)
out_dir = './model_outputs'
n_train = 500
n_temporal = 500

torch.manual_seed(0)
np.random.seed(0)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x14ff1f5f6f50>

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = getattr(torch, torch_dtype)

# load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=True,
    padding_side='left',
    token=hf_token
)

# ensure pad_token exists for batch padding
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    
model_hf = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map='auto',
    token=hf_token
)

# resize embeddings if pad_token added
model_hf.resize_token_embeddings(len(tokenizer))

model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    hf_model=model_hf,
    tokenizer=tokenizer,
    device=device,
    dtype=dtype
).eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer


## Helper

In [6]:
def load_data(path, n_per_label):
    df = pd.read_csv(path, encoding='utf-8-sig')
    frames = []
    for label in ['past', 'present', 'future']:
        subset = df[df['tense'] == label].head(n_per_label)
        frames.append(subset)
    return pd.concat(frames, ignore_index=True)


def extract_stream(tokenized, verb_indices, hook_name, model):
    input_ids = tokenized['input_ids'].to(model.cfg.device)
    attention_mask = tokenized['attention_mask'].to(model.cfg.device)
    # Use correct signature: tokens tensor and attention_mask keyword
#     _, cache = model.run_with_cache(input_ids, attention_mask=attention_mask)
    logits, cache = model.run_with_cache(
        input_ids,
        attention_mask=attention_mask,
        names_filter=[hook_name],
    )
    
    acts = cache[hook_name]
    out = []
    for i, vidx in enumerate(verb_indices):
        word_ids = tokenized.word_ids(batch_index=i)
        positions = [pos for pos, w in enumerate(word_ids) if w == vidx]
        if not positions:
            raise RuntimeError(f"No token for verb_index={vidx} in example {i}")
        sub_embs = acts[i, positions, :]
        emb = sub_embs.mean(dim=0).cpu().numpy()
        out.append(emb)
    return torch.tensor(out)


def process_split(df, model, tokenizer, layers, split_name, batch_size, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    # Prepare dataframes per batch
    sentences = df['sentence'].tolist()
    verbs = df['verb_index'].tolist()
    metadata = df[['language','sentence','main_verb','verb_index','tense']]

    for layer in layers:
        print(f"# processing layer {layer}")
        for stream_name, hook_fmt in STREAM_HOOKS.items():
            print(f"## processing stream {stream_name}")
            hook_name = hook_fmt.format(layer=layer)
            records = []
            # iterate in batches
            for start in range(0, len(sentences), batch_size):
                batch_sent = sentences[start:start+batch_size]
                batch_verbs = verbs[start:start+batch_size]
                # tokenize with word alignment
                tokenized = tokenizer(
                    [s.split() for s in batch_sent],
                    is_split_into_words=True,
                    return_tensors='pt',
                    padding=True,
                    truncation=True
                )
                
                # extract activations
                emb_batch = extract_stream(tokenized, batch_verbs, hook_name, model)
                # collect records
                for i, emb in enumerate(emb_batch.numpy()):
                    rec = {f'{stream_name}_{j}': float(v)
                           for j, v in enumerate(emb)}
                    md = metadata.iloc[start + i].to_dict()
                    rec.update({'layer': layer, 'stream': stream_name, **md})
                    records.append(rec)

            df_out = pd.DataFrame.from_records(records)
            out_file = os.path.join(
                out_dir,
                f'llama_{split_name}_layer{layer}_{stream_name}.parquet'
            )
            df_out.to_parquet(out_file, index=False)
            print(f'Saved {out_file}: {df_out.shape}')

In [7]:
df_train = load_data(train_csv, n_per_label=n_train)
df_temporal = load_data(temporal_csv, n_per_label=n_temporal)

In [8]:
process_split(df_train, model, tokenizer,
              layers, 'train',
              batch_size, out_dir)

# processing layer 15
## processing stream attention


  return torch.tensor(out)


Saved ./model_outputs/llama_train_layer15_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_train_layer15_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_train_layer15_residual.parquet: (1500, 4103)
# processing layer 16
## processing stream attention
Saved ./model_outputs/llama_train_layer16_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_train_layer16_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_train_layer16_residual.parquet: (1500, 4103)
# processing layer 17
## processing stream attention
Saved ./model_outputs/llama_train_layer17_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_train_layer17_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_train_layer17_residual.parquet: (1500, 4103)
# processing layer 18
## processing stream attention
Saved ./model_outputs/llama_tra

In [9]:
process_split(df_temporal, model, tokenizer,
              layers, 'temporal',
              batch_size, out_dir)

# processing layer 15
## processing stream attention
Saved ./model_outputs/llama_temporal_layer15_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_temporal_layer15_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_temporal_layer15_residual.parquet: (1500, 4103)
# processing layer 16
## processing stream attention
Saved ./model_outputs/llama_temporal_layer16_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_temporal_layer16_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_temporal_layer16_residual.parquet: (1500, 4103)
# processing layer 17
## processing stream attention
Saved ./model_outputs/llama_temporal_layer17_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_temporal_layer17_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_temporal_layer17_residual.parquet: (1500, 4103)
# pr

In [10]:
process_split(df_train, model, tokenizer,
              [14], 'train',
              batch_size, out_dir)

# processing layer 14
## processing stream attention
Saved ./model_outputs/llama_train_layer14_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_train_layer14_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_train_layer14_residual.parquet: (1500, 4103)


In [11]:
process_split(df_temporal, model, tokenizer,
              [14], 'temporal',
              batch_size, out_dir)

# processing layer 14
## processing stream attention
Saved ./model_outputs/llama_temporal_layer14_attention.parquet: (1500, 4103)
## processing stream mlp
Saved ./model_outputs/llama_temporal_layer14_mlp.parquet: (1500, 4103)
## processing stream residual
Saved ./model_outputs/llama_temporal_layer14_residual.parquet: (1500, 4103)


### Test

In [10]:
def compare_embeddings(train_path, temp_path, eps=1e-6):
    # load
    df1 = pd.read_parquet(train_path)
    df2 = pd.read_parquet(temp_path)

    # align on metadata columns
    meta_cols = ['language','sentence','main_verb','verb_index','tense','layer','stream']
    emb_cols = [c for c in df1.columns if c not in meta_cols]
    
    # merge on metadata
    merged = df1.merge(df2, on=meta_cols, suffixes=('_1','_2'))
    # compute difference norms
    diffs = merged.apply(
        lambda row: np.linalg.norm(
            row[[f"{c}_1" for c in emb_cols]].values
          - row[[f"{c}_2" for c in emb_cols]].values
        ),
        axis=1
    )
    # report
    n_zero = (diffs < eps).sum()
    total = len(diffs)
#     print(f"{n_zero}/{total} embeddings are identical within {eps}")
    if n_zero>0:
        identical = merged[diffs<eps][meta_cols]
        print("Examples with identical embeddings:")
        print(identical.head())
    else:
        print("All embeddings differ.")

In [12]:
for layer in range(15,32):
        for stream in ['attention','mlp','residual']:
            train_f = f'./model_outputs/llama_train_layer{layer}_{stream}.parquet'
            temp_f  = f'./model_outputs/llama_temporal_layer{layer}_{stream}.parquet'
            print(f"Layer {layer} | {stream}")
            compare_embeddings(train_f, temp_f)

Layer 15 | attention
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 15 | mlp
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 15 | residual
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 16 | attention
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 16 | mlp
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 16 | residual
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 17 | attention
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 17 | mlp
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 17 | residual
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 18 | attention
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 18 | mlp
0/0 embeddings are identical within 1e-06
All embeddings differ.
Layer 18 | residual
0/0 embeddings are identical within 1e-06
All embeddings diff