In [7]:
from openicl import DatasetReader, PromptTemplate, ZeroRetriever, GenInferencer, PPLInferencer, CoTInferencer, RandomRetriever
from openicl.icl_dataset_reader import load_dataset
import pandas as pd
from accelerate import Accelerator
from QKPGeNAlgRetriever import QPKRetriever
from QPKTabuRetriever import QPKTabuRetriever
import numpy as np


qa_models = ["roberta-large", "gpt2-large"]
task = 'sentiment-analysis'
model = qa_models[1]

dataset_names = ['commonsense_qa', 'tasksource/bigbench', 'gpt3mix/sst2', 'imdb']

dataset_name = dataset_names[2]
if dataset_name == 'tasksource/bigbench':
    dataset = load_dataset(dataset_name, 'disambiguation_qa', split='train')
    dataset = dataset.train_test_split(test_size=50, train_size=156, shuffle=True)

else:
    dataset = load_dataset(dataset_name, split='train')
    dataset = dataset.train_test_split(test_size=100, train_size=500, shuffle=True)

def cmqa_pre_process(example):
    for i in range(5):
        example[chr(ord('A') + i)] = example['choices']['text'][i]
    return example

def bb_pre_process(example):
    for i in range(3):
        example[chr(ord('A') + i)] = example['multiple_choice_targets'][i]
    example['multiple_choice_scores'] = chr(ord('A') + np.where(np.array(example['multiple_choice_scores']) == 1)[0][0])
    example['context'] = "Disambiguation"
    return example

if dataset_name=='commonsense_qa':
    dataset = dataset.map(cmqa_pre_process)
    dataset = dataset.rename_column("question_concept","context")
    dataset = dataset.rename_column("answerKey","answer")
    input_cols = ["question", "context", "A", "B", "C", "D", "E"]
    data = DatasetReader(dataset=dataset, input_columns=input_cols, output_column="answer")

elif dataset_name=='wiki_qa':
    dataset = dataset.rename_column("document_title","context")
    input_cols = ["question", "context"]
    data = DatasetReader(dataset=dataset, input_columns=input_cols, output_column="answer")

elif dataset_name=='tasksource/bigbench':
    dataset = dataset.map(bb_pre_process)
    dataset = dataset.rename_column("multiple_choice_scores","answer")
    dataset = dataset.rename_column("inputs","question")
    input_cols = ["question", "context", "A", "B", "C"]
    data = DatasetReader(dataset=dataset, input_columns=input_cols, output_column="answer")

elif dataset_name=='gpt3mix/sst2' or dataset_name=='imdb':
    data = DatasetReader(dataset=dataset, input_columns=['text'], output_column="label")

# print(data['test']['te'])
# print(data['test']['answer'])

Found cached dataset sst2 (/home/nlonyuk/.cache/huggingface/datasets/gpt3mix___sst2/default/0.0.0/90167692658fa4abca2ffa3ede1a43a71e2bf671078c5c275c64c4231d5a62fa)


In [8]:
# Accelerate Prepare
accelerator = Accelerator()

In [9]:
#common sense QA
cmsqa_template=PromptTemplate(
    {
        'A': "</E>Answer the following question:\n</Q>\nAnswer: </Ans1>",
        'B': "</E>Answer the following question:\n</Q>\nAnswer: </Ans2>",
        'C': "</E>Answer the following question:\n</Q>\nAnswer: </Ans3>",
        'D': "</E>Answer the following question:\n</Q>\nAnswer: </Ans4>",
        'E': "</E>Answer the following question:\n</Q>\nAnswer: </Ans5>",
    },
    {'question':'</Q>', 'A': '</Ans1>', 'B': '</Ans2>', 'C': '</Ans3>', 'D': '</Ans4>', 'E': '</Ans5>'},
    ice_token='</E>' 
)

bb_template=PromptTemplate(
    {
        'A': "</E>Answer the following question:\n</Q>\nAnswer: </Ans1>",
        'B': "</E>Answer the following question:\n</Q>\nAnswer: </Ans2>",
        'C': "</E>Answer the following question:\n</Q>\nAnswer: </Ans3>"
    },
    {'question':'</Q>', 'A': '</Ans1>', 'B': '</Ans2>', 'C': '</Ans3>'},
    ice_token='</E>' 
)
sst2_template=PromptTemplate({
            0: '</E>Positive Movie Review: \"<X>\"', 
            1: '</E>Negative Movie Review: \"<X>\"',
        }, column_token_map={'text' : '<X>'}, 
        ice_token='</E>'
    )



In [17]:
# retriever = QPKTabuRetriever(data, model=model, task=task, ice_num=10, sentence_transformer='sentence-transformers/all-mpnet-base-v2')
retriever = ZeroRetriever(data)
# retriever = RandomRetriever(data)
# retr_idxs = retriever.retrieve()

# for idx in retr_idxs[0]:
#     print(f"{retriever.train_ds['question'][idx]} --- {retriever.train_ds['answers'][idx]['text']}")


In [18]:
# Define a Inferencer
inferencer = PPLInferencer(model_name=model, accelerator=accelerator)

If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`


In [19]:
# Inference
predictions = inferencer.inference(retriever, ice_template=sst2_template)
print(predictions)

[2023-06-14 07:39:39,884] [openicl.icl_inferencer.icl_ppl_inferencer] [INFO] Calculating PPL for prompts labeled '0'
100%|██████████| 100/100 [00:21<00:00,  4.71it/s]
[2023-06-14 07:40:01,115] [openicl.icl_inferencer.icl_ppl_inferencer] [INFO] Calculating PPL for prompts labeled '1'
 41%|████      | 41/100 [00:08<00:14,  4.11it/s]

In [14]:
print(retriever.test_ds['text'])
    
print(retriever.test_ds['label'])



predictions


['A zombie movie in every sense of the word -- mindless , lifeless , meandering , loud , painful , obnoxious .', 'A solid , spooky entertainment worthy of the price of a ticket .', 'Great fun both for sports aficionados and for ordinary louts whose idea of exercise is climbing the steps of a stadium-seat megaplex .', 'Criminal conspiracies and true romances move so easily across racial and cultural lines in the film that it makes My Big Fat Greek Wedding look like an apartheid drama .', 'Goofy , nutty , consistently funny .', "Frida is n't that much different from many a Hollywood romance .", 'It is parochial , accessible to a chosen few , standoffish to everyone else , and smugly suggests a superior moral tone is more important than filmmaking skill', "We hate -LRB- Madonna -RRB- within the film 's first five minutes , and she lacks the skill or presence to regain any ground .", "She 's all-powerful , a voice for a pop-cyber culture that feeds on her Bjorkness .", "-LRB- Moore 's -RRB

[0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [16]:
accuracy = np.sum(np.array(retriever.test_ds['label']) == np.array(predictions))
accuracy

43