In [None]:
import json
import os
import pickle
import pprint
import random

from collections import Counter, defaultdict
from dataclasses import dataclass
from nltk.sentiment.vader import SentimentIntensityAnalyzer

import nltk
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import wandb

from matplotlib import pyplot
from scipy.sparse import csr_matrix, vstack
from scipy.stats import kendalltau
from sklearn.linear_model import Ridge

from torch import Tensor
from tqdm import tqdm_notebook

In [None]:
# Fetch NLTK's VADER lexicon (token -> sentiment score) used as the ground-truth utility.
nltk.download('vader_lexicon')
sentiment_analyzer = SentimentIntensityAnalyzer()
lexicon = sentiment_analyzer.lexicon

# Extremes of the VADER scale; default target range for rescaling probe outputs.
min_vader_value, max_vader_value = min(lexicon.values()), max(lexicon.values())

In [None]:
filename = 'linear_probe_training_dataset'
model_name = 'gpt_neo_125m'
policy_model_name = f'{model_name}_utility_reward'
project_name = 'utility_reconstruction'

# Artifact version per model; models not listed fall back to 'v0'.
versions_dict = {"gpt_neo_125m": 'v1'}
version = versions_dict.get(model_name, 'v0')
random_seed = 42

# W&B credentials are intentionally NOT set here: run `wandb login` or export
# WANDB_API_KEY before starting Jupyter. Assigning os.environ['WANDB_API_KEY']
# inside the notebook overwrites (and, with '', blanks out) any key already
# provided by the environment, and invites committing a real key.

In [None]:
class ModelCustomizer:
    '''
    Base class for per-architecture helpers that map between MLP layer numbers
    and full module names, and that choose which layers to inspect.

    Subclasses implement ``get_target_layers``,
    ``parse_layer_name_to_layer_number`` and ``convert_ae_dict_keys``.
    '''

    def __init__(self):
        '''
        Start with no explicit layer selection.
        '''
        # None means "no explicit selection was made".
        self.target_layers = None

    def set_target_layers(self, target_layers=None):
        '''
        Explicitly set the full layer names to target (None clears the selection).
        '''
        self.target_layers = target_layers

    def get_target_layers(self) -> list[str]:
        '''
        Return the full names of the layers to inspect.
        '''
        raise NotImplementedError(f'{type(self).__name__} must implement get_target_layers()')

    def parse_layer_name_to_layer_number(self, layer_name: str) -> str:
        '''
        Extract the layer number (as a string) from a full layer name.
        '''
        raise NotImplementedError(f'{type(self).__name__} must implement parse_layer_name_to_layer_number()')

    def convert_ae_dict_keys(self, autoencoders_dict: dict) -> dict:
        '''
        Re-key an autoencoder dict from layer numbers to full layer names.
        '''
        raise NotImplementedError(f'{type(self).__name__} must implement convert_ae_dict_keys()')

In [None]:
class GPTNeoCustomizer(ModelCustomizer):
    '''
    Layer-name helpers for GPT-Neo models, whose MLP modules are named
    ``transformer.h.<layer_no>.mlp``.
    '''

    def __init__(self, num_layers: int = 12):
        '''
        Args:
            num_layers: number of transformer blocks. Defaults to 12, the value
                this class has always assumed (used for gpt_neo_125m).
        '''
        super().__init__()
        self.num_layers = num_layers

    def get_target_layers(self) -> list[str]:
        '''
        Return the explicitly set layers, or every MLP layer when none
        (or an empty selection) has been set.
        '''
        if self.target_layers:
            return self.target_layers
        return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]

    def set_target_layers(self, target_layers):
        '''Store an explicit list of full layer names to target.'''
        self.target_layers = target_layers

    def layer_num_to_full_name(self, layer_no):
        '''Map a layer number to its full MLP module name.'''
        return f'transformer.h.{layer_no}.mlp'

    def parse_layer_name_to_layer_number(self, layer_name) -> str:
        '''Return the layer number (a str) from e.g. 'transformer.h.3.mlp' -> '3'.'''
        return layer_name.split('.')[-2]

    def convert_ae_dict_keys(self, autoencoders_dict: dict) -> dict:
        '''
        Standardize keys from layer numbers to full layer names; values are
        passed through unchanged.
        '''
        return {
            self.layer_num_to_full_name(key): autoencoder
            for key, autoencoder in autoencoders_dict.items()
        }

In [None]:
class PythiaCustomizer(ModelCustomizer):
    '''
    Layer-name helpers for Pythia (GPT-NeoX) models, whose MLP modules are
    named ``gpt_neox.layers.<layer_no>.mlp``.
    '''

    def __init__(self, num_layers):
        '''
        Args:
            num_layers: number of transformer layers (e.g. 6 for pythia_70m).
        '''
        super().__init__()  # initializes target_layers to None
        self.num_layers = num_layers

    def get_target_layers(self) -> list[str]:
        '''
        Return the explicitly set layers, or every MLP layer when none
        (or an empty selection) has been set.
        '''
        if self.target_layers:
            return self.target_layers
        return [self.layer_num_to_full_name(layer_no) for layer_no in range(self.num_layers)]

    def set_target_layers(self, target_layers):
        '''Store an explicit list of full layer names to target.'''
        self.target_layers = target_layers

    def layer_num_to_full_name(self, layer_no):
        '''Map a layer number to its full MLP module name.'''
        return f'gpt_neox.layers.{layer_no}.mlp'

    def parse_layer_name_to_layer_number(self, layer_name) -> str:
        '''Return the layer number (a str) from e.g. 'gpt_neox.layers.3.mlp' -> '3'.'''
        return layer_name.split('.')[-2]

    def convert_ae_dict_keys(self, autoencoders_dict: dict) -> dict:
        '''
        Standardize keys from layer numbers to full layer names; values are
        passed through unchanged.
        '''
        return {
            self.layer_num_to_full_name(key): autoencoder
            for key, autoencoder in autoencoders_dict.items()
        }

