# SuperGLUE Task formats

## Preliminaries

In [None]:
import torch
import torch.nn.functional as F
import os
import json
import sys
import gc
import time
import random
import re
from tqdm import tqdm 

from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side='left',)

def read_json(filepath: str) -> dict:
    with open(filepath, "r") as f:
        return json.load(f)
    
def read_jsonl(filepath: str) -> dict:
    data = []
    with open(filepath, "r") as f:
        for line in f.readlines():
            example = json.loads(line)
            data.append(example)
    return data

In [None]:
tasks = ["BoolQ", "CB", "COPA", "MultiRC", "ReCoRD", "RTE", "WiC", "WSC"]
data = {}
for task_name in tasks:
    path = os.path.join("/nethome/dhe83/mice/data", task_name)
    task = read_jsonl(os.path.join(path, "val.jsonl"))
    task = {x['idx']: x for x in task}
    data[task_name] = task

In [3]:
# def batch_score(model, prompts, labels):
#     prompt_tokens = tokenizer(prompts, padding=True, return_tensors="pt", return_attention_mask=True).to("cuda:0")
#     pad_length = prompt_tokens.input_ids.shape[-1]
#     labels = tokenizer(labels, padding="max_length", max_length=pad_length, return_tensors="pt").to("cuda:0")
#     labels_mask = labels.attention_mask
# #     print("prompt_tokens", prompt_tokens)
# #     print("labels_mask", labels_mask)
#     with torch.no_grad():
#         logits = model(**prompt_tokens).logits
        
# #     print("logits", logits)
    
# #     print(prompt_tokens.input_ids)

#     sequence_logits = logits.log_softmax(-1).gather(-1, prompt_tokens.input_ids.unsqueeze(-1)).squeeze(-1)
# #     print("sequence_logits", sequence_logits)
#     labels_logits = labels_mask * sequence_logits
# #     print("labels_logits", labels_logits)
#     log_probs = labels_logits.sum(-1).to("cpu")
#     return log_probs

## In-Context Aggregation

In [None]:
t = random.sample(sorted(data), k=400)s
train = {k: data[k] for k in t[:200]}
test = {k: data[k] for k in t[200:]}

prompt_map = {k: [tuple(random.sample(sorted(train), k=1)) for x in range(1)] for k in test.keys()}

In [None]:
results = {}
raw = {}
for test_id, train_ids in tqdm(prompt_map.items()):
    test = data[test_id]
    prompts = []
    labels = []
    for pair in train_ids:
        demonstrations = [data[i] for i in pair]
        p, l = COPA_few_shot(demonstrations, test)
        prompts.append(p)
        labels.append(l)
                
    prompts =[prompt for pair in prompts for prompt in pair]
    labels =[label for pair in labels for label in pair]
    log_probs = batch_score(model, prompts, labels)
    
    log_probs = log_probs.view(-1, 2)
    preds = log_probs.argmax(-1)
    
    pred = 0
    if preds.sum(-1) > preds.shape[0] / 2:
        pred = 1
    results[test_id] = pred
    raw[test_id] = preds

In [None]:
correct = 0
total = 0
for id, pred in results.items():
    if data[id]['label'] == pred:
        correct+=1
    total+=1

print(correct/total)

## Zero Shot

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side='left',pad_token="-")
results = {}
raw = {}
total = 0
correct = 0
for test_id, ex in tqdm(data.items()):
    prompts, labels = COPA_few_shot([], ex)
    
    print("prompts", prompts)
#     print("labels", labels)
          
    log_probs = batch_score(model, prompts, labels)
    print("log_probs, log_probs")
    log_probs = log_probs.view(-1, 2)
    print("viewed log_probs", log_probs)
    pred = log_probs.argmax(-1)
    print("pred", pred)
    
    if pred.item() == ex['label']:
#         print("pred.item()", pred.item())
#         print("label", ex['label'])
        correct +=1
    total+=1
            
    results[test_id] = pred.item()
    
#     print(log_probs)
#     break

    
print(correct/total)

## BoolQ

In [4]:
data["ReCoRD"][5]

