In [31]:
from openicl import (DatasetReader, PromptTemplate, 
                     ZeroRetriever, RandomRetriever, BM25Retriever,
                     GenInferencer, PPLInferencer)
from openicl.icl_dataset_reader import load_dataset
import pandas as pd
from accelerate import Accelerator
from QPKTabuRetriever import QPKTabuRetriever
import numpy as np
import matplotlib.pyplot as plt

In [32]:
# Experiment grid: every model is run on every dataset of every task with every retriever.
# Datasets evaluated for each task; keys must match the dataset names handled by select_dataset.
DATASET_NAMES = {
    'question-answering': ['commonsense_qa', 'tasksource/bigbench'],
    'sentiment-analysis': ['imdb', 'gpt3mix/sst2'],
}
# Task list in the same order as DATASET_NAMES.
TASKS = list(DATASET_NAMES)
# Hugging Face model names passed to the inferencer (and to the 'qkp' retriever).
MODELS = ["roberta-large", "gpt2-large"]
# Retriever keys understood by select_retriever.
RETRIEVERS = ['zero', 'random', 'bm25', 'qkp']

In [33]:
def cmqa_pre_process(example):
    """Copy the first five CommonsenseQA answer choices into columns 'A'..'E'.

    Reads ``example['choices']['text']`` and stores choice i under the i-th capital
    letter so the multiple-choice prompt templates can reference each option.
    Mutates ``example`` in place and returns it (as ``Dataset.map`` expects).
    """
    choice_texts = example['choices']['text']
    for offset, letter in enumerate('ABCDE'):
        example[letter] = choice_texts[offset]
    return example

def bb_pre_process(example):
    """Prepare one BIG-bench disambiguation_qa example for the multiple-choice templates.

    - Copies the first three entries of ``multiple_choice_targets`` into columns 'A'..'C'.
    - Replaces ``multiple_choice_scores`` with the letter of the first choice whose
      score equals 1 (raises IndexError if none does).
    - Sets a constant ``context`` column.
    Mutates ``example`` in place and returns it.
    """
    targets = example['multiple_choice_targets']
    for offset, letter in enumerate('ABC'):
        example[letter] = targets[offset]
    correct_positions = np.where(np.array(example['multiple_choice_scores']) == 1)[0]
    example['multiple_choice_scores'] = chr(ord('A') + correct_positions[0])
    example['context'] = "Disambiguation"
    return example

In [34]:
def _split_train_test(dataset, train_size, test_size, seed):
    """Randomly draw disjoint train/test subsets of the requested sizes from ``dataset``."""
    return dataset.train_test_split(test_size=test_size, train_size=train_size, shuffle=True, seed=seed)


