In [24]:
import os
import torch
import numpy as np
import pandas as pd
from faiss import IndexIDMap, IndexFlatIP
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM
from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer, AccEvaluator
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [25]:
exemplar_count = 10
dataset = load_dataset("ag_news")
dataset["train"] = dataset["train"].select(range(100))
dataset["test"] = dataset["test"].select(range(100))
data = DatasetReader(dataset, input_columns=["text"], output_column="label")
tp_dict = {
    0: "</text>:0</E>",
    1: "</text>:1</E>",
    2: "</text>:2</E>",
    3: "</text>:3</E>",
}

template = PromptTemplate(tp_dict, {'text': '</text>'}, ice_token='</E>')
exemplar_retriever = TopkRetriever(data, ice_num=exemplar_count, index_split='train')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

Found cached dataset ag_news (/home/kyle/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
100%|██████████| 2/2 [00:00<00:00, 617.85it/s]
[2023-04-26 21:17:47,461] [openicl.icl_retriever.icl_topk_retriever] [INFO] Creating index for index set...
  0%|          | 0/100 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 100/100 [00:01<00:00, 68.24it/s]


In [28]:
edit_entries = []
num_successfully_edits = 0
edit_retriever = IndexIDMap(IndexFlatIP(exemplar_retriever.model.get_sentence_embedding_dimension()))

for i in range(len(dataset["test"])):
    current_entry = dataset["test"][i]
    input_text = current_entry["text"]
    input_label = current_entry["label"]
    input_sequence_embedding = exemplar_retriever.model.encode([input_text], convert_to_numpy=True)
    distances, exemplar_indices = exemplar_retriever.index.search(input_sequence_embedding, k=exemplar_count)
    exemplars = [exemplar_retriever.dataset_reader.dataset["train"][int(index)] for index in exemplar_indices[0]]
    prompt_lines = [template.generate_ice_item(entry, entry["label"]) for entry in exemplars]
    prompt_lines.append(input_text + ":")
    prompts = "\n".join(prompt_lines)

    tokenized_prompt = tokenizer.encode(prompts, return_tensors="pt").to(device)
    outputs = model.generate(
                tokenized_prompt,
                max_new_tokens=1,
                do_sample=False,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=tokenizer.eos_token_id)
    
    
    judgment = None
    try:
        judgment = int(tokenizer.decode(outputs.sequences[:, -1]))
    except:
        pass

    if judgment != input_label:
        edit_entries.append({
            "text": input_text,
            "label": input_label,
        })
        edit_retriever.add_with_ids(input_sequence_embedding, np.array([len(edit_entries) - 1]))

        # Get edit pool exemplars - filter out -1 indices
        edit_distances, edit_exemplar_indices = edit_retriever.index.search(input_sequence_embedding, k=exemplar_count)
        edit_exemplar_indices = [int(index) for index in edit_exemplar_indices[0] if index != -1]
        edit_exemplars = [edit_entries[index] for index in edit_exemplar_indices]

        # Backfill with exemplars from the original dataset
        if len(edit_exemplars) < exemplar_count:
            exemplar_index = 0
            while exemplar_index < 4:
                edit_exemplars.append(exemplars[exemplar_index])
                exemplar_index += 1
        
        edit_prompt_lines = [template.generate_ice_item(entry, entry["label"]) for entry in edit_exemplars]
        edit_prompt_lines.reverse()
        edit_prompt_lines.append(input_text + ":")
        edit_prompt = "\n".join(edit_prompt_lines)
        tokenized_prompt = tokenizer.encode(prompts, return_tensors="pt").to(device)
        outputs = model.generate(
                    tokenized_prompt,
                    max_new_tokens=1,
                    do_sample=False,
                    output_scores=True,
                    return_dict_in_generate=True,
                    pad_token_id=tokenizer.eos_token_id)
        
        
        judgment = None
        try:
            judgment = int(tokenizer.decode(outputs.sequences[:, -1]))
        except:
            pass

        # print(edit_prompt)
        # print(f"input label: {input_label}, new token: {judgment}")
        if judgment == input_label:
            num_successfully_edits += 1

print(f"num_successfully_edits: {num_successfully_edits}")
print(f"num_edits: {len(edit_entries)}")
print(f"success rate: {num_successfully_edits / len(edit_entries)}")
edit_entries


num_successfully_edits: 0
num_edits: 81
success rate: 0.0


[{'text': 'The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\\privately funded suborbital space flight, has officially announced the first\\launch date for its manned rocket.',
  'label': 3},
 {'text': 'Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.',
  'label': 3},
 {'text': "Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar.",
  'label': 3},
 {'text': "

In [27]:
x = [1, 2, 3]
x.reverse()
x

[3, 2, 1]