In [None]:
import os
import pickle
import pprint
import random

from collections import Counter
from dataclasses import dataclass

import numpy as np
import wandb

from scipy.sparse import csr_matrix, vstack
from sklearn.linear_model import LinearRegression

from torch import Tensor
from tqdm import tqdm_notebook

In [None]:
# Name of the pickle file inside the downloaded dataset artifact.
filename = 'linear_probe_training_dataset'
# Policy model whose dataset artifact / wandb project is used below.
policy_model_name = 'gpt_neo_125m_utility_reward'
project_name = 'utility_reconstruction'
# wandb artifact version alias.
version = 'v1'
random_seed = 42

# Do not hard-code a real API key in the notebook. Prefer a WANDB_API_KEY that is
# already exported in the environment; the placeholder is only a fallback and
# must be replaced (or the variable exported) before wandb can authenticate.
wandb_api_key = os.environ.get('WANDB_API_KEY', 'YOUR_KEY_HERE')

os.environ['WANDB_API_KEY'] = wandb_api_key

### Randomization and other utilities.

In [None]:
def random_split_list(lst, split_ratio=0.8, seed=random_seed):
    """
    Shuffle a copy of ``lst`` and split it into two parts.

    Args:
        lst: Sliceable sequence (e.g. a list); a copy is shuffled, so a list
            argument is left unmodified.
        split_ratio: Fraction of elements that go into the first part; the
            split index is ``int(len(lst) * split_ratio)``.
        seed: Seed for a private ``random.Random`` instance, making the split
            reproducible. ``None`` shuffles with the module-level RNG instead.

    Returns:
        ``(first_part, second_part)``.
    """
    # A dedicated generator avoids reseeding the global `random` state as a
    # side effect. Random(seed).shuffle yields the same order as
    # random.seed(seed); random.shuffle(...), so existing splits are unchanged.
    rng = random.Random(seed) if seed is not None else random

    shuffled_list = lst[:]
    rng.shuffle(shuffled_list)

    split_index = int(len(shuffled_list) * split_ratio)
    return shuffled_list[:split_index], shuffled_list[split_index:]

random_split_list([1,2,3,4,5,6,7,8,9,10])

In [None]:
def concantenate_matrices(layer_to_csr_dict):
    """
    Stack per-layer sparse feature matrices into a single matrix.

    The matrices are stacked vertically (row-wise, with ``scipy.sparse.vstack``)
    in sorted order of their layer-name keys, so the result does not depend on
    dict insertion order. No flattening happens here; every matrix must have
    the same number of columns.
    """
    ordered_layer_names = sorted(layer_to_csr_dict)
    return vstack([layer_to_csr_dict[layer_name] for layer_name in ordered_layer_names])

In [None]:
def euclidean_distance(matrix1: csr_matrix, matrix2: csr_matrix):
    """
    Return the Euclidean (L2) distance between two sparse matrices.

    Both matrices are densified and flattened to 1-D before subtracting, so
    they are expected to flatten to the same length. For equal shapes this is
    the Frobenius norm of ``matrix1 - matrix2``.
    """
    flat_first, flat_second = (m.toarray().flatten() for m in (matrix1, matrix2))
    return np.linalg.norm(flat_first - flat_second)

def euclidean_distance_bw_dicts_of_csr_matrices(
    matrix_dict_1: dict[str, csr_matrix], matrix_dict_2: dict[str, csr_matrix]):
    """
    Euclidean distance between two layer-name -> sparse-matrix dictionaries.

    Each dictionary is stacked into a single matrix with
    ``concantenate_matrices`` (layers in sorted key order), and the two stacked
    matrices are then compared with ``euclidean_distance``.
    """
    stacked_1, stacked_2 = map(concantenate_matrices, (matrix_dict_1, matrix_dict_2))
    return euclidean_distance(stacked_1, stacked_2)

### Load artifact from wandb

In [None]:
# Open the wandb run used below to fetch artifacts and log config. The project
# name is the base project name joined with the policy model name.
run = wandb.init(project='_'.join((project_name, policy_model_name)))

