In [1]:
import os
dir_list = os.chdir('./../reverse-dynamics-nlp/')

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from prompt_optimizer import PromptOptimizer
from utils import reverse_generate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Callable, Iterable, Any
def get_reverse_pair(dataset: Iterable[Any], chunk_func: Callable[..., Any], tokenizer: AutoTokenizer):
    for chunk in dataset:
        for chunk in chunk_func(chunk, tokenizer):
            yield chunk

def end_chunk_hf(chunk, tokenizer):
    chunk = chunk['text']
    tokens = tokenizer(chunk[-200:])['input_ids'][2:] #drop first couple tokens given risk of incomplete token
    yield tokenizer.decode(tokens[-40:-30]), tokenizer.decode(tokens[-30:])

def start_chunk_hf(chunk, tokenizer, num_prefix_tokens=10, num_suffix_tokens=40):
    chunk = chunk['text']
    tokens = tokenizer(chunk[:200])['input_ids'] #drop first couple tokens given risk of incomplete token
    yield tokenizer.decode(tokens[:num_prefix_tokens]), tokenizer.decode(tokens[num_prefix_tokens:num_prefix_tokens+num_suffix_tokens])

def rand_init(seq_length: int, tokenizer):
    return tokenizer.decode(torch.randint(0, tokenizer.vocab_size, (seq_length,)))

In [3]:

dataset = load_dataset("NeelNanda/pile-10k")
tokenizer = AutoTokenizer.from_pretrained("afterless/reverse-pythia-160m")
pairs = get_reverse_pair(dataset['train'], start_chunk_hf, tokenizer)
print(next(pairs))
bwd_model = GPTNeoXForCausalLM.from_pretrained("afterless/reverse-pythia-160m").cuda()
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m", cache_dir='/scratch/jp6263/hf/models/').cuda()