def select_dataset(name, train_size=20, test_size=10, seed=None):
    """Load dataset ``name``, sample a small train/test split and wrap it in a DatasetReader.

    Args:
        name: one of 'commonsense_qa', 'tasksource/bigbench' (the 'disambiguation_qa'
            config is used), 'imdb' or 'gpt3mix/sst2'.
        train_size: number of examples in the sampled train split.
        test_size: number of examples in the sampled test split.
        seed: forwarded to ``train_test_split``; the default ``None`` draws a different
            random split on every call.

    Returns:
        A DatasetReader. QA datasets use 'answer' (a choice letter) as the output column;
        sentiment datasets use 'label'.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if name == 'commonsense_qa':
        dataset = _split_train_test(load_dataset(name, split='train'), train_size, test_size, seed)
        # Expose each answer choice as its own column A..E for the prompt templates.
        dataset = dataset.map(cmqa_pre_process)
        dataset = dataset.rename_column("question_concept", "context")
        dataset = dataset.rename_column("answerKey", "answer")
        input_cols = ["question", "context", "A", "B", "C", "D", "E"]
        output_col = "answer"
    elif name == 'tasksource/bigbench':
        dataset = _split_train_test(
            load_dataset(name, 'disambiguation_qa', split='train'), train_size, test_size, seed
        )
        # Columns A..C, a letter answer (stored in multiple_choice_scores) and a context.
        dataset = dataset.map(bb_pre_process)
        dataset = dataset.rename_column("multiple_choice_scores", "answer")
        dataset = dataset.rename_column("inputs", "question")
        input_cols = ["question", "context", "A", "B", "C"]
        output_col = "answer"
    elif name in ('imdb', 'gpt3mix/sst2'):
        dataset = _split_train_test(load_dataset(name, split='train'), train_size, test_size, seed)
        input_cols = ["text"]
        output_col = "label"
    else:
        # Fail loudly here instead of returning None and breaking later in a retriever.
        raise ValueError(
            f"Unsupported dataset {name!r}; expected one of "
            "'commonsense_qa', 'tasksource/bigbench', 'imdb', 'gpt3mix/sst2'"
        )
    return DatasetReader(dataset=dataset, input_columns=input_cols, output_column=output_col)


In [35]:
def _multiple_choice_template(choices):
    """Build a PPL template with one candidate prompt per answer letter in ``choices``.

    Letter number i (1-based) gets the answer token ``</Ans{i}>``, which the column
    token map fills from the dataset column named by that letter; ``</Q>`` is filled
    from 'question' and ``</E>`` marks where in-context examples are inserted.
    """
    answer_tokens = {letter: f'</Ans{i}>' for i, letter in enumerate(choices, start=1)}
    return PromptTemplate(
        {
            letter: f"</E>Answer the following question:\n</Q>\nAnswer: {token}"
            for letter, token in answer_tokens.items()
        },
        {'question': '</Q>', **answer_tokens},
        ice_token='</E>',
    )


def _sentiment_template(label_to_polarity):
    """Build a PPL template mapping each dataset label to a '<Polarity> Movie Review' prompt.

    ``label_to_polarity`` maps the dataset's integer label to 'Positive' or 'Negative';
    the review text column 'text' fills ``<X>``.
    """
    return PromptTemplate(
        {
            label: f'</E>{polarity} Movie Review: "<X>"'
            for label, polarity in label_to_polarity.items()
        },
        column_token_map={'text': '<X>'},
        ice_token='</E>',
    )


# Prompt template per dataset name (keys must match DATASET_NAMES / select_dataset).
TEMPLATES = {
    'commonsense_qa': _multiple_choice_template('ABCDE'),
    'tasksource/bigbench': _multiple_choice_template('ABC'),
    # HF `imdb` labels are ClassLabel(['neg', 'pos']): 0 = negative, 1 = positive.
    # The previous template had these swapped, which inverted every IMDB prediction.
    'imdb': _sentiment_template({0: 'Negative', 1: 'Positive'}),
    # gpt3mix/sst2 uses 0 = positive, 1 = negative (as in the OpenICL README example).
    # NOTE(review): confirm against the gpt3mix/sst2 dataset card.
    'gpt3mix/sst2': _sentiment_template({0: 'Positive', 1: 'Negative'}),
}

In [36]:
def select_retriever(retr_name, data, model, task, ice_num, accelerator):
    """Build the in-context-example retriever named ``retr_name`` over ``data``.

    Args:
        retr_name: 'zero', 'random', 'bm25' or 'qkp'.
        data: the DatasetReader to retrieve from.
        model: model name; only used by the 'qkp' (QPKTabuRetriever) retriever.
        task: task name; only used by the 'qkp' retriever.
        ice_num: number of in-context examples to retrieve (unused for 'zero').
        accelerator: Accelerator forwarded to every retriever except 'zero'.

    Returns:
        The constructed retriever instance.

    Raises:
        ValueError: if ``retr_name`` is unknown. This subclasses Exception, so callers
            that caught the previous bare Exception still work.
    """
    if retr_name == 'zero':
        # ZeroRetriever is built from the data alone; ice_num/accelerator are not passed.
        return ZeroRetriever(data)
    if retr_name == 'random':
        return RandomRetriever(data, ice_num=ice_num, accelerator=accelerator)
    if retr_name == 'bm25':
        return BM25Retriever(data, ice_num=ice_num, accelerator=accelerator)
    if retr_name == 'qkp':
        return QPKTabuRetriever(data, model=model, task=task, ice_num=ice_num, accelerator=accelerator)
    raise ValueError(
        f"Unknown retriever {retr_name!r}; expected one of 'zero', 'random', 'bm25', 'qkp'"
    )

In [37]:
# Metrics for every (model, task, dataset, retriever) configuration. All lists are
# appended together after a configuration finishes, so they always have equal length
# and pd.DataFrame(results) stays valid.
results = {
    'model': [],
    'task': [],
    'dataset': [],
    'retriever': [],
    'accuracy_mean': [],
    'accuracy_std': [],
    'predictions': [],
    'inputs': [],
}

accelerator = Accelerator()
ice_num = 5  # in-context examples requested from each retriever
reps = 3     # repetitions per configuration, each with a freshly sampled train/test split

for model in MODELS:
    # The inferencer is configured only by the model name and accelerator, so construct
    # it (and presumably load the model weights) once per model, not once per repetition.
    inferencer = PPLInferencer(model_name=model, accelerator=accelerator)
    for task in TASKS:
        for dataset_name in DATASET_NAMES[task]:
            ice_template = TEMPLATES[dataset_name]
            for retr_name in RETRIEVERS:
                print(f'model={model} task={task} dataset={dataset_name} retriever={retr_name}')
                accuracies = []
                all_predictions = []
                all_inputs = []

                for _ in range(reps):
                    data = select_dataset(dataset_name)
                    retriever = select_retriever(retr_name, data, model, task, ice_num, accelerator)
                    predictions = inferencer.inference(retriever, ice_template=ice_template)
                    test_ds = retriever.test_ds
                    reader = retriever.dataset_reader
                    gold = np.array(test_ds[reader.output_column])
                    all_predictions.append(predictions)
                    all_inputs.append(test_ds[reader.input_columns[0]])
                    # Fraction of test examples predicted correctly. Summing here would
                    # give a raw count of correct predictions, not an accuracy.
                    accuracies.append(np.mean(gold == np.array(predictions)))

                # Record the row only once every repetition has succeeded, so a failure
                # cannot leave the columns of `results` with different lengths.
                results['model'].append(model)
                results['task'].append(task)
                results['dataset'].append(dataset_name)
                results['retriever'].append(retr_name)
                results['accuracy_mean'].append(np.mean(accuracies))
                results['accuracy_std'].append(np.std(accuracies))
                results['predictions'].append(all_predictions)
                results['inputs'].append(all_inputs)
            

Using the latest cached version of the module from /home/nlonyuk/.cache/huggingface/modules/datasets_modules/datasets/gpt3mix--sst2/90167692658fa4abca2ffa3ede1a43a71e2bf671078c5c275c64c4231d5a62fa (last modified on Fri Jun  2 10:43:35 2023) since it couldn't be found locally at gpt3mix/sst2., or remotely on the Hugging Face Hub.
Found cached dataset sst2 (/home/nlonyuk/.cache/huggingface/datasets/gpt3mix___sst2/default/0.0.0/90167692658fa4abca2ffa3ede1a43a71e2bf671078c5c275c64c4231d5a62fa)


zero


[2023-06-14 09:02:48,472] [openicl.icl_inferencer.icl_ppl_inferencer] [INFO] Calculating PPL for prompts labeled '0'
100%|██████████| 10/10 [00:04<00:00,  2.20it/s]
[2023-06-14 09:02:53,022] [openicl.icl_inferencer.icl_ppl_inferencer] [INFO] Calculating PPL for prompts labeled '1'
100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Using the latest cached version of the module from /home/nlonyuk/.cache/huggingface/modules/datasets_modules/datasets/gpt3mix--sst2/90167692658fa4abca2ffa3ede1a43a71e2bf671078c5c275c64c4231d5a62fa (last modified on Fri Jun  2 10:43:35 2023) since it couldn't be found locally at gpt3mix/sst2., or remotely on the Hugging Face Hub.
Found cached dataset sst2 (/home/nlonyuk/.cache/huggingface/datasets/gpt3mix___sst2/default/0.0.0/90167692658fa4abca2ffa3ede1a43a71e2bf671078c5c275c64c4231d5a62fa)


random


In [None]:
# Tabulate the collected metrics: one row per (model, task, dataset, retriever) entry in `results`.
df_results = pd.DataFrame(data=results)

In [None]:
def plot(x, ys, title, labels, savefile):
    """Draw one line per series in ``ys`` over the shared x-values ``x`` and save it as PNG.

    Args:
        x: x-values, also used as tick labels (e.g. a list of dataset names).
        ys: sequence of y-series; each must have one value per entry of ``x``.
        title: figure title.
        labels: legend label for each series in ``ys`` (same length as ``ys``).
        savefile: output path without extension; written to ``f'{savefile}.png'``.
    """
    # Marker/line styles, cycled so any number of series can be drawn (the original
    # 3-entry list raised IndexError for the 4 retrievers).
    configs = ['g*-', 'bo-', 'r+-', 'mx-']
    fig, ax = plt.subplots()
    # ax.set_xscale('log')
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    ax.set_title(title)
    ax.grid()
    for idx, y in enumerate(ys):
        ax.plot(x, y, configs[idx % len(configs)], label=labels[idx])
    # Build the legend once, after all series exist (and only if there is something to label).
    if len(ys) > 0:
        ax.legend()
    fig.savefig(f'{savefile}.png')

In [None]:
# One figure per (model, task): each retriever is a line of mean accuracy over that task's datasets.
for model in MODELS:
    for task in TASKS:
        dataset_names = DATASET_NAMES[task]
        df_instance = df_results[(df_results['model'] == model) & (df_results['task'] == task)]
        # Align each retriever's values to dataset_names so y matches the x-axis order;
        # a missing (dataset, retriever) run shows up as NaN (a gap) instead of a shape error.
        accuracy_means = [
            df_instance[df_instance['retriever'] == retr_name]
            .set_index('dataset')['accuracy_mean']
            .reindex(dataset_names)
            .to_numpy()
            for retr_name in RETRIEVERS
        ]

        plot(
            dataset_names,
            accuracy_means,
            f'Mean accuracy for model: {model}, task: {task}',
            RETRIEVERS,
            f'{model}_{task}_accuracy_mean',
        )
