# RL for LLMs

## Load the LLM

In [1]:
model_size = "70m"  # or 2.8b
revision = "step143000"

In [2]:
from datasets import load_dataset
import functools
import random
import re
import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer, BertTokenizer, BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tokenizer = AutoTokenizer.from_pretrained(
  f"EleutherAI/pythia-{model_size}-deduped",
  revision=revision,
  cache_dir=f"./pythia-{model_size}-deduped/{revision}",
  padding_side='left',
)

In [4]:
lm = GPTNeoXForCausalLM.from_pretrained(
  f"EleutherAI/pythia-{model_size}-deduped",
  revision=revision,
  cache_dir=f"./pythia-{model_size}-deduped/{revision}",
)

In [5]:
def generate(lm, text: str) -> str:
    tokens = lm.generate(
        **tokenizer(text, return_tensors="pt"),
        pad_token_id=tokenizer.eos_token_id, max_length=30)
    return tokenizer.decode(tokens[0])

In [6]:
generate(lm, "Hi there, how")

"Hi there, how can I get the data from the database?\n\nA:\n\nI'm not sure what you mean.  I'm"

In [7]:
generate(lm, "Tell me a joke!")

"Tell me a joke!\n\nI'm not sure what to do with the words.\n\nI'm not sure what to do with the words"

## Reward Function

In [8]:
def sentence_contains_one_of(sequence: str, words: set[str]) -> float:
    sequence_words = set(word.lower() for word in sequence.split())
    return len(sequence_words & words)

In [9]:
def has_length(sample: str, target_len: int) -> float:
    return -abs(len(sample) - target_len)

In [10]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')


@functools.cache
def bert_sentence_embedding(sentence: str) -> torch.Tensor:    
    inputs = bert_tokenizer(
        sentence, return_tensors='pt', padding=True, truncation=True
    )

    with torch.no_grad():
        outputs = bert_model(**inputs)

    embeddings = outputs.last_hidden_state.mean(dim=1)

    return embeddings.squeeze()

In [11]:
def bert_emb_dist(t1: str, t2: str) -> float:
    t1_emb = bert_sentence_embedding(t1)
    t2_emb = bert_sentence_embedding(t2)

    return -torch.norm(t1_emb - t2_emb)

In [12]:
[
    bert_emb_dist("How are you?", "What's up?"),
    bert_emb_dist("How are you?", "How are you?!"),
    bert_emb_dist("How are you?", "Beethoven was a great composer."),
]

[tensor(-8.4843), tensor(-4.2851), tensor(-10.2245)]

In [13]:
# OpenAssistant dataset

def is_capitalized(s: str) -> bool:
    return s[:1].isupper() and s[1:] == s[1:].lower()

oasst1_data = load_dataset("OpenAssistant/oasst1")
oasst1_prompts = [
    x['text']
    for x in oasst1_data['train']
    if 0 < len(x['text']) < 32 and x['lang'] == 'en' and is_capitalized(x['text'][0])
]
len(oasst1_prompts), oasst1_prompts[:5]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


(2054,
 ['Now explain it to a dog',
  'Is SEO relevant in 2023?',
  "Don't care, didn't ask!",
  "I'm sorry",
  'No I just wanted to thank you.'])

## Optimization

In [14]:
def get_selected_log_probs(tokens, log_probs, prompt_len: int):
    """Returns the scores of the tokens that were sampled.

    tokens: [mc_samples, seq_len (incl. prompt)]
    log_probs: [len seq_len (excl. prompt), mc_samples, vocab_size]
    """

    output = []

    for mc_sample_idx in range(len(tokens)):
        sample_log_probs = []
        for token_pos in range(len(tokens[mc_sample_idx, prompt_len:])):
            token_idx = tokens[mc_sample_idx][token_pos + prompt_len]
            log_prob = log_probs[token_pos][mc_sample_idx][token_idx]
            sample_log_probs.append(log_prob)
        output.append(torch.stack(sample_log_probs))

    return output  # list[tensor] with shape [mc_samples, seq_len]

In [15]:
def draw_mc_samples(prompt, lm, num_samples, temperature=1.0, max_length=30):
    """Draws MC samples for a given prompt."""
    prompt_input = tokenizer(prompt, return_tensors="pt")
    prompt_input_ids = prompt_input['input_ids']
    prompt_attention_mask = prompt_input['attention_mask']

    outputs = lm.generate(
        input_ids=prompt_input_ids.repeat(num_samples, 1),
        attention_mask=prompt_attention_mask.repeat(num_samples, 1),
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=True,  # Must be True for temperature to take effect
        temperature=temperature,
    )
    # TODO: Missing <bos> here?
    prompt_len = len(prompt_input_ids[0])
    output_tokens = outputs.sequences
    log_probs_with_grad = lm(output_tokens).logits[:, prompt_len:, :]  # [num_mc_samples, seq_len, vocab_size]
    log_probs_with_grad = log_probs_with_grad.permute(1, 0, 2)

    output_log_softmax = tuple(
        torch.log_softmax(score_matrix, dim=-1)
        for score_matrix in log_probs_with_grad
    )
    selected_log_probs = get_selected_log_probs(output_tokens, output_log_softmax, prompt_len)
    return output_tokens, selected_log_probs

In [16]:
output_tokens, output_log_probs = draw_mc_samples("Hello", lm, num_samples=4)
output_log_probs[0]