Found cached dataset parquet (/home/jp6263/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 1/1 [00:00<00:00, 433.12it/s]


('It is done, and submitted. You can play', ' “Survival of the Tastiest” on Android, and on the web. Playing on the web works, but you have to simulate multi-touch for table moving and that can be a bit')


In [4]:
def forward_loss(model, pair, loss=torch.nn.CrossEntropyLoss(),):
    prefix, suffix = pair
    whole_tensor = tokenizer(prefix+suffix, return_tensors='pt').input_ids.cuda()
    with torch.no_grad():
        logs = model(whole_tensor).logits
    start_ind = len(tokenizer.encode(prefix))
    l_pref = loss(logs[0,:start_ind], whole_tensor[0,1:start_ind+1])
    l_suff = loss(logs[0,start_ind:-1], whole_tensor[0,start_ind+1:])
    return l_pref, l_suff

In [5]:
gcg = PromptOptimizer(model, tokenizer, n_proposals=128, n_epochs=250, n_top_indices=128, prefix_loss_weight=0.1)

In [6]:
tokenwise_acc = []
loss = []
temp = 5 #None for default GCG with uniform sampling
for p,pair in enumerate(pairs):
    if len(loss)==5: break
    if len(pair[0])<10 or len(pair[1])<10: continue
    prefix_loss,suffix_loss = forward_loss(model, pair)
    if suffix_loss>2.1: continue #this is around 10th percentile of losses for 170m
    prefix, suffix = pair
    len_prefix = len(tokenizer(prefix)['input_ids'])
    rand_prefix = rand_init(len_prefix, tokenizer)
    optimized_string = gcg.optimize(rand_prefix, suffix, temperature=temp)
    optimal_prefix = tokenizer.decode(tokenizer.encode(optimized_string)[:len_prefix])
    predicted_prefix_loss, predicted_suffix_loss = forward_loss(model, (optimal_prefix, suffix))
    print(f'True prefix is:\n{prefix} \n\nPredicted prefix:\n{optimal_prefix}\nfor suffix:\n {suffix}')
    print(f'Loss for suffix given predicted prefix is {predicted_suffix_loss.item()} \n Suffix loss for true prefix is {suffix_loss.item()}')
    print(f'NLL on predicted prefix is {predicted_prefix_loss.item()} \n NLL on true prefix is {prefix_loss.item()}')
    loss.append(predicted_suffix_loss.item())
    # tokenwise_acc.append(sum([1 for i in range(len(prefix)) if prefix[i] == optimal_prefix[i]])/len(prefix))
# print(f'Average tokenwise accuracy is {sum(tokenwise_acc)/len(tokenwise_acc)}')
print(f'Average loss is {sum(loss)/len(loss)}')

True prefix is:
<?xml version="1.0" encoding=" 

Predicted prefix:
 Zend_, prostate missionary}$$\ige ${{\waitSeg?"
for suffix:
 UTF-8"?>
<segment>
    <name>PD1</name>
    <description>Patient Additional Demographic</description>
    <elements>
        <field
Loss for suffix given predicted prefix is 1.967311143875122 
 Suffix loss for true prefix is 2.0244696140289307
NLL on predicted prefix is 12.9093599319458 
 NLL on true prefix is 0.6212190389633179
True prefix is:
{
  "fpsLimit": 60, 

Predicted prefix:
dialopacity bilWVeveJSON ripownerOFFSET collagen
for suffix:
 
  "preset": "basic",
  "background": {
    "color": "#0d47a1",
    "image": "",
    "position": "50
Loss for suffix given predicted prefix is 2.122087240219116 
 Suffix loss for true prefix is 1.8352282047271729
NLL on predicted prefix is 12.266073226928711 
 NLL on true prefix is 3.9634578227996826
True prefix is:
This application is based upon and claims the benefit of 

Predicted prefix:
iticMeasureManagement outst

In [144]:
tokenwise_acc = []
loss = []
for p,pair in enumerate(pairs):
    if len(loss)==100: break
    if len(pair[0])<10 or len(pair[1])<10: continue
    prefix_loss,suffix_loss = forward_loss(model, pair)
    if suffix_loss>2.1: continue #this is around 10th percentile of losses for 170m
    prefix, suffix = pair
    len_prefix = len(tokenizer(prefix)['input_ids'])
    predicted_prefix = reverse_generate(bwd_model, tokenizer, suffix, len_prefix)[0]
    predicted_prefix = tokenizer.decode(tokenizer.encode(predicted_prefix)[:len_prefix])

    predicted_prefix_loss, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix))
    print(f'True prefix is:\n{prefix} \n\nPredicted prefix:\n{predicted_prefix}\nfor suffix:\n {suffix}')
    print(f'Loss for suffix given predicted prefix is {predicted_suffix_loss.item()} \n Suffix loss for true prefix is {suffix_loss.item()}')
    print(f'NLL on predicted prefix is {predicted_prefix_loss.item()} \n NLL on true prefix is {prefix_loss.item()}')
    loss.append(predicted_suffix_loss.item())
# print(f'Average tokenwise accuracy is {sum(tokenwise_acc)/len(tokenwise_acc)}')
print(f'Average loss is {sum(loss)/len(loss)}')

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


True prefix is:
                                                                                  [PUBLISH]


                   

Predicted prefix:
           [DO NOT PUBLISH]


                IN
for suffix:
 IN THE UNITED STATES COURT OF APPEALS

                            FOR THE ELEVENTH CIRC
Loss for suffix given predicted prefix is 1.5578012466430664 
 Suffix loss for true prefix is 1.1715296506881714
NLL on predicted prefix is 5.3471269607543945 
 NLL on true prefix is 3.7086830139160156
True prefix is:
Mystikal (album)

My 

Predicted prefix:
 he said.Mystikal

My
for suffix:
 stikal is the eponymous self-titled debut studio album by American rapper Mystikal. It was independently self-released on June 14, 1994, by Big Boy Records. The
Loss for suffix given predicted prefix is 2.3714146614074707 
 Suffix loss for true prefix is 2.054072141647339
NLL on predicted prefix is 4.269046783447266 
 NLL on true prefix is 3.6076202392578125


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


True prefix is:
5. Let b be (-105)/(-10)* 

Predicted prefix:
*k. Let b be 4/2*
for suffix:
 6/k. Let y be (-3)/1 + (-196)/b. Solve -y = 4*u + 3*v, -2*v + 5*v = -2
Loss for suffix given predicted prefix is 1.522289514541626 
 Suffix loss for true prefix is 1.646061658859253
NLL on predicted prefix is 2.721585750579834 
 NLL on true prefix is 3.1036183834075928
True prefix is:
Hyperolius ferrugineus

Hyper 

Predicted prefix:
Hyperolius ferrugineus

Hyper
for suffix:
 olius ferrugineus is a species of frog in the family Hyperoliidae.
It is endemic to Democratic Republic of the Congo.
Its natural habitats are subtropical or tropical mois
Loss for suffix given predicted prefix is 1.979470133781433 
 Suffix loss for true prefix is 1.979470133781433
NLL on predicted prefix is 4.303708553314209 
 NLL on true prefix is 4.303708553314209


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


True prefix is:
Hatamabad, Markazi

H 

Predicted prefix:
 he said.Jatamabad

J
for suffix:
 atamabad (, also Romanized as Ḩātamābād) is a village in Farmahin Rural District, in the Central District of Farahan County, Markazi
Loss for suffix given predicted prefix is 1.5585622787475586 
 Suffix loss for true prefix is 1.4432446956634521
NLL on predicted prefix is 4.541042804718018 
 NLL on true prefix is 4.4801740646362305
True prefix is:
<?php defined('BX_DOL') 

Predicted prefix:
_BEFORE_HACK_PAUSE')
for suffix:
  or die('hack attempt');
/**
 * Copyright (c) UNA, Inc - https://una.io
 * MIT License - https://opensource.org/licenses/MIT
Loss for suffix given predicted prefix is 1.8252325057983398 
 Suffix loss for true prefix is 1.6047663688659668
NLL on predicted prefix is 4.435699462890625 
 NLL on true prefix is 4.108913421630859


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


True prefix is:
<header class="header-wrapper">

   

Predicted prefix:
>
<div class="container">
  
for suffix:
 <nav class="inner">
    <div class="title">
      <a href="/">
        <img class="logo" src="<%- url_for(theme.profile
Loss for suffix given predicted prefix is 1.8682317733764648 
 Suffix loss for true prefix is 1.8162084817886353
NLL on predicted prefix is 1.1058136224746704 
 NLL on true prefix is 2.780362367630005
True prefix is:
/*
Copyright (C) 2011 Mark Chandler ( 

Predicted prefix:
 part of the Desura(R) project (
for suffix:
 Desura Net Pty Ltd)
Copyright (C) 2014 Bad Juju Games, Inc.

This program is free software: you can redistribute it and/or modify
it under the
Loss for suffix given predicted prefix is 2.050963878631592 
 Suffix loss for true prefix is 1.5706523656845093
NLL on predicted prefix is 4.815761566162109 
 NLL on true prefix is 4.146985054016113


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


True prefix is:
The present invention relates to catalyst components for the polymerization 

Predicted prefix:
, to the use of said catalysts in the polymerization
for suffix:
  of olefins, to the catalysts obtained therefrom and to the use of said catalysts in the polymerization of olefins CH2xe2x95x
Loss for suffix given predicted prefix is 1.7928210496902466 
 Suffix loss for true prefix is 1.878375768661499
NLL on predicted prefix is 3.803551435470581 
 NLL on true prefix is 2.978477954864502
True prefix is:
1. Introduction {#sec1-ijerph-17 

Predicted prefix:
1. Introduction {#sec1-ijerph-17
for suffix:
 -01067}

Nasolacrimal duct obstruction (NLDO) is the most common cause of childhood epiphora \[[@B1-ijerph-17-010
Loss for suffix given predicted prefix is 1.7206493616104126 
 Suffix loss for true prefix is 1.7206493616104126
NLL on predicted prefix is 1.8608165979385376 
 NLL on true prefix is 1.8608165979385376
Average loss is 1.824743640422821


In [None]:
from utils import reverse_tokenize, reverse_decode
def reverse_normalized_forward(reverse_model, tokenizer, target, normalizer):
    inputs = reverse_tokenize(tokenizer, target)
    outputs = reverse_model(inputs).logits[0,-1,:]
    outputs = torch.nn.Softmax(dim=-1)(outputs).cpu()
    outputs = torch.mul(outputs, normalizer)
    return outputs

def reverse_normalized_generate(reverse_model, tokenizer, target, max_length, normalizer, temperature=1):
    prefix = []
    for i in range(max_length):
        normalized_probs = reverse_normalized_forward(reverse_model, tokenizer, ''.join(prefix[::-1]) + target, normalizer)
        if not temperature:
            token = tokenizer.decode(torch.argmax(normalized_probs))
        else:
            probs = torch.div(normalized_probs, temperature)
            probs = torch.nn.Softmax(dim=-1)(probs)
            token = tokenizer.decode(torch.multinomial(probs, num_samples=1))
        if token == '[PAD]' or token == '[EOS]':
            break
        prefix.append(token)
    return ''.join(prefix[::-1])+target

inverse_dataset_probs = torch.reciprocal(dataset_probs)
reverse_normalized_generate(bwd_model, tokenizer, ' on the mat next to the kitchen.', 5, inverse_dataset_probs**0.1, temperature=0)

In [156]:
tokenwise_acc = []
loss = []
for p,pair in enumerate(pairs):
    if len(loss)==250: break
    if len(pair[0])<10 or len(pair[1])<10: continue
    prefix_loss,suffix_loss = forward_loss(model, pair)
    # if suffix_loss>2.1: continue #this is around 10th percentile of losses for 170m
    prefix, suffix = pair
    len_prefix = len(tokenizer(prefix)['input_ids'])
    predicted_prefix = reverse_normalized_generate(bwd_model, tokenizer, suffix, len_prefix, inverse_dataset_probs**0.25, temperature=0) #1.425 at 0.25 partial Bayes update vs 1.437 at 0 i.e. default
    predicted_prefix = tokenizer.decode(tokenizer.encode(predicted_prefix)[:len_prefix])

    predicted_prefix_loss, predicted_suffix_loss = forward_loss(model, (predicted_prefix, suffix))
    # print(f'True prefix is:\n{prefix} \n\n')
    # print(f'Predicted prefix:\n{predicted_prefix}\nfor suffix:\n {suffix}')
    # print(f'Loss for suffix given predicted prefix is {predicted_suffix_loss.item()} \n Suffix loss for true prefix is {suffix_loss.item()}')
    # print(f'NLL on predicted prefix is {predicted_prefix_loss.item()} \n NLL on true prefix is {prefix_loss.item()}')
    loss.append(predicted_suffix_loss.item())
# print(f'Average tokenwise accuracy is {sum(tokenwise_acc)/len(tokenwise_acc)}')
print(f'Average loss is {sum(loss)/len(loss)}')

Average loss is 1.5034217107743024


In [57]:
def get_token_probabilities(tokenizer, dataset="NeelNanda/pile-10k", vocab_size=50304):
    data = load_dataset(dataset)
    counts = torch.zeros(vocab_size, dtype=torch.float) #tokenizer.vocab_size is fake 50304 is the model output dimension which is what we care about

    for chunk in data['train']:
        # Extract text from chunk (assuming each chunk is a dictionary with a "text" key)
        text = chunk['text']

        # Tokenize the text
        tokens = tokenizer(text, return_tensors="pt").input_ids[0]

        # Count occurrences for each token
        for tok in tokens:
            counts[tok] += 1

    # Normalize the counts to get probabilities
    total_tokens = torch.sum(counts)
    probabilities = counts / total_tokens
    min_val = probabilities[probabilities > 0].min()
    probabilities[probabilities == 0] = min_val
    return probabilities

dataset_probs = get_token_probabilities(tokenizer)

Found cached dataset parquet (/home/jp6263/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 1/1 [00:00<00:00, 508.71it/s]


In [117]:
reverse_generate(bwd_model, tokenizer, ' on the mat next to the kitchen.', 5,)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


[' it, and set it on the mat next to the kitchen.']