### Load the model, tokenizer and VADER sentiment analyzer

In [None]:
import dataclasses
import gc
import json
import os
import pickle
import pprint

from time import time

import nltk
import spacy
import torch
import torch.nn.functional as F
import tqdm
import transformers
import wandb

from scipy.sparse import csr_matrix
from torch import FloatTensor, LongTensor, Tensor

from dataclasses import dataclass
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from torch import nn
from tqdm import tqdm_notebook
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
model_name = 'pythia-160m'
policy_model_name = 'pythia_160m_utility_reward'

# Paste a W&B API key here, or export WANDB_API_KEY before starting the notebook.
# The environment is only written when a key is actually provided, so an empty
# placeholder never clobbers a key that is already exported in the shell.
wandb_api_key = ''
if wandb_api_key:
    os.environ['WANDB_API_KEY'] = wandb_api_key

In [None]:
# Start the W&B run; it is also the default `run` used by load_autoencoders_for_artifact.
run = wandb.init(project="utility_reconstruction")

In [None]:
full_model_name = f'EleutherAI/{model_name}'

# Base causal LM on the GPU, in inference mode (nn.Module.cuda/eval both return the module).
model = AutoModelForCausalLM.from_pretrained(full_model_name).cuda().eval()

# Reuse EOS as the padding token so batches can be padded. No new token is added,
# so the embedding matrix does not need to be resized.
tokenizer = AutoTokenizer.from_pretrained(full_model_name)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# VADER lexicon: a word -> sentiment score mapping, used below as the reward lookup.
sentiment_analyzer = SentimentIntensityAnalyzer()
lexicon = sentiment_analyzer.lexicon
# Extremes of the lexicon's score range.
max_value, min_value = max(lexicon.values()), min(lexicon.values())

## Tokenization and torch utilities.

In [None]:
def batch(iterable, n=1):
    """Yield consecutive slices of `iterable` holding at most `n` items each.

    `iterable` must support `len()` and slicing. The final slice may be shorter.
    """
    length = len(iterable)
    for start in range(0, length, n):
        # Slicing clamps a past-the-end stop index, so no explicit min() is needed.
        yield iterable[start:start + n]

def clear_gpu_memory():
    """Run the garbage collector, empty the CUDA allocator cache, and print the time taken."""
    started = time()
    gc.collect()
    torch.cuda.empty_cache()
    elapsed = round(time() - started, 2)
    print(f'Took {elapsed} seconds to clear cache.')

In [None]:
def pad_list_of_lists(list_of_lists, pad_token):
    """Right-pad every inner list with `pad_token` to the length of the longest one.

    New lists are returned; the inputs are not modified. An empty outer list yields
    `[]` instead of raising from `max()` on an empty sequence.
    """
    max_length = max((len(lst) for lst in list_of_lists), default=0)
    return [lst + [pad_token] * (max_length - len(lst)) for lst in list_of_lists]

In [None]:
def check_number_of_tokens(word, tokenizer=tokenizer):
    """Return how many token ids `tokenizer` produces for `word`."""
    encoded = tokenizer(word)
    return len(encoded['input_ids'])

In [None]:
def get_tokens_and_ids(text, tokenizer=tokenizer):
    """Tokenize lowercased `text` and return (normalized token strings, token ids).

    Each id is decoded on its own, then stripped and lowercased, so a piece decoded as
    " positive" becomes "positive". Input longer than the tokenizer's maximum length
    is truncated (`truncation=True`).

    Returns:
        tuple[list[str], list[int]]: Parallel lists of equal length.
    """
    input_ids = tokenizer(text.lower(), truncation=True)['input_ids']
    # Decoding id-by-id keeps leading-space artifacts (e.g. " positive"), so normalize
    # each piece to make it comparable with plain words.
    tokens = [tokenizer.decode(input_id).strip().lower() for input_id in input_ids]
    return tokens, input_ids

In [None]:
def get_single_target_token_id(word, tokenizer=tokenizer):
    """Return the id of the single token that encodes `word`, or None if there is none.

    The word is stripped and lowercased first. If it does not encode to exactly one
    token, a leading space is prepended (" word" may be a single piece where "word"
    is not) and the check is repeated with the same tokenizer.

    If neither form is a single token (including an empty word), None is returned.
    Callers can then drop the word instead of matching on a sub-word prefix.
    """
    word = word.lower().strip()
    input_ids = tokenizer(word)['input_ids']
    if len(input_ids) > 1:
        # Backoff to include a single space.
        input_ids = tokenizer(f' {word}')['input_ids']
    if len(input_ids) != 1:
        return None
    return input_ids[0]

In [None]:
@dataclass
class TextTokensIdsTarget:
    """A prompt trimmed so that its last token is a target word, with its tokenization."""

    attention_mask: list[int]      # 1 per real token (trim_example builds it as all ones)
    text: str                      # decoded text of `ids`
    tokens: list[str]              # normalized (stripped, lowercased) token strings
    ids: list[int]                 # token ids, parallel to `tokens`
    target_token: str              # normalized string of the target token
    target_token_id: int           # id of the target token
    target_token_position: int     # index of the target token within `ids`

    @staticmethod
    def get_tensorized(datapoints: list["TextTokensIdsTarget"]):
        """Right-pad a batch of datapoints and return model inputs on the GPU.

        Returns:
            dict with "input_ids" (int32) and "attention_mask" (uint8) tensors of shape
            (len(datapoints), longest sequence). Padding uses the tokenizer's pad token
            and mask value 0 on the right, so each datapoint's token positions are unchanged.
        """
        pad_token_id = tokenizer.encode(tokenizer.pad_token)[0]
        input_ids_padded = pad_list_of_lists([dp.ids for dp in datapoints], pad_token_id)
        attention_masks_padded = pad_list_of_lists([dp.attention_mask for dp in datapoints], 0)
        return {
            "input_ids": torch.IntTensor(input_ids_padded).cuda(),
            "attention_mask": torch.ByteTensor(attention_masks_padded).cuda(),
        }

