In [None]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from collections import OrderedDict
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer
import requests
import re

In [None]:
# Fetch a fresh checkout of the Interpreting-Reward-Models repository
# (presumably the source of the `reward_analyzer` package imported below).
! rm -rf Interpreting-Reward-Models || true
! git clone https://github.com/apartresearch/Interpreting-Reward-Models.git
# Use %pip rather than `! pip`: the magic installs into the environment of the
# running kernel, whereas a shell `pip` may belong to a different interpreter.
%pip install ./Interpreting-Reward-Models

In [None]:
from reward_analyzer import SparseAutoencoder
from reward_analyzer.utils.model_storage_utils import load_autoencoders_for_artifact
from reward_analyzer.utils.transformer_utils import batch

In [None]:
model_name = 'gpt_neo_125m'
task_name = 'hh_rlhf'
version = 'v0'


def layer_name_stem_for(model_name):
    """Return the MLP module-name template for the family ``model_name`` belongs to.

    The template contains one ``{}`` placeholder that is filled in with a
    layer key via ``str.format``.

    Args:
        model_name: Model identifier; the family is detected by substring
            ('pythia', then 'neo', then 'gemma' -- first match wins).

    Returns:
        The template string, e.g. ``'h.{}.mlp'``.

    Raises:
        ValueError: If ``model_name`` matches none of the known families.
    """
    if 'pythia' in model_name:
        return 'layers.{}.mlp'
    elif 'neo' in model_name:
        return 'h.{}.mlp'
    elif 'gemma' in model_name:
        return 'layers.{}.mlp'
    raise ValueError(f'Not familiar with model name family of {model_name}')


# Previously the pythia branch assigned a misspelled name (`layer_name_step`),
# leaving `layer_name_stem` undefined for pythia models.
layer_name_stem = layer_name_stem_for(model_name)

In [None]:
# Load the autoencoders stored under the artifact identified by the current
# model_name / task_name / version (presumably a W&B artifact path -- confirm
# against load_autoencoders_for_artifact).
autoencoders_dict = load_autoencoders_for_artifact(
    'nlp_and_interpretability/Autoencoder_training_hh_rlhf/'
    f'autoencoders_{model_name}_{task_name}:{version}'
)

In [None]:
# Pick out the entry stored under the 'rlhf_small' key.
# NOTE(review): presumably the autoencoders trained on the smaller RLHF model,
# keyed by layer -- confirm against load_autoencoders_for_artifact.
rlhf_small = autoencoders_dict['rlhf_small']

In [None]:
def extract_and_process_activations(texts, model, tokenizer, layer_name_stem, autoencoders_dict):
    """Run one text through ``model`` and encode selected layer outputs with autoencoders.

    Forward hooks capture the output of every module whose name equals
    ``layer_name_stem.format(key)`` for some key of ``autoencoders_dict``.
    Each captured activation is passed through the autoencoder stored under
    that key, and the resulting per-token features are concatenated across
    layers in the iteration order of ``autoencoders_dict``.

    Args:
        texts: Input handed to ``tokenizer``. It must tokenize to a batch of
            exactly one sequence (e.g. a single string).
        model: ``torch.nn.Module`` invoked as ``model(**inputs)``.
        tokenizer: Callable used as ``tokenizer(texts, return_tensors="pt")``
            that returns a mapping containing an ``"input_ids"`` tensor.
        layer_name_stem: Format string with one placeholder, e.g. ``'h.{}.mlp'``.
        autoencoders_dict: Mapping from layer key to a callable autoencoder.
            Calling an autoencoder must return a pair whose first element is
            the feature tensor (assumed to have one row per token -- TODO
            confirm against SparseAutoencoder.forward).

    Returns:
        dict mapping each token id to the flat list of its concatenated
        features. Because the dict is keyed by token id, a token that occurs
        more than once keeps only the features of its last occurrence.

    Raises:
        ValueError: If the tokenizer output is not a batch of one sequence.
        KeyError: If a requested layer name does not exist in ``model``.
    """
    inputs = tokenizer(texts, return_tensors="pt")
    input_ids = inputs["input_ids"]
    if input_ids.dim() != 2 or input_ids.shape[0] != 1:
        raise ValueError(
            "extract_and_process_activations expects exactly one sequence, "
            f"got input_ids of shape {tuple(input_ids.shape)}"
        )
    # Index the batch dimension instead of squeeze(): squeeze() would collapse
    # a one-token sequence to a scalar and tolist() would return a bare int.
    token_ids = input_ids[0].tolist()

    layer_names = {key: layer_name_stem.format(key) for key in autoencoders_dict}
    wanted_names = set(layer_names.values())
    activations = {}

    def get_activation(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook

    hooks = [
        module.register_forward_hook(get_activation(name))
        for name, module in model.named_modules()
        if name in wanted_names
    ]

    try:
        with torch.no_grad():
            model(**inputs)
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # model is not left with stale hooks attached.
        for hook in hooks:
            hook.remove()

    missing = [name for name in layer_names.values() if name not in activations]
    if missing:
        raise KeyError(f"No activations captured for layers: {missing}")

    per_token_features = [[] for _ in token_ids]
    # Gradients are not needed for feature extraction.
    with torch.no_grad():
        for key, autoencoder in autoencoders_dict.items():
            act = activations[layer_names[key]].squeeze(0)  # drop the batch dimension
            features, _ = autoencoder(act)
            for token_features, row in zip(per_token_features, features.tolist()):
                token_features.extend(row)

    return dict(zip(token_ids, per_token_features))

In [None]:
def save_training_dataset_to_wandb(training_dataset: Dataset, model_name, dataset_name="logistic_probe_data.hf"):
    """Save ``training_dataset`` to disk and log it to the active W&B run as an artifact.

    Args:
        training_dataset: Dataset to persist; written with ``save_to_disk``.
        model_name: Used to build the artifact name
            ``logistic_probe_training_dataset_{model_name}``.
        dataset_name: Local path the dataset is written to before upload.

    Side effects:
        Writes ``dataset_name`` on local disk and calls ``wandb.log_artifact``,
        which requires that a W&B run has already been started.
    """
    # save_to_disk writes a directory at `dataset_name` and returns None, so
    # the path to upload is `dataset_name` itself, not the call's return value.
    training_dataset.save_to_disk(dataset_name)

    my_artifact = wandb.Artifact(f"logistic_probe_training_dataset_{model_name}", type="data")

    # The saved dataset is a directory, so it must be added with add_dir
    # (add_file only accepts a single file path).
    my_artifact.add_dir(local_path=dataset_name, name="logistic_probe_training_dataset")

    metadata_dict = {
        "description": "Training dataset, with activations and rewards",
        "source": "Generated by my script",
        "num_examples": len(training_dataset),
        "split": "full"
    }

    my_artifact.metadata.update(metadata_dict)

    # Log the artifact to the run
    wandb.log_artifact(my_artifact)

# Upload the training dataset for the currently configured model.
# NOTE(review): `full_training_dataset` is not defined in any cell visible here --
# confirm it is created earlier, otherwise this raises NameError on a fresh kernel.
save_training_dataset_to_wandb(full_training_dataset, model_name=model_name)