In [None]:
# Record the seed (the default seed of random_split_list) in the run's config.
run.config['random_seed'] = random_seed

In [None]:
def load_linear_probe_training_dataset(policy_model_name=policy_model_name, project_name=project_name, version=version, filename=filename):
    """
    Download the linear-probe training dataset artifact from wandb and unpickle it.

    The artifact is fetched through the module-level wandb ``run`` (which must
    already be initialised) from
    ``nlp_and_interpretability/{project_name}/linear_probe_training_dataset_{policy_model_name}:{version}``.

    Args:
        policy_model_name: Policy model whose dataset artifact is fetched.
        project_name: wandb project that owns the artifact.
        version: Artifact version alias, e.g. ``'v1'``.
        filename: Name of the pickle file inside the artifact directory.

    Returns:
        The unpickled object; downstream cells treat it as a list of
        ``LinearProbeTrainingPoint``.
    """
    artifact_path = f'linear_probe_training_dataset_{policy_model_name}:{version}'

    artifact = run.use_artifact(
        f'nlp_and_interpretability/{project_name}/{artifact_path}', type='data'
    )
    # Read from the directory wandb actually downloaded into instead of a
    # hard-coded 'artifacts/<name>' path, which only matches wandb's default
    # download location relative to the current working directory.
    artifact_dir = artifact.download()

    # NOTE: pickle.load can execute arbitrary code; only load artifacts from a
    # trusted entity/project.
    with open(os.path.join(artifact_dir, filename), 'rb') as f_in:
        training_dataset = pickle.load(f_in)

    return training_dataset


In [None]:
@dataclass
class TextTokensIdsTarget:
    """
    A tokenized text plus the single target token selected from it.

    Fields:
        attention_mask: Attention mask for ``ids`` (padded with 0 in
            ``get_tensorized``).
        text: The raw text.
        tokens: Token strings of ``text``.
        ids: Token ids of ``text``.
        target_token: The selected target token string.
        target_token_id: Token id of ``target_token``.
        target_token_position: Position of the target token; presumably an
            index into ``tokens``/``ids`` — TODO confirm.
    """
    attention_mask: list[int]
    text: str
    tokens: list[str]
    ids: list[int]
    target_token: str
    target_token_id: int
    target_token_position: int

    @staticmethod
    def get_tensorized(datapoints: "list[TextTokensIdsTarget]"):
        """
        Pad a batch of datapoints and return them as CUDA tensors.

        NOTE(review): relies on module-level ``tokenizer`` and
        ``pad_list_of_lists``, which are not defined in this notebook.

        Args:
            datapoints: Sequence of ``TextTokensIdsTarget`` to batch together.

        Returns:
            dict with ``"input_ids"`` (``torch.IntTensor``) and
            ``"attention_mask"`` (``torch.ByteTensor``), both moved to CUDA.
        """
        # Only `Tensor` is imported from torch at the top of the notebook, so the
        # bare `torch` name used below must be imported here.
        import torch

        input_ids = [datapoint.ids for datapoint in datapoints]
        attention_masks = [datapoint.attention_mask for datapoint in datapoints]

        # ids are padded with the tokenizer's pad-token id, masks with 0.
        input_ids_padded = pad_list_of_lists(input_ids, tokenizer.encode(tokenizer.pad_token)[0])
        attention_masks_padded = pad_list_of_lists(attention_masks, 0)
        return {
            "input_ids": torch.IntTensor(input_ids_padded).cuda(),
            "attention_mask": torch.ByteTensor(attention_masks_padded).cuda(),
        }