def trim_example(input_text: str, target_words: list[str], verbose=False, tokenizer=tokenizer):
    """Cut `input_text` right after the first token that matches one of `target_words`.

    Each target word is mapped to a single-token id with get_single_target_token_id.
    Words for which no id is returned are ignored. Matching compares normalized
    (stripped, lowercased) token strings.

    Args:
        input_text: Text to tokenize and trim.
        target_words: Candidate target words.
        verbose: Print a diagnostic when no target token is found.
        tokenizer: Tokenizer used for both the target words and the text.

    Returns:
        A TextTokensIdsTarget whose last token is the matched target. Returns None if no
        target token occurs or if the trimmed prompt exceeds the model's max length.
    """
    single_target_token_ids = [
        get_single_target_token_id(word.strip().lower(), tokenizer=tokenizer) for word in target_words
    ]
    # A set gives O(1) membership; `is not None` keeps a legitimate token id of 0.
    single_target_tokens = {
        tokenizer.decode(token_id).strip().lower()
        for token_id in single_target_token_ids
        if token_id is not None
    }

    input_tokens, input_token_ids = get_tokens_and_ids(input_text, tokenizer=tokenizer)

    # Keep tokens up to and including the first one that matches a target token.
    trimmed_input_tokens = []
    trimmed_input_token_ids = []
    for input_token, input_token_id in zip(input_tokens, input_token_ids):
        trimmed_input_tokens.append(input_token)
        trimmed_input_token_ids.append(input_token_id)
        if input_token.strip().lower() in single_target_tokens:
            break

    assert len(trimmed_input_token_ids) == len(trimmed_input_tokens), "Num of tokens and token ids should be equal"

    last_token = None
    last_token_id = None
    if trimmed_input_tokens:
        last_token = trimmed_input_tokens[-1].lower().strip()
        last_token_id = trimmed_input_token_ids[-1]

    # Defensive: get_tokens_and_ids already truncates, so this is likely unreachable.
    if len(trimmed_input_tokens) > tokenizer.model_max_length:
        print(f'Dropping example since exceed model max length. Input text was:\n{input_text}')
        return None

    if last_token and last_token in single_target_tokens:
        return TextTokensIdsTarget(
            attention_mask=[1] * len(trimmed_input_tokens),
            text=tokenizer.decode(trimmed_input_token_ids),
            tokens=trimmed_input_tokens,
            ids=trimmed_input_token_ids,
            target_token=last_token,
            target_token_id=last_token_id,
            target_token_position=len(trimmed_input_token_ids) - 1,
        )

    if verbose:
        print(f'last token was {last_token} in {trimmed_input_tokens}, and was not in target tokens.')
    return None

### Load training examples for linear probe.

In [None]:
def load_wandb_json_artifact(
    project_name='utility_reconstruction', artifact_name='contrastive_sentiment_pairs', version='v6'
):
    """Download a JSON W&B data artifact and return its parsed contents.

    The artifact `nlp_and_interpretability/{project_name}/{artifact_name}:{version}` is
    expected to contain a file named `artifact_name` holding JSON.
    """
    api = wandb.Api()
    artifact = api.artifact(f'nlp_and_interpretability/{project_name}/{artifact_name}:{version}', type='data')
    # Read from the directory W&B actually downloaded to. The download root is
    # configurable, so it is not guaranteed to be `artifacts/<name>:<version>`.
    artifact_dir = artifact.download()

    with open(os.path.join(artifact_dir, artifact_name), 'r', encoding='utf-8') as f_in:
        return json.load(f_in)

all_input_dicts = load_wandb_json_artifact()

In [None]:
def clean_capitalization_neutral_terms(input_dicts):
    """Lowercase each neutral word wherever it appears in its example's `neutral_text`.

    The dicts are modified in place. The values of each dict's `neutral_words` mapping
    are substituted (exact-case, all occurrences) one at a time, in mapping order.
    The same `input_dicts` object is returned.
    """
    for entry in input_dicts:
        words = list(entry['neutral_words'].values())
        if not words:
            continue
        text = entry['neutral_text']
        for word in words:
            text = text.replace(word, word.lower())
        entry['neutral_text'] = text
    return input_dicts

# Normalize capitalization of neutral words (in place; the helper returns the same list).
all_input_dicts = clean_capitalization_neutral_terms(all_input_dicts)

In [None]:
class TrainingPoint:
    """One contrastive example: positive, negative and neutral variants of a text.

    Each variant is trimmed with `trim_example` so that it ends at its first target word.
    For the positive and negative variants, the lexicon score of the target token is
    stored as its reward (None if the token is not in `lexicon`).

    If trimming a variant raises or finds no target word, its `trimmed_*_example` is
    None and its target fields stay None.

    Keys read from `input_dict`:
    - 'input_text' (positive text), 'output_text' (negative text), 'neutral_text';
    - 'positive_words';
    - 'new_words' and 'neutral_words', whose values are the negative and neutral
      target words respectively.
    """

    def __init__(self, input_dict: dict, tokenizer=tokenizer):
        self.input_dict = input_dict
        self.positive_text = input_dict['input_text']
        self.negative_text = input_dict['output_text']
        self.neutral_text = input_dict['neutral_text']

        # Dictionary of layer name to activations by mlp layer (left as None here).
        self.activations: dict = None

        # Dictionary of layer name to autoencoder feature by mlp layer (left as None here).
        self.autoencoder_feature: dict = None

        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text, tokenizer=tokenizer)
        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text, tokenizer=tokenizer)

        self.positive_words = input_dict['positive_words']
        self.negative_words = list(input_dict['new_words'].values())
        self.neutral_words = list(input_dict['neutral_words'].values())

        # Reward (lexicon score), normalized string and id of each variant's target token.
        self.target_positive_reward = None
        self.target_positive_token = None
        self.target_positive_token_id = None

        self.target_negative_reward = None
        self.target_negative_token = None
        self.target_negative_token_id = None

        self.target_neutral_token = None
        self.target_neutral_token_id = None

        self.trimmed_positive_example: TextTokensIdsTarget = self._trim_or_none(
            self.positive_text, self.positive_words, 'positive', tokenizer
        )
        if self.trimmed_positive_example:
            positive_token = self.trimmed_positive_example.target_token.strip().lower()
            self.target_positive_reward = lexicon.get(positive_token, None)
            self.target_positive_token = positive_token
            self.target_positive_token_id = self.trimmed_positive_example.target_token_id

        self.trimmed_negative_example: TextTokensIdsTarget = self._trim_or_none(
            self.negative_text, self.negative_words, 'negative', tokenizer
        )
        if self.trimmed_negative_example:
            negative_token = self.trimmed_negative_example.target_token.strip().lower()
            self.target_negative_reward = lexicon.get(negative_token, None)
            self.target_negative_token = negative_token
            self.target_negative_token_id = self.trimmed_negative_example.target_token_id

        self.trimmed_neutral_example: TextTokensIdsTarget = self._trim_or_none(
            self.neutral_text, self.neutral_words, 'neutral', tokenizer
        )
        if self.trimmed_neutral_example:
            self.target_neutral_token = self.trimmed_neutral_example.target_token.strip().lower()
            self.target_neutral_token_id = self.trimmed_neutral_example.target_token_id

    def _trim_or_none(self, text, words, label, tokenizer):
        """Best-effort trim_example: print the exception and return None if it raises."""
        try:
            return trim_example(text, words, tokenizer=tokenizer)
        except Exception as e:
            print(f'Caught exception {e} on {self.input_dict} for {label} example.')
            return None

    def __str__(self):
        return pprint.pformat(self.__dict__)

