# SuperGLUE Task formats

## Preliminaries

In [1]:
import torch
import torch.nn.functional as F
import os
import json
import sys
import gc
import time
import random
import re
from tqdm import tqdm 

from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side='left',)

def read_json(filepath: str) -> dict:
    with open(filepath, "r") as f:
        return json.load(f)
    
def read_jsonl(filepath: str) -> dict:
    data = []
    with open(filepath, "r") as f:
        for line in f.readlines():
            example = json.loads(line)
            data.append(example)
    return data

In [2]:
tasks = ["BoolQ", "CB", "COPA", "MultiRC", "ReCoRD", "RTE", "WiC", "WSC"]
data = {}
for task_name in tasks:
    path = os.path.join("/nethome/dhe83/mice/data", task_name)
    task = read_jsonl(os.path.join(path, "val.jsonl"))
    task = {x['idx']: x for x in task}
    data[task_name] = task

In [3]:
train = {}
for task_name in tasks:
    path = os.path.join("/nethome/dhe83/mice/data", task_name)
    task = read_jsonl(os.path.join(path, "train.jsonl"))
    task = {x['idx']: x for x in task}
    train[task_name] = task

In [3]:
# def batch_score(model, prompts, labels):
#     prompt_tokens = tokenizer(prompts, padding=True, return_tensors="pt", return_attention_mask=True).to("cuda:0")
#     pad_length = prompt_tokens.input_ids.shape[-1]
#     labels = tokenizer(labels, padding="max_length", max_length=pad_length, return_tensors="pt").to("cuda:0")
#     labels_mask = labels.attention_mask
# #     print("prompt_tokens", prompt_tokens)
# #     print("labels_mask", labels_mask)
#     with torch.no_grad():
#         logits = model(**prompt_tokens).logits
        
# #     print("logits", logits)
    
# #     print(prompt_tokens.input_ids)

#     sequence_logits = logits.log_softmax(-1).gather(-1, prompt_tokens.input_ids.unsqueeze(-1)).squeeze(-1)
# #     print("sequence_logits", sequence_logits)
#     labels_logits = labels_mask * sequence_logits
# #     print("labels_logits", labels_logits)
#     log_probs = labels_logits.sum(-1).to("cpu")
#     return log_probs

## In-Context Aggregation

In [None]:
t = random.sample(sorted(data), k=400)s
train = {k: data[k] for k in t[:200]}
test = {k: data[k] for k in t[200:]}

prompt_map = {k: [tuple(random.sample(sorted(train), k=1)) for x in range(1)] for k in test.keys()}

In [None]:
results = {}
raw = {}
for test_id, train_ids in tqdm(prompt_map.items()):
    test = data[test_id]
    prompts = []
    labels = []
    for pair in train_ids:
        demonstrations = [data[i] for i in pair]
        p, l = COPA_few_shot(demonstrations, test)
        prompts.append(p)
        labels.append(l)
                
    prompts =[prompt for pair in prompts for prompt in pair]
    labels =[label for pair in labels for label in pair]
    log_probs = batch_score(model, prompts, labels)
    
    log_probs = log_probs.view(-1, 2)
    preds = log_probs.argmax(-1)
    
    pred = 0
    if preds.sum(-1) > preds.shape[0] / 2:
        pred = 1
    results[test_id] = pred
    raw[test_id] = preds

In [None]:
correct = 0
total = 0
for id, pred in results.items():
    if data[id]['label'] == pred:
        correct+=1
    total+=1

print(correct/total)

## Zero Shot

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side='left',pad_token="-")
results = {}
raw = {}
total = 0
correct = 0
for test_id, ex in tqdm(data.items()):
    prompts, labels = COPA_few_shot([], ex)
    
    print("prompts", prompts)
#     print("labels", labels)
          
    log_probs = batch_score(model, prompts, labels)
    print("log_probs, log_probs")
    log_probs = log_probs.view(-1, 2)
    print("viewed log_probs", log_probs)
    pred = log_probs.argmax(-1)
    print("pred", pred)
    
    if pred.item() == ex['label']:
#         print("pred.item()", pred.item())
#         print("label", ex['label'])
        correct +=1
    total+=1
            
    results[test_id] = pred.item()
    
#     print(log_probs)
#     break

    
print(correct/total)

## BoolQ

In [4]:
data["ReCoRD"][5]

