In [1]:
import transformers
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
import re
import random
import datasets
from datasets import load_dataset
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Dict
from tqdm import tqdm
import pickle
from dotenv import load_dotenv
import openai
import json

from typing import List, Optional, Tuple, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

import plotly.graph_objects as go
import plotly.express as px

from utils import untuple, eval_completions, levenshtein_distance
from act_add.contrast_dataset import ContrastDataset
from scripts.get_activations import gen_pile_data, compare_token_lists, slice_acts
from act_add.model_wrapper import ModelWrapper
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = 'EleutherAI/pythia-12b'
# use two gpus
model = AutoModelForCausalLM.from_pretrained(model_name, device_map = 'auto')
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side = 'left')
tokenizer.pad_token = tokenizer.eos_token

# dataset_name = 'duped.12b'
# N_PROMPTS = 5000
# mem_data = load_dataset('EleutherAI/pythia-memorized-evals')[dataset_name]
# pile_prompts = gen_pile_data(N_PROMPTS, tokenizer, min_n_toks = 64)
mw = ModelWrapper(model, tokenizer)

Loading checkpoint shards: 100%|██████████| 3/3 [00:37<00:00, 12.40s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
mw = ModelWrapper(model, tokenizer)

In [4]:
# load probe
from probes import LRProbe
path = '../../gld/train-data-probes/data/12b/'

# open .pth file
probe_weights = torch.load(path + 'probe_weights_regged.pth')
probe = LRProbe.from_weights(probe_weights['net.0.weight'], probe_weights['net.0.bias'])
probe.to('cuda')

LRProbe(
  (net): Sequential(
    (0): Linear(in_features=5120, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

In [5]:
from datasets import load_from_disk, Dataset, DatasetDict

dataset = load_from_disk(os.path.join(path, 'hf_token_dataset_v2'))

# # filter for labels == 1
# prompts = [x[:32] for x in dataset['test'].filter(lambda x: x['labels'] == 1)['input_ids']]
# prompts = tokenizer.batch_decode(prompts, skip_special_tokens=True)
# prompts = list(set(prompts))

# words_list = [x.split() for x in dataset['quotes'].filter(lambda x: x['labels'] == 1)['text']]
# prompts = [x[:len(x) // 2] for x in words_list]
# prompts = [' '.join(x) for x in prompts]
# # prompts = list(set(prompts))

prompts = [x[:32] for x in dataset.filter(lambda x: x['labels'] == 1)['input_ids']]
ref = dataset.filter(lambda x: x['labels'] == 1)['input_ids']

print(len(prompts))
print(prompts[0])

3373
[187, 186, 94, 187, 186, 7481, 13830, 299, 18, 3843, 27045, 9, 9316, 64, 9465, 21840, 64, 44285, 64, 26679, 13, 10131, 9, 18939, 15, 11732, 10107, 81, 17, 9679, 10131, 9]


In [6]:
dataset

Dataset({
    features: ['input_ids', 'labels', 'df_ref_idx'],
    num_rows: 6746
})

In [7]:
# sample 500 
seed = 0
random.seed(seed)
np.random.seed(seed)
prompts, ref = zip(*random.sample(list(zip(prompts, ref)), 200))
prompts = list(prompts)
ref = list(ref)

print(len(prompts))

200


In [8]:
# batch_size = 10
# decoded_refs = [tokenizer.decode(x, skip_special_tokens=True) for x in ref]


# final_generations = []
# prompt_idxs = []

# i = 0
# while len(final_generations) < 100: 
#     input_ids = torch.tensor(prompts[i:i+batch_size]).cuda()
#     out = model.generate(
#                     input_ids=input_ids,
#                     pad_token_id=tokenizer.eos_token_id,
#                     max_new_tokens = 32,
#                     return_dict_in_generate = True,
#                     output_hidden_states = False,
#                     do_sample=False,
#                     )
#     decoded = [tokenizer.decode(x, skip_special_tokens=True) for x in out['sequences']]
#     # print(decoded_refs[i])
#     distance = levenshtein_distance(decoded, decoded_refs[i:i+batch_size])
#     print(distance)
#     for j, d in enumerate(distance):
#         if d == 1: 
#             final_generations.append(decoded[j])
#             prompt_idxs.append(i+j)
#     print(len(final_generations))
#     i += batch_size

In [9]:
# len(final_generations)

In [10]:
# final_generations = final_generations[:100]
# prompt_idxs = prompt_idxs[:100]

In [11]:
# with open('rej_sample/normal_generations.pkl', 'wb') as f:
#     pickle.dump(final_generations, f)

# with open('rej_sample/normal_generations_idxs.pkl', 'wb') as f:
#     pickle.dump(prompt_idxs, f)



In [12]:
with open('rej_sample/normal_generations.pkl', 'rb') as f:
    final_generations = pickle.load(f)

with open('rej_sample/normal_generations_idxs.pkl', 'rb') as f:
    prompt_idxs = pickle.load(f)


In [13]:
new_generations = {}
distances = {}

In [18]:
batch_size = 12

# new_generations = {}
# distances = {}

decoded_refs = [tokenizer.decode(x, skip_special_tokens=True) for x in ref]

new_generations[32] = []
for i in tqdm(range(0, 60, batch_size)):
    idxs = prompt_idxs[i:i+batch_size]
    to_run = torch.tensor(prompts)[idxs].cuda()
    t_gen, _, _ = mw.rej_sampl_generate(to_run, probe, 34, rej_sample_length=32, max_tries=1, log_rej_samples=True, max_new_tokens=32)
    new_generations[32].extend(t_gen)

distances[32] = levenshtein_distance(new_generations[32], decoded_refs)
print(distances)


  0%|          | 0/5 [00:00<?, ?it/s]

part 0 of generation


In [16]:
new_generations.keys()

dict_keys([17])

In [17]:
with open('rej_sample/rej_generations17.pkl', 'wb') as f:
    pickle.dump(new_generations, f)

In [10]:
t_gen

[' for an item by clicking the product "Add to Cart" button or "See Price In Cart" link.\n\nPlease be assured that simply adding an item into Your Cart (except in cases where it is indicated that the item is in stock and/or available for purchase) does NOT reserve the item for You - the buyer.',
 '.ly/1W9Lk0n\nRead more: http://www.businessinsider.com/\n--------------------------------------------------\nBusiness Insider is the hottest new site on the web, covering entrepreneurship news, tech news and more.\n--------------------------------------------------Q:\n\nHow to get the value of a variable in a function']

In [11]:
t_tries, t_failures

(tensor([35., 20.]), 11)

In [8]:
input_ids = torch.tensor(prompts[:2]).to(model.device)

out = model.generate(
                    input_ids=input_ids,
                    pad_token_id=tokenizer.eos_token_id,
                    max_new_tokens = 32,
                    return_dict_in_generate = True,
                    output_hidden_states = True,
                    )

In [9]:
out['sequences'][0], tokenizer.decode(out['sequences'][0])

(tensor([  323,   271,  5382,   407, 19009,   253,  1885,   346,  4717,   281,
         16619,     3,  6409,   390,   346,  5035, 16040,   496, 16619,     3,
          3048,    15,   187,   187,  7845,   320, 17839,   326,  3365,  6240,
           271,  5382,   281,   634,  7281,  1057,   417,  7206,   366,   368,
           281,  4489,   352,    15,  1422,   476,  1900,  1818,   634,  2564,
           285, 11352,   253,  5382,   432,   634,  7281,   604,   368,  7617,
           417,   281,  7471,   352], device='cuda:0'),
 ' for an item by clicking the product "Add to Cart" button or "See Price In Cart" link.\n\nPlease be assured that simply adding an item to your cart does not obligate you to buy it. You can always change your mind and delete the item from your cart if you decide not to purchase it')

In [11]:
out['sequences'], tokenizer.decode(out['sequences'])

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

In [17]:
torch.tensor(ref[0]), tokenizer.decode(ref[0])

(tensor([  323,   271,  5382,   407, 19009,   253,  1885,   346,  4717,   281,
         16619,     3,  6409,   390,   346,  5035, 16040,   496, 16619,     3,
          3048,    15,   187,   187,  7845,   320, 17839,   326,  3365,  6240,
           271,  5382,   281,   634,  7281,  1057,   417,  7206,   366,   368,
           281,  4489,   352,    15,  1422,   476,  1900,  1818,   634,  2564,
           285, 11352,   253,  5382,   432,   634]),
 ' for an item by clicking the product "Add to Cart" button or "See Price In Cart" link.\n\nPlease be assured that simply adding an item to your cart does not obligate you to buy it. You can always change your mind and delete the item from your')

In [None]:
out = model.generate(tokenizer(prompts[0], return_tensors='pt')['input_ids'], max_length=1, do_sample=False, temperature=0.0, pad_token_id=tokenizer.eos_token_id, return_dict_in_generate=True, output_hidden_states=True, bad_words_ids=[[625], [253]])
out.sequences

tensor([[26851,   257,   900,   844,   304,   423,   310,   253, 17696, 11981,
         25113,   296,   323,   253,  5197,   457,    84,   443,  4539,   285,
         23325, 11981,  2285,    15,  3837, 23967,  3658,   187,   187,  1992,
          1239,   436]])

In [None]:
len(out.sequences[0]), len(tokenizer(prompts[0], return_tensors='pt')['input_ids'][0])

(32, 31)