tensor([ -7.5776,  -7.6723, -12.8539,  -8.5047,  -9.0768,  -7.4754,  -5.6155,
         -7.4578,  -5.9396,  -9.2624,  -9.5459,  -7.9808,  -8.3465,  -7.3535,
         -6.8771,  -9.5508, -10.9028, -10.4836,  -9.2769,  -9.0039,  -7.6014,
         -8.9415,  -7.8402,  -8.8226,  -7.8247,  -7.3780,  -4.9987,  -8.5346,
        -10.5963], grad_fn=<StackBackward0>)

l-th output token: $y_l$

Probability of the sequence: $p(y) = p(y_0) \cdot p(y_1 | y_0) \cdot ...$

\begin{align}
\nabla \ln p(y) &= \nabla \ln p(y_0) \cdot p(y_1 | y_0) \cdot ... \\
&= \nabla \sum_i \ln p(y_i \vert y_{<i})
\end{align}

In [17]:
batch_size = 128
num_steps = 50
num_mc_samples = 8
max_length_tokens = 20  # Sampling length in tokens.
required_words = {
    "red", "blue", "green", "yellow", "orange", "purple", "pink", "cyan", "magenta", "turquoise",
    "violin", "piano", "guitar", "drums", "flute", "trumpet", "saxophone", "cello", "clarinet", "harp",
    "running", "swimming", "cycling", "painting", "dancing", "singing", "writing", "climbing", "skiing", "cooking",
    "black", "white", "beige", "brown", "gray",
    "banjo", "accordion", "trombone", "oboe", "mandolin",
    "jogging", "hiking", "knitting", "gardening", "surfing",
    "karate", "yoga", "chess", "fishing", "skating"
}

sample_sentence = "Hi there!"

opt = torch.optim.Adam(lm.parameters(), lr=3e-4)

reward_history = []
baseline_history = []

for step in range(num_steps):
    batch_loss = 0
    batch_rewards = []

    for _ in range(batch_size):
        # Sample prompt from text dataset.
        prompt = random.choice(oasst1_prompts)
        prompt_length = tokenizer(prompt, return_tensors="pt")['input_ids'].shape[1]

        # Sample batch of continuations for a given prompt (MC samples per prompt).
        output_tokens, output_log_probs = draw_mc_samples(
            prompt, lm, num_mc_samples, temperature=1.0, max_length=max_length_tokens)

        # Compute sequence log probability from individual token probs.
        sequence_log_probs = torch.stack([vec.sum() for vec in output_log_probs])

        # Decode to text to compute the reward per MC sample.
        output_texts = [tokenizer.decode(tokens) for tokens in output_tokens]  # Prompt + continuation
        output_continuations = [tokenizer.decode(tokens) for tokens in output_tokens[:, prompt_length:]]  # Only continuation
        rewards = torch.tensor(
            [bert_emb_dist(text, sample_sentence) for text in output_continuations], dtype=torch.float32)

        reward_baseline = rewards.mean()
        batch_rewards.append(rewards.mean().item())

        loss = -((rewards - reward_baseline) * sequence_log_probs).mean()
        batch_loss += loss

    # Take the mean loss across the batch of prompts.
    batch_loss /= batch_size

    opt.zero_grad()
    batch_loss.backward()
    opt.step()

    mean_batch_reward = sum(batch_rewards) / batch_size
    reward_history.append(mean_batch_reward)
    baseline_history.append(mean_batch_reward)

    # Print learning progress (using first continuation from last prompt of batch)
    sample_output = re.sub(r'\s+', ' ', output_continuations[0]).strip()

    print(
        f"Step: {step} "
        f"Baseline: {mean_batch_reward:.4f} "
        f"Reward variance: {rewards.var().item():.4f} "
        f"Example sentence: {sample_output}"
    )

Step: 0 Baseline: -9.8361 Reward variance: 0.4873 Example sentence: But there are now many different questions that are not yet known.
Step: 1 Baseline: -9.9320 Reward variance: 0.5705 Example sentence: What’’’
Step: 2 Baseline: -11.8985 Reward variance: 0.7114 Example sentence: ,"!!!!!,"!,"!,"!,"!!,"!?" I
Step: 3 Baseline: -11.8068 Reward variance: 1.9457 Example sentence: you you You you you you you you you you you you you you
Step: 4 Baseline: -11.4442 Reward variance: 1.1765 Example sentence: ,"?"?"?","?"?","'re,"'d?"?",","
Step: 5 Baseline: -11.1384 Reward variance: 0.4877 Example sentence: iiii'di'd'd'd'd'd'd'd'd'd'd'd
Step: 6 Baseline: -11.1759 Reward variance: 0.1344 Example sentence: iiiiiiiiiiiiiii
Step: 7 Baseline: -12.1487 Reward variance: 1.0732 Example sentence: --""""""""""""
Step: 8 Baseline: -11.8994 Reward variance: 0.4735 Example sentence: ,"--'m'm'm'm'm'm'm'm'm'm'm'm'm'm'm
Step: 9 Baseline: -11.4417 Reward variance: 0.3433 Example sentence: --------,"---'t't't
Step:

In [18]:
generate(lm, "PDG")

'PDG,",",",",",",",",",",",",",",",",",",",",",",",",",",","'