{'source': 'Daily mail',
 'passage': {'text': "The wife of one of the\xa0San Bernardino victims believes shooter\xa0Syed Farook targeted her husband because he was a Jew. Jennifer Thalasinos said her husband\xa0Nicholas had discussed religion and Israel with his co-worker Farook, as well as whether Islam is a peaceful religion. It has previously been reported that Mr Thalasinos, a Messianic Jew who wore tzitzits and the Star of David, harbored strong views against radical Islam and was a staunch supporter of the right to bear arms. 'Because of my husband being a Messianic Jew and because of the discussions, I think the shooter was intending on getting my husband,' Mrs\xa0Thalasinos told Fox News.\n@highlight\nJennifer Thalasinos believes\xa0Syed Farook targeted her husband Nicholas\n@highlight\nHe had discussed Israel and religion with his co-worker Farook, she said\n@highlight\nMr Thalasinos and Farook may have argued on Facebook before attack",
  'entities': [{'start': 23, 'end': 36}

In [14]:
max_generated_len = 40
def format_BoolQ(ex: dict, in_context=False):
    prompt = f"Context: {ex['passage']}\nQuestion: {ex['question']}\nAnswer:"
    
    substitutions = ["no", "yes"]
    if in_context:
        prompt+=substitutions[ex['label']]
        
    return prompt

def format_CB(ex, in_context=False):
    return f"{ex['premise']}\nquestion: {ex['hypothesis']}. true, false, or neither?"

def format_COPA(ex: dict, in_context=False):
    
    substitute = {"cause": "because", "effect": "so"}
    
    context = f"Context: {ex['premise'][:-1]} {substitute[ex['question']]} "
    
    if in_context:
        label = ex['choice1'] if ex['label'] == 0 else ex['choice2']
        return context + label
    
    else:
        prompts = [context + ex['choice1'], context + ex['choice2']]
        labels = [ex['choice1'], ex['choice2']]

        return prompts, labels

def format_MultiRC(ex, in_context=False):
    prompts = []
    labels = []
    substitutions = ["False", "True"]
    
    def format_answer(label, text):
        return f"[{substitutions[label]}] {text}"
    
    for question in ex['passage']["questions"]:
        for answer in question["answers"]:
            prompt = (f"READING COMPREHENSION ANSWER KEY\n{ex['passage']['text']}\n\n"
                        f"{question['question']}\n")
            prompts.append((prompt + format_answer(0, answer['text']), prompt + format_answer(1, answer['text'])))
            labels.append(prompt + format_answer(answer['label'], answer['text']))
            
    return prompts, labels

#TODO: 
def format_ReCoRD(ex, in_context=False):
    return (f"{ex['premise']}\n"
            f"question: {ex['hypothesis']}. true, false, or neither?")

def format_RTE(ex, in_context=False):
    return (f"{ex['premise']}\n"
            f"question: {ex['hypothesis']}. True or False?"
            f"answer:")

def format_WiC(ex, in_context=False):
    return (f"{ex['sentence1']}\n{ex['sentence2']}\n"
            f"question: Is the word \'{ex['word']}\' used in the same way in the two sentences above?"
            f"answer:")

#TODO:
def format_WSC(ex, in_context=False):
    return (f"Passage: {ex['text']}\n"
            f"Questions: In the passage above, does the pronoun" 
            f"\"{ex['target']['span2_text']}\" refer to {ex['target']['span1_text']}?\n"
            f"Answer:")

def format_example(ex : dict, dataset: str, in_context=False):
    assert dataset in tasks
    
    templates = {"BoolQ": format_BoolQ,
                "CB": format_CB,
                "RTE": format_RTE,
                "WiC": format_WiC}
    
    return templates[dataset](ex, in_context)

def format_few_shot(demonstrations, test, dataset):
    instructions = {"BoolQ": "Answer the question."}
    demonstrations = [format_example(ex, dataset, in_context=True) for ex in demonstrations]
    context = [instructions[dataset], *demonstrations]
    
    prompt = format_example(test, dataset)
    prompt = "\n\n".join([*context, prompt])
    return prompt

def COPA_few_shot(demonstrations, test):
    instruction = "Pick the more likely continuation to the following sentence."
    demonstrations = [format_COPA(ex, in_context=True) for ex in demonstrations]
    context = "\n".join([instruction, *demonstrations])
    
    prompts, labels = format_COPA(test)
    
    prompts = [("\n".join([context, prompt])) for prompt in prompts]
    return prompts, labels

def first_word(s):
    return "".join([c for c in re.split(" |\n",s.strip())[0] if str.isalpha(c)]).lower()

In [37]:
def zero_shot_inference(model, ex, dataset):
    prompt = format_few_shot([], ex, dataset)
    tokens = tokenizer(prompt, padding=True, return_tensors="pt").to('cuda:0')
    outputs = model.generate(
         **tokens,
         max_new_tokens=max_generated_len,
         temperature=0,
         return_dict_in_generate=True,
         output_scores=True,
         eos_token_id=198,  # special character 'ċ' (bytecode for new line?) NOTE use this for generation
    )
    output_text = tokenizer.decode(outputs.sequences.squeeze()[-max_generated_len:])
    res = first_word(output_text)    
    return res

def zero_shot_scoring(model, prompts):
    encoded_prompt = tokenizer(
        prompts,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded_prompt = encoded_prompt.to("cuda:0")
    with torch.no_grad():
        logits = model(**encoded_prompt,).logits
        
    labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
    masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
        logits.float(), dim=-1
    )
    seq_token_log_probs = torch.gather(
        masked_log_probs, -1, encoded_prompt["input_ids"].unsqueeze(-1)
    )
    seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
    seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
    return seq_log_prob

In [83]:
batch_size = 100
def batch_encode(prompts):
    return tokenizer(prompts, padding=True, return_tensors="pt").to('cuda:0')

def batch_inference(tokenized):
    output_tokens = torch.empty(0, dtype=torch.int64).to('cuda:0')
    num_batches = round(tokenized.input_ids.shape[0] / batch_size + 0.5)
    for tokens, mask in tqdm(zip(tokenized.input_ids.chunk(num_batches), tokenized.attention_mask.chunk(num_batches)), total=num_batches):
        outputs = model.generate(
            tokens,
            attention_mask=mask,
            max_new_tokens=max_generated_len,
            temperature=0,
            return_dict_in_generate=True,
            output_scores=True,
            eos_token_id=198,  # special character 'ċ' (bytecode for new line?) NOTE use this for generation
        )
        output_tokens = torch.cat((output_tokens, outputs.sequences))
        print(output_tokens)

    return output_tokens

def batch_scoring(prompts):
    log_probs = []
    tokenized = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda:0')
    
    num_batches = round(tokenized.input_ids.shape[0] / batch_size + 0.5)
    with torch.no_grad():
        for tokens, mask in zip(tokenized.input_ids.chunk(num_batches), tokenized.attention_mask.chunk(num_batches)):
            logits = model(tokens, attention_mask=mask).logits
            labels_attention_mask = mask.unsqueeze(-1)
            masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
                logits.float(), dim=-1
            )
            seq_token_log_probs = torch.gather(
                masked_log_probs, -1, tokens.unsqueeze(-1)
            )
            seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
            seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
            print(seq_log_prob)
            
            log_probs.extend(seq_log_prob)
        
    return log_probs

In [19]:
print(format_few_shot([], data["BoolQ"][5], "BoolQ"))

Answer the question.

Context: Barq's -- Barq's /ˈbɑːrks/ is an American soft drink. Its brand of root beer is notable for having caffeine. Barq's, created by Edward Barq and bottled since the turn of the 20th century, is owned by the Barq family but bottled by the Coca-Cola Company. It was known as Barq's Famous Olde Tyme Root Beer until 2012.
Question: is barq's root beer a pepsi product
Answer:


In [40]:
t = "BoolQ"
prompts = []
for idx, ex in tqdm(data[t].items()):
    prompts.append(format_example(ex, t))
#     prompts.append(format_few_shot([], ex, t))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3270/3270 [00:00<00:00, 736870.69it/s]


In [8]:
tokens = batch_encode(prompts[:20])

In [36]:
zero_shot_inference(model, data[t][37], t)

no


' No. The copyright holder is the copyright holder, and the song is not covered by the copyright. The song is not covered by the copyright.\n\nAnswer: No. The copyright holder is the'

In [39]:
correct = 0
total = 0
for idx, ex in tqdm(list(data[t].items())[:1000]):
    r = zero_shot_inference(model, ex, t)
    if r == "yes":
        r = True
    elif r == "no":
        r = False
    else:
        r = None
    
    if r != None and r == ex["label"]:
        correct+=1
    total+=1

print(correct/total)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:22<00:00,  3.81it/s]

0.379





In [9]:
res = batch_inference(tokens)

 50%|██████████████████████████████████████████████████████████████████████████████                                                                              | 1/2 [00:03<00:03,  3.38s/it]

tensor([[    2, 48522,    35,  ...,   473, 24731,   185],
        [    1,     2, 48522,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ..., 45641,    35,    16],
        [    1,     1,     1,  ..., 45641,    35,    64],
        [    1,     1,     1,  ...,    80,   893,    21]], device='cuda:0')


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.85s/it]