In [None]:
# One TrainingPoint per raw example, with a notebook progress bar.
training_points = []
for input_dict in tqdm_notebook(all_input_dicts):
    training_points.append(TrainingPoint(input_dict))

In [None]:
# Spot-check one TrainingPoint. `autoencoder_feature` is initialised to None in
# TrainingPoint.__init__ and, in the code shown here, is not assigned elsewhere.
x = training_points[10]
x.autoencoder_feature

In [None]:
# Keep only points where all three variants were trimmed and both
# the positive and negative target tokens have a lexicon reward.
successful_training_points = [
    point
    for point in training_points
    if point.trimmed_positive_example
    and point.trimmed_negative_example
    and point.trimmed_neutral_example
    and point.target_positive_reward is not None
    and point.target_negative_reward is not None
]

In [None]:
# Number of examples that passed the filter.
len(successful_training_points)

In [None]:
# Small subset used for the quick activation sanity check below.
sample_training_points = successful_training_points[:10]

In [None]:
# Points rejected by the filter above. TrainingPoint defines no __eq__, so list
# membership compared by identity. An id() set gives the same result in O(n)
# instead of O(n*m).
_successful_ids = {id(point) for point in successful_training_points}
bad_training_points = [point for point in training_points if id(point) not in _successful_ids]

### Load autoencoders for linear probe.

In [None]:
class SparseAutoencoder(nn.Module):
    """Encoder half of a sparse autoencoder, used here only to compute dictionary features.

    The parameter names `encoder_weight`, `encoder_bias` and `decoder_bias` must match the
    state dicts loaded by `load_models_from_folder`. `decoder_bias` is not used by `forward`.

    Args:
        input_size: Dimensionality of the activations being encoded.
        hidden_size: Number of dictionary features.
        l1_coef: Sparsity coefficient; stored (also in `kwargs`) but not used by `forward`.
    """

    def __init__(self, input_size, hidden_size, l1_coef):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size

        # Constructor arguments. Checkpoints are loaded as (kwargs, state) pairs and
        # rebuilt with SparseAutoencoder(**kwargs).
        self.kwargs = dict(input_size=input_size, hidden_size=hidden_size, l1_coef=l1_coef)
        self.l1_coef = l1_coef

        initial_weight = torch.randn(hidden_size, input_size)
        self.encoder_weight = nn.Parameter(initial_weight)
        nn.init.orthogonal_(self.encoder_weight)

        self.encoder_bias = nn.Parameter(torch.zeros(self.hidden_size))
        self.decoder_bias = nn.Parameter(torch.zeros(input_size))

    def forward(self, x):
        """Return detached ReLU(x @ W.T + b), with W the row-wise L2-normalized encoder weight."""
        unit_rows = F.normalize(self.encoder_weight, p=2, dim=1)
        return F.relu(F.linear(x, unit_rows, self.encoder_bias)).detach()

In [None]:
# W&B entity, plus the project / artifact name prefixes. The simplified policy-model
# name is appended to these prefixes in load_autoencoders_for_artifact.
entity_name = 'nlp_and_interpretability'
project_prefix = 'Autoencoder_training'
artifact_prefix = 'autoencoders'

def load_autoencoders_for_artifact(policy_model_name, alias='latest', run=run):
    '''
    Download one autoencoder artifact through `run` and load its four autoencoder groups.

    The artifact path is built from module-level `entity_name`, `project_prefix` and
    `artifact_prefix`, plus the last path component of `policy_model_name` with '-' -> '_'.
    For example, try autoencoders_dict = load_autoencoders_for_artifact('pythia_70m_sentiment_reward')

    Returns:
        dict with keys 'base_big', 'base_small', 'rlhf_big' and 'rlhf_small'. Each value is
        the dict returned by `load_models_from_folder` for `<artifact dir>/saves/<key>`.
    '''
    short_name = policy_model_name.split('/')[-1].replace('-', '_')
    full_path = f'{entity_name}/{project_prefix}_{short_name}/{artifact_prefix}_{short_name}:{alias}'
    print(f'Loading artifact from {full_path}')

    save_dir = f'{run.use_artifact(full_path).download()}/saves'
    group_names = ('base_big', 'base_small', 'rlhf_big', 'rlhf_small')
    return {
        group: load_models_from_folder(load_dir=f'{save_dir}/{group}', given_device='cpu')
        for group in group_names
    }

def load_models_from_folder(load_dir, given_device=None):
    """
    Load every saved SparseAutoencoder in `load_dir` into a dict keyed by file name.

    Each entry of `load_dir`, visited in sorted order, must be a `torch.load`-able
    `(kwargs, state_dict)` pair. The model is rebuilt with `SparseAutoencoder(**kwargs)`.

    Args:
        load_dir (str): Directory containing the saved autoencoders.
        given_device: Device used for `map_location` and the initial `.to()`. Defaults to
            CUDA when available, otherwise CPU.

    Returns:
        model_dict (dict): File name -> autoencoder in eval mode.

    Note:
        After `.to(device)` every model is also moved with `.cuda()`, regardless of
        `given_device`.
    """
    device = given_device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loaded = {}
    for file_name in sorted(os.listdir(load_dir)):
        path = os.path.join(load_dir, file_name)
        kwargs, state = torch.load(path, map_location=device)

        autoencoder = SparseAutoencoder(**kwargs)
        autoencoder.load_state_dict(state)
        autoencoder.to(device)
        autoencoder.cuda()
        autoencoder.eval()

        loaded[file_name] = autoencoder
        print(f"Loaded {file_name} from {path}")

    return loaded

