In [None]:
%pip install --upgrade pip
%pip install openicl

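# Self-adaptive In-context Learning
---
Code for the paper [Self-adaptive In-context Learning](https://arxiv.org/abs/2212.10375). This notebook uses OpenICL to run the paper's TopK-MDL in-context example selection on several classification and multiple-choice tasks.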

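## Templates

Each task gets one prompt per label. `</E>` is the in-context-example placeholder (`ice_token`) shared by all templates.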

In [1]:
from openicl import PromptTemplate

# Every template below uses the same in-context-example placeholder (`ice_token`).
ICE_TOKEN = '</E>'


def _make_template(label_prompts, column_token_map):
    """Wrap a ``{label: prompt}`` mapping in a PromptTemplate that uses ``ICE_TOKEN``.

    ``column_token_map`` maps dataset column names to the placeholders that appear
    inside the prompts.
    """
    return PromptTemplate(label_prompts, column_token_map=column_token_map, ice_token=ICE_TOKEN)


# SST-2: label 0 uses the "Positive" prompt, label 1 the "Negative" one.
sst2_tp_dict = dict(enumerate([
    '</E>Positive Movie Review: \"<X>\"', 
    '</E>Negative Movie Review: \"<X>\"',
]))
sst2_template = _make_template(sst2_tp_dict, {'text' : '<X>'})

# SST-5: labels 0..4 go from "terrible" to "great".
sst5_tp_dict = dict(enumerate([
    "</E>Review: <X>\nSentiment: terrible",
    "</E>Review: <X>\nSentiment: bad",
    "</E>Review: <X>\nSentiment: okay",
    "</E>Review: <X>\nSentiment: good",
    "</E>Review: <X>\nSentiment: great",
]))
sst5_template = _make_template(sst5_tp_dict, {'text' : '<X>'})

# AG_NEWS: labels 0..3 are world, sports, business, science and technology.
ag_news_tp_dict = dict(enumerate([
    "</E>\"<X>\" It is about world.",
    "</E>\"<X>\" It is about sports.",
    "</E>\"<X>\" It is about business.",
    "</E>\"<X>\" It is about science and technology.",
]))
ag_news_template = _make_template(ag_news_tp_dict, {'text' : '<X>'})

# TREC: labels 0..5 are the six coarse question types.
trec_tp_dict = dict(enumerate([
    "</E>\"<X>\" It is about abbreviation.",
    "</E>\"<X>\" It is about entity.",
    "</E>\"<X>\" It is about description and abstract concept.",
    "</E>\"<X>\" It is about human being.",
    "</E>\"<X>\" It is about location.",
    "</E>\"<X>\" It is about numeric value."
]))
trec_template = _make_template(trec_tp_dict, {'text' : '<X>'})

# SNLI & MNLI share one template: 0 -> "Yes", 1 -> "Maybe", 2 -> "No".
xnli_tp_dict = dict(enumerate([
    '</E><X1>? Yes, <X2>',
    '</E><X1>? Maybe, <X2>',
    '</E><X1>? No, <X2>'
]))
xnli_template = _make_template(xnli_tp_dict, {'premise' : '<X1>', 'hypothesis' : '<X2>'})

# QNLI: 0 -> "Yes.", 1 -> "No."
qnli_tp_dict = dict(enumerate([
    "</E><X1> Can we know <X2>? Yes.",
    "</E><X1> Can we know <X2>? No.",
]))
qnli_template = _make_template(qnli_tp_dict, {'sentence' : '<X1>', 'question' : '<X2>'})

# Commonsense QA: labels are the answer letters A-E, each bound to its own answer column.
cmsqa_tp_dict = {
    'A': "</E>Answer the following question:\n</Q>\nAnswer: </Ans1>",
    'B': "</E>Answer the following question:\n</Q>\nAnswer: </Ans2>",
    'C': "</E>Answer the following question:\n</Q>\nAnswer: </Ans3>",
    'D': "</E>Answer the following question:\n</Q>\nAnswer: </Ans4>",
    'E': "</E>Answer the following question:\n</Q>\nAnswer: </Ans5>",
}
cmsqa_template = _make_template(
    cmsqa_tp_dict,
    {'question':'</Q>', 'A': '</Ans1>', 'B': '</Ans2>', 'C': '</Ans3>', 'D': '</Ans4>', 'E': '</Ans5>'},
)

# Task name -> prompt template used by the inferencer.
templates = {
    'sst2': sst2_template,
    'snli': xnli_template,
    'mnli': xnli_template,
    "qnli": qnli_template,
    "sst5": sst5_template,
    "ag_news": ag_news_template,
    "trec": trec_template,
    "commonsense_qa": cmsqa_template,
}

In [None]:
## Datasets 

In [None]:
from datasets import load_dataset
from openicl import DatasetReader

# Where each task lives on the Hugging Face Hub: [dataset path, config name (None = default config)].
data_path = {
    'sst2': ["gpt3mix/sst2", None],
    'snli': ['snli', None],
    'mnli': ['LysandreJik/glue-mnli-train', None],
    "qnli": ["glue", "qnli"],
    "sst5": ["SetFit/sst5", None],
    "ag_news": ["ag_news", None],
    "trec": ["trec", None],
    "commonsense_qa": ["commonsense_qa", None],
}

# Dataset columns that are substituted into each task's prompt template.
input_columns = {
    'sst2': ["text"],
    'snli': ['premise', 'hypothesis'],
    'mnli': ['premise', 'hypothesis'],
    "qnli": ["sentence", "question"],
    "sst5": ["text"],
    "ag_news": ["text"],
    "trec": ["text"],
    "commonsense_qa": ['question', 'A', 'B', 'C', 'D', 'E'],
}

# Column holding the gold label. Every task uses `label` except the two overridden below.
output_column = {task: 'label' for task in data_path}
output_column.update({"trec": 'label-coarse', "commonsense_qa": "answerKey"})

# Pick the task to run; it must be one of the keys above.
task_name = 'snli'

path, name = data_path[task_name]
dataset = load_dataset(path=path, name=name)

# Preprocess for commonsense_qa
def pre_process(example):
    """Copy the five answer texts of a commonsense_qa example into columns ``A`` .. ``E``.

    The example is modified in place and also returned, so the function can be used
    with ``dataset.map``. Raises IndexError if fewer than five answer texts are present.
    """
    option_texts = example['choices']['text']
    for position, letter in enumerate('ABCDE'):
        example[letter] = option_texts[position]
    return example

if task_name == 'commonsense_qa':
    # Flatten the five answer options into columns A-E (see `pre_process`) so the
    # template's </Ans1>..</Ans5> placeholders can read them.
    dataset = dataset.map(pre_process).remove_columns(['question_concept', 'id', 'choices'])

if task_name == 'snli':
    # SNLI marks examples that have no gold label with -1. The template only defines
    # labels 0-2, so such examples cannot be used as in-context examples and can never
    # be predicted correctly; drop them from every split.
    dataset = dataset.filter(lambda example: example['label'] != -1)

if task_name == 'trec':
    # Newer releases of the Hub `trec` dataset call the coarse label column
    # `coarse_label` instead of `label-coarse`; rename it so output_column['trec'] works.
    trec_columns = dataset['train'].column_names
    if 'label-coarse' not in trec_columns and 'coarse_label' in trec_columns:
        dataset = dataset.rename_column('coarse_label', 'label-coarse')

# Split that is evaluated for each task.
test_split={
    'sst2': 'test',
    'snli': 'test',
    "sst5": 'test',
    "ag_news": 'test',
    "trec": 'test',
    'mnli': 'validation', # cannot get gold labels for the test split
    "qnli": 'validation',
    "commonsense_qa": "validation"
}

# For a faster run, set these to a small integer (e.g. 100) to keep only the first N
# examples of the evaluated split / of the train split. Leave as None for full runs.
# This is done before the DatasetReader is built so the reader sees the subset.
max_eval_examples = None
max_train_examples = None

eval_split_name = test_split[task_name]
if max_eval_examples is not None:
    n_eval = min(max_eval_examples, len(dataset[eval_split_name]))
    dataset[eval_split_name] = dataset[eval_split_name].select(list(range(n_eval)))
if max_train_examples is not None:
    n_train = min(max_train_examples, len(dataset['train']))
    dataset['train'] = dataset['train'].select(list(range(n_train)))

# Pass the evaluated split so the reader's reference labels come from that split
# rather than from the default `test` split.
data = DatasetReader(dataset, input_columns=input_columns[task_name], output_column=output_column[task_name],
                     test_split=test_split[task_name])

### TopK-MDL experiments (the method proposed in the paper)

In [None]:
from openicl import MDLRetriever, PPLInferencer, AccEvaluator

# Split being evaluated for the selected task (e.g. `validation` for mnli/qnli/commonsense_qa).
eval_split = test_split[task_name]

# TopK-MDL retrieval: `ice_num` in-context examples per test input. Going by the parameter
# names, they are chosen from `candidate_num` candidates over `select_time` selection rounds;
# see the MDLRetriever documentation for the exact procedure.
retriever = MDLRetriever(data, ice_num=8, candidate_num=30, select_time=10, seed=1, batch_size=12, test_split=eval_split)

inferencer = PPLInferencer(model_name='gpt2-xl', batch_size=8)

predictions = inferencer.inference(retriever, ice_template=templates[task_name], output_json_filename=f'mdl_{task_name}')

# Score against the gold labels of the split that was actually evaluated. `data.references`
# is taken from the reader's own test split (default `test`), which would not match
# for tasks evaluated on `validation`.
references = list(dataset[eval_split][output_column[task_name]])
scores = AccEvaluator().score(predictions=predictions, references=references)
print(scores)