In [None]:
import os
# The local helper modules imported below (prompt_optimizer, utils) live in the repo
# root, so make it the working directory. os.chdir returns None, so nothing is stored.
os.chdir('./../reverse-dynamics-nlp/')

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, GPTNeoXForCausalLM

from prompt_optimizer import PromptOptimizer
from utils import reverse_normalized_generate, start_chunk_hf, forward_loss, reverse_normalized_forward, get_token_probabilities, rand_init, get_reverse_pair

Pile-10k evaluation: load the data, the backward (reverse) model, and the forward model.

In [None]:
# Hugging Face model ids used in this notebook.
REVERSE_MODEL_NAME = "afterless/reverse-pythia-160m"
FORWARD_MODEL_NAME = "EleutherAI/pythia-160m"
# Cache directory for the forward model. Set HF_MODEL_CACHE_DIR to override the
# cluster-specific default instead of editing this cell.
FORWARD_MODEL_CACHE_DIR = os.environ.get("HF_MODEL_CACHE_DIR", "/scratch/jp6263/hf/models/")

dataset = load_dataset("NeelNanda/pile-10k")
# One tokenizer (the reverse model's) is used for both models; both are loaded as
# GPTNeoXForCausalLM. NOTE(review): assumes the two vocabularies are identical — confirm.
tokenizer = AutoTokenizer.from_pretrained(REVERSE_MODEL_NAME)
# get_reverse_pair produces (prefix, suffix) string pairs lazily (it supports next()).
# Preview a pair from a throwaway iterator so `pairs` itself is not advanced and its
# first pair is not silently skipped by later loops.
print(next(get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)))
pairs = get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)
bwd_model = GPTNeoXForCausalLM.from_pretrained(REVERSE_MODEL_NAME).cuda()
model = GPTNeoXForCausalLM.from_pretrained(FORWARD_MODEL_NAME, cache_dir=FORWARD_MODEL_CACHE_DIR).cuda()

Evaluate GCG with forward-LM-guided proposal sampling.

In [None]:
# GCG configuration 1.
#   temp: proposal-sampling temperature (None gives default GCG with uniform sampling)
#   prefix_probability_grad_weight: weight passed to the optimizer as prefix_loss_weight
temp = 5
prefix_probability_grad_weight = 0.1
gcg = PromptOptimizer(
    model,
    tokenizer,
    n_proposals=128,
    n_epochs=250,
    n_top_indices=128,
    prefix_loss_weight=prefix_probability_grad_weight,
)

In [None]:
# Evaluate GCG configuration 1 on the first 100 usable (prefix, suffix) pairs:
# token-level accuracy of the recovered prefix and the forward-model suffix loss
# given that prefix.
# A fresh pair iterator is built here rather than reusing the shared `pairs`
# generator, which other cells advance. That way every method is scored on the same
# pairs. NOTE(review): assumes get_reverse_pair yields the same pairs on every call.
gcg_tokenwise_acc = []
gcg_loss = []
for prefix, suffix in get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer):
    if len(gcg_loss) == 100:
        break
    # Skip pairs where either side is shorter than 10 characters.
    if len(prefix) < 10 or len(suffix) < 10:
        continue
    prefix_tokens = tokenizer.encode(prefix)
    len_prefix = len(prefix_tokens)
    # Start GCG from a random prefix with the same token length as the true prefix.
    rand_prefix = rand_init(len_prefix, tokenizer)
    optimized_string = gcg.optimize(rand_prefix, suffix, temperature=temp)
    # Re-tokenizing the optimized string can change its length; cap it at the true length.
    predicted_prefix_tokens = tokenizer.encode(optimized_string)[:len_prefix]
    predicted_prefix = tokenizer.decode(predicted_prefix_tokens)
    _, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix), tokenizer)
    gcg_loss.append(predicted_suffix_loss.item())
    # zip stops at the shorter sequence, so a shorter prediction counts the missing
    # positions as wrong instead of raising IndexError.
    n_correct = sum(true_tok == pred_tok for true_tok, pred_tok in zip(prefix_tokens, predicted_prefix_tokens))
    gcg_tokenwise_acc.append(n_correct / len(prefix_tokens))
print(f'Average tokenwise accuracy is {sum(gcg_tokenwise_acc)/len(gcg_tokenwise_acc)}')
print(f'Average loss is {sum(gcg_loss)/len(gcg_loss)}')

In [None]:
# GCG configuration 2.
#   temp: proposal-sampling temperature (None gives default GCG with uniform sampling)
#   prefix_probability_grad_weight: weight passed to the optimizer as prefix_loss_weight
temp = 2
prefix_probability_grad_weight = 0.25
gcg = PromptOptimizer(
    model,
    tokenizer,
    n_proposals=128,
    n_epochs=250,
    n_top_indices=128,
    prefix_loss_weight=prefix_probability_grad_weight,
)

In [None]:
# Evaluate GCG configuration 2 on the first 100 usable (prefix, suffix) pairs:
# token-level accuracy of the recovered prefix and the forward-model suffix loss
# given that prefix.
# A fresh pair iterator is built here rather than reusing the shared `pairs`
# generator, which other cells advance. That way every method is scored on the same
# pairs. NOTE(review): assumes get_reverse_pair yields the same pairs on every call.
gcg_tokenwise_acc = []
gcg_loss = []
for prefix, suffix in get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer):
    if len(gcg_loss) == 100:
        break
    # Skip pairs where either side is shorter than 10 characters.
    if len(prefix) < 10 or len(suffix) < 10:
        continue
    prefix_tokens = tokenizer.encode(prefix)
    len_prefix = len(prefix_tokens)
    # Start GCG from a random prefix with the same token length as the true prefix.
    rand_prefix = rand_init(len_prefix, tokenizer)
    optimized_string = gcg.optimize(rand_prefix, suffix, temperature=temp)
    # Re-tokenizing the optimized string can change its length; cap it at the true length.
    predicted_prefix_tokens = tokenizer.encode(optimized_string)[:len_prefix]
    predicted_prefix = tokenizer.decode(predicted_prefix_tokens)
    _, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix), tokenizer)
    gcg_loss.append(predicted_suffix_loss.item())
    # zip stops at the shorter sequence, so a shorter prediction counts the missing
    # positions as wrong instead of raising IndexError.
    n_correct = sum(true_tok == pred_tok for true_tok, pred_tok in zip(prefix_tokens, predicted_prefix_tokens))
    gcg_tokenwise_acc.append(n_correct / len(prefix_tokens))
print(f'Average tokenwise accuracy is {sum(gcg_tokenwise_acc)/len(gcg_tokenwise_acc)}')
print(f'Average loss is {sum(gcg_loss)/len(gcg_loss)}')

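Now load the dataset token probabilities and set up the reverse-LM evaluation with $p(\text{prefix})$ normalization: the inverse token probabilities, raised to the power `dataset_p_temp`, are passed to `reverse_normalized_generate` as normalization weights.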

In [None]:
# Per-token probabilities from the project helper. This is a torch tensor, since
# its elementwise reciprocal is taken with the Tensor API. The reciprocals are the
# inverse-probability weights used by the normalized reverse generation below.
dataset_probs = get_token_probabilities(tokenizer)
inverse_dataset_probs = dataset_probs.reciprocal()

In [None]:
# Reverse-LM decoding settings for the next cell.
#   dataset_p_temp = 0: exponent applied to inverse_dataset_probs; 0 makes every weight 1
#   rlm_temp = 0: temperature 0, which this notebook uses for its greedy baseline
dataset_p_temp, rlm_temp = 0, 0

In [None]:
# Reverse-LM evaluation on the first 100 usable (prefix, suffix) pairs, decoding with
# the `dataset_p_temp` / `rlm_temp` settings and the `inverse_dataset_probs` weights.
# A fresh pair iterator is built here rather than reusing the shared `pairs`
# generator, which other cells advance. That way every method is scored on the same
# pairs. NOTE(review): assumes get_reverse_pair yields the same pairs on every call.
rlm_tokenwise_acc = []
rlm_loss = []
for prefix, suffix in get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer):
    if len(rlm_loss) == 100:
        break
    # Skip pairs where either side is shorter than 10 characters.
    if len(prefix) < 10 or len(suffix) < 10:
        continue
    prefix_tokens = tokenizer.encode(prefix)
    len_prefix = len(prefix_tokens)
    predicted_prefix = reverse_normalized_generate(bwd_model, tokenizer, suffix, len_prefix, inverse_dataset_probs**dataset_p_temp, temperature=rlm_temp)
    # Round-trip through the tokenizer and cap the prediction at the true prefix length.
    predicted_prefix_tokens = tokenizer.encode(predicted_prefix)[:len_prefix]
    predicted_prefix = tokenizer.decode(predicted_prefix_tokens)
    _, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix), tokenizer)
    # zip stops at the shorter sequence, so a shorter prediction counts the missing
    # positions as wrong instead of raising IndexError.
    n_correct = sum(true_tok == pred_tok for true_tok, pred_tok in zip(prefix_tokens, predicted_prefix_tokens))
    rlm_tokenwise_acc.append(n_correct / len(prefix_tokens))
    rlm_loss.append(predicted_suffix_loss.item())

print(f'Average tokenwise accuracy is {sum(rlm_tokenwise_acc)/len(rlm_tokenwise_acc)}')
print(f'Average loss is {sum(rlm_loss)/len(rlm_loss)}')

Evaluate best-of-N rejection sampling from the reverse LM (RLM), without normalization, against greedy decoding and the true prefix.

In [None]:
# Best-of-N rejection sampling from the reverse LM (no normalization), compared with
# greedy reverse decoding and with the true prefix. For each usable pair it records:
#   rlm_tokenwise_acc / rlm_loss               - one sample (the last of the N draws)
#   rlm_best_tokenwise_acc / rlm_best_loss     - the draw with the lowest suffix loss
#   rlm_greedy_tokenwise_acc / rlm_greedy_loss - greedy (temperature 0) decoding
#   all_losses / all_naturals                  - suffix loss / prefix NLL of every draw
#   greedy_natural                             - prefix NLL of the greedy decode
#   dataset_gold_loss / pile_prefix_loss       - suffix loss / prefix NLL for the true prefix
rlm_tokenwise_acc = []
rlm_loss = []
rlm_best_tokenwise_acc = []
rlm_best_loss = []
all_losses = []
rlm_greedy_loss = []
rlm_greedy_tokenwise_acc = []
all_naturals = []
greedy_natural = []
pile_prefix_loss = []

dataset_gold_loss = []
dataset_p_temp = 0
rlm_temp = 0.01
rejection_sample = 100  # N: number of sampled prefixes per pair
eval_size = 100  # number of usable pairs to evaluate

# A fresh pair iterator keeps this evaluation on the same pairs as the other cells,
# independent of how far the shared `pairs` generator has been advanced.
# NOTE(review): assumes get_reverse_pair yields the same pairs on every call.
for prefix, suffix in tqdm(get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)):
    if len(rlm_loss) == eval_size:
        break
    # Skip pairs where either side is shorter than 10 characters.
    if len(prefix) < 10 or len(suffix) < 10:
        continue
    # Reference point: forward-model losses given the true prefix.
    prefix_loss, suffix_loss = forward_loss(model, (prefix, suffix), tokenizer)
    prefix_tokens = tokenizer.encode(prefix)
    len_prefix = len(prefix_tokens)

    # Draw N sampled prefixes and keep the one giving the lowest suffix loss.
    # The tracked minimum is reset per pair so a previous pair's winner cannot leak in.
    min_loss, min_prefix, min_prefix_tokens = float('inf'), None, None
    all_losses.append([])
    all_naturals.append([])
    for _ in range(rejection_sample):
        predicted_prefix = reverse_normalized_generate(bwd_model, tokenizer, suffix, len_prefix, None, temperature=rlm_temp)
        predicted_prefix_tokens = tokenizer.encode(predicted_prefix)[:len_prefix]
        predicted_prefix = tokenizer.decode(predicted_prefix_tokens)
        predicted_prefix_loss, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix), tokenizer)
        all_losses[-1].append(predicted_suffix_loss.item())
        all_naturals[-1].append(predicted_prefix_loss.item())
        if predicted_suffix_loss < min_loss:
            min_loss = predicted_suffix_loss
            min_prefix = predicted_prefix
            min_prefix_tokens = predicted_prefix_tokens
    assert min_prefix_tokens is not None, 'no rejection sample produced a finite suffix loss'

    # Greedy (temperature 0) decode as a baseline. It uses its own variable names so the
    # last sampled prediction above stays intact for rlm_tokenwise_acc / rlm_loss.
    greedy_prefix = reverse_normalized_generate(bwd_model, tokenizer, suffix, len_prefix, None, temperature=0)
    greedy_prefix_tokens = tokenizer.encode(greedy_prefix)[:len_prefix]
    greedy_prefix = tokenizer.decode(greedy_prefix_tokens)
    greedy_prefix_loss, greedy_loss = forward_loss(model, (greedy_prefix, suffix), tokenizer)

    pile_prefix_loss.append(prefix_loss.item())
    greedy_natural.append(greedy_prefix_loss.item())
    dataset_gold_loss.append(suffix_loss.item())
    # Token accuracies use zip, so a prediction shorter than the true prefix counts the
    # missing positions as wrong instead of raising IndexError.
    rlm_tokenwise_acc.append(sum(t == p for t, p in zip(prefix_tokens, predicted_prefix_tokens)) / len(prefix_tokens))
    rlm_loss.append(predicted_suffix_loss.item())
    rlm_best_tokenwise_acc.append(sum(t == p for t, p in zip(prefix_tokens, min_prefix_tokens)) / len(prefix_tokens))
    rlm_best_loss.append(min_loss.item())
    rlm_greedy_tokenwise_acc.append(sum(t == p for t, p in zip(prefix_tokens, greedy_prefix_tokens)) / len(prefix_tokens))
    rlm_greedy_loss.append(greedy_loss.item())

print(f'Average tokenwise accuracy is {sum(rlm_tokenwise_acc)/len(rlm_tokenwise_acc)}')
print(f'Average loss is {sum(rlm_loss)/len(rlm_loss)}')
print(f'Average dataset gold loss is {sum(dataset_gold_loss)/len(dataset_gold_loss)}')
print(f'Best tokenwise accuracy is {sum(rlm_best_tokenwise_acc)/len(rlm_best_tokenwise_acc)}')
print(f'Best loss is {sum(rlm_best_loss)/len(rlm_best_loss)}')
print(f'Greedy tokenwise accuracy is {sum(rlm_greedy_tokenwise_acc)/len(rlm_greedy_tokenwise_acc)}')
print(f'Greedy loss is {sum(rlm_greedy_loss)/len(rlm_greedy_loss)}')

In [None]:
# Best-of-N curve: for N = 1..rejection_sample, the mean over pairs of the lowest
# suffix loss among each pair's first N samples (from all_losses). Horizontal
# reference lines show the mean suffix loss given the true prefix and given the
# greedy-decoded prefix.
Ns = range(1, rejection_sample + 1)  # inclusive, so the last point uses all N draws
mean_best_of_N_loss = [
    np.mean([min(sample_losses[:N]) for sample_losses in all_losses])
    for N in Ns
]

fig, ax = plt.subplots()
# The reference lines go on the same axes as the curve so that they appear in this figure.
ax.axhline(y=sum(dataset_gold_loss)/len(dataset_gold_loss), color='r', linestyle='--', label='Loss given true prefix')
ax.axhline(y=sum(rlm_greedy_loss)/len(rlm_greedy_loss), color='g', linestyle='--', label='Loss given greedy decode prefix')
ax.plot(Ns, mean_best_of_N_loss, marker='o', label='Best-of-N')
ax.set_xlabel('Number of Rejection Sampling Steps')
ax.set_ylabel('Arithmetic Mean of Best-of-N Loss')
ax.set_title('Arithmetic Mean of Best-of-N Loss vs Rejection Sampling Steps')
ax.grid(True)
ax.legend()  # without this the axhline labels are never shown
plt.show()

In [None]:
# Naturalness vs. suffix-loss trade-off. x is the forward-LM NLL of a prefix; y is the
# suffix loss given that prefix. Points: best-of-N rejection sampling for
# N = 1..rejection_sample, the greedy decode, and the true (Pile) prefix.
mean_greedy_natural = sum(greedy_natural)/len(greedy_natural)
mean_greedy_loss = sum(rlm_greedy_loss)/len(rlm_greedy_loss)
pile_suffix_loss = sum(dataset_gold_loss)/len(dataset_gold_loss)
pile_prefix_natural = sum(pile_prefix_loss)/len(pile_prefix_loss)

Ns = range(1, rejection_sample + 1)  # inclusive, so the last point uses all N draws
mean_natural_loss = []
best_of_N_loss = []
for N in Ns:
    # Per pair, the index of the lowest-suffix-loss draw among the first N
    # (np.argmin returns the first minimum on ties).
    best_idx = [int(np.argmin(sample_losses[:N])) for sample_losses in all_losses]
    # x and y both describe the selected prefix, matching how the greedy and Pile
    # points are defined. Averaging NLL over all N candidates would mix in prefixes
    # that were never selected.
    mean_natural_loss.append(np.mean([naturals[i] for naturals, i in zip(all_naturals, best_idx)]))
    best_of_N_loss.append(np.mean([sample_losses[i] for sample_losses, i in zip(all_losses, best_idx)]))

fig, ax = plt.subplots()
ax.plot(mean_natural_loss, best_of_N_loss, marker='o', label='Best-of-N')
ax.plot([mean_greedy_natural], [mean_greedy_loss], marker='x', linestyle='', color='red', label='Greedy')
ax.plot([pile_prefix_natural], [pile_suffix_loss], marker='s', linestyle='', color='green', label='Pile')
ax.set_xlabel('Arithmetic Mean of NLL of forwards LM on Prefix')
ax.set_ylabel('Best-of-N Suffix Loss')
ax.set_title('Best-of-N Suffix Loss vs Arithmetic Mean of NLL of forwards LM on Prefix')
ax.legend(loc='upper right')
ax.grid(True)
plt.show()