In [None]:
# Registry of per-architecture layer-name helpers, keyed by model_name.
model_customizers = {
    "pythia_70m": PythiaCustomizer(num_layers=6),
    "pythia_160m": PythiaCustomizer(num_layers=12),
    "pythia_410m": PythiaCustomizer(num_layers=24),
    "gpt_neo_125m": GPTNeoCustomizer()
}
if model_name not in model_customizers:
    # Fail with the list of supported models rather than a bare KeyError.
    raise KeyError(
        f'No ModelCustomizer registered for model_name={model_name!r}; '
        f'known models: {sorted(model_customizers)}'
    )
model_customizer = model_customizers[model_name]

### Randomization and other utilities.

In [None]:
def clamp(number, min_value, max_value):
    '''
    Limit ``number`` to the interval [min_value, max_value].

    The upper bound is applied first and the lower bound second, so if
    min_value > max_value the result is min_value.
    '''
    bounded = number
    if max_value < bounded:
        bounded = max_value
    if min_value > bounded:
        bounded = min_value
    return bounded

In [None]:
def calculate_average_values(list_of_dicts):
    '''
    Average each key's values across a list of dicts.

    Args:
        list_of_dicts: iterable of {token: value} mappings; a token may appear
            in any number of them.

    Returns:
        {token: mean value rounded to 3 decimals}, ordered by each token's
        first appearance.
    '''
    # token -> (running sum, occurrence count)
    totals = {}
    for mapping in list_of_dicts:
        for token, value in mapping.items():
            running_sum, running_count = totals.get(token, (0.0, 0))
            totals[token] = (running_sum + value, running_count + 1)

    return {
        token: round(running_sum / running_count, 3)
        for token, (running_sum, running_count) in totals.items()
    }

In [None]:
def rescale_value(value, values_list, new_min=min_vader_value, new_max=max_vader_value, percentile_range=90):
    '''
    Map ``value`` linearly from the robust range of ``values_list`` onto
    [new_min, new_max], clamp it to that interval, and round to 2 decimals.

    The robust range is [P(100 - percentile_range), P(percentile_range)] of
    ``values_list``; with the default 90 that is the 10th-90th percentile, so
    values in the outer tails saturate at new_min / new_max. percentile_range
    should be > 50.

    If that percentile range is degenerate (zero width), the midpoint of the
    new range is returned instead of dividing by zero, which would otherwise
    raise or produce NaN.

    Note: the percentiles are recomputed on every call.
    '''
    old_max = np.percentile(values_list, percentile_range)
    old_min = np.percentile(values_list, 100 - percentile_range)
    old_span = old_max - old_min

    if old_span == 0:
        return round((new_min + new_max) / 2, 2)

    # Normalize to [0, 1] w.r.t. the robust range, then stretch to the new range.
    normalized_value = (value - old_min) / old_span
    new_value = normalized_value * (new_max - new_min) + new_min

    return round(clamp(new_value, new_min, new_max), 2)

In [None]:
def random_split_list(lst, split_ratio=0.8, seed=random_seed):
    '''
    Shuffle a copy of ``lst`` and split it into (head, tail) at ``split_ratio``.

    With a non-None ``seed``, a private ``random.Random(seed)`` is used. It yields
    the same permutation as ``random.seed(seed); random.shuffle(...)`` would, but
    it does not re-seed the global ``random`` module as a side effect. With
    ``seed=None``, the global generator is used.

    Returns:
        (first int(len * split_ratio) items, remaining items), both lists.
    '''
    rng = random.Random(seed) if seed is not None else random

    shuffled_list = list(lst)  # copy; caller's sequence is left untouched
    rng.shuffle(shuffled_list)

    split_index = int(len(shuffled_list) * split_ratio)
    return shuffled_list[:split_index], shuffled_list[split_index:]

random_split_list([1,2,3,4,5,6,7,8,9,10])

In [None]:
def concantenate_matrices(layer_to_csr_dict):
    """
    Stack per-layer sparse feature matrices vertically, in sorted layer-name order.

    Args:
        layer_to_csr_dict: {layer_name: sparse matrix}; all matrices must have
            the same number of columns.

    Returns:
        The ``scipy.sparse.vstack`` of the values, ordered by ascending key, so
        two dicts with the same keys always yield row-aligned matrices. No
        flattening happens here; callers flatten if they need a vector.
    """
    ordered_layer_names = sorted(layer_to_csr_dict.keys())
    return vstack([layer_to_csr_dict[name] for name in ordered_layer_names])

In [None]:
def euclidean_distance(matrix1: csr_matrix, matrix2: csr_matrix):
    '''
    Euclidean (L2) distance between two equally shaped sparse matrices,
    each treated as one flattened vector.

    Computed on the sparse difference as sqrt(sum of squared entries), i.e. the
    Frobenius norm, which equals the L2 norm of the flattened difference. This
    avoids materializing two dense copies.
    '''
    difference = matrix1 - matrix2
    return np.sqrt(difference.multiply(difference).sum())