tensor([[    2, 48522,    35,  ...,   473, 24731,   185],
        [    1,     2, 48522,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,    35,  3216,     4],
        [    1,     1,     1,  ...,     4, 50118, 45641],
        [    1,     1,     1,  ...,    24,     4, 50118]], device='cuda:0')





In [10]:
output_text = tokenizer.batch_decode(res)
for x in output_text:
    print(x)
    print("=====")

</s>Context: Ethanol fuel -- All biomass goes through at least some of these steps: it needs to be grown, collected, dried, fermented, distilled, and burned. All of these steps require resources and an infrastructure. The total amount of energy input into the process compared to the energy released by burning the resulting ethanol fuel is known as the energy balance (or ``energy returned on energy invested''). Figures compiled in a 2007 report by National Geographic Magazine point to modest results for corn ethanol produced in the US: one unit of fossil-fuel energy is required to create 1.3 energy units from the resulting ethanol. The energy balance for sugarcane ethanol produced in Brazil is more favorable, with one unit of fossil-fuel energy required to create 8 from the ethanol. Energy balance estimates are not easily produced, thus numerous such reports have been generated that are contradictory. For instance, a separate survey reports that production of ethanol from sugarcane, whi

In [84]:
tokens = batch_encode(prompts[:1000])
x = batch_inference(tokens)

 10%|███████████████▌                                                                                                                                           | 1/10 [00:02<00:26,  3.00s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,    16,     5,  1154],
        [    1,     1,     1,  ...,     4, 50118, 45641],
        [    1,     1,     1,  ..., 45641,    35,   222]], device='cuda:0')


 20%|███████████████████████████████                                                                                                                            | 2/10 [00:06<00:24,  3.01s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ..., 50118, 33683,    35],
        [    1,     1,     1,  ...,     5,  1926,  9438],
        [    1,     1,     1,  ...,  2171,  1082, 50118]], device='cuda:0')


 30%|██████████████████████████████████████████████▌                                                                                                            | 3/10 [00:09<00:21,  3.01s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,     4, 50118, 45641],
        [    1,     1,     1,  ...,   490,   116, 50118],
        [    1,     1,     1,  ..., 45641,    35,    40]], device='cuda:0')


 40%|██████████████████████████████████████████████████████████████                                                                                             | 4/10 [00:12<00:18,  3.01s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ..., 45641,    35,    64],
        [    1,     1,     1,  ...,    10, 11906,     9],
        [    1,     1,     1,  ...,     4, 50118, 50118]], device='cuda:0')


 50%|█████████████████████████████████████████████████████████████████████████████▌                                                                             | 5/10 [00:15<00:15,  3.01s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ..., 50118, 33683,    35],
        [    1,     1,     1,  ...,  8985,     7,   121],
        [    1,     1,     1,  ..., 45641,    35,    64]], device='cuda:0')


 60%|█████████████████████████████████████████████████████████████████████████████████████████████                                                              | 6/10 [00:18<00:12,  3.02s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ..., 45641,    35,    16],
        [    1,     1,     1,  ...,    11,   144,  1200],
        [    1,     1,     1,  ..., 45641,    35,    21]], device='cuda:0')


 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                              | 7/10 [00:21<00:09,  3.02s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,   623,   968,     4],
        [    1,     1,     1,  ...,    11,     5,   232],
        [    1,     1,     1,  ..., 19258,  8766,   462]], device='cuda:0')


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                               | 8/10 [00:24<00:06,  3.02s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,     4, 50118, 45641],
        [    1,     1,     1,  ..., 45641,    35,   109],
        [    1,     1,     1,  ...,     8,    16,   278]], device='cuda:0')


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 9/10 [00:27<00:03,  3.02s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,     4, 50118, 45641],
        [    1,     1,     1,  ...,   129,   114,    47],
        [    1,     1,     1,  ..., 27769,   179,   413]], device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:30<00:00,  3.02s/it]

