In [1]:
from transformers import GPT2Tokenizer, OPTForCausalLM, GPTJForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import torch
import os
import numpy as np
import torch.nn.functional as F
import json

from collections import defaultdict
from tqdm import tqdm
from promptsource.templates import DatasetTemplates
import copy


@torch.no_grad()
def gen(
        model, 
        tokenizer, 
        prompt, 
        input_device, 
        num_return_sequences = 5, 
        do_sample = True,
        max_length = 32,
        temperature = 0.9
    ):
    # inputs = tokenizer.encode(prompt, return_tensors="pt").to(input_device)
    # generation_output = model.generate(inputs, return_dict_in_generate = True)
    inputs = tokenizer(prompt, return_tensors="pt",truncation=True,padding=True).to(input_device)
    generation_output = model.generate(
        inputs.input_ids,
        no_repeat_ngram_size = 3,
        temperature=temperature,
        max_length=max_length,
        do_sample=do_sample,
        num_return_sequences=num_return_sequences,
        output_scores = True,
        return_dict_in_generate = True
    )
    response = tokenizer.batch_decode(generation_output['sequences'], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return generation_output, response

@torch.no_grad()
def cal_log_perplexity_generate(generation_output):
    # print(generation_output.keys())
    sequence_num, generated_sequence_length = generation_output['sequences'].size()[0] ,len(generation_output['scores'])
    print("generate length:", generated_sequence_length)
    generation_output['scores'] = torch.stack(list(generation_output['scores']), dim=0)
    # print(generation_output['scores'].size())
    # print(f'sequence num = {sequence_num}, generated_sequence_length = {generated_sequence_length}')
    
    perp = []
    for i in range(sequence_num):
        generated_squence_ids = generation_output['sequences'][i][-generated_sequence_length:]
        scores = generation_output['scores'][:,i,:]
        log_softmax_scores = F.log_softmax(scores, dim=1)
        # print(scores.size(),log_softmax_scores.size())
        # print(scores[0][:10], log_softmax_scores[0][:10])
        assert scores.size()[0] == generated_squence_ids.size()[0]
        generated_squence_ids = generated_squence_ids.cpu().numpy()
        log_softmax_scores = log_softmax_scores.cpu().numpy()
        log_sum = 0
        for j in range(len(generated_squence_ids)):
            idx = generated_squence_ids[j]
            log_sum += log_softmax_scores[j][idx]
        perp.append(np.exp((-1/generated_sequence_length)*log_sum))
    return perp

def cal_log_perplexity_decode(logits, input_ids):
    perp = []
    for i in range(input_ids.size()[0]):
        generated_squence_ids = input_ids[i]
        n = generated_squence_ids.size()[0]
        scores = logits[i]
        log_softmax_scores = F.log_softmax(scores, dim=1)
        assert scores.size()[0] == n
        generated_squence_ids = generated_squence_ids.cpu().numpy()
        log_softmax_scores = log_softmax_scores.cpu().numpy()
        log_sum = 0
        for j in range(len(generated_squence_ids)):
            idx = generated_squence_ids[j]
            log_sum += log_softmax_scores[j][idx]
        perp.append(np.exp((-1/n)*log_sum))
    return perp

@torch.no_grad()
def rerank(input_text, tokenizer, model, input_device):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True).to(input_device)
    outputs = model(**inputs, labels=inputs["input_ids"])
    # outputs = model(**inputs) #TODO
    logits = outputs.logits
    perps = cal_log_perplexity_decode(logits, inputs.input_ids)
    assert len(perps) == len(input_text)
    ranked = sorted([ (i, input_text[i], perps[i]) for i in range(len(perps))], key = lambda x: x[2])
    return ranked

In [9]:
#### GPT-2 ####
os.environ['CUDA_VISIBLE_DEVICES']="0,1,2,3,4,5,6,7"
main_model_name = "gpt2-large"

print(f"loading main model: {main_model_name}")
### set up model ###
main_model = AutoModelForCausalLM.from_pretrained(main_model_name)
main_model.parallelize()

### set up tokenizer ###
main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
main_tokenizer.pad_token = main_tokenizer.eos_token

### set input device as ###
main_input_device = main_model.device

loading main model: gpt2-large


In [None]:
#### GPT-J ####

os.environ['CUDA_VISIBLE_DEVICES']="0,1,2,3,4,5,6,7"
main_model_name = "EleutherAI/gpt-j-6B"

print(f"loading main model: {main_model_name}")
### set up model ###
main_model = GPTJForCausalLM.from_pretrained(main_model_name)
main_model.parallelize()

### set input device as ###
main_input_device = main_model.device

### set up tokenizer ###
main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
main_tokenizer.pad_token = main_tokenizer.eos_token

In [None]:
#### T0 ####

### set up device
os.environ['CUDA_VISIBLE_DEVICES']="0,1,2,3,4,5,6,7"
# main_model_name = "bigscience/T0-3B"
main_model_name = "bigscience/T0"
# main_model_name = "bigscience/T0p" # plus
# main_model_name = "bigscience/T0pp" # plus plus

### set up tokenizer ###
main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
# main_tokenizer.pad_token = main_tokenizer.eos_token

print(f"loading main model: {main_model_name}")
### set up model ###
main_model = AutoModelForSeq2SeqLM.from_pretrained(main_model_name)
main_model.parallelize()

### set input device as ###
main_input_device = main_model.device
print('model main device:',main_input_device)