def euclidean_distance_bw_dicts_of_csr_matrices(
    matrix_dict_1: dict[str, csr_matrix], matrix_dict_2: dict[str, csr_matrix]):
    '''
    Euclidean distance between two {layer_name: sparse matrix} dicts, after
    stacking each dict's matrices in sorted layer-name order.
    '''
    stacked_1, stacked_2 = (
        concantenate_matrices(per_layer) for per_layer in (matrix_dict_1, matrix_dict_2)
    )
    return euclidean_distance(stacked_1, stacked_2)

### Load artifact from wandb

In [None]:
run = wandb.init(project='{}_{}'.format(project_name, policy_model_name))  # one W&B run per policy model

In [None]:
# Record the notebook's random seed in the run config so the run can be reproduced.
wandb.run.config['random_seed'] = random_seed

In [None]:
def load_linear_probe_training_dataset(policy_model_name=policy_model_name, project_name=project_name, version=version, filename=filename):
    '''
    Download the ``linear_probe_training_dataset_<policy_model_name>:<version>``
    W&B artifact and unpickle the dataset file inside it.

    Args:
        policy_model_name, project_name, version: identify the artifact.
        filename: name of the pickled file inside the artifact.

    Returns:
        The unpickled object (used below as a list of training points).
    '''
    artifact_path = f'linear_probe_training_dataset_{policy_model_name}:{version}'

    artifact = run.use_artifact(
        f'nlp_and_interpretability/{project_name}/{artifact_path}', type='data'
    )
    # Read from the directory W&B actually downloaded to. Assuming
    # 'artifacts/<name>:<version>' breaks when the artifact root is configured
    # elsewhere or the directory name is sanitized (e.g. ':' on Windows).
    artifact_dir = artifact.download()

    # SECURITY: pickle.load can execute arbitrary code from the file; only load
    # artifacts from a project you trust.
    with open(os.path.join(artifact_dir, filename), 'rb') as f_in:
        training_dataset = pickle.load(f_in)

    return training_dataset

In [None]:
@dataclass
class TextTokensIdsTarget:
    '''
    A tokenized text plus one target token inside it.

    Instances are presumably produced by ``trim_example`` (not defined in this
    notebook) and may be stored in the pickled dataset, so field names and
    order must not change.
    '''
    attention_mask: list[int]
    text: str
    tokens: list[str]
    ids: list[int]
    target_token: str
    target_token_id: int
    target_token_position: int

    @staticmethod
    def get_tensorized(datapoints: list["TextTokensIdsTarget"]):
        '''
        Pad a batch of datapoints and return CUDA tensors usable as model kwargs.

        NOTE(review): relies on globals ``tokenizer`` and ``pad_list_of_lists``
        that are not defined in this notebook; confirm their source before calling.

        Returns:
            {"input_ids": torch.IntTensor, "attention_mask": torch.ByteTensor}, on CUDA.

        Raises:
            ValueError: if ``datapoints`` is empty.
        '''
        if not datapoints:
            raise ValueError('get_tensorized() requires at least one datapoint')

        input_ids = [datapoint.ids for datapoint in datapoints]
        attention_masks = [datapoint.attention_mask for datapoint in datapoints]

        # Pad ids with the tokenizer's pad-token id and masks with 0.
        input_ids_padded = pad_list_of_lists(input_ids, tokenizer.encode(tokenizer.pad_token)[0])
        attention_masks_padded = pad_list_of_lists(attention_masks, 0)
        return {
            "input_ids": torch.IntTensor(input_ids_padded).cuda(),
            "attention_mask": torch.ByteTensor(attention_masks_padded).cuda(),
        }