### Training-point classes for the examples stored in the wandb artifact
class TrainingPoint:
    """
    One raw training example (positive / negative / neutral texts) plus the
    target token extracted from each text.

    NOTE(review): ``__init__`` calls ``get_tokens_and_ids`` and ``trim_example``
    and reads ``lexicon``; none of these are defined in this notebook, so
    constructing new instances here would raise NameError. Unpickling existing
    instances does not call ``__init__``.
    """

    def __init__(self, input_dict: dict, tokenizer=None):
        """
        Args:
            input_dict: Raw example. Keys read here: ``'input_text'`` (positive
                text), ``'output_text'`` (negative text), ``'neutral_text'``,
                ``'positive_words'``, ``'new_words'`` (its values are the
                negative words) and ``'neutral_words'`` (its values are used).
            tokenizer: Unused; kept so existing callers keep working.
        """
        self.input_dict = input_dict
        self.positive_text = input_dict['input_text']
        self.negative_text = input_dict['output_text']
        self.neutral_text = input_dict['neutral_text']

        # Dictionary of layer name to activations by MLP layer (not populated here).
        self.activations: dict = None

        # Dictionary of layer name to autoencoder feature by MLP layer (not populated here).
        self.autoencoder_feature: dict = None

        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text)
        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text)

        self.positive_words = input_dict['positive_words']
        self.negative_words = list(input_dict['new_words'].values())
        self.neutral_words = list(input_dict['neutral_words'].values())

        # Target-token metadata. Each stays None unless trimming below succeeds.
        # Rewards come from `lexicon`; the neutral token has no reward.
        self.target_positive_reward = None
        self.target_positive_token = None
        self.target_positive_token_id = None

        self.target_negative_reward = None
        self.target_negative_token = None
        self.target_negative_token_id = None

        self.target_neutral_token = None
        self.target_neutral_token_id = None

        # Each trim is best-effort: any exception is printed and the
        # corresponding trimmed example is set to None. Target tokens are
        # stripped and lower-cased before the lexicon lookup.
        try:
            self.trimmed_positive_example: "TextTokensIdsTarget" = trim_example(self.positive_text, self.positive_words)
            if self.trimmed_positive_example:
                positive_token = self.trimmed_positive_example.target_token.strip().lower()
                self.target_positive_reward = lexicon.get(positive_token, None)
                self.target_positive_token = positive_token
                self.target_positive_token_id = self.trimmed_positive_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for positive example.')
            self.trimmed_positive_example = None

        try:
            self.trimmed_negative_example: "TextTokensIdsTarget" = trim_example(self.negative_text, self.negative_words)
            if self.trimmed_negative_example:
                negative_token = self.trimmed_negative_example.target_token.strip().lower()
                self.target_negative_reward = lexicon.get(negative_token, None)
                self.target_negative_token = negative_token
                self.target_negative_token_id = self.trimmed_negative_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for negative example.')
            self.trimmed_negative_example = None

        try:
            self.trimmed_neutral_example: "TextTokensIdsTarget" = trim_example(self.neutral_text, self.neutral_words)
            if self.trimmed_neutral_example:
                self.target_neutral_token = self.trimmed_neutral_example.target_token.strip().lower()
                self.target_neutral_token_id = self.trimmed_neutral_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for neutral example.')
            self.trimmed_neutral_example = None

    def __str__(self):
        # Pretty-print all instance attributes.
        return pprint.pformat(self.__dict__)


class LinearProbeTrainingPoint:
    """
    A ``TrainingPoint`` paired with the target token of each of its three texts
    (positive, negative, neutral) and that token's autoencoder features.

    The positive/negative rewards are copied from the wrapped ``TrainingPoint``;
    the neutral token has no reward.
    """
    def __init__(
        self, training_point: "TrainingPoint",
        # positive token
        target_positive_token_id: int,
        target_positive_token: str,
        positive_token_ae_features: dict[str, Tensor],
        # negative token
        target_negative_token_id: int,
        target_negative_token: str,
        negative_token_ae_features: dict[str, Tensor],
        # neutral token
        target_neutral_token_id: int,
        target_neutral_token: str,
        neutral_token_ae_features: dict[str, Tensor]
    ):
        """
        Args:
            training_point: Source example; must expose ``target_positive_reward``
                and ``target_negative_reward``.
            target_*_token_id / target_*_token: Id and string of each target token.
            *_token_ae_features: Mapping of layer name -> autoencoder features for
                that token. NOTE(review): annotated as ``Tensor``, but downstream
                code stacks these dicts with scipy ``vstack`` and calls
                ``.toarray()`` on the result, so the values may actually be
                scipy sparse matrices — confirm.
        """
        self.training_point: "TrainingPoint" = training_point

        self.target_positive_token = target_positive_token
        self.target_positive_token_id = target_positive_token_id
        self.target_positive_reward = self.training_point.target_positive_reward
        self.positive_token_ae_features = positive_token_ae_features

        self.target_negative_token = target_negative_token
        self.target_negative_token_id = target_negative_token_id
        self.target_negative_reward = self.training_point.target_negative_reward
        self.negative_token_ae_features = negative_token_ae_features

        self.target_neutral_token = target_neutral_token
        self.target_neutral_token_id = target_neutral_token_id
        self.neutral_token_ae_features = neutral_token_ae_features

    def __str__(self):
        # Pretty-print all instance attributes.
        return pprint.pformat(self.__dict__)