In [None]:
# Load the base/rlhf x big/small autoencoder groups trained for the policy model.
autoencoders_dictionaries = load_autoencoders_for_artifact(policy_model_name=policy_model_name)

In [None]:
class ModelCustomizer:
    '''
    Base interface for model-specific layer naming and network parsing details.

    Subclasses map layer numbers to full module names and choose which layers to hook.
    Every method here raises NotImplementedError instead of silently returning None.
    '''

    def __init__(self):
        '''
        Initialize with no explicit target layers.
        '''
        self.target_layers = None

    def set_target_layers(self, target_layers: list[str] = None) -> None:
        '''
        Set target layers (full module names).
        '''
        raise NotImplementedError

    def get_target_layers(self) -> list[str]:
        '''
        Get target layers (full module names).
        '''
        raise NotImplementedError

    def parse_layer_name_to_layer_number(self, layer_name: str) -> str:
        '''
        Parse layer name to layer number
        '''
        raise NotImplementedError

    def convert_ae_dict_keys(self, autoencoders_dict: dict[str, nn.Module]) -> dict[str, nn.Module]:
        '''
        Parse ae dict keys to full layer names.
        '''
        raise NotImplementedError

In [None]:
class GPTNeoCustomizer(ModelCustomizer):
    """Layer naming for GPT-Neo models, whose MLP modules are `transformer.h.<n>.mlp`."""

    def __init__(self, num_layers=12):
        super().__init__()
        # Number of transformer blocks enumerated when no target layers are set.
        self.num_layers = num_layers

    def get_target_layers(self) -> list[str]:
        """Return the explicitly set target layers, or every layer's MLP module if none are set."""
        if self.target_layers:
            return self.target_layers
        return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]

    def set_target_layers(self, target_layers):
        self.target_layers = target_layers

    def layer_num_to_full_name(self, layer_no):
        return f'transformer.h.{layer_no}.mlp'

    def parse_layer_name_to_layer_number(self, layer_name) -> str:
        """'transformer.h.3.mlp' -> '3' (the second-to-last dotted component)."""
        return layer_name.split('.')[-2]

    # Standardize layer names to full names instead of 'int'
    def convert_ae_dict_keys(self, autoencoders_dict: dict[str, nn.Module]) -> dict[str, nn.Module]:
        return {self.layer_num_to_full_name(key): autoencoder for key, autoencoder in autoencoders_dict.items()}

In [None]:
class PythiaCustomizer(ModelCustomizer):
    """Layer naming for Pythia models, whose MLP modules are `gpt_neox.layers.<n>.mlp`."""

    def __init__(self, num_layers):
        super().__init__()  # also sets target_layers to None
        # Number of transformer blocks enumerated when no target layers are set.
        self.num_layers = num_layers

    def get_target_layers(self) -> list[str]:
        """Return the explicitly set target layers, or every layer's MLP module if none are set."""
        if self.target_layers:
            return self.target_layers
        return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]

    def set_target_layers(self, target_layers):
        self.target_layers = target_layers

    def layer_num_to_full_name(self, layer_no):
        return f'gpt_neox.layers.{layer_no}.mlp'

    def parse_layer_name_to_layer_number(self, layer_name) -> str:
        """'gpt_neox.layers.3.mlp' -> '3' (the second-to-last dotted component)."""
        return layer_name.split('.')[-2]

    # Standardize layer names to full names instead of 'int'
    def convert_ae_dict_keys(self, autoencoders_dict: dict[str, nn.Module]) -> dict[str, nn.Module]:
        return {self.layer_num_to_full_name(key): autoencoder for key, autoencoder in autoencoders_dict.items()}

In [None]:
# Naming helper per supported base model, keyed by `model_name` (hub name without the
# "EleutherAI/" prefix). num_layers sets how many layers are enumerated by default.
model_customizers = {
    "pythia-70m": PythiaCustomizer(num_layers=6),
    "pythia-160m": PythiaCustomizer(num_layers=12),
    "pythia-410m": PythiaCustomizer(num_layers=24),
    "gpt-neo-125m": GPTNeoCustomizer()
}

In [None]:
# Pick the naming helper for the current base model. With no explicit target
# layers, this lists every layer's MLP module.
model_customizer = model_customizers[model_name]
model_target_layers = model_customizer.get_target_layers()
model_target_layers

**Convert layer numbers to full layer names.**

In [None]:
# Will hold each autoencoder group re-keyed by full module name.
mapped_dictionaries = {}

In [None]:
# Re-key every autoencoder group from its saved file names (presumably layer numbers)
# to full module names, e.g. '3' -> 'gpt_neox.layers.3.mlp' for Pythia.
for key, ae_dict in autoencoders_dictionaries.items():
    mapped_dictionaries[key] = model_customizer.convert_ae_dict_keys(ae_dict)

autoencoders_dictionaries = mapped_dictionaries

In [None]:
# The small RLHF autoencoders are used below, so only hook the layers they cover.
rlhf_small = autoencoders_dictionaries['rlhf_small']
rlhf_big = autoencoders_dictionaries['rlhf_big']

model_customizer.set_target_layers(list(rlhf_small.keys()))
model_customizer.get_target_layers()

### Extract Activations.

In [None]:
class ActivationsHook:
    """Forward hook that stores a module's outputs on the CPU, split into rows along dim 0."""

    def __init__(self):
        # Detached CPU tensors, each of size 1 along dim 0 (presumably one per example).
        self.activations = []

    def clear_activations(self):
        """Start a fresh activation buffer.

        Rebinds to a new list instead of calling `.clear()`. Lists previously handed out
        (e.g. by ActivationsExtractor.get_activations) therefore keep their contents.
        The old buffer is freed once nothing else references it.
        """
        self.activations = []

    def hook_fn(self, module, input, output):
        """Detach `output`, move it to the CPU and append it split into size-1 chunks along dim 0."""
        new_activations = torch.split(output.detach().cpu(), 1, dim=0)
        self.activations.extend(new_activations)