### Source training point in the wandb artifact
class TrainingPoint:
    '''
    One source record from the W&B dataset: a positive text, a negative
    ("output") text and a neutral text, each trimmed around a target word.

    Args:
        input_dict: raw record; must contain 'input_text', 'output_text',
            'neutral_text', 'positive_words', 'new_words' and 'neutral_words'
            (the last two are dicts whose values are used as word lists).
        tokenizer: unused; kept so existing callers keep working.

    NOTE(review): depends on module-level ``get_tokens_and_ids`` and
    ``trim_example`` (not defined in this notebook) and on ``lexicon``.
    '''

    def __init__(self, input_dict: dict, tokenizer=None):
        self.input_dict = input_dict
        self.positive_text = input_dict['input_text']
        self.negative_text = input_dict['output_text']
        self.neutral_text = input_dict['neutral_text']

        # Dictionary of layer name to activations by mlp layer.
        self.activations: dict = None

        # Dictionary of layer name to autoencoder feature by mlp layer.
        self.autoencoder_feature: dict = None

        self.positive_text_tokens, self.positive_input_ids = get_tokens_and_ids(self.positive_text)
        # Attribute name (*_token_ids, not *_input_ids) kept as-is for pickled instances.
        self.negative_text_tokens, self.negative_token_ids = get_tokens_and_ids(self.negative_text)

        self.positive_words = input_dict['positive_words']
        self.negative_words = list(input_dict['new_words'].values())
        self.neutral_words = list(input_dict['neutral_words'].values())

        # Target token / id are None whenever trimming fails or yields nothing.
        (self.trimmed_positive_example,
         self.target_positive_token,
         self.target_positive_token_id) = self._trim_target(
            self.positive_text, self.positive_words, 'positive')
        # VADER score of the target token (None if absent from the lexicon).
        self.target_positive_reward = lexicon.get(self.target_positive_token, None)

        (self.trimmed_negative_example,
         self.target_negative_token,
         self.target_negative_token_id) = self._trim_target(
            self.negative_text, self.negative_words, 'negative')
        self.target_negative_reward = lexicon.get(self.target_negative_token, None)

        (self.trimmed_neutral_example,
         self.target_neutral_token,
         self.target_neutral_token_id) = self._trim_target(
            self.neutral_text, self.neutral_words, 'neutral')

    def _trim_target(self, text, words, label):
        '''
        Run ``trim_example`` on one text and extract its normalized target token.

        Returns:
            (example, token, token_id), where token is the target token stripped
            and lower-cased. A falsy result from ``trim_example`` is passed through
            with (None, None). If anything raises, the exception is printed and
            (None, None, None) is returned.
        '''
        try:
            example: "TextTokensIdsTarget" = trim_example(text, words)
            if not example:
                return example, None, None
            return example, example.target_token.strip().lower(), example.target_token_id
        except Exception as e:
            # Deliberately best-effort: one malformed record should not abort loading.
            print(f'Caught exception {e} on {self.input_dict} for {label} example.')
            return None, None, None

    def __str__(self):
        return pprint.pformat(self.__dict__)


class LinearProbeTrainingPoint:
    '''
    A TrainingPoint plus the per-layer autoencoder features of its positive,
    negative and neutral target tokens.

    Each ``*_ae_features`` argument is a {full layer name: sparse matrix} dict.
    Downstream code stacks the values with ``scipy.sparse.vstack`` and calls
    ``.toarray()`` on them, so they are sparse matrices, not torch tensors.
    The rewards are copied from ``training_point``.
    '''

    def __init__(
        self, training_point: "TrainingPoint",
        # positive token
        target_positive_token_id: int,
        target_positive_token: str,
        positive_token_ae_features: dict[str, csr_matrix],
        # negative token
        target_negative_token_id: int,
        target_negative_token: str,
        negative_token_ae_features: dict[str, csr_matrix],
        # neutral token
        target_neutral_token_id: int,
        target_neutral_token: str,
        neutral_token_ae_features: dict[str, csr_matrix]
    ):
        self.training_point: "TrainingPoint" = training_point

        self.target_positive_token = target_positive_token
        self.target_positive_token_id = target_positive_token_id
        self.target_positive_reward = self.training_point.target_positive_reward
        self.positive_token_ae_features = positive_token_ae_features

        self.target_negative_token = target_negative_token
        self.target_negative_token_id = target_negative_token_id
        self.target_negative_reward = self.training_point.target_negative_reward
        self.negative_token_ae_features = negative_token_ae_features

        # The neutral token has no reward; it is only the reference point for divergence.
        self.target_neutral_token = target_neutral_token
        self.target_neutral_token_id = target_neutral_token_id
        self.neutral_token_ae_features = neutral_token_ae_features

    def __str__(self):
        return pprint.pformat(self.__dict__)

In [None]:
full_training_dataset = load_linear_probe_training_dataset(policy_model_name=policy_model_name, project_name=project_name, version=version)

In [None]:
train_split_dataset, test_split_dataset = random_split_list(full_training_dataset, split_ratio=0.8, seed=random_seed)

### Define linear probe helper classes.

In [None]:
@dataclass
class LinearProbeFinalInput:
    '''
    One regression example for the linear probe: a token's stacked autoencoder
    features (X) and its signed divergence from the neutral token (y).
    '''
    token: str
    token_id: int
    # Distance between this token's features and the neutral token's features;
    # stored negated for 'negative' points.
    divergence: float
    # Stacked features of the positive or negative token.
    features: csr_matrix
    # 'positive' or 'negative'.
    point_type: str
    # Kept so the original per-layer features can be inspected or retrieved later.
    source_training_point: LinearProbeTrainingPoint

    def __str__(self):
        return pprint.pformat(self.__dict__)

    # repr and str render identically (pretty-printed field dict).
    __repr__ = __str__

### Construct training dataset and linear probe

In [None]:
def map_lp_training_point_to_pair_of_lp_final_inputs(lp_training_point: LinearProbeTrainingPoint) -> list[LinearProbeFinalInput]:
    '''
    Turn one LinearProbeTrainingPoint into two probe inputs, [positive, negative].

    Each side's per-layer features are stacked with ``concantenate_matrices``.
    Its divergence is the Euclidean distance to the stacked neutral features,
    used as-is for the positive point and multiplied by -1 for the negative point.
    '''
    positive_features, negative_features, neutral_features = (
        concantenate_matrices(per_layer_features)
        for per_layer_features in (
            lp_training_point.positive_token_ae_features,
            lp_training_point.negative_token_ae_features,
            lp_training_point.neutral_token_ae_features,
        )
    )

    def _make_input(token, token_id, features, point_type, divergence):
        # Shared factory so both sides are built identically.
        return LinearProbeFinalInput(
            token=token, token_id=token_id,
            divergence=divergence, features=features,
            point_type=point_type,
            source_training_point=lp_training_point,
        )

    positive_probe_final_input = _make_input(
        lp_training_point.target_positive_token,
        lp_training_point.target_positive_token_id,
        positive_features,
        'positive',
        euclidean_distance(positive_features, neutral_features),
    )
    negative_probe_final_input = _make_input(
        lp_training_point.target_negative_token,
        lp_training_point.target_negative_token_id,
        negative_features,
        'negative',
        -1 * euclidean_distance(negative_features, neutral_features),
    )
    return [positive_probe_final_input, negative_probe_final_input]