In [None]:
# Pass the config explicitly: the function's defaults were captured when its
# def cell ran, so later edits to the config cell would otherwise be ignored
# until that def cell is re-run.
full_training_dataset = load_linear_probe_training_dataset(
    policy_model_name=policy_model_name,
    project_name=project_name,
    version=version,
)

In [None]:
# Pass the seed explicitly so the split uses the current config value (the same
# `random_seed` logged to the wandb config), not the default captured when the
# def cell ran.
train_split_dataset, test_split_dataset = random_split_list(
    full_training_dataset, split_ratio=0.8, seed=random_seed
)

### Define linear probe helper classes.

In [None]:
@dataclass
class LinearProbeFinalInput:
    """
    One row of the linear-probe dataset.

    Fields:
        token: Target token string.
        token_id: Id of ``token``.
        divergence: Euclidean distance between this token's features and the
            neutral token's features; stored negated for negative points.
        features: Stacked autoencoder features of the positive or negative token.
        point_type: ``'positive'`` or ``'negative'``.
    """
    token: str
    token_id: int
    divergence: float
    features: csr_matrix
    point_type: str

    def __str__(self):
        # Pretty-print all instance attributes.
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        # Same text as str(); this explicit definition takes precedence over
        # the dataclass-generated repr.
        return self.__str__()

### Construct training dataset and linear probe

In [None]:
def map_lp_training_point_to_pair_of_lp_final_inputs(lp_training_point: LinearProbeTrainingPoint) -> list[LinearProbeFinalInput]:
    """
    Turn one ``LinearProbeTrainingPoint`` into its two probe rows.

    The positive and negative tokens' per-layer features are each stacked with
    ``concantenate_matrices`` and compared (Euclidean distance) with the
    neutral token's stacked features. The positive row keeps that distance as
    its divergence; the negative row stores it negated, so positive rows are
    >= 0 and negative rows are <= 0.

    Returns:
        ``[positive_input, negative_input]``.
    """
    point = lp_training_point
    positive_features = concantenate_matrices(point.positive_token_ae_features)
    negative_features = concantenate_matrices(point.negative_token_ae_features)
    neutral_features = concantenate_matrices(point.neutral_token_ae_features)

    def _to_final_input(token, token_id, features, point_type, negate):
        # Distance to the neutral token, sign-flipped for negative examples.
        distance = euclidean_distance(features, neutral_features)
        return LinearProbeFinalInput(
            token=token,
            token_id=token_id,
            divergence=-distance if negate else distance,
            features=features,
            point_type=point_type,
        )

    return [
        _to_final_input(point.target_positive_token, point.target_positive_token_id,
                        positive_features, 'positive', negate=False),
        _to_final_input(point.target_negative_token, point.target_negative_token_id,
                        negative_features, 'negative', negate=True),
    ]

In [None]:
# Pick one example to sanity-check the mapping function. Despite the name, it
# comes from the *training* split; index 4 has no special meaning in the code.
test_point = train_split_dataset[4]

In [None]:
# Expand the sampled point into its positive and negative probe inputs.
positive_test_point, negative_test_point = (
    map_lp_training_point_to_pair_of_lp_final_inputs(test_point)
)

In [None]:
for point_label, probe_point in (('Positive', positive_test_point), ('Negative', negative_test_point)):
    print(f'\n{point_label} point:\n{pprint.pformat(probe_point)}')

