In [58]:
import json
import os
import pickle
import pprint
import random
from collections import Counter, defaultdict
from dataclasses import dataclass

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import wandb
from matplotlib import pyplot
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from scipy.sparse import csr_matrix, vstack
from scipy.stats import kendalltau
from sklearn.linear_model import Ridge
from torch import Tensor
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm

In [59]:
# Make sure the VADER lexicon is available locally before building the analyzer.
nltk.download('vader_lexicon')

sentiment_analyzer = SentimentIntensityAnalyzer()
# Token -> score mapping; looked up below to obtain per-token rewards.
lexicon = sentiment_analyzer.lexicon

# Extremes of the lexicon scores, used later as the default target range
# when probe outputs are rescaled.
min_vader_value, max_vader_value = min(lexicon.values()), max(lexicon.values())

[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /Users/rauno/nltk_data...
[nltk_data]   Package vader_lexicon is already up-to-date!


In [60]:
# Name of the pickled dataset file inside the downloaded W&B artifact directory.
filename = 'linear_probe_training_dataset'
# W&B project that hosts both the run and the dataset artifact.
project_name = 'utility_reconstruction'

# Seed shared by the train/test split and the probe initialisation.
random_seed = 42

# Never commit a real key here. The empty placeholder is only installed when
# the environment does not already provide WANDB_API_KEY, so a key exported in
# the shell is no longer overwritten by this cell.
os.environ.setdefault('WANDB_API_KEY', '')

In [61]:
# Start the W&B run; `run` is also used later to fetch the dataset artifact.
run = wandb.init(project=project_name)
# Record the seed alongside the run so results can be reproduced.
run.config['random_seed'] = random_seed



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167844911364631, max=1.0…

In [62]:
def random_split_list(lst, split_ratio=0.8, seed=random_seed):
    """Shuffle a copy of ``lst`` and cut it into two consecutive parts.

    Args:
        lst: Sequence to split; it is copied with ``lst[:]`` and never mutated.
        split_ratio: Fraction of the elements that goes into the first part
            (the cut index is truncated towards zero).
        seed: When not ``None``, the *global* ``random`` module is reseeded
            with this value before shuffling, which makes the split
            reproducible but also affects later users of ``random``.

    Returns:
        A ``(first_part, second_part)`` tuple of lists.
    """
    if seed is not None:
        random.seed(seed)

    items = lst[:]
    random.shuffle(items)

    cut = int(len(items) * split_ratio)
    first_part, second_part = items[:cut], items[cut:]
    return first_part, second_part

In [69]:
def load_linear_probe_training_dataset(policy_model_name="pythia_160m_utility_reward", project_name=project_name, version='v0'):
    """Download the linear-probe dataset artifact from W&B and unpickle it.

    Relies on the notebook-level ``run`` (active W&B run) and ``filename``
    (name of the pickle inside the artifact).

    Args:
        policy_model_name: Suffix of the artifact name
            (``linear_probe_training_dataset_<policy_model_name>``).
        project_name: W&B project under the ``nlp_and_interpretability``
            entity that owns the artifact.
        version: Artifact version alias, e.g. ``'v0'``.

    Returns:
        Whatever object was pickled into the artifact file.
    """
    artifact_path = f'linear_probe_training_dataset_{policy_model_name}:{version}'

    artifact = run.use_artifact(
        f'nlp_and_interpretability/{project_name}/{artifact_path}', type='data'
    )
    # Use the directory reported by download() instead of re-deriving it by
    # hand: the actual location depends on the working directory / W&B
    # settings and is not guaranteed to be 'artifacts/<name>:<version>'.
    artifact_dir = artifact.download()

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # artifacts that come from a trusted W&B project.
    with open(os.path.join(artifact_dir, filename), 'rb') as f_in:
        training_dataset = pickle.load(f_in)

    return training_dataset

In [70]:
@dataclass
class TextTokensIdsTarget:
    """A tokenized text together with the token singled out as its target.

    NOTE(review): the per-field notes are inferred from the field names and
    from how ``TrainingPoint`` reads ``target_token`` / ``target_token_id``;
    confirm against the code that builds these objects, which is not visible
    in this section of the notebook.
    """
    attention_mask: list[int]    # presumably one entry per token -- TODO confirm
    text: str
    tokens: list[str]
    ids: list[int]               # token ids; these become "input_ids" in get_tensorized
    target_token: str
    target_token_id: int
    target_token_position: int   # presumably index of the target within `tokens` -- TODO confirm

    @staticmethod
    def get_tensorized(datapoints: "list[TextTokensIdsTarget]"):
        """Pad a batch of datapoints and move it to the GPU.

        Depends on the notebook-level ``tokenizer`` and ``pad_list_of_lists``
        helper (neither is defined in the visible cells) and on a CUDA
        device, because both tensors are moved with ``.cuda()``.

        Args:
            datapoints: Non-empty batch of ``TextTokensIdsTarget`` objects.

        Returns:
            Dict with ``"input_ids"`` (``torch.IntTensor``) and
            ``"attention_mask"`` (``torch.ByteTensor``), both on the GPU.

        Raises:
            ValueError: If ``datapoints`` is empty.
        """
        # The original computed an unused max length whose only effect was to
        # raise ValueError on an empty batch; keep that contract explicitly.
        if len(datapoints) == 0:
            raise ValueError('get_tensorized() needs at least one datapoint.')

        input_ids = [datapoint.ids for datapoint in datapoints]
        attention_masks = [datapoint.attention_mask for datapoint in datapoints]

        # Ids are padded with the tokenizer's pad token, masks with 0.
        pad_token_id = tokenizer.encode(tokenizer.pad_token)[0]
        input_ids_padded = pad_list_of_lists(input_ids, pad_token_id)
        attention_masks_padded = pad_list_of_lists(attention_masks, 0)

        return {
            "input_ids": torch.IntTensor(input_ids_padded).cuda(),
            "attention_mask": torch.ByteTensor(attention_masks_padded).cuda(),
        }

class TrainingPoint:
    """One raw example holding a positive, a negative and a neutral text.

    The constructor tokenizes the positive and negative texts, trims each of
    the three texts around one of its listed words via ``trim_example`` and
    looks up the reward of the positive/negative target token in ``lexicon``.
    ``get_tokens_and_ids`` and ``trim_example`` are notebook-level helpers
    that are not defined in the visible cells.
    """

    def __init__(self, input_dict: dict, tokenizer=None):
        """Build the point from ``input_dict``.

        Args:
            input_dict: Must provide the keys ``'input_text'``,
                ``'output_text'``, ``'neutral_text'``, ``'positive_words'``,
                ``'new_words'`` and ``'neutral_words'`` (the last two are
                dicts; only their values are kept).
            tokenizer: Accepted but never used inside this constructor.
        """
        self.input_dict = input_dict
        # 'input_text' is treated as the positive text, 'output_text' as the negative one.
        self.positive_text = input_dict['input_text']
        self.negative_text = input_dict['output_text']
        self.neutral_text = input_dict['neutral_text']
        
        # Dictionary of layer name to activations by mlp layer.
        self.activations: dict = None

        # Dictionary of layer name to autoencoder feature by mlp layer
        self.autoencoder_feature: dict = None

        # Reward value of target_token.
        # NOTE(review): both rewards are reset to None again a few lines below.
        self.target_positive_reward = None
        self.target_negative_reward = None

        # NOTE(review): the id attributes are named inconsistently
        # (`positive_input_ids` vs `negative_token_ids`); the neutral text is
        # not tokenized here.
        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text)
        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text)
        
        self.positive_words = input_dict['positive_words']
        self.negative_words = list(input_dict['new_words'].values())
        self.neutral_words = list(input_dict['neutral_words'].values())

        # Defaults used when trimming fails or returns a falsy value.
        self.target_positive_reward = None
        self.target_positive_token = None
        self.target_positive_token_id = None
    
        self.target_negative_reward = None
        self.target_negative_token = None
        self.target_negative_token_id = None

        self.target_neutral_token = None
        self.target_neutral_token_id = None

        # NOTE(review): the "TextTokensIdTarget" annotations below do not match
        # any class in the visible cells; presumably `TextTokensIdsTarget` is
        # meant -- confirm against `trim_example`.
        #
        # Each section is best-effort: any exception is printed and the
        # trimmed example is set to None. Target fields that were assigned
        # before the exception keep their values.
        try:
            self.trimmed_positive_example: "TextTokensIdTarget" = trim_example(self.positive_text, self.positive_words)
            if self.trimmed_positive_example:
                # Normalise the token before the lexicon lookup; the reward is
                # None when the token is not in the lexicon.
                positive_token = self.trimmed_positive_example.target_token.strip().lower()
                self.target_positive_reward = lexicon.get(positive_token, None)
                self.target_positive_token = positive_token
                self.target_positive_token_id = self.trimmed_positive_example.target_token_id
        
        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for positive example.')
            self.trimmed_positive_example = None
        
        try:
            self.trimmed_negative_example: "TextTokensIdTarget" = trim_example(self.negative_text, self.negative_words)
            if self.trimmed_negative_example:
                # Same normalisation and lookup as for the positive token.
                negative_token = self.trimmed_negative_example.target_token.strip().lower()
                self.target_negative_reward = lexicon.get(negative_token, None)
                self.target_negative_token = negative_token
                self.target_negative_token_id = self.trimmed_negative_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for negative example.')
            self.trimmed_negative_example = None

        try:
            self.trimmed_neutral_example: "TextTokensIdTarget" = trim_example(self.neutral_text, self.neutral_words)
            if self.trimmed_neutral_example:
                # No reward is looked up for the neutral token.
                self.target_neutral_token = self.trimmed_neutral_example.target_token.strip().lower()
                self.target_neutral_token_id = self.trimmed_neutral_example.target_token_id

        except Exception as e:
            print(f'Caught exception {e} on {input_dict} for neutral example.')
            self.trimmed_neutral_example = None

    def __str__(self):
        """Pretty-printed dump of all instance attributes."""
        return pprint.pformat(self.__dict__)


