In [None]:
import json
import os
import re

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer
import requests
from scipy.sparse import csr_matrix
from tqdm import tqdm

In [None]:
import transformers
from transformers import AutoModel, AutoTokenizer

In [None]:
# Load the 'EleutherAI/pythia-70m' checkpoint through AutoModel.
# NOTE(review): `pythia` is not referenced by any later cell visible in this notebook.
pythia = AutoModel.from_pretrained('EleutherAI/pythia-70m')

In [None]:
# The original statement was cut off mid-call (unclosed parenthesis) and used
# AutoModel to build a tokenizer. Completed with AutoTokenizer and the same
# checkpoint as the `pythia` model loaded in the previous cell.
# NOTE(review): the checkpoint id is an assumption -- confirm. `tokenizer` is
# re-assigned later from `model_name`, so this value is only used in between.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-70m')

In [None]:
! rm -rf Interpreting-Reward-Models || true
! git clone https://github.com/apartresearch/Interpreting-Reward-Models.git
! cd Interpreting-Reward-Models && pip install .

In [None]:
from reward_analyzer import SparseAutoencoder, TaskConfig
from reward_analyzer.utils.model_storage_utils import load_autoencoders_for_artifact, load_latest_model_from_hub
from reward_analyzer.utils.transformer_utils import batch

In [None]:
# Checkpoint under analysis; switch models by changing which line is active.
model_name = 'EleutherAI/gpt-neo-125m'
# model_name = 'EleutherAI/pythia-70m'
# model_name = 'EleutherAI/pythia-160m'

task = TaskConfig.HH_RLHF
task_name = task.name
version = 'v0'

# `layer_name_stem` is a format template for the name of one block's MLP
# sub-module; later cells fill in a layer number and match the result against
# `model.named_modules()`. Every branch must bind this exact name -- the Pythia
# branch previously bound `layer_name_step`, leaving `layer_name_stem` undefined.
if 'pythia' in model_name:
    layer_name_stem = 'layers.{}.mlp'
elif 'neo' in model_name:
    layer_name_stem = 'h.{}.mlp'
elif 'gemma' in model_name:
    layer_name_stem = 'layers.{}.mlp'
else:
    raise Exception(f'Not familiar with model name family of {model_name}')

### Load model and autoencoder artifacts.

In [None]:
# Tokenizer for the checkpoint selected in the configuration cell.
tokenizer =  AutoTokenizer.from_pretrained(model_name)
# Use the EOS token for padding (presumably the tokenizer ships without a pad
# token -- confirm); the extraction code below tokenizes with padding=True.
tokenizer.pad_token = tokenizer.eos_token
# Project helper: loads the model for this checkpoint / task pair.
model = load_latest_model_from_hub(model_name = model_name, task_config=task)
# Trailing expression: shows the device the loaded model lives on.
model.device

In [None]:
# The artifact reference is built from the model's short name (text after the
# last '/', with '-' replaced by '_'), the task name and the version tag.
# The result is indexed by autoencoder-set name in the next cell.
autoencoders_dict = load_autoencoders_for_artifact(f'nlp_and_interpretability/Autoencoder_training_hh_rlhf/autoencoders_{model_name.split("/")[-1].replace("-", "_")}_{task_name}:{version}')

In [None]:
# Pick the 'rlhf_small' entry. It is passed as `autoencoders_dict` to the
# extraction functions below, which format each of its keys into
# `layer_name_stem` and call each of its values on that layer's activations.
rlhf_small = autoencoders_dict['rlhf_small']

In [None]:
def dump_data_to_jsonl(data: dict, filename: str, mode: str = 'w'):
    """Write column-oriented data to a JSON Lines file, one record per row.

    Args:
        data: Mapping from field name to a sequence of values. All sequences
            must have the same length; output line ``i`` is the JSON object
            ``{key: data[key][i] for key in data}``.
        filename: Path of the JSONL file to write.
        mode: Mode handed to ``open``. The default ``'w'`` replaces any
            existing file; pass ``'a'`` to append records to it instead.

    Raises:
        AssertionError: If the value sequences differ in length.
    """
    list_lengths = [len(value_list) for value_list in data.values()]

    # An empty mapping simply has zero rows; without `default`, min()/max()
    # would raise ValueError on the empty list.
    n = max(list_lengths, default=0)
    assert min(list_lengths, default=0) == n, f'Expected list lengths to be the same! Instead got {list_lengths}'
    print(f'Writing {n} records to {filename}')

    with open(filename, mode) as jsonl_file:
        for i in range(n):
            # Gather the i-th value of every column into one JSON object.
            json_object = {key: values[i] for key, values in data.items()}
            jsonl_file.write(json.dumps(json_object) + '\n')