In [None]:
def map_lp_dataset_to_final_input_dataset(
    input_dataset: list[LinearProbeTrainingPoint]) -> list[LinearProbeFinalInput]:
    """
    Expand every ``LinearProbeTrainingPoint`` into its positive and negative
    ``LinearProbeFinalInput`` rows.

    Returns:
        A list of ``2 * len(input_dataset)`` rows ordered
        ``[pos_0, neg_0, pos_1, neg_1, ...]``.
    """
    # `tqdm.tqdm_notebook` is deprecated; tqdm.auto uses the notebook widget
    # inside Jupyter and a plain text bar elsewhere.
    from tqdm.auto import tqdm

    final_dataset = []
    for datapoint in tqdm(input_dataset, desc='Mapping probe inputs'):
        positive_point, negative_point = map_lp_training_point_to_pair_of_lp_final_inputs(datapoint)
        final_dataset.extend((positive_point, negative_point))

    return final_dataset

In [None]:
# Each LinearProbeTrainingPoint expands into two final inputs (positive and
# negative), so each mapped split holds twice as many rows as its source split.
mapped_train_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(train_split_dataset)
mapped_test_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(test_split_dataset)

In [None]:
class FeatureConstructor:
    """Builds the dense design matrix the linear probe is trained on."""

    def construct_feature_representation(self, linear_probe_inputs):
        """
        Densify and flatten each input's ``features`` into one row.

        Returns:
            ``np.ndarray`` with one row per input; row i is the row-major
            flattening of ``linear_probe_inputs[i].features``.
        """
        rows = []
        for probe_input in linear_probe_inputs:
            rows.append(probe_input.features.toarray().ravel())
        return np.array(rows)

feature_constructor = FeatureConstructor()

In [None]:
def train_linear_model(train_linear_probe_inputs: list[LinearProbeFinalInput], feature_constructor: FeatureConstructor = feature_constructor):
    """
    Fit an ordinary-least-squares probe that predicts each input's divergence.

    Args:
        train_linear_probe_inputs: Training rows; their ``features`` form X
            and their ``divergence`` values form y.
        feature_constructor: Converts the rows into a dense design matrix.

    Returns:
        The fitted ``sklearn.linear_model.LinearRegression``.
    """
    design_matrix = feature_constructor.construct_feature_representation(train_linear_probe_inputs)
    divergences = np.array([probe_input.divergence for probe_input in train_linear_probe_inputs])

    print(f'Shapes are {design_matrix.shape} and {divergences.shape}')

    # NOTE(review): with more feature columns than training rows, plain OLS can
    # fit the training targets exactly; a regularized model may generalize better.
    return LinearRegression().fit(design_matrix, divergences)

In [None]:
# Fit the probe on the training rows (divergence regressed on AE features).
linear_model = train_linear_model(mapped_train_split_dataset)

In [None]:
def get_fitted_values(linear_model, test_linear_probe_inputs, feature_constructor: FeatureConstructor = feature_constructor):
    """
    Predict divergences for probe inputs with an already-fitted model.

    Args:
        linear_model: Fitted estimator exposing ``predict`` (e.g. the result of
            ``train_linear_model``).
        test_linear_probe_inputs: ``LinearProbeFinalInput`` rows to score.
        feature_constructor: Builds the design matrix; should match the one
            used for training.

    Returns:
        The model's predictions, one per input, in input order.
    """
    design_matrix = feature_constructor.construct_feature_representation(test_linear_probe_inputs)
    return linear_model.predict(design_matrix)

In [None]:
# Predict divergences for the held-out rows.
fitted_values = get_fitted_values(linear_model, mapped_test_split_dataset)

In [None]:
# Pair each prediction with the probe input it was computed from.
fitted_values_and_inputs = [
    (fitted, probe_input)
    for fitted, probe_input in zip(fitted_values, mapped_test_split_dataset)
]

In [None]:
# Spot-check one held-out (prediction, LinearProbeFinalInput) pair; index 15 has
# no special meaning in the code.
fitted_values_and_inputs[15]

### Analyze divergence values relative to the original VADER lexicon.