# Nudge Intrinsic Geometry

In this notebook, we test whether we can perform [distributional control](https://arxiv.org/abs/2310.04444)
by nudging the value vectors along the positive/negative valence and arousal directions.

 1. Compute mean logits (next token distribution) for high/low valence and arousal datasets. 
     - Added to `scripts/get_value_reps.py` -- each element now has a set of logits over the last token associated with it. 
     - Now we just need to compute the mean and standard deviation of the logits for each class (high/low valence, high/low arousal) and check whether the difference is statistically significant. 
 2. Load `weights.npz` from `cache/happy_sad_0330b2024/` and `cache/low_high_arousal_0330b2024/` for the direction in value space. 
 3. Set up the GPT-2 model to do a forward pass, get the past_kv cache, add `epsilon * weights.npz['coeff']` for `epsilon = [-1, -0.5, -0.25, -0.125, -0.0625, -0.03125, -0.015625, 0, +0.015625, +0.03125, +0.0625, +0.125, +0.25, +0.5, +1]`. 
     - Catch the logits for each of these. 
     - Compute the KL divergence between the logits and the mean happy/sad and low/high arousal logits. 
     - Record logits + KL divergences in a data structure. 
 4. Plot epsilon vs. KL divergence to each cluster center. 
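
The epsilon grid in step 3 appears to be symmetric around zero with magnitudes that are powers of two. As a minimal sketch (the helper name `symmetric_epsilon_grid` and its `max_power` parameter are hypothetical, not part of the notebook), the grid can be generated rather than typed by hand:

```python
def symmetric_epsilon_grid(max_power: int = 6) -> list[float]:
    """Return [-1, -1/2, ..., -2**-max_power, 0, 2**-max_power, ..., 1/2, 1]."""
    # positive half in ascending order: 2**-max_power, ..., 0.5, 1
    positive = [2.0 ** -k for k in range(max_power, -1, -1)]
    # negative half in ascending order: -1, -0.5, ..., -2**-max_power
    negative = [-eps for eps in reversed(positive)]
    return negative + [0.0] + positive


epsilon_range = symmetric_epsilon_grid()
print(epsilon_range)
```

Generating the list this way avoids transcription slips in individual entries and keeps the grid symmetric by construction.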


In [14]:
RESULTS_DIR = '../cache/happy_sad_0330b2024' # this is where weights.npz are kept 
VALUE_REPS_JSON = '../cache/gpt2_happy_sad_0330b2024.json'

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# gpt-2 model 
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

In [6]:
# Load weights.npz from RESULTS_DIR
weights = np.load(os.path.join(RESULTS_DIR, 'weights.npz'))
linreg_weights = weights['arr_0']
linreg_bias = weights['arr_1']
print('linreg_weights:', linreg_weights.shape)
print('linreg_bias:', linreg_bias.shape)   

linreg_weights: (1, 9216)
linreg_bias: (1,)


In [8]:
# Load the gpt-2 model 
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to('cuda')

## 1: Compute Mean Pr(Next Token) for Happy vs. Sad
 1. Load `VALUE_REPS_JSON`. 
 2. Iterate through the json, compute mean logits (numpy array) for each `class_0_true==True` and `class_0_true==False`. 
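
The per-class mean can be accumulated together with a per-dimension standard deviation in a single pass, which is what a later significance check would need. The sketch below uses Welford's algorithm on toy data; `RunningStats`, `toy_logits`, and `toy_labels` are hypothetical stand-ins for the real logits and the `class_0_true` flags, and the Welch t-statistic is one possible test, not necessarily the one intended here.

```python
import numpy as np


class RunningStats:
    """Accumulate per-dimension mean and variance with Welford's algorithm."""

    def __init__(self, dim):
        self.n = 0
        self.mean = np.zeros(dim, dtype=np.float64)
        self.m2 = np.zeros(dim, dtype=np.float64)  # sum of squared deviations

    def update(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    def std(self):
        # sample standard deviation (ddof=1)
        return np.sqrt(self.m2 / (self.n - 1))


rng = np.random.default_rng(0)
toy_logits = rng.normal(size=(200, 8))   # stand-in for per-example logits
toy_labels = np.arange(200) % 2 == 0     # stand-in for class_0_true

stats = {True: RunningStats(8), False: RunningStats(8)}
for x, label in zip(toy_logits, toy_labels):
    stats[bool(label)].update(x)

# Welch t-statistic per dimension between the two classes
diff = stats[True].mean - stats[False].mean
se = np.sqrt(stats[True].std() ** 2 / stats[True].n
             + stats[False].std() ** 2 / stats[False].n)
t_stat = diff / se
print(t_stat)
```

Accumulating in float64 also sidesteps the small rounding drift that can build up when adding tens of thousands of float32 vectors.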

In [9]:
# load value reps json 
with open(VALUE_REPS_JSON) as f:
    value_reps = json.load(f)
print('value_reps:', len(value_reps))
print('value_reps[0].keys(): ', value_reps[0].keys())

value_reps: 48260
value_reps[0].keys():  dict_keys(['final_prompt', 'final_prompt_ids', 'token_of_interest', 'prompt_template', 'note', 'negation', 'class_0_true', 'class_name', 'adjective', 'model', 'latent_space'])


In [13]:
value_reps[0]['final_prompt_ids']

[3260, 4854, 262, 1705, 11, 24799, 373, 8131, 6507, 11, 290, 673, 3393]

In [11]:
# how many class_0_true==True vs. class_0_true==False
class_0_true = [x['class_0_true'] for x in value_reps]
print('class_0_true:', len(class_0_true))
num_class_0_true = sum(class_0_true)
print('class_0_true==True:', num_class_0_true)
num_class_0_false = len(class_0_true) - num_class_0_true
print('class_0_true==False:', num_class_0_false)

class_0_true: 48260
class_0_true==True: 24130
class_0_true==False: 24130


In [31]:
# now we iterate thru, and get the gpt-2 final token logits for each
mean_class_0_true_logits = None
mean_class_0_false_logits = None
do_softmax = False

for i, value_rep in tqdm(enumerate(value_reps)):
    input_ids = torch.tensor([value_rep['final_prompt_ids']]).to('cuda')
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
        # print("Shape of logits: ", logits.shape)
        logits = logits[0, -1, :]
        logits = logits.cpu().numpy()
        # softmax (numerically stable: subtract the max before exponentiating)
        if do_softmax:
            exp_logits = np.exp(logits - np.max(logits))
            logits = exp_logits / np.sum(exp_logits)
        # print("Shape of logits: ", logits.shape)
        if value_rep['class_0_true']:
            if mean_class_0_true_logits is None:
                mean_class_0_true_logits = logits / num_class_0_true
            else:
                mean_class_0_true_logits += logits / num_class_0_true
        else:
            if mean_class_0_false_logits is None:
                mean_class_0_false_logits = logits / num_class_0_false
            else:
                mean_class_0_false_logits += logits / num_class_0_false


48260it [05:15, 153.16it/s]


In [32]:
# save the mean_class_0_true_logits and mean_class_0_false_logits to RESULTS_DIR 
np.savez(os.path.join(RESULTS_DIR, 'mean_class_0_logits.npz'), mean_class_0_true_logits, mean_class_0_false_logits)
print('mean_class_0_true_logits:', mean_class_0_true_logits.shape)
print("\tMean: ", np.mean(mean_class_0_true_logits))
print("\tStd: ", np.std(mean_class_0_true_logits))

print('mean_class_0_false_logits:', mean_class_0_false_logits.shape)
print("\tMean: ", np.mean(mean_class_0_false_logits))
print("\tStd: ", np.std(mean_class_0_false_logits))


# grab them back 
mean_class_0_logits = np.load(os.path.join(RESULTS_DIR, 'mean_class_0_logits.npz'))
mean_class_0_true_logits = mean_class_0_logits['arr_0']
mean_class_0_false_logits = mean_class_0_logits['arr_1']

print('\nmean_class_0_true_logits:', mean_class_0_true_logits.shape)
print("\tMean: ", np.mean(mean_class_0_true_logits))
print("\tStd: ", np.std(mean_class_0_true_logits))

print('mean_class_0_false_logits:', mean_class_0_false_logits.shape)
print("\tMean: ", np.mean(mean_class_0_false_logits))
print("\tStd: ", np.std(mean_class_0_false_logits))

mean_class_0_true_logits: (50257,)
	Mean:  -138.81699
	Std:  3.7843447
mean_class_0_false_logits: (50257,)
	Mean:  -138.69728
	Std:  3.7409842

mean_class_0_true_logits: (50257,)
	Mean:  -138.81699
	Std:  3.7843447
mean_class_0_false_logits: (50257,)
	Mean:  -138.69728
	Std:  3.7409842


## 2: KL Divergence vs. Epsilon Function

This function will take as arguments: 
 1. `model`: HuggingFace model 
 2. `full_prompt_ids`: Original prompt ids
 3. `mean_class_0_true_logits`
 4. `mean_class_0_false_logits`
 5. `epsilon_range = [-1, -0.5, -0.25, -0.125, -0.0625, -0.03125, -0.015625, 0, +0.015625, +0.03125, +0.0625, +0.125, +0.25, +0.5, +1]`
 6. `weight_vec`: Numpy array of shape [1, 9216]. 

The function will get the past_kv for the full_prompt_ids, then modify the 
value reps of the last token by adding `epsilon * weight_vec` for 
`epsilon in epsilon_range`. 

Note that we will need to reshape `epsilon * weight_vec` into `[n_layers, n_heads, head_dim]`. 

We can read all of these sizes from the `past_kv` object, which has shape `[num_layers, 2=kv, [batch=1, num_heads, seq_len, head_dim]]`. 

`weight_vec[0:layer_size]` is the flattened value tensor for layer 0, with `layer_size = (batch=1) * (num_heads=12) * (seq_len=1) * (head_dim=64) = 768`.

weight_vec has shape 9216 = 12 layers * 12 heads * 64 head_dim. 
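
To make the layout concrete, here is a small sketch of how a `[1, 9216]` vector maps onto `[n_layers, n_heads, head_dim]`. It assumes the flattening order is layer, then head, then head dimension, which is what the per-layer slicing in this notebook implies; whether `scripts/get_value_reps.py` flattened the values in the same order is an assumption worth checking.

```python
import numpy as np

n_layers, n_heads, head_dim = 12, 12, 64

# toy weight vector with the same shape as linreg_weights: [1, 9216]
weight_vec = np.arange(n_layers * n_heads * head_dim, dtype=np.float32).reshape(1, -1)

# one reshape gives a [n_layers, n_heads, head_dim] view of the vector
per_layer = weight_vec.reshape(n_layers, n_heads, head_dim)

# the same block obtained by slicing out one layer and reshaping it
layer = 3
layer_dim = n_heads * head_dim  # 768 entries per layer
sliced = weight_vec[:, layer * layer_dim:(layer + 1) * layer_dim].reshape(n_heads, head_dim)

print(per_layer.shape, sliced.shape)
print(np.array_equal(per_layer[layer], sliced))
```

Both routes select the same 768 entries for a given layer, so either can be used when adding the direction to that layer's value tensor.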

In [39]:
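# compute kl divergence between mean_class_0_true_logits and mean_class_0_false_logits
def kl_divergence(p, q):
    # softmax each
    # NOTE: the mean logits sit around -138, so np.exp(p) most likely underflows to 0
    # in float32; p and q then become all zeros and 0 * log(0) yields the nan below.
    p = np.exp(p) / (np.sum(np.exp(p)) + 1e-6)
    q = np.exp(q) / (np.sum(np.exp(q))+ 1e-6)
    return np.sum(p * np.log(p / (q+1e-6)))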

kl_div = kl_divergence(mean_class_0_true_logits, mean_class_0_false_logits)
print('kl_divergence between +/- class:', kl_div)

kl_divergence between +/- class: nan
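
A `nan` here is what one would expect when the softmax is computed by exponentiating very negative logits directly. One common remedy is to work in log space: subtract the maximum logit, compute a log-softmax, and evaluate the KL divergence as `sum(p * (log p - log q))`. The sketch below is one such variant on synthetic logits; the names `log_softmax` and `kl_divergence_from_logits` are illustrative and not part of the notebook.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over a 1-D array of logits."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - np.max(logits)  # largest entry becomes 0, so exp cannot overflow
    return shifted - np.log(np.sum(np.exp(shifted)))


def kl_divergence_from_logits(p_logits, q_logits):
    """KL(softmax(p_logits) || softmax(q_logits)), computed in log space."""
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return float(np.sum(np.exp(log_p) * (log_p - log_q)))


# synthetic logits in the same range as the class means above (around -138)
rng = np.random.default_rng(0)
a = rng.normal(loc=-138.0, scale=3.8, size=1000).astype(np.float32)
b = rng.normal(loc=-138.0, scale=3.8, size=1000).astype(np.float32)

kl_ab = kl_divergence_from_logits(a, b)
kl_aa = kl_divergence_from_logits(a, a)
print(kl_ab, kl_aa)
```

Because softmax is invariant to adding a constant to all logits, the shift does not change the distributions being compared.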




In [40]:
np.linalg.norm(mean_class_0_true_logits - mean_class_0_false_logits)

46.352303

In [42]:
# decode the top k tokens for each (k = 20)
top_k = 20
top_k_true = np.argsort(mean_class_0_true_logits)[::-1][:top_k]
top_k_false = np.argsort(mean_class_0_false_logits)[::-1][:top_k]

print('Top k true: ', top_k_true)
print('Top k false: ', top_k_false)

# decode the top k tokens
def decode_tokens(token_ids):
    return tokenizer.decode(token_ids)

top_k_true_decoded = [decode_tokens([x]) for x in top_k_true]
top_k_false_decoded = [decode_tokens([x]) for x in top_k_false]

print('Top k true decoded: ', top_k_true_decoded)
print('Top k false decoded: ', top_k_false_decoded)


Top k true:  [ 366  655  407  262 1107  257 1682   11  477  523  307  287 1234  691
  991  517  635  900 1464  379]
Top k false:  [ 366  655 1107  262  407  257 1682   11  523  477  287  517  307  991
  635 1234 1464  379  772  900]
Top k true decoded:  [' "', ' just', ' not', ' the', ' really', ' a', ' actually', ',', ' all', ' so', ' be', ' in', ' put', ' only', ' still', ' more', ' also', ' set', ' always', ' at']
Top k false decoded:  [' "', ' just', ' really', ' the', ' not', ' a', ' actually', ',', ' so', ' all', ' in', ' more', ' be', ' still', ' also', ' put', ' always', ' at', ' even', ' set']


In [73]:
def add_to_past_kv(weight_vec, past_kv_): 
    """ Past kv is the generic tuple[tuple[torch.Tensor]] 
    and weight_vec is a numpy array of shape [1, 9216].

    We split weight_vec across len(past_kv) (num_layers), reshape each slice to
    (num_heads, head_dim), and add it to the last-token values of that layer.

    NOTE: this modifies the tensors in past_kv_ in place; pass in a copy if the
    original cache must stay unchanged.
    """
    # no copy is made here: past_kv aliases past_kv_
    past_kv = past_kv_
    # print("Shape of weight_vec: ", weight_vec.shape)
    num_layers = len(past_kv)
    assert len(past_kv[0]) == 2
    batch = past_kv[0][0].shape[0]
    assert batch == 1
    n_heads = past_kv[0][0].shape[1]
    seq_len = past_kv[0][0].shape[2]
    head_dim = past_kv[0][0].shape[3]
    # print(f"n_heads: {n_heads}, seq_len: {seq_len}, head_dim: {head_dim}")
    layer_dim = batch * n_heads * head_dim
    # print("layer_dim: ", layer_dim)

    for layer in range(len(past_kv)): 
        weight_vec_l = weight_vec[:, (layer*layer_dim): ((layer+1)*layer_dim)]
        # reshape to [n_heads, head_dim]
        weight_vec_l = weight_vec_l.reshape((n_heads, head_dim))
        # print("Shape of weight_vec_l: ", weight_vec_l.shape)
        # print("Shape of past_kv[layer][1]: ", past_kv[layer][1].shape)
        past_kv[layer][1][0, :, -1, :] += torch.tensor(weight_vec_l).to(past_kv[layer][1].device)
    return past_kv
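
Because `add_to_past_kv` assigns `past_kv = past_kv_` rather than copying, repeated calls on the same cache add one perturbation on top of another. A non-mutating variant is sketched below, with numpy arrays standing in for torch tensors so that it runs without a model; `add_to_values_copy` and the toy cache are hypothetical, and with real tensors the analogous step would be `values.clone()`.

```python
import numpy as np


def add_to_values_copy(weight_vec, past_kv):
    """Return a new cache whose last-token values are shifted by weight_vec.

    past_kv: tuple of (keys, values) pairs, each [batch=1, n_heads, seq_len, head_dim]
    weight_vec: array of shape [1, n_layers * n_heads * head_dim]
    """
    n_heads, head_dim = past_kv[0][1].shape[1], past_kv[0][1].shape[3]
    layer_dim = n_heads * head_dim
    new_past = []
    for layer, (keys, values) in enumerate(past_kv):
        delta = weight_vec[0, layer * layer_dim:(layer + 1) * layer_dim]
        new_values = values.copy()  # copy so the input cache is left untouched
        new_values[0, :, -1, :] += delta.reshape(n_heads, head_dim)
        new_past.append((keys, new_values))
    return tuple(new_past)


# toy cache: 2 layers, 3 heads, seq_len 4, head_dim 5
rng = np.random.default_rng(0)
toy_past = tuple(
    (rng.normal(size=(1, 3, 4, 5)), rng.normal(size=(1, 3, 4, 5))) for _ in range(2)
)
toy_weights = rng.normal(size=(1, 2 * 3 * 5))
original_values = toy_past[0][1].copy()

perturbed = add_to_values_copy(0.5 * toy_weights, toy_past)
print(np.array_equal(toy_past[0][1], original_values))  # input cache is unchanged
```

With this shape of function, each epsilon in a sweep starts from the same unperturbed cache.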

In [77]:
# take a single example, see how the logits change with epsilon 
epsilon_range = np.linspace(-3, 3, 100)

# get the prompt
idx = 0
full_prompt_ids = value_reps[idx]['final_prompt_ids']
prompt = tokenizer.decode(full_prompt_ids)
print('prompt:', prompt)
print("class_0_true:", value_reps[idx]['class_0_true'])
print("Full prompt ids: ", full_prompt_ids)

# get the past kv 
full_prompt_ids = torch.tensor([full_prompt_ids]).to('cuda')
with torch.no_grad():
    outputs = model(full_prompt_ids)
    past_kv = outputs.past_key_values

print("Length of past_kv (12): ", len(past_kv))
print("Shape of past_kv[0] (2): ", len(past_kv[0]))
print("Shape of past_kv[0][0] (batch=1, n_head=12, seq_len, head_dim=64): ", past_kv[0][0].shape)
print("\n\n=== Starting past_kv epsilon arp ===")
for epsilon in epsilon_range:
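    # add epsilon * weights to the last token's values, working on a fresh copy
    # of the cache so perturbations do not accumulate across iterations
    weight_vec = linreg_weights * epsilon
    past_kv_copy = tuple((k.clone(), v.clone()) for k, v in past_kv)
    past_kv_ = add_to_past_kv(weight_vec, past_kv_copy)
    with torch.no_grad():
        outputs = model(torch.tensor([[0]]).to('cuda'), past_key_values=past_kv_)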
        logits = outputs.logits
        logits = logits[0, -1, :]
        logits = logits.cpu().numpy()
        if do_softmax:
            exp_logits = np.exp(logits - np.max(logits))
            logits = exp_logits / np.sum(exp_logits)
        print(f"epsilon: {epsilon}")

        # compute distance to mean_class_0_true_logits and mean_class_0_false_logits
        dist_true = np.linalg.norm(mean_class_0_true_logits - logits)
        dist_false = np.linalg.norm(mean_class_0_false_logits - logits)
        # print(f"\tDistance to mean_class_0_true_logits: {dist_true}")
        # print(f"\tDistance to mean_class_0_false_logits: {dist_false}")
        # cosine similarity
        sim_true = np.dot(mean_class_0_true_logits, logits) / (np.linalg.norm(mean_class_0_true_logits) * np.linalg.norm(logits))
        sim_false = np.dot(mean_class_0_false_logits, logits) / (np.linalg.norm(mean_class_0_false_logits) * np.linalg.norm(logits))
        # print(f"\tCosine similarity to mean_class_0_true_logits: {sim_true}")
        # print(f"\tCosine similarity to mean_class_0_false_logits: {sim_false}")
        print("\tCloser to class true (cosine similarity): ", sim_true > sim_false)
        print("\tCloser to class true (L2 distance): ", dist_true < dist_false)

prompt: After hearing the news, Sara was incredibly sad, and she immediately
class_0_true: True
Full prompt ids:  [3260, 4854, 262, 1705, 11, 24799, 373, 8131, 6507, 11, 290, 673, 3393]
Length of past_kv (12):  12
Shape of past_kv[0] (2):  2
Shape of past_kv[0][0] (batch=1, n_head=12, seq_len, head_dim=64):  torch.Size([1, 12, 13, 64])


=== Starting past_kv epsilon arp ===
epsilon: -3.0
	Closer to class true:  False
epsilon: -2.9393939393939394
	Closer to class true:  False
epsilon: -2.878787878787879
	Closer to class true:  False
epsilon: -2.8181818181818183
	Closer to class true:  False
epsilon: -2.757575757575758
	Closer to class true:  False
epsilon: -2.696969696969697
	Closer to class true:  False
epsilon: -2.6363636363636362
	Closer to class true:  False
epsilon: -2.5757575757575757
	Closer to class true:  False
epsilon: -2.515151515151515
	Closer to class true:  False
epsilon: -2.4545454545454546
	Closer to class true:  False
epsilon: -2.393939393939394
	Closer to class true:  