In [None]:
# Pick one training point (index 4, arbitrary) to smoke-test the mapping below.
test_point = train_split_dataset[4]

In [None]:
positive_test_point, negative_test_point = map_lp_training_point_to_pair_of_lp_final_inputs(
    lp_training_point=test_point
)

In [None]:
# Show both probe inputs derived from the sample point.
print('\nPositive point:\n' + pprint.pformat(positive_test_point))
print('\nNegative point:\n' + pprint.pformat(negative_test_point))

In [None]:
def map_lp_dataset_to_final_input_dataset(
    input_dataset: list[LinearProbeTrainingPoint]) -> list[LinearProbeFinalInput]:
    '''
    Flatten a dataset into probe inputs. Each training point contributes its
    positive input followed by its negative input, so the result has
    2 * len(input_dataset) entries in input order.
    '''
    # tqdm.auto renders a notebook widget in Jupyter and a text bar elsewhere;
    # the tqdm_notebook shim is deprecated.
    from tqdm.auto import tqdm

    final_dataset = []
    for datapoint in tqdm(input_dataset):
        positive_point, negative_point = map_lp_training_point_to_pair_of_lp_final_inputs(datapoint)
        final_dataset.extend((positive_point, negative_point))

    return final_dataset

In [None]:
# Expand each split into probe inputs (positive then negative, per training point).
mapped_train_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(train_split_dataset)
mapped_test_split_dataset: list[LinearProbeFinalInput] = map_lp_dataset_to_final_input_dataset(test_split_dataset)

In [None]:
class FeatureConstructor:
    '''
    Builds the dense design matrix for the linear probe from probe inputs.
    '''

    def construct_feature_representation(self, linear_probe_inputs):
        '''
        Flatten each input's sparse ``features`` into one dense row.

        Returns:
            numpy array with one row per input.
        '''
        dense_rows = []
        for probe_input in linear_probe_inputs:
            dense_rows.append(probe_input.features.toarray().flatten())
        return np.array(dense_rows)


feature_constructor = FeatureConstructor()

In [None]:
def train_linear_model(train_linear_probe_inputs: list[LinearProbeFinalInput], feature_constructor: FeatureConstructor = feature_constructor):
    '''
    Fit a Ridge regression (sklearn defaults) from each input's flattened
    features to its signed divergence, and record the model type in the
    W&B run summary.

    Returns:
        The fitted sklearn Ridge model.
    '''
    features_matrix = feature_constructor.construct_feature_representation(train_linear_probe_inputs)
    targets = np.array([probe_input.divergence for probe_input in train_linear_probe_inputs])

    print(f'Shapes are {features_matrix.shape} and {targets.shape}')

    ridge = Ridge()
    wandb.run.summary['linear_model_type'] = 'Ridge'
    ridge.fit(features_matrix, targets)
    return ridge

In [None]:
linear_model = train_linear_model(mapped_train_split_dataset)

In [None]:
def get_fitted_values(linear_model, test_linear_probe_inputs, feature_constructor: FeatureConstructor = feature_constructor):
    """
    Predict the signed divergence for each probe input with a fitted model.

    Returns:
        Whatever ``linear_model.predict`` returns (an array for sklearn
        models), aligned with ``test_linear_probe_inputs``.
    """
    return linear_model.predict(
        feature_constructor.construct_feature_representation(test_linear_probe_inputs)
    )

In [None]:
fitted_values = get_fitted_values(linear_model, mapped_test_split_dataset)

In [None]:
fitted_values_and_inputs = [(prediction, probe_input) for prediction, probe_input in zip(fitted_values, mapped_test_split_dataset)]

In [None]:
# Spot-check one (prediction, probe input) pair; index 15 is arbitrary.
fitted_values_and_inputs[15]

### Analyze reconstructed values vis-à-vis the original VADER lexicon.

In [None]:
def scale_values_and_input_list_to_range(values_and_input_list, min_range=min_vader_value, max_range=max_vader_value):
    '''
    Rescale raw probe predictions onto [min_range, max_range].

    Args:
        values_and_input_list: iterable of (fitted_value, LinearProbeFinalInput) pairs.
        min_range, max_range: target interval (default: the VADER lexicon's min/max).

    Returns:
        (list of single-entry {token: rescaled_value} dicts,
         list of tokens,
         list of raw fitted values), all aligned with the input order.
    '''
    # Materialize once: the pairs are read several times below, which would
    # silently yield nothing the second time if a one-shot iterator were passed.
    pairs = list(values_and_input_list)

    all_probe_values = [fitted_value for fitted_value, _ in pairs]
    all_tokens = [lp_input.token for _, lp_input in pairs]

    # Forward the requested range; otherwise rescale_value's own defaults would
    # silently win and min_range / max_range would have no effect.
    rescaled_token_to_value_dict_list = [
        {lp_input.token: rescale_value(fitted_value, all_probe_values, new_min=min_range, new_max=max_range)}
        for fitted_value, lp_input in pairs
    ]

    return rescaled_token_to_value_dict_list, all_tokens, all_probe_values