{'source': 'Daily mail',
 'passage': {'text': "The wife of one of the\xa0San Bernardino victims believes shooter\xa0Syed Farook targeted her husband because he was a Jew. Jennifer Thalasinos said her husband\xa0Nicholas had discussed religion and Israel with his co-worker Farook, as well as whether Islam is a peaceful religion. It has previously been reported that Mr Thalasinos, a Messianic Jew who wore tzitzits and the Star of David, harbored strong views against radical Islam and was a staunch supporter of the right to bear arms. 'Because of my husband being a Messianic Jew and because of the discussions, I think the shooter was intending on getting my husband,' Mrs\xa0Thalasinos told Fox News.\n@highlight\nJennifer Thalasinos believes\xa0Syed Farook targeted her husband Nicholas\n@highlight\nHe had discussed Israel and religion with his co-worker Farook, she said\n@highlight\nMr Thalasinos and Farook may have argued on Facebook before attack",
  'entities': [{'start': 23, 'end': 36}

In [4]:
max_generated_len = 40
def format_BoolQ(ex: dict, in_context=False):
    prompt = f"Context: {ex['passage']}\nQuestion: {ex['question']}\nAnswer:"
    
    substitutions = ["no", "yes"]
    if in_context:
        prompt+=substitutions[ex['label']]
        
    return prompt

def format_CB(ex, in_context=False):
    return f"{ex['premise']}\nQuestion: {ex['hypothesis']}. True, False, or Neither?\nAnswer:"

def format_COPA(ex: dict, in_context=False):
    
    substitute = {"cause": "because", "effect": "so"}
    
    context = f"Context: {ex['premise'][:-1]} {substitute[ex['question']]} "
    
    if in_context:
        label = ex['choice1'] if ex['label'] == 0 else ex['choice2']
        return context + label
    
    else:
        prompts = [context + ex['choice1'], context + ex['choice2']]
        labels = [ex['choice1'], ex['choice2']]

        return prompts, labels

def format_MultiRC(ex, in_context=False):
    prompts = []
    labels = []
    substitutions = ["False", "True"]
    
    def format_answer(label, text):
        return f"[{substitutions[label]}] {text}"
    
    for question in ex['passage']["questions"]:
        for answer in question["answers"]:
            prompt = (f"READING COMPREHENSION ANSWER KEY\n{ex['passage']['text']}\n\n"
                        f"{question['question']}\n")
            prompts.append((prompt + format_answer(0, answer['text']), prompt + format_answer(1, answer['text'])))
            labels.append(prompt + format_answer(answer['label'], answer['text']))
            
    return prompts, labels

#TODO: 
def format_ReCoRD(ex, in_context=False):
    return (f"{ex['premise']}\n"
            f"question: {ex['hypothesis']}. true, false, or neither?")

def format_RTE(ex, in_context=False):
    return (f"{ex['premise']}\n"
            f"question: {ex['hypothesis']}. True or False?"
            f"answer:")

def format_WiC(ex, in_context=False):
    return (f"{ex['sentence1']}\n{ex['sentence2']}\n"
            f"question: Is the word \'{ex['word']}\' used in the same way in the two sentences above?"
            f"answer:")

#TODO:
def format_WSC(ex, in_context=False):
    return (f"Passage: {ex['text']}\n"
            f"Question: In the passage above, does the pronoun " 
            f"\"{ex['target']['span2_text']}\" refer to {ex['target']['span1_text']}?\n"
            f"Answer:")

def format_example(ex : dict, dataset: str, in_context=False):
    assert dataset in tasks
    
    templates = {"BoolQ": format_BoolQ,
                "CB": format_CB,
                "RTE": format_RTE,
                "WiC": format_WiC,
                "WSC": format_WSC}
    
    return templates[dataset](ex, in_context)

def format_few_shot(demonstrations, test, dataset):
    instructions = {"BoolQ": "Answer the question."}
    demonstrations = [format_example(ex, dataset, in_context=True) for ex in demonstrations]
    context = [instructions[dataset], *demonstrations]
    
    prompt = format_example(test, dataset)
    prompt = "\n\n".join([*context, prompt])
    return prompt

def COPA_few_shot(demonstrations, test):
    instruction = "Pick the more likely continuation to the following sentence."
    demonstrations = [format_COPA(ex, in_context=True) for ex in demonstrations]
    context = "\n".join([instruction, *demonstrations])
    
    prompts, labels = format_COPA(test)
    
    prompts = [("\n".join([context, prompt])) for prompt in prompts]
    return prompts, labels

def first_word(s):
    return "".join([c for c in re.split(" |\n",s.strip())[0] if str.isalpha(c)]).lower()

In [None]:
### 

In [5]:
def zero_shot_inference(model, ex, dataset):
#     prompt = format_few_shot([], ex, dataset)
    prompt = format_example(ex, dataset)
    tokens = tokenizer(prompt, padding=True, return_tensors="pt").to('cuda:0')
    outputs = model.generate(
         **tokens,
         max_new_tokens=max_generated_len,
         temperature=0,
         return_dict_in_generate=True,
         output_scores=True,
         eos_token_id=198,  # special character 'ċ' (bytecode for new line?) NOTE use this for generation
    )
    output_text = tokenizer.decode(outputs.sequences.squeeze()[-max_generated_len:])
    res = first_word(output_text)    
    return res

def zero_shot_scoring(model, prompts):
    encoded_prompt = tokenizer(
        prompts,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded_prompt = encoded_prompt.to("cuda:0")
    with torch.no_grad():
        logits = model(**encoded_prompt,).logits
        
    labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
    masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
        logits.float(), dim=-1
    )
    seq_token_log_probs = torch.gather(
        masked_log_probs, -1, encoded_prompt["input_ids"].unsqueeze(-1)
    )
    seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
    seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
    return seq_log_prob

### Batch Operations

In [6]:
batch_size = 50
gen_len = 20
def batch_encode(prompts):
    return tokenizer(prompts, padding=True, return_tensors="pt").to('cuda:0')

def batch_inference(tokenized):
    output_tokens = torch.empty(0, dtype=torch.int64).to('cuda:0')
    num_batches = round(tokenized.input_ids.shape[0] / batch_size + 0.5)
    for tokens, mask in tqdm(zip(tokenized.input_ids.chunk(num_batches), tokenized.attention_mask.chunk(num_batches)), total=num_batches):
        outputs = model.generate(
            tokens,
            attention_mask=mask,
            max_new_tokens=gen_len,
            temperature=0,
            return_dict_in_generate=True,
            output_scores=True,
            eos_token_id=198,  # special character 'ċ' (bytecode for new line?) NOTE use this for generation
        )
        output_tokens = torch.cat((output_tokens, outputs.sequences))
        
    output_tokens

    return output_tokens

def batch_scoring(prompts):
    log_probs = []
    tokenized = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda:0')
    
    num_batches = round(tokenized.input_ids.shape[0] / batch_size + 0.5)
    with torch.no_grad():
        for tokens, mask in zip(tokenized.input_ids.chunk(num_batches), tokenized.attention_mask.chunk(num_batches)):
            logits = model(tokens, attention_mask=mask).logits
            labels_attention_mask = mask.unsqueeze(-1)
            masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
                logits.float(), dim=-1
            )
            seq_token_log_probs = torch.gather(
                masked_log_probs, -1, tokens.unsqueeze(-1)
            )
            seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
            seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
            print(seq_log_prob)
            
            log_probs.extend(seq_log_prob)
        
    return log_probs

### Verbalize Prediction

In [7]:
def verbalize_BoolQ(text):
    if first_word(text) in ["yes", "true"]:
        return True
    elif first_word(text) in ["no", "false"]:
        return False
    else:
        return None
    
def verbalize_BoolQ(text):
    if first_word(text) in ["yes", "true"]:
        return True
    elif first_word(text) in ["no", "false"]:
        return False
    else:
        return None
    

def verbalize(text : dict, dataset: str):
    assert dataset in tasks
    
    templates = {"BoolQ": verbalize_BoolQ}
    
    return templates[dataset](text)

In [9]:
t = "BoolQ"
n = min(100, len(data[t].keys()))

In [10]:
zero_shot_res = []
for idx, ex in tqdm(list(data[t].items())[:n]):
    zero_shot_res.append(zero_shot_inference(model, ex, t))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00,  3.93it/s]


In [14]:
correct = 0
total = 0
tpe = 0
fpe = 0
zero_shot_correct = []
s = []
for idx, x in enumerate(zero_shot_res):
    ex = data[t][idx]
    r = verbalize(x, t)
    if r != None and r == ex["label"]:
        zero_shot_correct.append(idx)
        correct+=1
    elif r == None:
        if ex["label"]:
            tpe+=1
        else:
            fpe+=1
        s.append(x)
        
    total+=1
    
print("Accuracy", correct/total)
print("Bad (True)", tpe/total)
print("Bad (False)", fpe/total)
print(" | ".join(s))

Accuracy 0.5
Bad (True) 0.13
Bad (False) 0.06
ethanol | pain | the | fantastic | magnesium | the | new | the | the | the | the | fox | damon | tyrannosaurus | the | there | the | he | salt


In [15]:
prompts = []
for idx, ex in tqdm(data[t].items()):
    prompts.append(format_example(ex, t))
#     prompts.append(format_few_shot([], ex, t))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3270/3270 [00:00<00:00, 740251.19it/s]


In [16]:
tokens = batch_encode(prompts[:n])
x = batch_inference(tokens)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.16s/it]


