In [None]:
import os
dir_list = os.chdir('./../reverse-dynamics-nlp/')

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from prompt_optimizer import PromptOptimizer
from utils import (
    forward_loss,
    get_reverse_pair,
    get_token_probabilities,
    rand_init,
    reverse_normalized_forward,
    reverse_normalized_generate,
    start_chunk_hf,
)

Pile-10k evaluation: load the dataset and the backward (reverse) and forward language models.

In [None]:
# Load the Pile-10k dataset and the tokenizer that is used with both models below.
dataset = load_dataset("NeelNanda/pile-10k")
tokenizer = AutoTokenizer.from_pretrained("afterless/reverse-pythia-160m")

# `pairs` is a one-shot iterator of (prefix, suffix) string pairs.
# NOTE: the peek below consumes the first pair, so any loop that later iterates
# `pairs` starts from the second one.
pairs = get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)
print(next(pairs))

# Download/cache location for the forward model. Defaults to the original
# cluster path; set PYTHIA_CACHE_DIR to run on a machine without that directory.
fwd_model_cache_dir = os.environ.get("PYTHIA_CACHE_DIR", '/scratch/jp6263/hf/models/')

# Both models are moved to the GPU; a CUDA device is required.
bwd_model = GPTNeoXForCausalLM.from_pretrained("afterless/reverse-pythia-160m").cuda()
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m", cache_dir=fwd_model_cache_dir).cuda()

In [None]:
# Build the GCG prompt optimizer around the forward model and its tokenizer.
# The hyperparameter semantics are defined by PromptOptimizer (prompt_optimizer.py).
gcg = PromptOptimizer(
    model,
    tokenizer,
    n_proposals=128,
    n_epochs=250,
    n_top_indices=128,
    prefix_loss_weight=0.1,
)

Evaluate GCG with forward LM-guided sampling

In [None]:
# Value forwarded as `temperature=` to gcg.optimize in the GCG evaluation loop.
# Set to None for default GCG with uniform sampling.
temp = 5 #None for default GCG with uniform sampling

In [None]:
# Score GCG prefix recovery: for each (prefix, suffix) pair, optimize a random
# prefix of the same token length and record (a) the forward-model loss on the
# suffix given the recovered prefix and (b) position-wise token accuracy.
N_EVAL_PAIRS = 100    # number of pairs to score
MIN_TEXT_CHARS = 10   # pairs whose prefix or suffix string is shorter are skipped

gcg_tokenwise_acc = []
gcg_loss = []

# Iterate a fresh stream rather than the shared `pairs` iterator: a shared
# one-shot iterator would make this cell and the reverse-LM cell score different
# samples, and would make every re-run of this cell score yet another set.
gcg_pairs = get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)
for pair in gcg_pairs:
    if len(gcg_loss) == N_EVAL_PAIRS: break
    if len(pair[0]) < MIN_TEXT_CHARS or len(pair[1]) < MIN_TEXT_CHARS: continue
    prefix_loss, suffix_loss = forward_loss(model, pair)
    # if suffix_loss>2.1: continue #this is around 10th percentile of losses for 170m
    prefix, suffix = pair
    prefix_tokens = tokenizer.encode(prefix)
    len_prefix = len(prefix_tokens)

    # Start GCG from a random prefix with the same number of tokens as the true one.
    rand_prefix = rand_init(len_prefix, tokenizer)
    optimized_string = gcg.optimize(rand_prefix, suffix, temperature=temp)

    # Re-tokenize the optimized text and keep at most `len_prefix` tokens. The
    # decode -> encode round trip may yield fewer tokens than `len_prefix`.
    predicted_prefix_tokens = tokenizer.encode(optimized_string)[:len_prefix]
    predicted_prefix = tokenizer.decode(predicted_prefix_tokens)
    predicted_prefix_loss, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix))

    print(f'True prefix is:\n{prefix} \n\nPredicted prefix:\n{predicted_prefix}\nfor suffix:\n {suffix}')
    print(f'Loss for suffix given predicted prefix is {predicted_suffix_loss.item()} \n Suffix loss for true prefix is {suffix_loss.item()}')
    print(f'NLL on predicted prefix is {predicted_prefix_loss.item()} \n NLL on true prefix is {prefix_loss.item()}')

    gcg_loss.append(predicted_suffix_loss.item())
    # zip stops at the shorter sequence, so a short prediction counts its missing
    # positions as misses instead of raising IndexError.
    n_matching = sum(true_tok == pred_tok for true_tok, pred_tok in zip(prefix_tokens, predicted_prefix_tokens))
    gcg_tokenwise_acc.append(n_matching / len_prefix)