In [None]:
# Map raw probe outputs onto the VADER score range; extreme predictions are clamped, not dropped.
(rescaled_token_to_value_dict_list,
 all_test_tokens,
 all_test_probe_values) = scale_values_and_input_list_to_range(fitted_values_and_inputs)

# Per-token reconstructed value: mean of its rescaled values over all occurrences.
averaged_token_values = calculate_average_values(rescaled_token_to_value_dict_list)

In [None]:
# Same rescaling as above, but kept paired with each probe input so the
# inputs can later be split by reward threshold.
rescaled_fitted_values_and_inputs = [
    (rescale_value(raw_prediction, values_list=all_test_probe_values), probe_input)
    for raw_prediction, probe_input in fitted_values_and_inputs
]

In [None]:
# Split test tokens by the sign of their original VADER score. Tokens missing
# from the lexicon count as 0 and land in neither list; duplicates are kept.
all_positive_test_tokens = list(filter(lambda token: lexicon.get(token, 0) > 0, all_test_tokens))
all_negative_test_tokens = list(filter(lambda token: lexicon.get(token, 0) < 0, all_test_tokens))

In [None]:
# Order the single-entry {token: value} dicts alphabetically by their token.
rescaled_token_to_value_dict_list = sorted(
    rescaled_token_to_value_dict_list, key=lambda single_entry: next(iter(single_entry))
)

In [None]:
# Sample a few *distinct* tokens of each polarity to compare reconstructed vs.
# original VADER values. The token lists hold one entry per occurrence, so
# sampling them directly could draw the same token twice (collapsing the dicts
# below). Dedupe first, preserving order so the draw is reproducible.
num_sampled_tokens = 3
token_sampling_rng = random.Random(random_seed)  # private RNG: no global re-seed

distinct_positive_test_tokens = list(dict.fromkeys(all_positive_test_tokens))
distinct_negative_test_tokens = list(dict.fromkeys(all_negative_test_tokens))

# min(...) avoids ValueError when fewer tokens than requested are available.
random_positive_tokens = token_sampling_rng.sample(
    distinct_positive_test_tokens, min(num_sampled_tokens, len(distinct_positive_test_tokens)))
random_negative_tokens = token_sampling_rng.sample(
    distinct_negative_test_tokens, min(num_sampled_tokens, len(distinct_negative_test_tokens)))

random_positive_token_values = {pos_token: averaged_token_values[pos_token] for pos_token in random_positive_tokens}
random_negative_token_values = {neg_token: averaged_token_values[neg_token] for neg_token in random_negative_tokens}

original_positive_values = {key: lexicon[key] for key in random_positive_tokens}
original_negative_values = {key: lexicon[key] for key in random_negative_tokens}

### Plot distribution of scores for positive and negative values.

In [None]:
def plot_original_vs_modified_values(token_values):
    '''
    Overlay KDEs of reconstructed token values and their original VADER scores.

    Args:
        token_values: {token: reconstructed value}. Tokens absent from the
            VADER lexicon are skipped for both curves, so both distributions
            cover the same tokens (looking them up would raise KeyError).
    '''
    shared_tokens = [token for token in token_values if token in lexicon]
    # Plain lists rather than dict views, which seaborn may not accept as vectors.
    reconstructed_values = [token_values[token] for token in shared_tokens]
    original_values = [lexicon[token] for token in shared_tokens]

    _, ax = plt.subplots()
    sns.kdeplot(reconstructed_values, label='Reconstructed', linestyle='-', ax=ax)
    sns.kdeplot(original_values, label='Original', linestyle='-', ax=ax)

    ax.set_xlabel('Values')
    # A KDE's y-axis is a probability density, not a count.
    ax.set_ylabel('Density')
    ax.set_title('Reconstructed vs. original VADER token values')
    ax.legend()
    plt.show()


plot_original_vs_modified_values(token_values = averaged_token_values)

In [None]:
def plot_reconstruction_errors(token_values):
    '''
    Histogram (with KDE) of per-token reconstruction error, computed as
    reconstructed value minus original VADER score.

    Args:
        token_values: {token: reconstructed value}. Tokens absent from the
            VADER lexicon are skipped (looking them up would raise KeyError).
    '''
    differences = [
        token_values[token] - lexicon[token]
        for token in token_values
        if token in lexicon
    ]

    _, ax = plt.subplots()
    sns.histplot(differences, kde=True, color='skyblue', bins=10, ax=ax)
    ax.set_xlabel('Difference (reconstructed - original)')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of reconstruction errors')
    plt.show()

plot_reconstruction_errors(token_values = averaged_token_values)

### Sampled positive and negative values.

In [None]:
# Compare the sampled tokens' reconstructed values with their original VADER scores.
print('\n'.join([
    f'Reconstructed positive values: {random_positive_token_values}',
    f'Original positive values: {original_positive_values}',
    f'Reconstructed negative values: {random_negative_token_values}',
    f'Original negative values: {original_negative_values}',
]))