class LinearProbeTrainingPoint:
    """Target tokens and per-layer autoencoder features for one ``TrainingPoint``.

    Groups, for the positive, negative and neutral target token of a training
    point, the token string, its id and a dict of autoencoder features keyed
    by layer name. The positive/negative rewards are copied from the wrapped
    training point.

    NOTE(review): the feature values are annotated as ``Tensor``, but the
    datapoint printed from the loaded dataset shows ``1x768`` scipy CSR
    matrices keyed by ``'gpt_neox.layers.<n>.mlp'`` -- confirm which type
    callers should expect.
    """

    def __init__(
        self, training_point: "TrainingPoint",
        # positive token
        target_positive_token_id: int,
        target_positive_token: str,
        positive_token_ae_features: dict[str, Tensor],
        # negative token
        target_negative_token_id: int,
        target_negative_token: str,
        negative_token_ae_features: dict[str, Tensor],
        # neutral token
        target_neutral_token_id: int,
        target_neutral_token: str,
        neutral_token_ae_features: dict[str, Tensor]
    ):
        """Store the three target tokens together with their features.

        Args:
            training_point: Source point; its ``target_positive_reward`` and
                ``target_negative_reward`` are copied onto this object.
            target_positive_token_id / target_positive_token: Id and string of
                the positive target token.
            positive_token_ae_features: Layer name -> features of the positive
                token.
            target_negative_token_id / target_negative_token: Id and string of
                the negative target token.
            negative_token_ae_features: Layer name -> features of the negative
                token.
            target_neutral_token_id / target_neutral_token: Id and string of
                the neutral target token.
            neutral_token_ae_features: Layer name -> features of the neutral
                token.
        """
        self.training_point: "TrainingPoint" = training_point

        self.target_positive_token = target_positive_token
        self.target_positive_token_id = target_positive_token_id
        self.target_positive_reward = self.training_point.target_positive_reward
        self.positive_token_ae_features = positive_token_ae_features

        self.target_negative_token = target_negative_token
        self.target_negative_token_id = target_negative_token_id
        self.target_negative_reward = self.training_point.target_negative_reward
        self.negative_token_ae_features = negative_token_ae_features

        # The neutral token carries no reward.
        self.target_neutral_token = target_neutral_token
        self.target_neutral_token_id = target_neutral_token_id
        self.neutral_token_ae_features = neutral_token_ae_features

    def __str__(self):
        """Pretty-printed dump of all instance attributes."""
        return pprint.pformat(self.__dict__)