In [None]:
def features_from_single_input(single_input):
    """Collapse one input's feature tensor by averaging over dimension 0.

    In this notebook it is called with the autoencoder features of a single
    text, so dimension 0 is the token axis of that text.
    """
    return torch.mean(single_input, dim=0)


def extract_and_process_activations(texts, model, tokenizer, layer_name_stem, autoencoders_dict):
    """Run texts through `model`, capture selected module outputs and encode them.

    Args:
        texts: Sequence of strings. A bare string is treated as a one-element list.
        model: Module whose forward output exposes `last_hidden_state`.
        tokenizer: Callable invoked as
            `tokenizer(texts, return_tensors="pt", padding=True, truncation=True)`;
            its result must provide `"input_ids"` and be usable as `model(**inputs)`.
        layer_name_stem: Format template (e.g. `'h.{}.mlp'`) that, filled with a
            key of `autoencoders_dict`, names a module in `model.named_modules()`.
        autoencoders_dict: Mapping from layer number to an autoencoder. Each
            autoencoder is called on that layer's captured output and must
            return a pair whose first element is the feature tensor.

    Returns:
        Dict of lists with one entry per input text: `"texts"`, `"token_ids"`,
        `"token_embeddings"`, and for every layer `n` the keys
        `activations_{n}`, `full_repr_{n}` and `averaged_reprs_{n}`.

    Raises:
        KeyError: If no output was captured for one of the requested layers.
    """
    if isinstance(texts, str):
        texts = [texts]

    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

    # Run the tokens on the same device as the model when the model reports one.
    device = getattr(model, "device", None)
    if device is not None and hasattr(inputs, "to"):
        inputs = inputs.to(device)

    # No squeeze here: the batch axis must survive even for a single text so
    # that every output list has exactly one entry per input.
    token_ids = inputs["input_ids"].tolist()

    layer_names = {layer_num: layer_name_stem.format(layer_num) for layer_num in autoencoders_dict}
    target_layer_names = set(layer_names.values())
    activations = {}

    def get_activation(name):
        def hook(module, hook_input, output):
            activations[name] = output.detach()
        return hook

    hooks = [
        module.register_forward_hook(get_activation(name))
        for name, module in model.named_modules()
        if name in target_layer_names
    ]

    try:
        with torch.no_grad():
            outputs = model(**inputs)
    finally:
        # Always detach the hooks, even if the forward pass fails; in a
        # notebook the model object outlives this call.
        for hook in hooks:
            hook.remove()

    missing_layers = [name for name in layer_names.values() if name not in activations]
    if missing_layers:
        raise KeyError(
            f'No activations captured for {missing_layers}; '
            f'check layer_name_stem against model.named_modules()'
        )

    # One vector per text, taken at sequence position 0 of the last hidden state.
    # NOTE(review): position 0 is the first token -- confirm this is the intended position.
    final_token_embeddings = outputs.last_hidden_state[:, 0, :].detach().cpu().tolist()

    all_features = {
        "texts": texts,
        "token_ids": token_ids,
        "token_embeddings": final_token_embeddings
    }

    for layer_num in autoencoders_dict:
        activation_values = activations[layer_names[layer_num]].cpu()
        autoencoder = autoencoders_dict[layer_num]

        # Inference only: skip building an autograd graph for the encoder pass.
        with torch.no_grad():
            batch_features, _ = autoencoder(activation_values)
        batch_features = batch_features.detach().cpu()

        all_features[f'activations_{layer_num}'] = activation_values.tolist()

        full_reprs = []
        averaged_reprs = []
        # Iterating the feature tensor walks its first (batch) axis: one item per text.
        for single_feature in batch_features:
            full_reprs.append(single_feature.tolist())
            averaged_reprs.append(features_from_single_input(single_feature).tolist())

        all_features[f'full_repr_{layer_num}'] = full_reprs
        all_features[f'averaged_reprs_{layer_num}'] = averaged_reprs

    return all_features