In [10]:
# prompts from retrieval result json
def load_prompts_from_retrieval_result():
    retrieval_results = json.load(open("/cephfs/user/mikeeewang/summer_22/code/t-zero/output/T0_3B__super_glue__wic_debug/retrieval_results.json"))
    all_prompts = []
    tempalte_names = list(retrieval_results.keys())
    for i in range(len(retrieval_results[tempalte_names[0]])):
        prompts = []
        for template_name in tempalte_names:
            items = retrieval_results[template_name]
            prompts.append(items[i]['input'])
        all_prompts.append(prompts)
    return tempalte_names, all_prompts
tempalte_names, all_prompts = load_prompts_from_retrieval_result()

In [3]:


# dataset_name = "super_glue"
# dataset_config_name = "wic"
dataset_name = "cos_e"
dataset_config_name ="v1.11"
template_name = None

if dataset_name == "anli":
    raw_datasets = load_dataset(dataset_name, split=dataset_config_name)
else:
    raw_datasets = load_dataset(dataset_name, dataset_config_name, split="validation")

prompts_wrapper = DatasetTemplates(
    f"{dataset_name}"
    if dataset_config_name is None
    else f"{dataset_name}/{dataset_config_name}"
)


dummy_example = copy.deepcopy(raw_datasets[0])
dummy_example['word'] = "N/A"
dummy_example['sentence1'] = "N/A"
dummy_example['sentence2'] = "N/A"

# example = raw_datasets[0]
example = dummy_example

# print(example)
tempalte_names = []
prompts = []
print()
for t in list(prompts_wrapper.templates.values()):
    # print("template name:", t.name)
    input, target = t.apply(example)
    # print("INPUT:", input, "\nTARGET:", target, "\n")
    prompts.append(input)
    tempalte_names.append(t.name)
    print(t.name)
all_prompts = [prompts]


Downloading builder script:   0%|          | 0.00/2.38k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.54k [00:00<?, ?B/s]

Downloading and preparing dataset cos_e/v1.11 (download: 6.23 MiB, generated: 2.91 MiB, post-processed: Unknown size, total: 9.14 MiB) to /data2/mikeeewang/.cache/huggingface/cos_e/v1.11/1.11.0/e8dc57a5b321a2a97063efb8d316d6d8a0d9a2d3a392dafc913e55bed42736d2...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/3.79M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/472k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/423k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/67.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/539k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/9741 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1221 [00:00<?, ? examples/s]

Dataset cos_e downloaded and prepared to /data2/mikeeewang/.cache/huggingface/cos_e/v1.11/1.11.0/e8dc57a5b321a2a97063efb8d316d6d8a0d9a2d3a392dafc913e55bed42736d2. Subsequent calls will reuse this data.

question_description_option_text
question_description_option_id
rationale
question_option_description_text
aligned_with_common_sense
description_question_option_id
explain_why_human
generate_explanation_given_text
description_question_option_text
i_think
question_option_description_id


In [12]:

print(tempalte_names, len(all_prompts), all_prompts[0])
template_name_to_ranks = defaultdict(int)

for prompts in tqdm(all_prompts):
    reranked_prompts = rerank(prompts, main_tokenizer, main_model, main_input_device)
    top_1_index = reranked_prompts[0][0]
    template_name_to_ranks[tempalte_names[top_1_index]] += 1
    # print(prompts)
    print(reranked_prompts)
    print(template_name_to_ranks)


['question-context-meaning-with-label', 'question-context-meaning', 'grammar_homework', 'affirmation_true_or_false', 'GPT-3-prompt', 'same_sense', 'question-context', 'GPT-3-prompt-with-label', 'polysemous', 'similar-sense'] 1 ['Does the word "N/A" have the same meaning in these two sentences? Yes, No?\nN/A\nN/A', 'Does the word "N/A" have the same meaning in these two sentences?\nN/A\nN/A', 'Homework\n\nDecide whether the word "N/A" is used with the same meaning in the two following sentences. Answer by yes or no.\nN/A\nN/A', 'Sentence A: N/A\nSentence B: N/A\n\n"N/A" has a similar meaning in sentences A and B. True or False?', "N/A\nN/A\nQuestion: Is the word 'N/A' used in the same sense in the two sentences above?", 'Sentence 1: N/A\nSentence 2: N/A\n\nDetermine whether the word "N/A" is used in the same sense in both sentences. Yes or no?', "Determine if the word 'N/A' is used in the same way in the two sentences below. \nN/A\nN/A", "N/A\nN/A\nQuestion: Is the word 'N/A' used in th

100%|██████████| 1/1 [00:03<00:00,  3.14s/it]

[(2, 'Homework\n\nDecide whether the word "N/A" is used with the same meaning in the two following sentences. Answer by yes or no.\nN/A\nN/A', 13006.41474802519), (7, "N/A\nN/A\nQuestion: Is the word 'N/A' used in the same sense in the two sentences above? Yes, No?", 18160.680249904097), (6, "Determine if the word 'N/A' is used in the same way in the two sentences below. \nN/A\nN/A", 25010.554855909813), (5, 'Sentence 1: N/A\nSentence 2: N/A\n\nDetermine whether the word "N/A" is used in the same sense in both sentences. Yes or no?', 27375.047994202898), (4, "N/A\nN/A\nQuestion: Is the word 'N/A' used in the same sense in the two sentences above?", 29543.16402057159), (3, 'Sentence A: N/A\nSentence B: N/A\n\n"N/A" has a similar meaning in sentences A and B. True or False?', 32327.70038905803), (0, 'Does the word "N/A" have the same meaning in these two sentences? Yes, No?\nN/A\nN/A', 33362.95340983403), (8, 'The word "N/A" has multiple meanings. Does it have the same meaning in sentenc