In [71]:
# Fetch the pickled probe dataset from W&B using the loader's defaults.
full_training_dataset = load_linear_probe_training_dataset()
# Only the second part of the seeded split (20% with the default ratio) is
# used in this notebook; the first part is discarded.
_, test_split_dataset = random_split_list(lst=full_training_dataset)

[34m[1mwandb[0m: Downloading large artifact linear_probe_training_dataset_pythia_160m_utility_reward:v0, 50.56MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5


In [72]:
# Sanity check: show the first datapoint of the test split.
print(test_split_dataset[0])

{'negative_token_ae_features': {'gpt_neox.layers.10.mlp': <1x768 sparse matrix of type '<class 'numpy.float64'>'
	with 50 stored elements in Compressed Sparse Row format>,
                                'gpt_neox.layers.11.mlp': <1x768 sparse matrix of type '<class 'numpy.float64'>'
	with 23 stored elements in Compressed Sparse Row format>,
                                'gpt_neox.layers.7.mlp': <1x768 sparse matrix of type '<class 'numpy.float64'>'
	with 18 stored elements in Compressed Sparse Row format>,
                                'gpt_neox.layers.8.mlp': <1x768 sparse matrix of type '<class 'numpy.float64'>'
	with 34 stored elements in Compressed Sparse Row format>,
                                'gpt_neox.layers.9.mlp': <1x768 sparse matrix of type '<class 'numpy.float64'>'
	with 31 stored elements in Compressed Sparse Row format>},
 'neutral_token_ae_features': {'gpt_neox.layers.10.mlp': <1x768 sparse matrix of type '<class 'numpy.float64'>'
	with 48 stored elements in Co

In [95]:
# Linear probe whose parameters are randomly (optionally reproducibly) initialised
class RandomInitializedLinearProbe(nn.Module):
    """Single ``nn.Linear`` layer with Xavier-uniform weights and zero bias."""

    def __init__(self, input_dim, output_dim, random_state=None):
        """Create the layer and immediately (re-)initialise its parameters.

        Args:
            input_dim: Number of input features.
            output_dim: Number of outputs.
            random_state: Optional seed passed to ``torch.manual_seed`` before
                the weights are drawn.
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.random_state = random_state
        self.random_initialize()

    def random_initialize(self):
        """Overwrite the weights (Xavier uniform) and the bias (zeros).

        When ``random_state`` is set, torch's global RNG is seeded first, so
        the same seed always yields the same weights.
        """
        seed = self.random_state
        if seed is not None:
            torch.manual_seed(seed)

        layer = self.linear
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.linear(x)

In [96]:
class FeatureConstructor:
    """Turns probe inputs into a dense 2-D feature array."""

    def construct_feature_representation(self, linear_probe_inputs):
        """Densify and flatten the ``features`` of every probe input.

        Args:
            linear_probe_inputs: Iterable of objects whose ``features``
                attribute supports ``.toarray()`` (e.g. a scipy sparse matrix).

        Returns:
            ``np.ndarray`` with one flattened feature row per input.
        """
        rows = []
        for probe_input in linear_probe_inputs:
            rows.append(probe_input.features.toarray().flatten())
        return np.array(rows)


feature_constructor = FeatureConstructor()

In [97]:
@dataclass
class LinearProbeFinalInput:
    """One row of the probe dataset: a target token and its stacked features."""
    token: str
    token_id: int
    # Distance between this token's features and the neutral token's features
    # (stored negated for negative points).
    divergence: float
    # Features of the positive or negative token.
    features: csr_matrix
    # Either 'positive' or 'negative'.
    point_type: str
    # Kept so the original per-layer features can be inspected/retrieved.
    source_training_point: LinearProbeTrainingPoint

    def __str__(self):
        """Pretty-printed dump of all fields."""
        return pprint.pformat(vars(self))

    def __repr__(self):
        """Same text as ``__str__``."""
        return self.__str__()

In [98]:
def euclidean_distance(matrix1: csr_matrix, matrix2: csr_matrix):
    """Return the Euclidean (L2) distance between two sparse matrices.

    Both matrices are converted to dense arrays and flattened, so they are
    compared element by element as plain vectors.

    Args:
        matrix1: First sparse matrix.
        matrix2: Second sparse matrix, with a matching number of elements.

    Returns:
        The L2 norm of the element-wise difference.
    """
    delta = matrix1.toarray().ravel() - matrix2.toarray().ravel()
    return np.linalg.norm(delta)


def concantenate_matrices(layer_to_csr_dict):
    """
    Stack the per-layer feature matrices of a layer-name -> matrix dictionary
    on top of each other (row-wise, via ``vstack``).

    Layers are taken in the sorted order of their names so the row order is
    deterministic. The names are sorted as strings, so e.g. 'layers.10' comes
    before 'layers.7'. The result is not flattened here; callers flatten it
    themselves where they need a vector.
    """
    return vstack(
        [layer_to_csr_dict[layer_name] for layer_name in sorted(layer_to_csr_dict)]
    )


def map_lp_training_point_to_pair_of_lp_final_inputs(lp_training_point: LinearProbeTrainingPoint) -> list[LinearProbeFinalInput]:
    """Turn one training point into a positive and a negative probe input.

    For each of the positive and negative target tokens, the per-layer
    features are stacked and their Euclidean distance to the stacked neutral
    features is stored as ``divergence`` (negated for the negative token).

    Args:
        lp_training_point: Source of the tokens, ids and per-layer features.

    Returns:
        ``[positive_input, negative_input]``.
    """
    # Stack the per-layer features of each of the three target tokens.
    pos_features = concantenate_matrices(lp_training_point.positive_token_ae_features)
    neg_features = concantenate_matrices(lp_training_point.negative_token_ae_features)
    baseline_features = concantenate_matrices(lp_training_point.neutral_token_ae_features)

    positive_input = LinearProbeFinalInput(
        token=lp_training_point.target_positive_token,
        token_id=lp_training_point.target_positive_token_id,
        divergence=euclidean_distance(pos_features, baseline_features),
        features=pos_features,
        point_type='positive',
        source_training_point=lp_training_point,
    )

    # The negative example carries its distance with a flipped sign.
    negative_input = LinearProbeFinalInput(
        token=lp_training_point.target_negative_token,
        token_id=lp_training_point.target_negative_token_id,
        divergence=-1 * euclidean_distance(neg_features, baseline_features),
        features=neg_features,
        point_type='negative',
        source_training_point=lp_training_point,
    )

    return [positive_input, negative_input]



def map_lp_dataset_to_final_input_dataset(
    input_dataset) -> list[LinearProbeFinalInput]:
    """Expand every training point into its positive and negative probe input.

    Args:
        input_dataset: Iterable of ``LinearProbeTrainingPoint`` objects.

    Returns:
        Flat list with, for each datapoint in order, its positive input
        followed by its negative input.
    """
    final_dataset = []

    # `tqdm.tqdm_notebook` is deprecated (it emits a warning on every call);
    # `tqdm.notebook.tqdm` is the supported progress bar for notebooks.
    for datapoint in tqdm(input_dataset):
        final_dataset.extend(
            map_lp_training_point_to_pair_of_lp_final_inputs(datapoint)
        )

    return final_dataset


# Two probe inputs (positive, negative) per datapoint of the test split.
inputs: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(test_split_dataset)

# Dense design matrix: one flattened feature row per probe input.
X = feature_constructor.construct_feature_representation(linear_probe_inputs=inputs)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for datapoint in tqdm_notebook(input_dataset):


  0%|          | 0/1120 [00:00<?, ?it/s]

In [99]:
# Tensor version of the design matrix; its second dimension is the probe's input width.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = torch.Tensor(X)
input_dim = input.shape[1]
output_dim = 1

# Untrained probe: parameters come straight from the seeded random initialisation.
model = RandomInitializedLinearProbe(input_dim, output_dim, random_state=random_seed)

# Show the freshly initialised parameters.
for name, param in model.state_dict().items():
    if 'weight' in name or 'bias' in name:
        print(f"{'Weight' if 'weight' in name else 'Bias'}: {name}\n{param}\n")


# Score every probe input with the untrained probe; no gradients are needed.
model.eval()
with torch.no_grad():
    predictions = model(input)

Weight: linear.weight
tensor([[ 3.0217e-02,  3.2805e-02, -9.2592e-03,  3.6307e-02, -8.6597e-03,
          7.9754e-03, -1.9242e-02,  2.3211e-02,  3.4842e-02, -2.8995e-02,
          3.4354e-02,  7.3972e-03,  2.9200e-02,  5.3527e-03,  1.9058e-02,
         -5.5803e-03,  3.0468e-02,  5.8419e-03, -1.8451e-02,  1.0074e-02,
         -1.8210e-02, -4.6350e-03, -1.6053e-02,  2.6219e-02, -3.1199e-02,
         -1.8221e-02, -1.1160e-02, -2.3764e-02,  3.7303e-03, -3.9036e-02,
          3.5694e-02, -3.3574e-02,  3.0513e-02,  6.5774e-03, -1.2833e-02,
          2.4423e-02,  6.1597e-03,  3.1933e-02,  4.3207e-03, -1.2465e-02,
          1.0619e-02, -1.0718e-02,  1.6634e-02,  3.5287e-02,  2.2847e-02,
         -1.7279e-02,  2.2815e-02,  7.0718e-03,  2.0071e-02, -2.4090e-02,
         -3.9124e-02, -1.5270e-02, -3.0315e-02,  3.2430e-02,  1.1384e-02,
          1.6371e-02,  1.2500e-02, -6.8755e-04,  3.0931e-02, -2.8082e-02,
          2.4885e-03, -2.6976e-02,  1.2187e-02, -1.3611e-02,  1.2111e-02,
         -8.2343

In [108]:
# An infinite threshold turns off torch's summarisation, so the whole tensor is
# printed below.
# NOTE(review): set_printoptions changes a global setting and therefore also
# affects every tensor printed by later cells; the full dump is very long.
torch.set_printoptions(threshold=torch.inf)
print(predictions)

tensor([[ 1.8343],
        [ 1.4013],
        [ 3.2289],
        [ 1.2013],
        [ 1.4986],
        [ 2.2743],
        [ 1.7920],
        [ 1.7221],
        [ 0.9280],
        [ 1.2682],
        [ 2.0106],
        [ 2.1503],
        [ 1.6616],
        [ 2.2675],
        [ 2.0423],
        [ 1.9260],
        [ 1.9395],
        [ 1.3860],
        [ 2.5446],
        [ 2.4690],
        [ 2.1716],
        [ 2.4079],
        [ 1.0744],
        [ 1.3617],
        [ 0.6591],
        [ 0.6624],
        [ 2.4725],
        [ 2.3156],
        [ 1.3164],
        [ 1.9665],
        [ 2.1603],
        [ 2.3257],
        [ 1.5927],
        [ 1.6942],
        [ 1.4967],
        [ 1.5930],
        [ 1.2776],
        [ 1.3186],
        [ 2.1765],
        [ 1.2932],
        [ 2.3839],
        [ 2.4636],
        [ 1.8193],
        [ 0.9052],
        [ 1.7501],
        [ 1.9115],
        [ 1.2495],
        [ 1.2872],
        [ 1.4936],
        [ 1.4872],
        [ 1.8352],
        [ 1.8162],
        [ 1.

In [107]:
# Lexicon reward of the negative target token for every test datapoint.
print([datapoint.target_negative_reward for datapoint in test_split_dataset])

[-1.7, -2.0, -2.7, -2.9, -2.1, -2.1, -3.4, -2.5, -1.5, -1.9, -1.5, -2.1, -2.1, -2.7, -1.6, -1.9, -1.6, -2.1, -3.1, -2.8, -3.2, -1.9, -1.5, -3.1, -3.1, -1.7, -2.1, -3.1, -2.5, -2.1, -2.1, -1.8, -1.3, -3.1, -2.5, -2.1, -1.6, -1.9, -3.1, -1.9, -1.1, -2.5, -1.6, -2.1, -2.5, -1.1, -3.1, -2.2, -1.9, -1.6, -2.1, -2.1, -3.1, -1.3, -2.1, -2.0, -2.3, -3.1, -3.1, -2.5, -1.9, -1.3, -2.4, -3.2, 2.2, -1.9, -2.6, -2.5, -3.2, -2.7, -3.2, -1.9, -2.7, -2.7, -3.2, -0.2, -2.1, -2.2, -1.6, -2.1, -1.1, -2.1, -2.3, -2.5, -1.3, -3.2, -2.1, -1.5, -2.5, -2.3, -2.2, -1.3, -3.2, -1.9, -3.4, -2.1, -3.1, -2.0, -2.5, -1.9, -2.2, -2.1, -3.1, -2.0, -1.3, -2.1, -2.1, -3.0, -3.1, -1.9, -3.1, -2.5, -3.1, -3.4, -2.0, -1.3, -3.1, -0.8, -1.8, -1.6, -2.1, -2.0, -2.1, -2.5, -2.1, -2.1, -1.8, -1.9, -2.5, -1.8, -3.1, -2.7, -1.6, -1.9, -2.8, -1.9, -1.5, -1.3, -2.5, -2.1, -2.1, -2.0, -2.1, -3.1, -2.7, -2.1, -1.6, -1.9, -1.3, -3.1, -1.3, -2.1, -2.7, -2.5, -2.5, -2.1, -2.1, -2.0, -3.1, -2.1, -2.1, -2.3, -2.7, -1.6, -1.9, -1.9, -1.3

In [101]:
def clamp(number, min_value, max_value):
    """Limit ``number`` to the closed interval ``[min_value, max_value]``.

    The upper bound is applied first and the lower bound second, so the lower
    bound wins if the two bounds are given in the wrong order.
    """
    capped = max_value if max_value < number else number
    return min_value if min_value > capped else capped


def calculate_average_values(list_of_dicts):
    """Average the values recorded for each token across several dicts.

    Args:
        list_of_dicts: Iterable of ``{token: value}`` mappings; a token may
            occur in any number of them.

    Returns:
        Dict mapping each token (in order of first appearance) to the mean of
        its values, rounded to 3 decimals.
    """
    # token -> [running sum, number of occurrences]
    running = {}
    for mapping in list_of_dicts:
        for token, value in mapping.items():
            stats = running.setdefault(token, [0.0, 0])
            stats[0] += value
            stats[1] += 1

    return {
        token: round(total / count, 3)
        for token, (total, count) in running.items()
    }


def rescale_value(value, values_list, new_min=min_vader_value, new_max=max_vader_value, percentile_range=90):
    """Linearly map ``value`` into ``[new_min, new_max]`` and round to 2 decimals.

    The source range is taken from ``values_list``: its lower end is the
    ``100 - percentile_range`` percentile and its upper end the
    ``percentile_range`` percentile, so outliers do not stretch the scale.
    Results falling outside the target range are clamped to it.

    Args:
        value: Number (or single-element tensor) to rescale.
        values_list: Reference values the percentiles are computed from.
        new_min: Lower end of the target range.
        new_max: Upper end of the target range.
        percentile_range: Upper percentile of the source range (default 90,
            i.e. the 10th..90th percentile band).

    Returns:
        The rescaled, clamped value rounded to 2 decimals.
    """
    old_max = np.percentile(values_list, percentile_range)
    old_min = np.percentile(values_list, 100 - percentile_range)
    old_span = old_max - old_min

    if old_span == 0:
        # Degenerate source range (both percentiles coincide): dividing would
        # produce inf/nan. Values above/below the band go to the top/bottom of
        # the target range, anything else to its midpoint.
        if value > old_max:
            normalized_value = 1.0
        elif value < old_min:
            normalized_value = 0.0
        else:
            normalized_value = 0.5
    else:
        # First, normalize the value to a range between 0 and 1
        normalized_value = (value - old_min) / old_span

    # Then, scale the normalized value to the new range
    new_value = normalized_value * (new_max - new_min) + new_min

    new_value = clamp(new_value, new_min, new_max)
    # Tensors are unwrapped to a plain Python number before rounding.
    if isinstance(new_value, torch.Tensor):
        new_value = new_value.item()

    return round(new_value, 2)


def scale_values_and_input_list_to_range(values_and_input_list, min_range=min_vader_value, max_range=max_vader_value):
    """Rescale every probe value into ``[min_range, max_range]``.

    Args:
        values_and_input_list: Iterable of ``(probe_value, lp_input)`` pairs,
            where ``lp_input`` exposes a ``token`` attribute.
        min_range: Lower end of the target range.
        max_range: Upper end of the target range.

    Returns:
        A 3-tuple ``(rescaled, tokens, values)``: ``rescaled`` is a list of
        single-entry ``{token: rescaled_value}`` dicts (one per pair),
        ``tokens`` the tokens in input order and ``values`` the untouched
        probe values in input order.
    """
    # Materialise once so that one-shot iterables can be traversed repeatedly.
    pairs = list(values_and_input_list)

    all_probe_values = [fitted_value for fitted_value, _ in pairs]
    all_tokens = [lp_input.token for _, lp_input in pairs]

    # Forward the requested range: previously min_range/max_range were
    # accepted but ignored, so rescale_value always used its own defaults.
    rescaled_token_to_value_dict_list = [
        {
            lp_input.token: rescale_value(
                fitted_value, all_probe_values, new_min=min_range, new_max=max_range
            )
        }
        for fitted_value, lp_input in pairs
    ]

    return rescaled_token_to_value_dict_list, all_tokens, all_probe_values


# Pair each raw probe output with the probe input it was computed from.
predictions_and_inputs = [
    (prediction, probe_input) for prediction, probe_input in zip(predictions, inputs)
]

# Map the probe outputs onto the lexicon's score range ...
(
    rescaled_token_to_value_dict_list_on_test,
    all_test_tokens,
    all_test_probe_values,
) = scale_values_and_input_list_to_range(predictions_and_inputs)

# ... and average the rescaled scores of tokens that occur more than once.
averaged_token_values_on_test = calculate_average_values(
    list_of_dicts=rescaled_token_to_value_dict_list_on_test
)

In [111]:
# Split the test tokens by the sign of their lexicon score (unknown tokens count as 0).
all_positive_test_tokens = [token for token in all_test_tokens if lexicon.get(token, 0) > 0]
all_negative_test_tokens = [token for token in all_test_tokens if lexicon.get(token, 0) < 0]

# Ground-truth lexicon score of every token that received a reconstructed score.
full_original_token_values = {token: lexicon[token] for token in averaged_token_values_on_test}

# Set copy for O(1) membership tests; checking against the list inside the
# comprehensions below made them quadratic in the number of tokens.
negative_test_token_set = set(all_negative_test_tokens)

# Both lists are restricted to negative tokens and sorted by token, so entry i
# of one list refers to the same token as entry i of the other.
reconstructed_ranking = sorted(
    [(token, value) for token, value in averaged_token_values_on_test.items() if token in negative_test_token_set],
    key=lambda x: x[0]
)
print(reconstructed_ranking)

original_ranking = sorted(
    [(token, lexicon[token]) for token in averaged_token_values_on_test if token in negative_test_token_set],
    key=lambda x: x[0]
)
print(original_ranking)

# NOTE: despite the "_tokens_only" names, these hold the scores (item[1]);
# the names are kept because other cells may refer to them.
reconstructed_ranking_tokens_only = [item[1] for item in reconstructed_ranking]
original_ranking_tokens_only = [item[1] for item in original_ranking]

# Rank correlation between the lexicon scores and the reconstructed scores.
tau, p_value = kendalltau(original_ranking_tokens_only, reconstructed_ranking_tokens_only)

print(f'Kendall Tau value: {tau}')
print(f'p-value: {p_value}')

[('annoyance', 0.38), ('annoyed', 3.4), ('ashamed', 2.35), ('awful', -2.232), ('bad', -1.483), ('betrayal', 2.55), ('bitter', -1.232), ('bitterly', -3.9), ('blame', 1.98), ('block', -2.37), ('bored', -0.057), ('boring', 0.154), ('coward', -3.47), ('cried', -0.014), ('cries', 3.4), ('criticism', 3.4), ('criticize', 0.415), ('cruel', -2.844), ('cruelty', 0.34), ('cry', 0.95), ('crying', 1.313), ('curse', -0.896), ('cynical', -0.02), ('dangerous', -1.055), ('deceit', -1.931), ('defeat', 2.91), ('depressing', 1.466), ('despair', 3.307), ('devastated', 3.03), ('difficult', 3.165), ('difficulty', 1.79), ('dirty', -2.005), ('disadvantage', -2.59), ('disappoint', 1.307), ('disappointed', 3.049), ('disappointing', -0.105), ('disappointment', 3.4), ('disaster', -0.605), ('disasters', 0.73), ('disastrous', -1.676), ('discourage', 2.055), ('discouraged', 3.24), ('disdain', 1.83), ('disgrace', 2.333), ('disgust', 1.035), ('dislike', 1.697), ('distracted', -2.6), ('distressed', 2.32), ('distrust', 2