print(f'Average tokenwise accuracy is {sum(gcg_tokenwise_acc)/len(gcg_tokenwise_acc)}')
print(f'Average loss is {sum(gcg_loss)/len(gcg_loss)}')

Now load the dataset token probabilities and set up the reverse LM evaluation.

In [None]:
# Token probabilities returned by the project helper for this tokenizer.
# NOTE(review): presumably one probability per vocabulary entry — confirm in utils.
dataset_probs = get_token_probabilities(tokenizer)

# Element-wise 1 / p; the reverse-LM evaluation raises this to a power before use.
inverse_dataset_probs = dataset_probs.reciprocal()

In [None]:
# Exponent applied to inverse_dataset_probs before it is handed to
# reverse_normalized_generate. An exponent of 0 turns every weight into 1.
dataset_p_temp = 0.25
# Passed as `temperature=` to reverse_normalized_generate.
# NOTE(review): 0 presumably selects greedy decoding — confirm in utils.
rlm_temp = 0

In [None]:
# Score reverse-LM prefix recovery: for each (prefix, suffix) pair, generate a
# prefix of the same token length with the backward model and record (a) the
# forward-model loss on the suffix given that prefix and (b) position-wise
# token accuracy.
N_EVAL_PAIRS = 100    # number of pairs to score
MIN_TEXT_CHARS = 10   # pairs whose prefix or suffix string is shorter are skipped

rlm_tokenwise_acc = []
rlm_loss = []

# Iterate a fresh stream rather than the shared `pairs` iterator: a shared
# one-shot iterator would leave this loop with whatever samples earlier cells
# had not consumed, and every re-run (e.g. with another dataset_p_temp) would
# score a different set.
rlm_pairs = get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)
for pair in rlm_pairs:
    if len(rlm_loss) == N_EVAL_PAIRS: break
    if len(pair[0]) < MIN_TEXT_CHARS or len(pair[1]) < MIN_TEXT_CHARS: continue
    prefix_loss, suffix_loss = forward_loss(model, pair)
    # if suffix_loss>2.1: continue #this is around 10th percentile of losses for 170m
    prefix, suffix = pair
    prefix_tokens = tokenizer.encode(prefix)
    len_prefix = len(prefix_tokens)

    # Author's recorded result: 1.425 at 0.25 partial Bayes update vs 1.437 at 0 i.e. default
    predicted_prefix = reverse_normalized_generate(
        bwd_model,
        tokenizer,
        suffix,
        len_prefix,
        inverse_dataset_probs**dataset_p_temp,
        temperature=rlm_temp,
    )

    # Re-tokenize the generated text and keep at most `len_prefix` tokens. The
    # encode step may yield fewer tokens than `len_prefix`.
    predicted_prefix_tokens = tokenizer.encode(predicted_prefix)[:len_prefix]
    predicted_prefix = tokenizer.decode(predicted_prefix_tokens)

    predicted_prefix_loss, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix))
    print(f'True prefix is:\n{prefix} \n\nPredicted prefix:\n{predicted_prefix}\nfor suffix:\n {suffix}')
    print(f'Loss for suffix given predicted prefix is {predicted_suffix_loss.item()} \n Suffix loss for true prefix is {suffix_loss.item()}')
    print(f'NLL on predicted prefix is {predicted_prefix_loss.item()} \n NLL on true prefix is {prefix_loss.item()}')

    # zip stops at the shorter sequence, so a short prediction counts its missing
    # positions as misses instead of raising IndexError.
    n_matching = sum(true_tok == pred_tok for true_tok, pred_tok in zip(prefix_tokens, predicted_prefix_tokens))
    rlm_tokenwise_acc.append(n_matching / len_prefix)
    rlm_loss.append(predicted_suffix_loss.item())

print(f'Average tokenwise accuracy is {sum(rlm_tokenwise_acc)/len(rlm_tokenwise_acc)}')
print(f'Average loss is {sum(rlm_loss)/len(rlm_loss)}')