# Nudge LLM Trajectories via Emotional Axes

_See `notebooks/llm emotional control.pdf` for a less abstract treatment._

In this notebook we demonstrate control over the probabilities a language model assigns to 
trajectories $t_i^+, t_i^-$, $i \in [N]$, sampled i.i.d. as $t_i^+ \sim P_{LM}(\cdot \mid x_0^+)$ 
and $t_i^- \sim P_{LM}(\cdot \mid x_0^-)$.
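The probability of a whole trajectory factorizes over its tokens: $\log P_{LM}(t_i^+ \mid x_0^+) = \sum_{k=1}^{T} \log P_{LM}(t_{i,k}^+ \mid x_0^+, t_{i,<k}^+)$. The NumPy sketch below computes this sum from per-step logits. It assumes that row `k` of `step_logits` already holds the next-token logits for step `k`. With a Hugging Face causal LM, the logits at position `j` score the token at position `j+1`, so they would need to be shifted by one. The function names are illustrative, not part of any library.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def trajectory_logprob(step_logits: np.ndarray, token_ids: np.ndarray) -> float:
    """Return sum_k log P(t_k | x_0, t_<k).

    step_logits: [T, V] array; row k is (by assumption) the next-token
                 distribution after the prompt x_0 and tokens t_<k.
    token_ids:   [T] integer array of the sampled tokens t_1..t_T.
    """
    logp = log_softmax(step_logits)                      # [T, V] log-probs
    picked = logp[np.arange(len(token_ids)), token_ids]  # log-prob of each sampled token
    return float(picked.sum())

# Toy check: uniform logits over V=4 tokens for T=3 steps,
# so every token has probability 1/4 and the total is 3 * log(1/4).
uniform = np.zeros((3, 4))
print(trajectory_logprob(uniform, np.array([0, 2, 3])))
```

Comparing this quantity for the same $t_i$ under different conditioning is one way to measure how much a given intervention shifts the model's preference between trajectories.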

**Experiment 1**: cache the key/value states `past_kv_plus`, `past_kv_minus` produced by
`model(x_0_plus)`, `model(x_0_minus)`.
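Caching the prompt works because a causal decoder computes each prompt position's keys and values from that position and earlier ones only. A continuation can therefore reuse them instead of re-running `model(x_0_plus)`. The toy below checks this for a single attention head in NumPy. The random weights and sizes are assumptions, and the toy omits GPT-2's positional embeddings, multiple heads, and layers.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 8                                   # toy sequence length and head dimension
x = rng.normal(size=(T, d))                   # toy per-token hidden states
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(x):
    """Recompute attention for every position with a causal mask."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(d)
    mask = np.triu(np.ones((len(x), len(x)), dtype=bool), k=1)  # hide future tokens
    scores[mask] = -np.inf
    return softmax(scores) @ v

# "Prefill": cache keys/values of the first T-1 tokens (the prompt).
k_cache, v_cache = x[:-1] @ Wk, x[:-1] @ Wv

# "Decode": project only the newest token; it attends over cache + itself.
q_new = x[-1:] @ Wq
k_all = np.vstack([k_cache, x[-1:] @ Wk])
v_all = np.vstack([v_cache, x[-1:] @ Wv])
out_cached = softmax(q_new @ k_all.T / np.sqrt(d)) @ v_all

print(np.allclose(out_cached, full_causal_attention(x)[-1:]))  # True
```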

In [1]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# gpt-2 model 
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

In [2]:
# Load the pretrained GPT-2 tokenizer and model onto the GPU (requires CUDA)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to('cuda')

## 1: Compute `past_kv_plus`, `past_kv_minus`

In [13]:
# define the two prompts, which differ only in the emotion word
x_0_plus = "Alice was happy, so Alice"
x_0_minus = "Alice was sad, so Alice"

# tokenize as PyTorch tensors
x_0_plus_ids = tokenizer.encode(x_0_plus, return_tensors='pt').to('cuda')
x_0_minus_ids = tokenizer.encode(x_0_minus, return_tensors='pt').to('cuda')
print("x_0_plus_ids: ", x_0_plus_ids)
print("x_0_minus_ids: ", x_0_minus_ids)

# ensure tokenizer can decode
print("x_0_plus: ", tokenizer.decode(x_0_plus_ids[0]))
print("x_0_minus: ", tokenizer.decode(x_0_minus_ids[0]))

# retrieve past_kv_plus and past_kv_minus 
with torch.no_grad():
    outputs = model(x_0_plus_ids, past_key_values=None)
    past_kv_plus = outputs.past_key_values
    outputs = model(x_0_minus_ids, past_key_values=None)
    past_kv_minus = outputs.past_key_values

x_0_plus_ids:  tensor([[44484,   373,  3772,    11,   523, 14862]], device='cuda:0')
x_0_minus_ids:  tensor([[44484,   373,  6507,    11,   523, 14862]], device='cuda:0')
x_0_plus:  Alice was happy, so Alice
x_0_minus:  Alice was sad, so Alice
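The printed ids show that the two prompts have the same length and differ at exactly one position: index 2 (` happy` = 3772 vs ` sad` = 6507). Because GPT-2 attention is causal, the cached keys and values at positions 0 and 1 should therefore be identical in `past_kv_plus` and `past_kv_minus`. Positions 2 to 5 can differ, since each of them attends to the emotion word. The pure-Python check below uses ids copied from the output above. The names `plus_ids` and `minus_ids` are illustrative, chosen so they do not overwrite the notebook's tensors.

```python
# Token ids copied from the printed output above.
plus_ids = [44484, 373, 3772, 11, 523, 14862]
minus_ids = [44484, 373, 6507, 11, 523, 14862]

# Same length, so the two caches have the same seq_len and align position by position.
assert len(plus_ids) == len(minus_ids)

diff_positions = [i for i, (a, b) in enumerate(zip(plus_ids, minus_ids)) if a != b]
first_diff = diff_positions[0]
print(diff_positions)                           # [2]
print("shared causal prefix:", plus_ids[:first_diff])  # positions whose K/V should match
```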


In [14]:
# let's go through one token at a time: 
for i in range(x_0_plus_ids.shape[1]): 
    print(f"Decoded {x_0_plus_ids[0, i]}: {tokenizer.decode(x_0_plus_ids[0, i])}")

Decoded 44484: Alice
Decoded 373:  was
Decoded 3772:  happy
Decoded 11: ,
Decoded 523:  so
Decoded 14862:  Alice


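 - `past_kv_plus`, `past_kv_minus` are tuples of length `num_layers = 12`, one entry per transformer layer.
 - `past_kv[l]` is a tuple of length 2, `(keys, values)`, for layer `l`.
 - `past_kv[l][0]` are the keys and `past_kv[l][1]` the values for layer `l`, each of shape `[batch, num_heads, seq_len, head_dim]`.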
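To make the indexing concrete without loading GPT-2, the hypothetical `mock_past_kv` below mimics this tuple-of-tuples layout with NumPy zeros of GPT-2 small's sizes. Recent `transformers` versions may return a `Cache` object rather than plain tuples, so treat this as a sketch of the layout, not of the exact return type.

```python
import numpy as np

# Hypothetical stand-in with GPT-2 small's dimensions (all values are zeros).
num_layers, batch, num_heads, seq_len, head_dim = 12, 1, 12, 6, 64
mock_past_kv = tuple(
    (np.zeros((batch, num_heads, seq_len, head_dim)),   # keys for this layer
     np.zeros((batch, num_heads, seq_len, head_dim)))   # values for this layer
    for _ in range(num_layers)
)

l = 3
keys, values = mock_past_kv[l]        # same as mock_past_kv[l][0], mock_past_kv[l][1]
print(len(mock_past_kv), keys.shape)  # 12 (1, 12, 6, 64)

# Per-head key vectors of the last prompt token in layer l:
print(keys[0, :, -1, :].shape)        # (12, 64)
```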


In [26]:
print("0th layer, keys shape: ", past_kv_plus[0][0].shape)
print("\tshape [batch, num_heads, seq_len, head_dim]")
print("0th layer, values shape: ", past_kv_plus[0][1].shape)
print("\tshape [batch, num_heads, seq_len, head_dim]")


0th layer, keys shape:  torch.Size([1, 12, 6, 64])
	shape [batch, num_heads, seq_len, head_dim]
0th layer, values shape:  torch.Size([1, 12, 6, 64])
	shape [batch, num_heads, seq_len, head_dim]
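The printed shapes give the cache size directly. For each token, every one of the 12 layers stores one key vector and one value vector per head, i.e. 2 × 12 × 12 × 64 numbers. A quick tally, assuming the default float32 weights (4 bytes per number) and batch size 1:

```python
# GPT-2 small dimensions, matching the shapes printed above.
num_layers, num_heads, head_dim, seq_len = 12, 12, 64, 6

per_token = 2 * num_layers * num_heads * head_dim  # keys + values across all layers
total = per_token * seq_len                        # whole 6-token prompt, batch size 1
total_bytes = total * 4                            # float32: 4 bytes per number
print(per_token, total, total_bytes)               # 18432 110592 442368
```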


In [24]:
print("Number of heads in gpt2 attention: ", model.transformer.h[0].attn.num_heads)
print("Head dim in gpt2 attention: ", model.transformer.h[0].attn.head_dim)

Number of heads in gpt2 attention:  12
Head dim in gpt2 attention:  64
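The two numbers multiply to GPT-2 small's hidden size: 12 × 64 = 768. Each attention layer projects tokens to a 768-dimensional key (and value), splits it into 12 heads of 64, and moves the head axis in front of the sequence axis. This produces the cached layout `[batch, num_heads, seq_len, head_dim]`. The NumPy sketch below mirrors that reshape-and-transpose on random data. The function names `split_heads` and `merge_heads` are illustrative.

```python
import numpy as np

batch, seq_len, num_heads, head_dim = 1, 6, 12, 64
hidden = num_heads * head_dim                        # 768, GPT-2 small's hidden size
rng = np.random.default_rng(0)
k_proj = rng.normal(size=(batch, seq_len, hidden))   # toy projected keys

def split_heads(t, num_heads, head_dim):
    """[batch, seq, hidden] -> [batch, num_heads, seq, head_dim]."""
    b, s, _ = t.shape
    return t.reshape(b, s, num_heads, head_dim).transpose(0, 2, 1, 3)

def merge_heads(t):
    """[batch, num_heads, seq, head_dim] -> [batch, seq, num_heads * head_dim]."""
    b, h, s, d = t.shape
    return t.transpose(0, 2, 1, 3).reshape(b, s, h * d)

k_heads = split_heads(k_proj, num_heads, head_dim)
print(hidden, k_heads.shape)                          # 768 (1, 12, 6, 64)
print(np.array_equal(merge_heads(k_heads), k_proj))   # True: the split is lossless
```

Head `h` simply owns columns `h*64` to `(h+1)*64 - 1` of the projected vector, so per-head edits to a cached key can be mapped back to hidden-space coordinates and vice versa.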