In [None]:
def log_dictionary_as_table(table_name: str, dictionary_values: dict, columns=("token", "value")):
    '''
    Print a {key: value} mapping as a two-column DataFrame and log it to W&B as a table.

    Args:
        table_name: key under which the table is logged.
        dictionary_values: mapping to log; one row per item, in insertion order.
        columns: names of the two columns (default ("token", "value")). A tuple
            default avoids a shared mutable default; lists are still accepted.
    '''
    final_df = pd.DataFrame(list(dictionary_values.items()), columns=list(columns))

    print(final_df)

    # Wrap explicitly so W&B stores it as a Table.
    wandb.log({table_name: wandb.Table(dataframe=final_df)})

In [None]:
log_dictionary_as_table(
    table_name="sample_reconstructed_negative_token_utilities",
    dictionary_values=random_negative_token_values,
)

In [None]:
log_dictionary_as_table(
    table_name="sample_reconstructed_positive_token_utilities",
    dictionary_values=random_positive_token_values,
)

In [None]:
log_dictionary_as_table(
    table_name="full_reconstructed_token_utilities",
    dictionary_values=averaged_token_values,
)

In [None]:
# Original VADER score for every token that received a reconstructed value.
full_original_token_values = {
    reconstructed_token: lexicon[reconstructed_token]
    for reconstructed_token in averaged_token_values.keys()
}

In [None]:
log_dictionary_as_table(
    table_name="original_vader_token_utilities",
    dictionary_values=full_original_token_values,
)

In [None]:
# Rank negative test tokens (ascending) by reconstructed value and by original VADER score.
reconstructed_ranking = sorted(
    filter(lambda pair: pair[0] in all_negative_test_tokens, averaged_token_values.items()),
    key=lambda pair: pair[1],
)

original_ranking = sorted(
    ((token, lexicon[token]) for token in averaged_token_values if token in all_negative_test_tokens),
    key=lambda pair: pair[1],
)

reconstructed_ranking_tokens_only = [token for token, _ in reconstructed_ranking]
original_ranking_tokens_only = [token for token, _ in original_ranking]

In [None]:
# Kendall's tau must correlate *paired scores for the same tokens*. Passing the
# two ordered lists of token strings instead pairs unrelated tokens position by
# position and compares them as strings, which is not a rank correlation of the
# scores at all.
original_negative_scores = dict(original_ranking)
reconstructed_negative_scores = dict(reconstructed_ranking)
ranked_negative_tokens = [
    token for token in original_negative_scores if token in reconstructed_negative_scores
]
kendall_tau_result = kendalltau(
    [original_negative_scores[token] for token in ranked_negative_tokens],
    [reconstructed_negative_scores[token] for token in ranked_negative_tokens],
)

In [None]:
# Display the Kendall tau statistic and its p-value.
kendall_tau_result

In [None]:
# Log plain floats so W&B can chart and summarize them; the scipy result is a
# (statistic, pvalue) tuple-like object, not a scalar.
wandb.log({
    "kendall_tau_result": float(kendall_tau_result[0]),
    "kendall_tau_pvalue": float(kendall_tau_result[1]),
})

### Compute correlation of GPT-4 features with activations.

In [None]:
# Rescaled probe values >= this threshold count as "high reward" (see the split below).
threshold_reward = 3.0
# Key into the high-value-features artifact: the policy model name without its
# "_utility_reward" suffix.
high_value_features_key = policy_model_name.replace("_utility_reward", "")

In [None]:
def get_high_value_features(version='v0', project='utility_reconstruction'):
    '''
    Download the ``high_value_features_artifact:<version>`` W&B artifact and parse its JSON file.

    Args:
        version: artifact version alias.
        project: W&B project holding the artifact (default: 'utility_reconstruction').

    Returns:
        The decoded JSON object (indexed below by model name).
    '''
    artifact_name = "high_value_features_artifact"
    artifact = run.use_artifact(
        f'nlp_and_interpretability/{project}/{artifact_name}:{version}', type='data'
    )
    # Read from the directory W&B actually used instead of assuming
    # 'artifacts/<name>:<version>', which breaks for a relocated artifact root.
    artifact_dir = artifact.download()

    with open(os.path.join(artifact_dir, artifact_name), "r", encoding="utf-8") as in_file:
        high_value_features = json.load(in_file)

    return high_value_features

high_value_features = get_high_value_features()

In [None]:
# NOTE(review): element [1] of this model's entry is assumed to be the list of
# (layer, feature_index) pairs consumed by count_activations below; confirm the
# artifact's schema.
high_value_features_for_model = high_value_features[high_value_features_key][1]

In [None]:
# Split (rescaled reward, probe input) pairs at threshold_reward. NaN rewards
# satisfy neither comparison and are left out of both lists.
high_reward_inputs = list(filter(lambda pair: pair[0] >= threshold_reward, rescaled_fitted_values_and_inputs))
lower_reward_inputs = list(filter(lambda pair: pair[0] < threshold_reward, rescaled_fitted_values_and_inputs))