In [None]:
class ActivationsExtractor:
    """Capture the outputs of `target_layers` of `model` through forward hooks.

    One ActivationsHook is registered per target layer when the extractor is created.
    The handles are kept in `self.hook_handles`; call `remove_hooks()` when the extractor
    is no longer needed. Otherwise its hooks keep recording (and holding on to) the
    outputs of every later forward pass of `model`, including passes made on behalf of
    other extractors attached to the same model.
    """

    def __init__(self, model, tokenizer, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.tokenizer = tokenizer

        # One ActivationsHook per target layer, keyed by full module name.
        self.activation_hooks = {}
        # Handles returned by register_forward_hook, kept so the hooks can be removed.
        self.hook_handles = []

        named_modules = dict(model.named_modules())  # build the lookup once, not per layer
        for layer_name in self.target_layers:
            activation_hook = ActivationsHook()
            self.activation_hooks[layer_name] = activation_hook
            # Register the forward hook to the chosen layer
            handle = named_modules[layer_name].register_forward_hook(activation_hook.hook_fn)
            self.hook_handles.append(handle)

    def remove_hooks(self):
        """Unregister this extractor's forward hooks; already captured activations are kept."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

    def clear_all_activations(self):
        """Empty the activation buffer of every hook."""
        for activation_hook in self.activation_hooks.values():
            activation_hook.clear_activations()

    def get_activations(self):
        """
        Retrieve all the cached activations.

        Returns a dict mapping layer name to that hook's list of captured tensors
        (the hooks' own list objects, not copies).
        """
        return {
            layer_name: activation_hook.activations for layer_name, activation_hook in self.activation_hooks.items()
        }

    def compute_activations_from_raw_texts(self, raw_texts: list[str]):
        """Run `raw_texts` through the model and return the captured activations.

        Texts are processed in batches from `batch` (default size). A single string is
        treated as a one-element list. Tokenized inputs are moved to the device of the
        model's parameters.
        """
        if isinstance(raw_texts, str):
            raw_texts = [raw_texts]
        self.clear_all_activations()

        device = next(self.model.parameters()).device
        for text_batch in batch(raw_texts):
            input_data = self.tokenizer(text_batch, return_tensors='pt', padding=True)
            # The tokenizer returns CPU tensors; move them next to the model.
            input_data = {name: tensor.to(device) for name, tensor in input_data.items()}
            with torch.no_grad():
                self.model(**input_data)

        return self.get_activations()

    def _flatten_activations(self, final_activations, num_samples):
        """Turn {layer: [a_0, a_1, ...]} into [{layer: [a_0]}, {layer: [a_1]}, ...]."""
        return [
            {layer_name: [activations_list[i]] for layer_name, activations_list in final_activations.items()}
            for i in range(num_samples)
        ]

    def compute_activations_from_text_tokens_ids_target(
        self, samples: list[TextTokensIdsTarget], target_token_only=True, flatten=True
    ):
        """Run `samples` through the model and return their hooked activations.

        Args:
            samples: Trimmed examples, forwarded in batches produced by `batch` (default size).
            target_token_only: Keep only the activation at each sample's
                `target_token_position` (index 1 of each captured tensor).
            flatten: Return one {layer name: [activation]} dict per sample instead of a
                {layer name: [activations...]} dict.
        """
        self.clear_all_activations()

        # Forward pass your input through the model
        for text_batch in batch(samples):
            tensorized = TextTokensIdsTarget.get_tensorized(text_batch)
            with torch.no_grad():
                self.model(**tensorized)

        all_activations = self.get_activations()
        activations_per_layer = [len(value) for value in all_activations.values()]

        assert max(activations_per_layer) == min(activations_per_layer) == len(samples), 'Each layer should have num_samples activations'

        if target_token_only:
            final_activations = {}
            for layer_name, layer_activations in all_activations.items():
                assert len(layer_activations) == len(samples), "Each layer should have same activations as num samples!"
                # Assumes each captured tensor is (1, seq, hidden) — TODO confirm for new models.
                final_activations[layer_name] = [
                    activations[:, sample.target_token_position, :]
                    for activations, sample in zip(layer_activations, samples)
                ]
        else:
            final_activations = all_activations

        if flatten:
            final_activations = self._flatten_activations(final_activations, num_samples=len(samples))
        else:
            print(f'Returning a dictionary mapping layer name to list of activations')

        return final_activations

In [None]:
# Sanity check: target-token activations for the positive variants of a few examples.
# NOTE: this extractor's forward hooks remain registered on `model` after this cell.
extractor = ActivationsExtractor(model=model, target_layers=model_customizer.get_target_layers(), tokenizer=tokenizer)

sample_positives = [point.trimmed_positive_example for point in sample_training_points]

sample_target_token_activations = extractor.compute_activations_from_text_tokens_ids_target(sample_positives)

In [None]:
class AutoencoderManager:
    """Encode per-layer activations into sparse-autoencoder dictionary features.

    `autoencoders_dict` maps full layer names to autoencoders. `model` and `tokenizer`
    are stored but not used by the methods below.
    """

    def __init__(self, model, tokenizer, autoencoders_dict):
        self.model = model
        self.tokenizer = tokenizer
        self.autoencoders_dict = autoencoders_dict

    def get_dictionary_features(self, activations, layer_name):
        """
        Returns raw dictionary features for activations at a layer number.

        `activations` is moved to CUDA and encoded without gradient tracking.
        """
        autoencoder = self.autoencoders_dict[layer_name]
        with torch.no_grad():
            return autoencoder(activations.cuda())

    def get_all_dictionary_features_for_list(self, activations_dict_list: list[dict[str, list[Tensor]]]):
        """Apply get_all_dictionary_features_for_point to every element, preserving order."""
        return list(map(self.get_all_dictionary_features_for_point, activations_dict_list))

    def get_all_dictionary_features_for_point(self, activations_dict: dict[str, list[Tensor]]):
        """Return {layer name: 1-row csr_matrix of features} for every layer with an autoencoder.

        Each layer's entry in `activations_dict` must be a one-element list.
        """
        features_by_layer = {}
        for layer_name in self.autoencoders_dict:
            layer_activations = activations_dict[layer_name]
            assert len(layer_activations) == 1, "Can only do conversion for single elements right now"
            dense_row = self.get_dictionary_features(layer_activations[0], layer_name)[0].tolist()
            features_by_layer[layer_name] = csr_matrix(dense_row)
        return features_by_layer

In [None]:
# Encode the sample activations with the small RLHF autoencoders
# (one dict of layer name -> csr_matrix per sample).
autoencoder_manager = AutoencoderManager(
    model=model, tokenizer=tokenizer, autoencoders_dict=rlhf_small
)
sample_features = autoencoder_manager.get_all_dictionary_features_for_list(
    activations_dict_list = sample_target_token_activations
)

In [None]:
class LinearProbeTrainingPoint:
    """Probe inputs for one TrainingPoint: target tokens, activations and autoencoder features.

    Each variant (positive, negative, neutral) carries:
    - its target token string and id;
    - per-layer target-token activations;
    - per-layer autoencoder features.

    Rewards for the positive and negative variants are copied from `training_point`;
    the neutral variant has no reward.
    """

    def __init__(
        self, training_point: TrainingPoint,
        # positive token
        target_positive_token_id: int,
        target_positive_token: str,
        positive_activations: dict[str, list[Tensor]],    # layer name -> [target-token activations]
        positive_token_ae_features: dict[str, csr_matrix],  # layer name -> autoencoder features
        # negative token
        target_negative_token_id: int,
        target_negative_token: str,
        negative_activations: dict[str, list[Tensor]],    # layer name -> [target-token activations]
        negative_token_ae_features: dict[str, csr_matrix],  # layer name -> autoencoder features
        # neutral token
        target_neutral_token_id: int,
        target_neutral_token: str,
        neutral_activations: dict[str, list[Tensor]],     # layer name -> [target-token activations]
        neutral_token_ae_features: dict[str, csr_matrix]   # layer name -> autoencoder features
    ):
        self.training_point: TrainingPoint = training_point

        self.target_positive_token = target_positive_token
        self.target_positive_token_id = target_positive_token_id
        self.target_positive_reward = self.training_point.target_positive_reward
        self.positive_token_ae_features = positive_token_ae_features
        self.positive_activations = positive_activations

        self.target_negative_token = target_negative_token
        self.target_negative_token_id = target_negative_token_id
        self.target_negative_reward = self.training_point.target_negative_reward
        self.negative_token_ae_features = negative_token_ae_features
        self.negative_activations = negative_activations

        self.target_neutral_token = target_neutral_token
        self.target_neutral_token_id = target_neutral_token_id
        self.neutral_token_ae_features = neutral_token_ae_features
        self.neutral_activations = neutral_activations

    def __str__(self):
        return pprint.pformat(self.__dict__)

In [None]:
class LinearProbeTrainingDataManager:
    """
    This takes all the sample training data, and calculates all activations for all training samples.

    For every TrainingPoint it:
    - extracts target-token activations at `target_layers` for the positive, negative and
      neutral variants;
    - encodes them with `autoencoders_dict`;
    - wraps the results in a LinearProbeTrainingPoint.
    """

    def __init__(
        self, training_data: list[TrainingPoint], autoencoders_dict,
        target_layers: list[str], model=model, tokenizer=tokenizer,
    ):
        # NOTE: creating an ActivationsExtractor registers another set of hooks on `model`.
        self.activations_extractor = ActivationsExtractor(
            model=model, tokenizer=tokenizer, target_layers=target_layers)
        self.autoencoders_dict = autoencoders_dict
        self.autoencoder_manager = AutoencoderManager(
            model=model, tokenizer=tokenizer, autoencoders_dict=autoencoders_dict
        )
        # Kept for reference; construct_training_dataset takes its data as an argument.
        self.training_data = training_data

    def _activations_and_features(self, samples):
        """Return (per-sample target-token activation dicts, per-sample autoencoder feature dicts)."""
        activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(
            samples, target_token_only=True, flatten=True
        )
        features = self.autoencoder_manager.get_all_dictionary_features_for_list(activations_list)
        return activations_list, features

    def compute_training_points_single_batch(self, single_batch: list[TrainingPoint]):
        """Convert one batch of TrainingPoints into LinearProbeTrainingPoints (same order)."""
        positive_activations_list, positive_dictionary_features = self._activations_and_features(
            [item.trimmed_positive_example for item in single_batch]
        )
        negative_activations_list, negative_dictionary_features = self._activations_and_features(
            [item.trimmed_negative_example for item in single_batch]
        )
        neutral_activations_list, neutral_dictionary_features = self._activations_and_features(
            [item.trimmed_neutral_example for item in single_batch]
        )

        assert (
            len(single_batch)
            == len(positive_activations_list) == len(negative_activations_list) == len(neutral_activations_list)
            == len(positive_dictionary_features) == len(negative_dictionary_features)
            == len(neutral_dictionary_features)
        ), "All samples, activations and dict features should align in length."

        linear_probe_training_batch = []
        zipped_point_and_features = zip(
            single_batch, positive_dictionary_features, positive_activations_list,
            negative_dictionary_features, negative_activations_list,
            neutral_dictionary_features, neutral_activations_list
        )

        for (training_point, positive_features, positive_activations, negative_features,
             negative_activations, neutral_features, neutral_activations) in zipped_point_and_features:
            linear_probe_training_point = LinearProbeTrainingPoint(
                training_point=training_point,
                positive_token_ae_features=positive_features,
                positive_activations=positive_activations,
                negative_token_ae_features=negative_features,
                negative_activations=negative_activations,
                neutral_token_ae_features=neutral_features,
                neutral_activations=neutral_activations,
                # Positive token
                target_positive_token_id=training_point.target_positive_token_id,
                target_positive_token=training_point.target_positive_token,
                # Negative token
                target_negative_token_id=training_point.target_negative_token_id,
                target_negative_token=training_point.target_negative_token,
                # Neutral token
                target_neutral_token_id=training_point.target_neutral_token_id,
                target_neutral_token=training_point.target_neutral_token,
            )
            linear_probe_training_batch.append(linear_probe_training_point)

        return linear_probe_training_batch

    def construct_training_dataset(self, all_training_data: list[TrainingPoint], option: str = None):
        """
        Constructs final training dataset, consisting of input and expected output.

        Processes `all_training_data` in batches from `batch` (default size) and clears
        the GPU cache every 250 batches. `option` is accepted but currently unused.
        """
        all_results = []
        for index, single_batch in tqdm_notebook(enumerate(batch(all_training_data))):
            if index % 250 == 0:
                print(f'Clearing cuda cache on batch {index}')
                clear_gpu_memory()

            all_results.extend(self.compute_training_points_single_batch(single_batch))

        assert len(all_results) == len(all_training_data), "Not all training points were converted to probe inputs!"

        return all_results

In [None]:
# Try the data manager on 5 points before the full run. It creates its own extractor,
# and therefore its own hooks on `model`.
sample_training_data = successful_training_points[:5]

lp_training_data_manager = LinearProbeTrainingDataManager(
    training_data=sample_training_data, autoencoders_dict=rlhf_small,
    model=model, tokenizer=tokenizer, target_layers = model_customizer.get_target_layers()    
)

In [None]:
# Decode the sample's positive target tokens as a quick check.
sample_probe_inputs = lp_training_data_manager.construct_training_dataset(sample_training_data)
x =[input_point.target_positive_token_id for input_point in sample_probe_inputs]
tokenizer.decode(x)

In [None]:
# Full conversion of every successfully trimmed training point.
full_training_dataset = lp_training_data_manager.construct_training_dataset(successful_training_points)

In [None]:
# Inspect one converted example.
print(full_training_dataset[10])

In [None]:
def save_training_dataset_to_wandb(training_dataset: list[LinearProbeTrainingPoint]):
    """
    Pickle the probe training dataset and log it to Weights & Biases as a data artifact.

    The pickle is written to ``training_dataset.pkl`` in the current working
    directory, overwriting any existing file. It is attached to an artifact named
    ``linear_probe_training_dataset_<policy_model_name>`` under the entry name
    ``linear_probe_training_dataset``.

    Args:
        training_dataset: The converted probe training points to upload.
    """
    pickle_path = "training_dataset.pkl"
    with open(pickle_path, "wb") as pickle_file:
        pickle.dump(training_dataset, pickle_file)

    artifact = wandb.Artifact(f"linear_probe_training_dataset_{policy_model_name}", type="data")

    # Attach the pickle file written above (not the in-memory list) to the artifact.
    artifact.add_file(local_path=pickle_path, name="linear_probe_training_dataset")
    artifact.metadata.update(
        {
            "description": "Training dataset, with activations and rewards",
            "source": "Generated by my script",
            "num_examples": len(training_dataset),
            "split": "full",
        }
    )

    # Log the artifact to the currently active run.
    wandb.log_artifact(artifact)


save_training_dataset_to_wandb(full_training_dataset)

### Upload high-value features to wandb (used for ad hoc uploads)

In [None]:
# High-value features for the gpt_neo_125m model.
# NOTE(review): entries look like (layer, feature_index) pairs, with the layer
# stored as a string (presumably to match the autoencoder dict keys). Confirm
# against the consumer of this artifact.
gpt_neo_125m_features = [
    ('10', 732), ('10', 131), ('10', 56), ('10', 600), ('10', 651), ('10', 433),
    ('10', 648), ('10', 613), ('10', 391), ('10', 273), ('10', 84), ('10', 106),
    ('10', 729), ('10', 405), ('10', 370), ('10', 59), ('10', 23), ('10', 736),
    ('10', 417), ('10', 69), ('10', 401), ('10', 738), ('10', 389), ('11', 573),
    ('11', 295), ('11', 276), ('11', 274), ('11', 95), ('11', 653), ('11', 598),
    ('11', 469), ('11', 397), ('11', 635), ('11', 124), ('11', 661), ('11', 498),
    ('11', 201), ('11', 135), ('11', 85), ('11', 401), ('11', 651), ('11', 119),
    ('11', 506), ('11', 224), ('11', 288), ('11', 455), ('11', 24), ('11', 533),
    ('11', 346), ('11', 33), ('7', 250), ('7', 389), ('7', 587), ('7', 134),
    ('7', 332), ('7', 123), ('7', 489), ('7', 435), ('7', 602), ('7', 574),
    ('7', 753), ('7', 68), ('7', 408), ('7', 36), ('7', 124), ('7', 301), ('7', 12),
    ('7', 333), ('7', 223), ('7', 434), ('7', 122), ('7', 588), ('7', 335),
    ('8', 732), ('8', 345), ('8', 400), ('8', 214), ('8', 348), ('8', 447), ('8', 541),
    ('8', 155), ('8', 172), ('8', 156), ('8', 658), ('8', 463), ('8', 507), ('8', 735),
    ('8', 551), ('8', 635), ('8', 434), ('8', 146), ('8', 662), ('8', 653), ('8', 743),
    ('8', 566), ('8', 380), ('8', 505), ('9', 566), ('9', 423), ('9', 494), ('9', 48),
    ('9', 426), ('9', 653), ('9', 457), ('9', 385), ('9', 23), ('9', 421), ('9', 572),
    ('9', 3), ('9', 649), ('9', 678), ('9', 11), ('9', 84), ('9', 717), ('9', 429),
    ('9', 356), ('9', 404)
]

# Source version label uploaded alongside the gpt_neo_125m feature list.
gpt_neo_125m_version = 'v3'

In [None]:
# High-value features for the pythia_410m model.
# NOTE(review): entries look like (layer, feature_index) pairs, with the layer
# stored as a string (presumably to match the autoencoder dict keys). Confirm
# against the consumer of this artifact.
pythia_410m_features = [
    ('16', 759), ('16', 661), ('16', 120), ('16', 154), ('16', 551), ('16', 651), 
    ('16', 923), ('16', 801), ('16', 380), ('16', 480), ('16', 705), ('16', 825), 
    ('16', 166), ('16', 750), ('16', 694), ('16', 140), ('16', 261), ('16', 866), 
    ('16', 571), ('16', 469), ('16', 691), ('16', 852), ('16', 966), ('17', 634), 
    ('17', 981), ('17', 994), ('17', 471), ('17', 761), ('17', 726), ('17', 21), 
    ('17', 1002), ('17', 605), ('17', 68), ('17', 16), ('17', 466), ('17', 185), 
    ('17', 413), ('17', 97), ('17', 238), ('17', 522), ('17', 518), ('17', 629), 
    ('17', 860), ('17', 213), ('17', 41), ('21', 314), ('21', 884), ('21', 911), 
    ('21', 876), ('21', 897), ('21', 416), ('21', 424), ('21', 829), ('21', 246), 
    ('21', 192), ('21', 850), ('21', 210), ('21', 627), ('21', 174), ('21', 927), 
    ('21', 28), ('21', 631), ('21', 781), ('21', 806), ('21', 628), ('21', 115), 
    ('22', 391), ('22', 329), ('22', 759), ('22', 349), ('22', 946), ('22', 819), 
    ('22', 289), ('22', 287), ('22', 1012), ('22', 318), ('22', 138), ('22', 267), 
    ('22', 430), ('22', 159), ('22', 276), ('22', 197), ('22', 632), ('22', 1015), 
    ('22', 63), ('22', 701), ('22', 200), ('22', 715), ('22', 964), ('22', 563), 
    ('23', 139), ('23', 626), ('23', 899), ('23', 738), ('23', 370), ('23', 37), 
    ('23', 749), ('23', 487), ('23', 826), ('23', 621), ('23', 213), ('23', 552), 
    ('23', 1013), ('23', 321), ('23', 635), ('23', 500), ('23', 215), ('23', 1000), 
    ('23', 1023), ('23', 478), ('23', 581), ('23', 947), ('23', 1022), ('23', 577)
]

# Source version label uploaded alongside the pythia_410m feature list.
pythia_410m_version = 'v2'

In [None]:
# High-value features for the pythia_160m model.
# NOTE(review): entries look like (layer, feature_index) pairs, with the layer
# stored as a string (presumably to match the autoencoder dict keys). Confirm
# against the consumer of this artifact.
pythia_160m_features = [
    ('10', 169), ('10', 507), ('10', 248), ('10', 215), ('10', 98), ('10', 145),
    ('10', 40), ('10', 714), ('10', 241), ('10', 541), ('10', 445), ('10', 315),
    ('10', 251), ('10', 116), ('11', 185), ('11', 261), ('11', 410), ('11', 207),
    ('11', 575), ('11', 198), ('11', 331), ('11', 212), ('11', 590), ('11', 99),
    ('11', 502), ('11', 471), ('11', 754), ('11', 218), ('11', 492), ('11', 55),
    ('11', 513), ('11', 70), ('7', 415), ('7', 644), ('7', 546), ('7', 52),
    ('7', 364), ('7', 260), ('7', 290), ('7', 472), ('7', 429), ('7', 123),
    ('7', 61), ('7', 43), ('7', 387), ('7', 236), ('7', 469), ('7', 15), ('7', 501),
    ('7', 379), ('8', 427), ('8', 284), ('8', 575), ('8', 498), ('8', 403),
    ('8', 410), ('8', 148), ('8', 680), ('8', 144), ('8', 516), ('8', 670),
    ('8', 102), ('8', 69), ('8', 260), ('9', 664), ('9', 556), ('9', 542),
    ('9', 560), ('9', 158), ('9', 268), ('9', 70), ('9', 547), ('9', 569),
    ('9', 193), ('9', 546), ('9', 589), ('9', 16), ('9', 583), ('9', 411),
    ('9', 186), ('9', 634)
]

# Source version label uploaded alongside the pythia_160m feature list.
pythia_160m_version = 'v4'

In [None]:
# High-value features for the pythia_70m model.
# NOTE(review): entries look like (layer, feature_index) pairs, with the layer
# stored as a string (presumably to match the autoencoder dict keys). Confirm
# against the consumer of this artifact.
pythia_70m_features = [
    ('1', 45), ('1', 314), ('1', 161), ('1', 43), ('1', 254), ('1', 391),
    ('1', 422), ('1', 420), ('1', 482), ('1', 127), ('1', 162), ('1', 193),
    ('2', 183), ('2', 164), ('2', 291), ('2', 380), ('2', 97), ('2', 415),
    ('2', 269), ('2', 229), ('2', 220), ('2', 457), ('2', 129), ('2', 96),
    ('3', 23), ('3', 252), ('3', 255), ('3', 104), ('3', 379), ('3', 141),
    ('3', 170), ('3', 128), ('3', 117), ('3', 244), ('3', 93), ('3', 130),
    ('4', 161), ('4', 22), ('4', 303), ('4', 119), ('4', 404), ('4', 368),
    ('4', 301), ('4', 96), ('4', 23), ('5', 9), ('5', 75), ('5', 377),
    ('5', 68), ('5', 93), ('5', 381), ('5', 22), ('5', 39), ('5', 189),
    ('5', 221), ('5', 231), ('5', 251)
]

# Source version label uploaded alongside the pythia_70m feature list.
pythia_70m_version = 'v11'

In [None]:
# Map each model name to a (source version, high-value feature list) pair.
# When serialized with json, the tuples are written as JSON arrays.
output_high_value_features_artifact = dict(
    pythia_70m=(pythia_70m_version, pythia_70m_features),
    pythia_160m=(pythia_160m_version, pythia_160m_features),
    pythia_410m=(pythia_410m_version, pythia_410m_features),
    gpt_neo_125m=(gpt_neo_125m_version, gpt_neo_125m_features),
)

In [None]:
def save_high_value_features_dataset_to_wandb(output_high_value_features_artifact):
    """
    Write the high-value-features mapping to JSON and log it to W&B as a data artifact.

    The JSON is written to ``high_value_features_dataset.json`` in the current
    working directory, overwriting any existing file. It is attached to the
    ``high_value_features_artifact`` artifact under the same entry name.

    Args:
        output_high_value_features_artifact: JSON-serializable mapping to upload.
            Tuples inside it are written as JSON arrays.
    """
    json_path = "high_value_features_dataset.json"
    with open(json_path, "w") as json_file:
        json.dump(output_high_value_features_artifact, json_file)

    artifact = wandb.Artifact("high_value_features_artifact", type="data")

    # Attach the JSON file written above to the artifact.
    artifact.add_file(local_path=json_path, name="high_value_features_artifact")
    artifact.metadata.update(
        {
            "description": "High value features for pythia70m, pythia160m, pythia410m, and gpt_neo_125m. Includes source versions for these.",
            "source": "Generated by Marc's experiments using GPT-4",
        }
    )

    # Log the artifact to the currently active run.
    wandb.log_artifact(artifact)

In [None]:
# Upload the per-model high-value feature mapping to W&B as an artifact.
save_high_value_features_dataset_to_wandb(output_high_value_features_artifact)