tensor([[    1,     1,     1,  ...,   473, 24731,   185],
        [    1,     1,     1,  ...,   691,     4, 23027],
        [    1,     1,     1,  ...,   443,     4, 50118],
        ...,
        [    1,     1,     1,  ...,  3420,    16,  6408],
        [    1,     1,     1,  ...,  8915,    13,  9473],
        [    1,     1,     1,  ...,    35,  4420,     4]], device='cuda:0')





In [82]:
type(x)

list

In [79]:
from collections import Counter
c = Counter()
for x in i:
    c.update([len(x)])
print(c)

Counter({774: 968, 773: 28, 775: 1, 772: 1, 770: 1, 771: 1})


In [65]:
xd = tokenizer.batch_decode(x)

In [47]:
z = [first_word(e) for e in x]
y = [s for s in z if s not in ["yes", "no"]]
print(len(y))

1000


In [None]:
for i in x:
    if first_word(i) not in ["yes", "no"]:
        print(i)
        print("==========")

In [None]:
w = [None] * 1000
for i, e in enumerate(z): 
    if e == "yes":
        w[i] = True
    elif e == "no":
        w[i] = False

In [None]:
correct = 0
for i, e in enumerate(w):
    if e and e == data[t][i]["label"]:
        correct+=1

print(correct/1000)

In [None]:
a = [i for i, e in zip(x, w) if e == None]

In [None]:
a