In [None]:
@dataclass
class ActivationsTokenReward:
    '''
    Which autoencoder features fired for one probe input, plus its token and
    rescaled linear-probe reward.
    '''
    # (layer_num, feature_index) -> whether the feature fired (truthy) or not.
    activations_dict: dict

    # layer_name -> sparse activation matrix, as stored on the training point.
    raw_activations_dict: dict
    token: str
    linear_probe_reward: float

    def count_features(self):
        '''Number of (layer, feature) slots tracked.'''
        return len(self.activations_dict)

    def count_all_activations(self):
        '''Sum of all activation flags, i.e. how many slots fired.'''
        return sum(self.activations_dict.values())

    def count_activations(self, targeted_features: list[int, int]):
        '''
        Count how many of ``targeted_features`` fired.

        Each feature is converted with tuple(...) before lookup, so [layer, index]
        lists work; an unknown feature raises KeyError.
        '''
        return sum(self.activations_dict[tuple(feature)] for feature in targeted_features)

    def __str__(self):
        return pprint.pformat(self.__dict__)

In [None]:
def process_lp_reward_and_lp_final_input_into_atr(
    lp_reward_and_lp_final_input: tuple[float, LinearProbeFinalInput], 
):
    '''
    Convert one (rescaled reward, LinearProbeFinalInput) pair into an
    ActivationsTokenReward with one boolean per (layer number, feature index).

    Uses the positive or negative token's per-layer features from the source
    training point, depending on ``point_type``. Only the first row of each
    layer's sparse matrix is used (``toarray()[0]``); this assumes one row per
    layer — TODO confirm.

    Raises:
        ValueError: if point_type is neither 'positive' nor 'negative'.
    '''
    lp_reward, lp_final_input = lp_reward_and_lp_final_input

    source_training_point = lp_final_input.source_training_point
    if lp_final_input.point_type == 'positive':
        activations_dict = source_training_point.positive_token_ae_features
    elif lp_final_input.point_type == 'negative':
        activations_dict = source_training_point.negative_token_ae_features
    else:
        # Previously any other value silently fell through to the negative features.
        raise ValueError(
            f'Unexpected point_type {lp_final_input.point_type!r}; expected "positive" or "negative".'
        )

    # layer number (str) -> boolean "fired" vector for that layer.
    fired_by_layer = {
        model_customizer.parse_layer_name_to_layer_number(layer_name):
            csr_activations.toarray()[0].astype(bool)
        for layer_name, csr_activations in activations_dict.items()
    }

    final_mapped_dict = {
        (layer_num, index): activation_fired_boolean
        for layer_num, activations_fired_boolean_list in fired_by_layer.items()
        for index, activation_fired_boolean in enumerate(activations_fired_boolean_list)
    }

    return ActivationsTokenReward(
        activations_dict=final_mapped_dict, raw_activations_dict=activations_dict,
        token=lp_final_input.token, linear_probe_reward=lp_reward
    )


# Smoke test on one high-reward input (index 124 is arbitrary and needs > 124 such inputs).
atr = process_lp_reward_and_lp_final_input_into_atr(
    high_reward_inputs[124])

# Display both counts; as separate statements only the last one would be shown.
atr.count_activations(targeted_features = high_value_features_for_model), atr.count_features()

In [None]:
 # Convert every (reward, probe input) pair in each reward group into an ActivationsTokenReward.
 # NOTE(review): this cell carries a stray 1-space indent; IPython dedents it, plain Python would not.
 high_reward_atr_points = list(map(process_lp_reward_and_lp_final_input_into_atr, high_reward_inputs))
 lower_reward_atr_points = list(map(process_lp_reward_and_lp_final_input_into_atr, lower_reward_inputs))

In [None]:
def get_total_feature_activation_percentage(atr_points):
    '''
    Percentage of all (input, feature) slots that fired across ``atr_points``.

    The first point's feature count is used for every point, which assumes all
    points track the same features. Returns NaN when there are no points or no
    features, instead of raising IndexError / ZeroDivisionError.
    '''
    if not atr_points:
        return float('nan')

    total_slots = len(atr_points) * atr_points[0].count_features()
    if total_slots == 0:
        return float('nan')

    total_feature_activations = sum(atr.count_all_activations() for atr in atr_points)
    return 100 * (total_feature_activations / total_slots)

def get_feature_activation_percentage(atr_points, high_value_features_for_model=high_value_features_for_model):
    '''
    Percentage of (input, targeted feature) slots that fired, counting only the
    features in ``high_value_features_for_model``.

    Returns NaN when there are no points or no targeted features, instead of
    raising ZeroDivisionError.
    '''
    total_slots = len(atr_points) * len(high_value_features_for_model)
    if total_slots == 0:
        return float('nan')

    total_targeted_feature_activations = sum(
        atr.count_activations(targeted_features=high_value_features_for_model)
        for atr in atr_points
    )
    return 100 * (total_targeted_feature_activations / total_slots)

In [None]:
total_fa_percentage_on_high_reward = get_total_feature_activation_percentage(atr_points=high_reward_atr_points)
high_reward_feature_activation_percentage = get_feature_activation_percentage(atr_points=high_reward_atr_points)
lower_reward_feature_activation_percentage = get_feature_activation_percentage(atr_points=lower_reward_atr_points)

In [None]:
print('\n'.join(str(percentage) for percentage in (
    total_fa_percentage_on_high_reward,
    high_reward_feature_activation_percentage,
    lower_reward_feature_activation_percentage,
)))

In [None]:
# Record the three activation percentages in the W&B run summary.
wandb.run.summary["total_fa_percentage_on_high_reward"] = total_fa_percentage_on_high_reward
wandb.run.summary["high_reward_feature_activation_percentage"] = high_reward_feature_activation_percentage
wandb.run.summary["lower_reward_feature_activation_percentage"] = lower_reward_feature_activation_percentage