### Load training dataset and Vader sentiment analyzer

In [1]:
import dataclasses
import gc
import json
import os
import pickle
import pprint

from time import time

import nltk
import spacy
import torch
import torch.nn.functional as F
import tqdm
import transformers
import wandb

from scipy.sparse import csr_matrix
from torch import FloatTensor, LongTensor, Tensor

from dataclasses import dataclass
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from torch import nn
from tqdm import tqdm_notebook
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model_name = 'gpt-neo-125m'
policy_model_name  = 'gpt_neo_125m_utility_reward'
os.environ['WANDB_API_KEY'] = ''

In [3]:
run = wandb.init(project="utility_reconstruction")

[34m[1mwandb[0m: Currently logged in as: [33mamirali1985[0m ([33mnlp_and_interpretability[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111230937777792, max=1.0)…

In [4]:
full_model_name = f'EleutherAI/{model_name}'

model = AutoModelForCausalLM.from_pretrained(full_model_name)
model.cuda()
model.eval()

tokenizer = AutoTokenizer.from_pretrained(full_model_name)
tokenizer.pad_token = tokenizer.eos_token

#model.resize_token_embeddings(len(tokenizer))

In [5]:
sentiment_analyzer = SentimentIntensityAnalyzer()
lexicon = sentiment_analyzer.lexicon

## Tokenization and torch utilities.

In [6]:
def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

def clear_gpu_memory():
    start_time = time()
    gc.collect()
    torch.cuda.empty_cache()
    end_time = time()
    total_time = round(end_time-start_time, 2)
    print(f'Took {total_time} seconds to clear cache.')

In [7]:
def pad_list_of_lists(list_of_lists, pad_token):
    max_length = max(len(lst) for lst in list_of_lists)
    padded_list = [lst + [pad_token] * (max_length - len(lst)) for lst in list_of_lists]
    return padded_list

In [8]:
def check_number_of_tokens(word, tokenizer=tokenizer):
    return len(tokenizer(word)['input_ids'])

In [9]:
def get_tokens_and_ids(text, tokenizer=tokenizer):
    input_ids = tokenizer(text.lower(), truncation=True)['input_ids']
    
    tokens = [tokenizer.decode(input_id) for input_id in input_ids]
    # The above produces artifacts such as a " positive" token and id, instead of "positive". So we redo this.

    tokens = [token.lower().strip() for token in tokens]
    tokenizer
    return tokens, input_ids

In [10]:
def get_single_target_token_id(word, tokenizer=tokenizer):
    word = word.lower().strip()
    num_tokens = check_number_of_tokens(word)
    if num_tokens > 1:
        # Backoff to include a single space.
        word = f' {word}'
        num_tokens = check_number_of_tokens(word)

    return tokenizer(word)['input_ids'][0]

In [11]:
@dataclass
class TextTokensIdsTarget:
    attention_mask: list[int]
    text: str
    tokens: list[str]
    ids: list[int]
    target_token: str
    target_token_id: int
    target_token_position: int

    @staticmethod
    def get_tensorized(datapoints: "TextTokensIdsTarget"):
        max_length = max([len(datapoint.tokens) for datapoint in datapoints])
        
        input_ids = [datapoint.ids for datapoint in datapoints]
        attention_masks = [datapoint.attention_mask for datapoint in datapoints]

        input_ids_padded = pad_list_of_lists(input_ids, tokenizer.encode(tokenizer.pad_token)[0])
        attention_masks_padded = pad_list_of_lists(attention_masks, 0)
        all_tokenized = {
            "input_ids": torch.IntTensor(input_ids_padded).cuda(), "attention_mask": torch.ByteTensor(attention_masks_padded).cuda()
        }
        return all_tokenized

def trim_example(input_text: str, target_words: list[str], verbose=False, tokenizer=tokenizer):
    single_target_token_ids = [get_single_target_token_id(word.strip().lower()) for word in target_words]
    
    single_target_token_ids = [token_id for token_id in single_target_token_ids if token_id]
    single_target_tokens = [tokenizer.decode(token_id).strip().lower() for token_id in single_target_token_ids]

    input_tokens, input_token_ids = get_tokens_and_ids(input_text)

    trimmed_input_tokens = []
    trimmed_input_token_ids = []

    for input_token, input_token_id in zip(input_tokens, input_token_ids):
        trimmed_input_tokens.append(input_token)
        trimmed_input_token_ids.append(input_token_id)
        if input_token.strip().lower() in single_target_tokens:
            break

    assert len(trimmed_input_token_ids) == len(trimmed_input_tokens), "Num of tokens and token ids should be equal"

    last_token = None

    if trimmed_input_tokens:
        last_token = trimmed_input_tokens[-1].lower().strip()
        last_token_id = trimmed_input_token_ids[-1]

    if len(trimmed_input_tokens) > tokenizer.model_max_length:
        print(f'Dropping example since exceed model max length. Input text was:\n{input_text}')
        return None
    
    elif last_token and last_token in single_target_tokens:
        text = tokenizer.decode(trimmed_input_token_ids)
        target_token_position = len(trimmed_input_token_ids) - 1
        return TextTokensIdsTarget(
            attention_mask=[1]*len(trimmed_input_tokens),
            text=text, tokens=trimmed_input_tokens, ids=trimmed_input_token_ids, 
            target_token=last_token, target_token_id=last_token_id,
            target_token_position=target_token_position
        )
    else:
        if verbose:
            print(f'last token was {last_token} in {trimmed_input_tokens}, and was not in target tokens.')
        return None

### Load training examples for linear probe.

In [12]:
def load_wandb_json_artifact(
    project_name='utility_reconstruction', artifact_name = 'contrastive_sentiment_pairs', version='v6'
):
    api = wandb.Api()
    artifact = api.artifact(f'nlp_and_interpretability/{project_name}/{artifact_name}:{version}', type='data')
    artifact_dir = artifact.download()

    with open(f'artifacts/{artifact_name}:{version}/{artifact_name}', 'r') as f_in:
        result = json.load(f_in)
        return result

all_input_dicts = load_wandb_json_artifact()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [13]:
def clean_capitalization_neutral_terms(input_dicts):
    for input_dict in input_dicts:
        neutral_words = list(input_dict['neutral_words'].values())
        for neutral_word in neutral_words:
            input_dict['neutral_text'] = input_dict['neutral_text'].replace(neutral_word, neutral_word.lower())

    return input_dicts

all_input_dicts = clean_capitalization_neutral_terms(all_input_dicts)

In [14]:
class TrainingPoint:

    def __init__(self, input_dict: dict, tokenizer=tokenizer):
        self.input_dict = input_dict
        self.positive_text = input_dict['input_text']
        self.negative_text = input_dict['output_text']
        self.neutral_text = input_dict['neutral_text']
        
        # Dictionary of layer name to activations by mlp layer.
        self.activations: dict = None

        # Dictionary of layer name to autoencoder feature by mlp layer
        self.autoencoder_feature: dict = None

        # Reward value of target_token.
        self.target_positive_reward = None
        self.target_negative_reward = None

        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text)
        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text)
        
        self.positive_words = input_dict['positive_words']
        self.negative_words = list(input_dict['new_words'].values())
        self.neutral_words = list(input_dict['neutral_words'].values())

        self.target_positive_reward = None
        self.target_positive_token = None
        self.target_positive_token_id = None
    
        self.target_negative_reward = None
        self.target_negative_token = None
        self.target_negative_token_id = None

        self.target_neutral_token = None
        self.target_neutral_token_id = None

        try:
            self.trimmed_positive_example: TextTokensIdTarget = trim_example(self.positive_text, self.positive_words)
            if self.trimmed_positive_example:
                positive_token = self.trimmed_positive_example.target_token.strip().lower()
                self.target_positive_reward = lexicon.get(positive_token, None)
                self.target_positive_token = positive_token
                self.target_positive_token_id = self.trimmed_positive_example.target_token_id
        
        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for positive example.')
            self.trimmed_positive_example = None
        
        try:
            self.trimmed_negative_example: TextTokensIdTarget = trim_example(self.negative_text, self.negative_words)
            if self.trimmed_negative_example:
                negative_token = self.trimmed_negative_example.target_token.strip().lower()
                self.target_negative_reward = lexicon.get(negative_token, None)
                self.target_negative_token = negative_token
                self.target_negative_token_id = self.trimmed_negative_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for negative example.')
            self.trimmed_negative_example = None

        try:
            self.trimmed_neutral_example: TextTokensIdTarget = trim_example(self.neutral_text, self.neutral_words)
            if self.trimmed_neutral_example:
                self.target_neutral_token = self.trimmed_neutral_example.target_token.strip().lower()
                self.target_neutral_token_id = self.trimmed_neutral_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for neutral example.')
            self.trimmed_neutral_example = None

    def __str__(self):
        return pprint.pformat(self.__dict__)

In [15]:
training_points = []
for input_dict in tqdm_notebook(all_input_dicts):
    training_points.append(TrainingPoint(input_dict))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for input_dict in tqdm_notebook(all_input_dicts):


  0%|          | 0/8640 [00:00<?, ?it/s]

Caught exception list index out of range on {'input_text': 'It seems like anybody can make a movie nowadays.', 'output_text': 'It seems anybody can make a movie nowadays.', 'neutral_text': 'It seems tolerate anybody can make a movie nowadays.', 'positive_words': ['like'], 'new_words': {'like': ''}, 'neutral_words': {'like': 'tolerate'}} for negative example.


Token indices sequence length is longer than the specified maximum sequence length for this model (4099 > 2048). Running this sequence through the model will result in indexing errors


In [16]:
successful_training_points = [
    training_point for training_point in training_points if 
    training_point.trimmed_positive_example and training_point.trimmed_negative_example
    and training_point.trimmed_neutral_example
    and training_point.target_positive_reward is not None
    and training_point.target_negative_reward is not None
]

In [17]:
len(successful_training_points)

6601

In [18]:
sample_training_points = successful_training_points[:10]

In [19]:
bad_training_points = [training_point for training_point in  training_points if training_point not in successful_training_points]

### Load autoencoders for linear probe.

In [20]:
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, l1_coef):
        super(SparseAutoencoder, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size

        self.kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'l1_coef': l1_coef}
        self.l1_coef = l1_coef

        self.encoder_weight = nn.Parameter(torch.randn(hidden_size, input_size))
        nn.init.orthogonal_(self.encoder_weight)

        self.encoder_bias = nn.Parameter(torch.zeros(self.hidden_size))
        self.decoder_bias = nn.Parameter(torch.zeros(input_size))

    def forward(self, x):
        normalized_encoder_weight = F.normalize(self.encoder_weight, p=2, dim=1)

        features = F.linear(x, normalized_encoder_weight, self.encoder_bias)
        features = F.relu(features)

        # reconstruction = F.linear(features, normalized_encoder_weight.t(), self.decoder_bias)
        return features.detach()

In [21]:
entity_name = 'nlp_and_interpretability'
project_prefix = 'Autoencoder_training'
artifact_prefix = 'autoencoders'

def load_autoencoders_for_artifact(policy_model_name, alias='latest', run=run):
    '''
    Loads the autoencoders from one run into memory. Note that these paths are to some extent hardcoded
    For example, try autoencoders_dict = load_autoencoders_for_artifact('pythia_70m_sentiment_reward')
    '''
    simplified_policy_model_name = policy_model_name.split('/')[-1].replace('-', '_')
    full_path = f'{entity_name}/{project_prefix}_{simplified_policy_model_name}/{artifact_prefix}_{simplified_policy_model_name}:{alias}'
    print(f'Loading artifact from {full_path}')

    artifact = run.use_artifact(full_path)
    directory = artifact.download()

    save_dir = f'{directory}/saves'
    autoencoders_base_big = load_models_from_folder(load_dir=f'{save_dir}/base_big', given_device='cpu')
    autoencoders_base_small = load_models_from_folder(load_dir=f'{save_dir}/base_small', given_device='cpu')
    autoencoders_rlhf_big = load_models_from_folder(load_dir=f'{save_dir}/rlhf_big', given_device='cpu')
    autoencoders_rlhf_small = load_models_from_folder(load_dir=f'{save_dir}/rlhf_small', given_device='cpu')

    return {
        'base_big': autoencoders_base_big, 'base_small': autoencoders_base_small,
        'rlhf_big': autoencoders_rlhf_big, 'rlhf_small': autoencoders_rlhf_small
    }

def load_models_from_folder(load_dir, given_device=None):
    """
    Load PyTorch models from subfolders of a directory into a dictionary where keys are subfolder names.

    Args:
        load_dir (str): The directory from which models will be loaded.

    Returns:
        model_dict (dict): A dictionary where keys are subfolder names and values are PyTorch models.
    """
    model_dict = {}

    device = given_device if given_device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for model_name in sorted(os.listdir(load_dir)):
        model_path = os.path.join(load_dir, model_name)

        kwargs, state = torch.load(model_path, map_location=device)

        model = SparseAutoencoder(**kwargs)
        model.load_state_dict(state)
        model.to(device)
        model.cuda()
        model.eval()

        model_dict[model_name] = model
        print(f"Loaded {model_name} from {model_path}")

    return model_dict

In [22]:
autoencoders_dictionaries = load_autoencoders_for_artifact(policy_model_name=policy_model_name)

Loading artifact from nlp_and_interpretability/Autoencoder_training_gpt_neo_125m_utility_reward/autoencoders_gpt_neo_125m_utility_reward:latest


[34m[1mwandb[0m: Downloading large artifact autoencoders_gpt_neo_125m_utility_reward:latest, 67.67MB. 20 files... 
[34m[1mwandb[0m:   20 of 20 files downloaded.  
Done. 0:0:0.2


Loaded 10 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_big/10
Loaded 11 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_big/11
Loaded 7 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_big/7
Loaded 8 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_big/8
Loaded 9 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_big/9
Loaded 10 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_small/10
Loaded 11 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility_reward:v3/saves/base_small/11
Loaded 7 from /data/home/amir/work/codes/rlhf/Notebooks/artifacts/autoencoders_gpt_neo_125m_utility

In [23]:
class ModelCustomizer:
    '''
    Used to customize model layer numbers and other network parsing details
    '''

    def __init__(self):
        '''
        Initialize
        '''
        self.target_layers = None

    def set_target_layers(self) -> list[str]:
        '''
        Set target layers
        '''

    def get_target_layers(self) -> list[str]:
        '''
        Get target layers.
        '''

    def parse_layer_name_to_layer_number(self, layer_name) -> str:
        '''
        Parse layer name to layer number
        '''

    def convert_ae_dict_keys(self, autoencoders_dict: [str, Tensor]):
        '''
        Parse ae dict keys to full layer names.
        '''


class GPTNeoCustomizer(ModelCustomizer):

    def get_target_layers(self) -> list[str]:
        if self.target_layers:
            return self.target_layers
        else:
            return [self.layer_num_to_full_name(layer_no) for layer_no in range(12)]

    def set_target_layers(self, target_layers):
        self.target_layers = target_layers

    def layer_num_to_full_name(self, layer_no):
        return f'transformer.h.{layer_no}.mlp'

    def parse_layer_name_to_layer_number(self, layer_name) -> str:
        return layer_name.split('.')[-2]

    # Standardize layer names to full names instead of 'int'
    def convert_ae_dict_keys(self, autoencoders_dict: [str, Tensor]):
        output_dict = {}
        for key, autoencoder in autoencoders_dict.items():
            output_dict[self.layer_num_to_full_name(key)] = autoencoder
        return output_dict

model_customizer = GPTNeoCustomizer()
model_target_layers = model_customizer.get_target_layers()
model_target_layers

['transformer.h.0.mlp',
 'transformer.h.1.mlp',
 'transformer.h.2.mlp',
 'transformer.h.3.mlp',
 'transformer.h.4.mlp',
 'transformer.h.5.mlp',
 'transformer.h.6.mlp',
 'transformer.h.7.mlp',
 'transformer.h.8.mlp',
 'transformer.h.9.mlp',
 'transformer.h.10.mlp',
 'transformer.h.11.mlp']

**Convert layer numbers to full layer names.**

In [24]:
mapped_dictionaries = {}

In [25]:
for key, ae_dict in autoencoders_dictionaries.items():
    mapped_dictionaries[key] = model_customizer.convert_ae_dict_keys(
      ae_dict
    )

autoencoders_dictionaries = mapped_dictionaries

In [26]:
rlhf_small = autoencoders_dictionaries['rlhf_small']
rlhf_big = autoencoders_dictionaries['rlhf_big']

model_customizer.set_target_layers(list(rlhf_small.keys()))
model_customizer.get_target_layers()

['transformer.h.10.mlp',
 'transformer.h.11.mlp',
 'transformer.h.7.mlp',
 'transformer.h.8.mlp',
 'transformer.h.9.mlp']

### Extract Activations.

In [27]:
class ActivationsHook:
    def __init__(self):
        self.activations = []

    def clear_activations(self):
        for tensor in self.activations:
            tensor = tensor.detach().cpu()
        self.activations.clear()
        self.activations = []

    def hook_fn(self, module, input, output):
        new_activations = torch.split(output.detach().cpu(), 1, dim=0)
        self.activations.extend(new_activations)

In [28]:
class ActivationsExtractor:
    def __init__(self, model, tokenizer, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.tokenizer = tokenizer

        # Create an instance of ActivationHook
        self.activation_hooks = {}
        
        for layer_name in self.target_layers:
            activation_hook = ActivationsHook()
            self.activation_hooks[layer_name] = activation_hook
            layer = dict(model.named_modules())[layer_name]
            # Register the forward hook to the chosen layer
            hook_handle = layer.register_forward_hook(activation_hook.hook_fn)

    def clear_all_activations(self):
        for layer_name, activation_hook in self.activation_hooks.items():
            activation_hook.clear_activations()

    def get_activations(self):
        """
        Retrieve all the cached activations.
        """
        return {
            layer_name: activation_hook.activations for layer_name, activation_hook in self.activation_hooks.items()
        }

    def compute_activations_from_raw_texts(self, raw_texts: str):
        self.clear_all_activations()

        # Forward pass your input through the model
        for text_batch in batch(texts):
            input_data = self.tokenizer(text_batch, return_tensors='pt', padding=True)  # Example input shape
            with torch.no_grad():
                output = self.model(**input_data)

        return self.get_activations()

    def _flatten_activations(self, final_activations, num_samples):
        flattened_activations = []
        for i in range(num_samples):
            current_activations = {}
            for layer_name, activations_list in final_activations.items():
                current_activations[layer_name] = [activations_list[i]]
    
            flattened_activations.append(current_activations)
        
        return flattened_activations

    def compute_activations_from_text_tokens_ids_target(
        self, samples: list[TextTokensIdsTarget], target_token_only=True, flatten=True
    ):
        self.clear_all_activations()

        # Forward pass your input through the model
        for text_batch in batch(samples):
            tensorized = TextTokensIdsTarget.get_tensorized(text_batch)

            with torch.no_grad():
                output = self.model(**tensorized)

        all_activations = self.get_activations()
        activations_per_layer = [len(value) for value in all_activations.values()]

        assert max(activations_per_layer) == min(activations_per_layer) == len(samples), 'Each layer should have num_samples activations'

        if target_token_only:
            all_target_token_activations = {layer_num: [] for layer_num in all_activations}
            for layer_num, layer_activations in all_activations.items():
                assert len(layer_activations) == len(samples), "Each layer should have same activations as num samples!"

                zipped_layer_activations_and_samples = zip(layer_activations, samples)
                for activations, sample in zipped_layer_activations_and_samples:
                    relevant_token_activations = activations[:, sample.target_token_position, :]
                    all_target_token_activations[layer_num].append(relevant_token_activations)

            final_activations = all_target_token_activations

        else:
            final_activations = all_activations

        if flatten:
            final_activations = self._flatten_activations(final_activations, num_samples=len(samples))
            
        else:
            print(f'Returning a dictionary mapping layer name to list of activations')

        return final_activations

In [29]:
extractor = ActivationsExtractor(model=model, target_layers=model_customizer.get_target_layers(), tokenizer=tokenizer)

sample_positives = [point.trimmed_positive_example for point in sample_training_points]

sample_target_token_activations = extractor.compute_activations_from_text_tokens_ids_target(sample_positives)

In [30]:
class AutoencoderManager:
    def __init__(self, model, tokenizer, autoencoders_dict):
        self.model = model
        self.tokenizer = tokenizer
        self.autoencoders_dict = autoencoders_dict

    def get_dictionary_features(self, activations, layer_name):
        """
        Returns raw dictionary features for activations at a layer number.
        """
        with torch.no_grad():
            features = self.autoencoders_dict[layer_name](activations.cuda())
            return features

    def get_all_dictionary_features_for_list(self, activations_dict_list: list[dict[str, list[Tensor]]]):
        return [self.get_all_dictionary_features_for_point(point) for point in activations_dict_list]
        
    def get_all_dictionary_features_for_point(self, activations_dict: dict[str, list[Tensor]]):
        all_features = {}
        for layer_name, autoencoder in self.autoencoders_dict.items():
            activations = activations_dict[layer_name]
            assert len(activations) == 1, "Can only do conversion for single elements right now"
            curr_dict_features = self.get_dictionary_features(activations[0], layer_name)[0].tolist()
            all_features[layer_name] = csr_matrix(curr_dict_features)
        return all_features

In [31]:
autoencoder_manager = AutoencoderManager(
    model=model, tokenizer=tokenizer, autoencoders_dict=rlhf_small
)
sample_features = autoencoder_manager.get_all_dictionary_features_for_list(
    activations_dict_list = sample_target_token_activations
)

In [32]:
class LinearProbeTrainingPoint:
    def __init__(
        self, training_point: TrainingPoint,
        # positive token
        target_positive_token_id: int,
        target_positive_token: str,
        positive_token_ae_features: [str, Tensor], 
        # negative token
        target_negative_token_id: int,
        target_negative_token: str,
        negative_token_ae_features: [str, Tensor],
        # neutral token
        target_neutral_token_id: int,
        target_neutral_token: str,
        neutral_token_ae_features: [str, Tensor]
    ):
        self.training_point: TrainingPoint = training_point

        self.target_positive_token = target_positive_token
        self.target_positive_token_id = target_positive_token_id
        self.target_positive_reward = self.training_point.target_positive_reward
        self.positive_token_ae_features = positive_token_ae_features

        self.target_negative_token = target_negative_token
        self.target_negative_token_id = target_negative_token_id
        self.target_negative_reward = self.training_point.target_negative_reward
        self.negative_token_ae_features = negative_token_ae_features

        self.target_neutral_token = target_neutral_token
        self.target_neutral_token_id = target_neutral_token_id
        self.neutral_token_ae_features = neutral_token_ae_features

    def __str__(self):
        return pprint.pformat(self.__dict__)

In [33]:
class LinearProbeTrainingDataManager:
    """
    This takes all the sample training data, and calculates all activations for all training samples
    """

    def __init__(
        self, training_data: list[TrainingPoint], autoencoders_dict,
        target_layers: list[str], model=model, tokenizer=tokenizer,
    ):
        self.activations_extractor = ActivationsExtractor(
            model=model, tokenizer=tokenizer, target_layers=target_layers)
        self.autoencoders_dict = autoencoders_dict
        self.autoencoder_manager = AutoencoderManager(
            model=model, tokenizer=tokenizer, autoencoders_dict=autoencoders_dict
        )
        self.training_data = training_data

    def compute_training_points_single_batch(self, single_batch: list[TrainingPoint]):
        positive_samples = [item.trimmed_positive_example for item in single_batch]
        negative_samples = [item.trimmed_negative_example for item in single_batch]
        neutral_samples = [item.trimmed_neutral_example for item in single_batch]
    
        positive_activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(
            positive_samples, target_token_only=True, flatten=True
        )
        negative_activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(
            negative_samples, target_token_only=True, flatten=True
        )

        neutral_activations_list = self.activations_extractor.compute_activations_from_text_tokens_ids_target(
            neutral_samples, target_token_only=True, flatten=True
        )

        positive_dictionary_features = self.autoencoder_manager.get_all_dictionary_features_for_list(positive_activations_list)
        negative_dictionary_features = self.autoencoder_manager.get_all_dictionary_features_for_list(negative_activations_list)
        neutral_dictionary_features = self.autoencoder_manager.get_all_dictionary_features_for_list(neutral_activations_list)

        assert (
            len(positive_samples) == len(negative_samples) == len(neutral_samples) == 
            len(positive_activations_list) == len(negative_activations_list) ==
            len(positive_dictionary_features) == len(negative_dictionary_features) == len(neutral_dictionary_features)
        ), "All samples, activations and dict features should align in length."

        linear_probe_training_batch = []
        zipped_point_and_features = zip(single_batch, positive_dictionary_features, negative_dictionary_features, neutral_dictionary_features)

        for training_point, positive_features, negative_features, neutral_features in zipped_point_and_features:
            linear_probe_training_point = LinearProbeTrainingPoint(
                training_point=training_point,
                positive_token_ae_features=positive_features, 
                negative_token_ae_features=negative_features,
                neutral_token_ae_features=neutral_features,
                # Positive token
                target_positive_token_id=training_point.target_positive_token_id,
                target_positive_token=training_point.target_positive_token,
                # Negative token
                target_negative_token_id=training_point.target_negative_token_id,
                target_negative_token=training_point.target_negative_token,
                # Neutral token
                target_neutral_token_id=training_point.target_neutral_token_id,
                target_neutral_token=training_point.target_neutral_token,
            )
            linear_probe_training_batch.append(linear_probe_training_point)

        return linear_probe_training_batch

    def construct_training_dataset(self, all_training_data: list[TrainingPoint], option: str = None):
        """
        Constructs final training dataset, consisting of input and expected output. 
        """
        all_results = []
        for index, single_batch in tqdm_notebook(enumerate(batch(all_training_data))):
            if index % 250 == 0:
                print(f'Clearing cuda cache on batch {index}')
                clear_gpu_memory()
    
            current_results = self.compute_training_points_single_batch(single_batch)
            all_results.extend(current_results)

        assert len(all_results) == len(all_training_data), "Not all training points were converted to probe inputs!"

        return all_results

In [34]:
sample_training_data = successful_training_points[:5]

lp_training_data_manager = LinearProbeTrainingDataManager(
    training_data=sample_training_data, autoencoders_dict=rlhf_small,
    model=model, tokenizer=tokenizer, target_layers = model_customizer.get_target_layers()    
)

In [35]:
sample_probe_inputs = lp_training_data_manager.construct_training_dataset(sample_training_data)
x =[input_point.target_positive_token_id for input_point in sample_probe_inputs]
tokenizer.decode(x)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for index, single_batch in tqdm_notebook(enumerate(batch(all_training_data))):


0it [00:00, ?it/s]

Clearing cuda cache on batch 0
Took 0.28 seconds to clear cache.


' inspired great interest good good'

In [36]:
full_training_dataset = lp_training_data_manager.construct_training_dataset(successful_training_points)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for index, single_batch in tqdm_notebook(enumerate(batch(all_training_data))):


0it [00:00, ?it/s]

Clearing cuda cache on batch 0
Took 0.27 seconds to clear cache.
Clearing cuda cache on batch 250
Took 0.27 seconds to clear cache.
Clearing cuda cache on batch 500
Took 0.27 seconds to clear cache.
Clearing cuda cache on batch 750
Took 0.28 seconds to clear cache.
Clearing cuda cache on batch 1000
Took 0.28 seconds to clear cache.
Clearing cuda cache on batch 1250
Took 0.29 seconds to clear cache.
Clearing cuda cache on batch 1500
Took 0.28 seconds to clear cache.
Clearing cuda cache on batch 1750
Took 0.3 seconds to clear cache.
Clearing cuda cache on batch 2000
Took 0.28 seconds to clear cache.
Clearing cuda cache on batch 2250
Took 0.29 seconds to clear cache.
Clearing cuda cache on batch 2500
Took 0.29 seconds to clear cache.
Clearing cuda cache on batch 2750
Took 0.3 seconds to clear cache.
Clearing cuda cache on batch 3000
Took 0.31 seconds to clear cache.
Clearing cuda cache on batch 3250
Took 0.31 seconds to clear cache.
Clearing cuda cache on batch 3500
Took 0.3 seconds to cl

In [37]:
def save_training_dataset_to_wandb(training_dataset: list[LinearProbeTrainingPoint]):
    out_filename = "training_dataset.pkl"

    with open(out_filename, "wb") as f_out:
        pickle.dump(training_dataset, f_out)
    
    my_artifact = wandb.Artifact(f"linear_probe_training_dataset_{policy_model_name}", type="data")
    
    # Add the list to the artifact
    my_artifact.add_file(local_path=out_filename, name="linear_probe_training_dataset")

    metadata_dict = {
        "description": "Training dataset, with activations and rewards",
        "source": "Generated by my script",
        "num_examples": len(training_dataset),
        "split": "full"
    }

    my_artifact.metadata.update(metadata_dict)

    # Log the artifact to the run
    wandb.log_artifact(my_artifact)

save_training_dataset_to_wandb(full_training_dataset)

In [None]:
class LinearProbeTrainer:
    def __init__(self, linear_probe_training_data_manager):
        """
        Initialize
        """
    def train_linear_probe(self):
        """
        Initialize
        """

    def assign_reward(self, training_example):
        """
        Assigns reward to aver
        """

    def compute_divergence(self, training_example):
        """
        """

    def average_reward(self, token):
        """
        """

    def average_divergence(self, token):
        """
        """