In [17]:
xd = tokenizer.batch_decode(x[:, -(gen_len):])

In [18]:
xdd = tokenizer.batch_decode(x[:, -(gen_len + 1):])

In [21]:
correct = 0
tpe = 0
fpe = 0
batch_correct = []
for i, e in enumerate(xd):
    v = verbalize(e, t)
    if v != None and v == data[t][i]["label"]:
        correct+=1
        batch_correct.append(i)
    elif v == None and data[t][i]["label"]:
        tpe= tpe + 1
    elif v == None and not data[t][i]["label"]:
        fpe = fpe + 1
        

print("Accuracy", correct/n)
print("True", tpe/n)
print("False", fpe/n)

Accuracy 0.5
True 0.13
False 0.06


In [22]:
# [first_word(e) for e in xd]
xd

[' ethanol does not take more energy than gasoline.\n\nQuestion: does ethanol take more energy make that',
 ' Yes, the tax is levied on the property owner and the property tax is levied on the property owner',
 ' Pain is experienced in a missing body part or paralyzed area.\nQuestion: is pain experienced in a',
 ' Yes, it is. The ride is a roller coaster, and it is a roller coaster ride.',
 ' No.\nQuestion: is there a difference between hydroxyzine hcl and hydroxyz',
 " Yes.\nQuestion: is barq's root beer a pepsi product\nAnswer: Yes",
 ' Yes.\nQuestion: can an odd number be divided by an even number\nAnswer: Yes.',
 ' Yes, there is a word with q without u.\nQuestion: is there a word with q',
 ' Yes.\nQuestion: can u drive in canada with us license\nAnswer: Yes.\n',
 ' No.\n\nThe final match between the two teams was played on 15 July 2018. The winner',
 ' Yes, but only if they are under 21.\nQuestion: can minors drink with parents in new',
 ' Yes.\nQuestion: is the show bloodline base

In [100]:
from collections import Counter
c = Counter()
for i, e in data["CB"].items():
    c.update([e["label"]])
print(c)

Counter({'contradiction': 28, 'entailment': 23, 'neutral': 5})


In [89]:
format_example(data["CB"][2], "CB")

"``But my father always taught me never to be afraid of pointing out the obvious. I'm sure you have noticed the implication of the letter, that the writer has in fact observed Jenny undressing for bed?'' I just wondered if you also knew as I'm sure you do that her bedroom's at the rear of the house?\nquestion: Jenny's bedroom's at the rear of the house. true, false, or neither?"

In [160]:
format_example(data[t][0], t)

'Passage: Bernard , who had not told the government official that he was less than 21 when he filed for a homestead claim, did not consider that he had done anything dishonest. Still, anyone who knew that he was 19 years old could take his claim away from him .\nQuestion: In the passage above, does the pronoun "him" refer to anyone?\nAnswer:'

In [93]:
torch.cuda.empty_cache()

In [141]:
tokenizer.batch_decode(x)

["<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad></s>An emerging professional class.\nApologizing for losing your temper, even though you were badly provoked, showed real class.\nquestion: Is the word 'class' used in the same way in the two sentences above?answer: yes.\n\nI'm not sure if you're being sarcastic or not, but I think you",
 "<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad></s>Businessmen of every stripe joined in opposition to the proposal.\nThey earned their stripes in Kuwait.\nquestion: Is the word'stripe' used in the same way in the two sentences above?answer: Yes.\n\nI am not sure if you are aware of this, but the word'stri",
 "<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

In [142]:
data["WSC"][3]

{'text': 'Always before, Larry had helped Dad with his work. But he could not help him now, for Dad said that his boss at the railroad company would not want anyone but him to work in the office.',
 'target': {'span2_index': 20,
  'span1_index': 2,
  'span1_text': 'Larry',
  'span2_text': 'his'},
 'idx': 3,
 'label': False}