def extract_features_batched(texts, model, tokenizer, layer_name_stem, autoencoders_dict, output_file=None, batch_size=8):
    """Extract features for `texts` in batches and write them to a JSONL file.

    Each batch yielded by `batch(texts, n=batch_size)` is processed on its own
    (previously the full `texts` list was re-processed on every iteration).
    Per-batch results are concatenated key by key, written once with
    `dump_data_to_jsonl`, and returned.

    Because each batch is tokenized separately with `padding=True`, padded
    sequence lengths may differ between batches.

    Args:
        texts: Sequence of strings to process.
        model, tokenizer, layer_name_stem, autoencoders_dict: Forwarded to
            `extract_and_process_activations`.
        output_file: Destination path. When falsy, a name is derived from the
            notebook-level `model_name` and `task`.
        batch_size: Number of texts per batch.

    Returns:
        Dict of lists covering all texts, with the same keys as
        `extract_and_process_activations`; an empty dict if `texts` yields no
        batches (nothing is written in that case).
    """
    output_file = output_file or f'./{model_name}_{task.name}_activations_dataset.jsonl'.split("/")[-1].replace("-", "_")

    merged_features = {}
    for curr_batch in tqdm(batch(texts, n=batch_size)):
        batch_features = extract_and_process_activations(curr_batch, model, tokenizer, layer_name_stem, autoencoders_dict)
        for key, rows in batch_features.items():
            merged_features.setdefault(key, []).extend(rows)

    # With no batches there is nothing to write.
    if merged_features:
        dump_data_to_jsonl(merged_features, filename = output_file)

    return merged_features

# NOTE(review): `texts` is not defined in any cell visible in this notebook; it
# must exist before this call on a fresh kernel. `50*texts` repeats it 50 times
# (assuming it is a list -- confirm).
# The trailing semicolon keeps the notebook from rendering the returned dict,
# which holds every extracted activation as nested Python lists.
extract_features_batched(texts=50*texts, model=model, tokenizer=tokenizer, layer_name_stem=layer_name_stem, autoencoders_dict=rlhf_small, batch_size=8);

In [None]:
from datasets import load_dataset

# Derive the file name with the same expression extract_features_batched uses
# for its default output path. The previous hard-coded name kept the hyphens
# of the model id, while the writer replaces '-' with '_', and it did not
# follow `model_name`.
activations_file = f'./{model_name}_{task.name}_activations_dataset.jsonl'.split("/")[-1].replace("-", "_")
dataset = load_dataset("json", data_files=activations_file)

In [None]:
def save_training_dataset_to_wandb(training_dataset: Dataset, model_name, dataset_name="logistic_probe_data.hf"):
    """Save a dataset to disk and log it to Weights & Biases as a data artifact.

    Args:
        training_dataset: Dataset to persist; its length is recorded in the
            artifact metadata.
        model_name: Model identifier appended to the artifact name. Any '/'
            (as in hub ids such as 'org/model') is replaced by '_'.
        dataset_name: Local directory the dataset is saved to before upload.

    Note:
        Logging goes through `wandb.log_artifact`, so a W&B run is presumably
        required to be active when this is called -- confirm.
    """
    # `Dataset.save_to_disk` returns None and writes a directory, so the path
    # to upload is `dataset_name` itself, not the call's return value.
    training_dataset.save_to_disk(dataset_name)

    # Artifact names cannot contain '/', which hub-style model ids do.
    artifact_name = f"logistic_probe_training_dataset_{model_name}".replace("/", "_")
    my_artifact = wandb.Artifact(artifact_name, type="data")

    # The saved dataset is a directory of files, so attach it with add_dir.
    my_artifact.add_dir(local_path=dataset_name, name="logistic_probe_training_dataset")

    metadata_dict = {
        "description": "Training dataset, with activations and rewards",
        "source": "Generated by my script",
        "num_examples": len(training_dataset),
        "split": "full"
    }

    my_artifact.metadata.update(metadata_dict)

    # Log the artifact to the run
    wandb.log_artifact(my_artifact)

# NOTE(review): `full_training_dataset` is not defined in any cell visible in
# this notebook, and no `wandb.init()` call is visible either -- both
# presumably have to happen before this cell can run on a fresh kernel.
save_training_dataset_to_wandb(full_training_dataset, model_name=model_name)

In [None]:
# Trailing expression: display the repr of the object returned by load_dataset above.
dataset