# SuperGLUE Task formats

# Preliminaries

In [1]:
import sys
import gc
import time
import random
import re
from tqdm import tqdm 
import os
import json

def read_json(filepath: str) -> dict:
    with open(filepath, "r") as f:
        return json.load(f)
    
def read_jsonl(filepath: str) -> dict:
    data = []
    with open(filepath, "r") as f:
        for line in f.readlines():
            example = json.loads(line)
            data.append(example)
    return data

def write_jsonl(data: list, filepath: str) -> None:
    with open(filepath, "w") as f:
        for example in data:
            f.write(json.dumps(example) + "\n")
            
delim = "|"

In [43]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side='left',)

In [12]:
def memory_usage(t)-> float:
    i = 0
    if not isinstance(t, torch.Tensor):
        for x in t:
            i+=memory_size(x)
    else:
        return t.element_size() * t.numel() / ((2**10)**3)
    return i

def flatten_dict(d, parent_key='', sep='_'):
    """
    Recursively flatten a dictionary that may contain nested lists and dictionaries.

    Args:
        d (dict): The dictionary to be flattened.
        parent_key (str): The prefix to be added to the keys of the flattened dictionary.
        sep (str): The separator to be used between the prefix and the key of the flattened dictionary.

    Returns:
        dict: The flattened dictionary.
    """
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def unflatten_dict(d, sep='_'):
    """
    Unflatten a flattened dictionary with keys containing a separator character.

    Args:
        d (dict): The flattened dictionary to be unflattened.
        sep (str): The separator character used in the keys of the flattened dictionary.

    Returns:
        dict: The unflattened dictionary.
    """
    result = {}
    for key, value in d.items():
        parts = key.split(sep)
        current = result
        for part in parts[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]
        current[parts[-1]] = value
    return result

In [46]:
mean([len(tokenizer(format_BoolQ(v)).input_ids) for ex, v in val["BoolQ"].items()])

141.75321100917432

In [38]:
format_RTE(val[dataset][0])

In [31]:
val[dataset]

{0: {'premise': 'Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation.',
  'hypothesis': 'Christopher Reeve had an accident.',
  'label': 'not_entailment',
  'idx': 0},
 1: {'premise': 'Yet, we now are discovering that antibiotics are losing their effectiveness against illness. Disease-causing bacteria are mutating faster than we can come up with new antibiotics to fight the new variations.',
  'hypothesis': 'Bacteria is winning the war against antibiotics.',
  'label': 'entailment',
  'idx': 1},
 2: {'premise': 'Cairo is now home to some 15 million people - a burgeoning population that produces approximately 10,000 tonnes of rubbish per day, putting an enormous strain on public services. In the past 10 years, the government has tried hard to encourage private investment in the refuse sector, but some estimate 4,000 tonnes of waste is left behind every day, festering in the heat as it waits for someone to

In [20]:
tasks = ["BoolQ", "CB", "COPA", "MultiRC", "ReCoRD", "RTE", "WiC", "WSC"]
val = {}
for task_name in tasks:
    path = os.path.join("/nethome/dhe83/mice/data", task_name)
    task = read_jsonl(os.path.join(path, "val.jsonl"))
    print(task_name)
    task = {x['idx']: x for x in task}
    val[task_name] = task

BoolQ
CB
COPA
MultiRC
ReCoRD
RTE
WiC
WSC


In [3]:
train = {}
for task_name in tasks:
    path = os.path.join("/nethome/dhe83/mice/data", task_name)
    task = read_jsonl(os.path.join(path, "train.jsonl"))
    task = {x['idx']: x for x in task}
    train[task_name] = task

In [7]:
train['Winograd'][0]

{'text': 'The pony behaved well, sir, and showed no vice; but at last he just threw up his heels and tipped the young gentleman into the thorn hedge. He wanted me to help him out, but I hope you will excuse me, sir, I did not feel inclined to do so.',
 'target': {'span2_index': 27,
  'span1_index': 21,
  'span1_text': 'young gentleman',
  'span2_text': 'He'},
 'idx': 3,
 'label': 'young gentleman'}

In [8]:
# GPT format for WSC uses only positive examples
val["Winograd"]  = [v for k, v in val["WSC"].items() if v["label"]]
for v in val["Winograd"]:
    v['label'] = v['target']['span1_text']
train["Winograd"]  = [v for k, v in train["WSC"].items() if v["label"]]
for v in train["Winograd"]:
    v['label'] = v['target']['span1_text']

In [9]:
write_jsonl(train["Winograd"], "/nethome/dhe83/mice/data/Winograd/train.jsonl")

In [7]:
# convert lists to dicts for MultiRC
for ex_id, ex in val["MultiRC"].items():
    ex = ex["passage"]
    for q in ex["questions"]:
        q["answers"] = {a["idx"]: a for a in q["answers"]}
    ex["questions"]= {q["idx"]: q for q in ex["questions"]}

In [38]:
# convert lists to dicts for ReCoRD
for ex_id, ex in val["ReCoRD"].items():
    ex["qas"]= {q["idx"]: q for q in ex["qas"]}

TypeError: 'int' object is not subscriptable

In [40]:
for ex_id, ex in train["ReCoRD"].items():
    ex["qas"]= {q["idx"]: q for q in ex["qas"]}

### Permanent dataset reformatting

In [89]:
path = os.path.join("/nethome/dhe83/mice/data/ReCoRD/val_original.jsonl")
w = read_jsonl(os.path.join(path))
w[100]

{'source': 'Daily mail',
 'passage': {'text': "Niall Boylan and his biological sister Fran Kavanagh, pictured during their first meeting last week, when they were reunited 50 years after being separated A brother and sister who were separated as babies and put up for adoption five decades ago have been reunited after realising they grew up just 500 metres apart. Niall Boylan, 51, a  radio presenter in Dublin, met his biological sister who now lives in London, Fran Kavanagh, 49, for the first time this weekend after finding each other online last year. Niall, who has a show on Irish radio station 4FM, hit the headlines last year when he spoke about discovering that his birth mother had a little girl when he was just 13 months old in 1965.\n@highlight\nNiall Boylan, 51, met his biological sister for the first time last weekend\n@highlight\nHe didn't know he had a sister until an adoption agency put it in a letter\n@highlight\nFran Kavanagh, 49, was also sent a letter and tracked brother 

In [90]:
new = []

for ex in w:
    text = ex['passage']['text']
    entities = list(set({text[e["start"]:e["end"]+1] for e in ex['passage']['entities']}))
    text = text.replace("@highlight\n", "- ")

    for query in ex["qas"]:
        idx = delim.join([str(x) for x in [ex["idx"], query['idx']]])
        answers = list(set(x['text'] for x in query['answers']))
        n = {"source": ex['source'],
            "text": text,
            "choices": entities,
            "labels": answers,
            "query": query['query'],
            "ex_idx": ex['idx'],
            "idx": query['idx']}
        new.append(n)

In [91]:
new[0]

{'source': 'Daily mail',
 'text': "Tracy Morgan hasn't appeared on stage since the devastating New Jersey crash that nearly ended his life last summer, but all that will change this fall when he returns to host Saturday Night Live. NBC announced on Twitter Monday that Morgan, an SNL alum with seven seasons as a cast member under his belt, will headline the third episode of Season 41 airing October 17. For Morgan, 46, it will be a second time hosting the long-running variety show, the first since the June 2014 pileup on the New Jersey Turnpike that killed his friend and mentor James 'Jimmy Mack' McNair.\n- Morgan, 46, will host third episode of season 41 of SNL airing October 17\n- He tweeted to his fans: 'Stoked to be going home...#SNL'\n- For the SNL alum who had spent seven years as cast member, it will be a second time hosting the show\n- Morgan has been sidelined by severe head trauma suffered in deadly June 2014 crash on New Jersey Turnpike that killed his friend\n- First episode 

In [92]:
write_jsonl(new, "/nethome/dhe83/mice/data/ReCoRD/val.jsonl")

## MultiRC

In [46]:
path = os.path.join("/nethome/dhe83/mice/data/MultiRC/val_original.jsonl")
w = read_jsonl(os.path.join(path))
w[0]

{'idx': 0,
 'version': 1.1,
 'passage': {'text': 'What causes a change in motion? The application of a force. Any time an object changes motion, a force has been applied. In what ways can this happen? Force can cause an object at rest to start moving. Forces can cause objects to speed up or slow down. Forces can cause a moving object to stop. Forces can also cause a change in direction. In short, forces cause changes in motion. The moving object may change its speed, its direction, or both. We know that changes in motion require a force. We know that the size of the force determines the change in motion. How much an objects motion changes when a force is applied depends on two things. It depends on the strength of the force. It also depends on the objects mass. Think about some simple tasks you may regularly do. You may pick up a baseball. This requires only a very small force. ',
  'questions': [{'question': 'Would the mass of a baseball affect how much force you have to use to pick i

In [47]:
new = []
for ex in w:
    for question in ex['passage']["questions"]:
        n = {"text": ex['passage']['text'],
            "question": question['question'],
            "answers": [a['text'] for a in question['answers']],
#             "answers": [{"text": a['text'], "label": a['label']} for a in question['answers']],
            "labels": [a['label'] for a in question['answers']],
            "pass_idx": ex['idx'],
            "idx": question['idx']}
        new.append(n)

In [48]:
new[0]

{'text': 'What causes a change in motion? The application of a force. Any time an object changes motion, a force has been applied. In what ways can this happen? Force can cause an object at rest to start moving. Forces can cause objects to speed up or slow down. Forces can cause a moving object to stop. Forces can also cause a change in direction. In short, forces cause changes in motion. The moving object may change its speed, its direction, or both. We know that changes in motion require a force. We know that the size of the force determines the change in motion. How much an objects motion changes when a force is applied depends on two things. It depends on the strength of the force. It also depends on the objects mass. Think about some simple tasks you may regularly do. You may pick up a baseball. This requires only a very small force. ',
 'question': 'Would the mass of a baseball affect how much force you have to use to pick it up?',
 'answers': ['No',
  'Yes',
  'Less the mass, le

In [49]:
write_jsonl(new, "/nethome/dhe83/mice/data/MultiRC/train.jsonl")

In [51]:
val["COPA"][0]

{'premise': 'The man turned on the faucet.',
 'choice1': 'The toilet filled with water.',
 'choice2': 'Water flowed from the spout.',
 'question': 'effect',
 'label': 1,
 'idx': 0}

In [80]:
def format_COPA(ex: dict)->str():
    substitutions = {"cause": "because", "effect": "so"}
    return f"{ex['premise'][:-1]} {substitutions[ex['question']]}"
    
def format_COPA_in_context(ex: dict)->str:
    label = ex['choice1'] if ex['label'] == 0 else ex['choice2']
    return " ".join([format_COPA(ex), label.lower()])

def COPA_choices(ex: dict)->list(str()):
    return [" ".join([format_COPA(ex), x.lower()]) for x in [ex['choice1'], ex['choice2']]]

def format_COPA_few_shot(demonstrations, test):
    context = "\n".join([format_COPA_in_context(ex) for ex in demonstrations])
    prompt = format_COPA(test)
    prompt = "\n".join([context, prompt])
    return prompt

def COPA_few_shot_choices(demonstrations, test):
    context = "\n".join([format_COPA_in_context(ex) for ex in demonstrations])
    prompts = COPA_choices(test)
    prompts = {str(i): ("\n".join([context, prompt])) for i, prompt in enumerate(prompts)}
    return prompts

In [81]:
COPA_few_shot_choices([val["COPA"][1],val["COPA"][2]],val["COPA"][0])

{'0': 'The girl found a bug in her cereal so she lost her appetite.\nThe woman retired so she received her pension.\nThe man turned on the faucet so the toilet filled with water.',
 '1': 'The girl found a bug in her cereal so she lost her appetite.\nThe woman retired so she received her pension.\nThe man turned on the faucet so water flowed from the spout.'}

In [95]:
val["MultiRC"][1]

{'text': "A four-year-old girl with Down syndrome, who began her life in a Colombian orphanage has had a remarkable change in fortune after being picked by Target to be a model for a new product in the retailer's children's range. Little Kayella Aschoff, who was adopted by Ted and Jodi from Minnesota in 2011, was hand-picked by the company to be the face of its training pants, after she attended an open casting call in 2014. 'Kayella loves being in front of the camera,' proud mom Jodi told\xa0The Mighty. 'She’s quite the ham. When you take a photo though, she has to check it and then says ‘cute picture’ or she’ll tell you to take another one!'\n- Kayella Aschoff was adopted by Ted and Jodi, from Minnesota, in 2011\n- The couple have adopted another orphaned child with Down syndrome, Leo, who was born in China\n- Kayella appears as the face of Target's new training pants",
 'question': 'What do you apply to an object to make it move or stop?',
 'answers': ['Strength',
  'Nothing, it wil

In [None]:
def format_MultiRC(ex: dict)->list(str()):
    return (f"{ex['text']}\n\n"
            f"{question['question']}\n")
    
def format_MultiRC_in_context(ex: dict)->str:
    substitutions = ["False", "True"]
    return " ".join([format_COPA(ex), f"[{substitutions[ex['labels'][0]]}] {ex['answers'][0]}"])

def COPA_choices(ex: dict)->list(str()):
    
    return [" ".join([format_COPA(ex), x.lower()]) for x in [ex['choice1'], ex['choice2']]]

def format_COPA_few_shot(demonstrations, test):
    context = "\n".join([format_COPA_in_context(ex) for ex in demonstrations])
    prompt = format_COPA(test)
    prompt = "\n".join([context, prompt])
    return prompt

def COPA_few_shot_choices(demonstrations, test):
    context = "\n".join([format_COPA_in_context(ex) for ex in demonstrations])
    prompts = COPA_choices(test)
    prompts = {str(i): ("\n".join([context, prompt])) for i, prompt in enumerate(prompts)}
    return prompts

# Replicating Baselines

## Prompt formats

In [10]:
from collections import Counter
c = Counter()
for i, e in val["BoolQ"].items():
    c.update([e["label"]])
print(c)

Counter({True: 2033, False: 1237})


In [10]:
def format_BoolQ(ex: dict)->str:
    return f"{ex['passage']}\nquestion: {ex['question']}\nanswer:"

def format_BoolQ_in_context(ex: dict)->str:
    substitutions = ["no", "yes"]
    return f"{ex['passage']}\nquestion: {ex['question']}\nanswer: {substitutions[ex['label']]}"

In [11]:
def format_general_few_shot(demonstrations, test, dataset):
    context = [format_example_in_context(ex, dataset) for ex in demonstrations]
    
    if dataset in instructions:
        context = [instructions[dataset], *context]
    
    prompt = format_example(test, dataset)
    prompt = "\n\n".join([*context, prompt])
    return prompt

In [39]:
def format_CB(ex: dict)->str:
    return f"{ex['premise']}\nquestion: {ex['hypothesis']}. true, false, or neither?\nanswer:"
    
def format_CB_in_context(ex: dict)->str:
    substitutions = {"contradiction": "false", "entailment": "true", "neutral": "neither"}
    return  f"{ex['premise']}\nquestion: {ex['hypothesis']}. true, false, or neither?\nanswer: {substitutions[ex['label']]}"
    
def format_MultiRC(ex:dict):
    prompts = {}
    
    for q_id, question in ex['passage']["questions"].items():
        prompt = (f"READING COMPREHENSION ANSWER KEY\n{ex['passage']['text']}\n\n"
            f"{question['question']}\n")
        for a_id, answer in question["answers"].items():
            prefix = "_".join([str(x) for x in [ex["idx"], q_id, a_id]])
            for label in ["False", "True"]:
                key = "_".join([prefix, label])
                prompts[key] = "".join([prompt, f"[{label}] {answer['text']}"])

    return prompts

def format_MultiRC(ex:dict):
    prompts = {}
    
    for q_id, question in ex['passage']["questions"].items():
        prompt = (f"READING COMPREHENSION ANSWER KEY\n{ex['passage']['text']}\n\n"
            f"{question['question']}\n")
        for a_id, answer in question["answers"].items():
            prefix = "_".join([str(x) for x in [ex["idx"], q_id, a_id]])
            for label in ["False", "True"]:
                key = "_".join([prefix, label])
                prompts[key] = "".join([prompt, f"[{label}] {answer['text']}"])

    return prompts


def format_MultiRC(ex: dict):
    prompts = []
    labels = []
    substitutions = ["False", "True"]
    
    def format_answer(label, text):
        return f"[{substitutions[label]}] {text}"
    
    for question in ex['passage']["questions"]:
        for answer in question["answers"]:
            prompt = (f"READING COMPREHENSION ANSWER KEY\n{ex['passage']['text']}\n\n"
                        f"{question['question']}\n")
            prompts.append((prompt + format_answer(0, answer['text']), prompt + format_answer(1, answer['text'])))
            labels.append(srompt + format_answer(answer['label'], answer['text']))
            
    return prompts

def format_ReCoRD(ex, in_context=False):
    prompts = {}
    
    text = ex['passage']['text']
    entities = set({text[e["start"]:e["end"]+1] for e in ex['passage']['entities']})
    
    for q_id, query in ex["qas"].items():
        entity_prompts = {}
        for entity in entities:
            key = "_".join([str(x) for x in [ex["idx"], q_id, entity]])
            prompts[key] = "".join([text.replace("@highlight\n", "- "), "\n- ", query["query"].replace("@placeholder", entity)])     

    return prompts

def format_RTE(ex:dict)->str:
    return (f"{ex['premise']}"
            f"\nquestion: {ex['hypothesis']} True or False?"
            f"\nanswer:")

def format_WiC(ex:dict)->str:
    return (f"{ex['sentence1']}\n{ex['sentence2']}\n"
            f"question: Is the word \'{ex['word']}\' used in the same way in the two sentences above?"
            f"answer:")

def format_WiC_in_context(ex:dict)->str:
    substitutions = ["no", "yes"]
    return (f"{ex['sentence1']}\n{ex['sentence2']}\n"
            f"question: Is the word \'{ex['word']}\' used in the same way in the two sentences above?"
            f"answer: {substitutions[ex['label']]}")

def format_WSC(ex: dict)->str:    
    return (f"Passage: {ex['text']}\n"
            f"Question: In the passage above, does the pronoun " 
            f"\"{ex['target']['span2_text']}\" refer to {ex['target']['span1_text']}?\n"
            f"Answer:")

def format_WSC_in_context(ex: dict)->str: 
    substitutions = ["no", "yes"]
    return (f"Passage: {ex['text']}\n"
            f"Question: In the passage above, does the pronoun " 
            f"\"{ex['target']['span2_text']}\" refer to {ex['target']['span1_text']}?\n"
            f"Answer: {substitutions[ex['label']]}")

def format_Winograd(ex:dict)->str:
    text = ex['text'].split(" ")
    pronoun = ex['target']['span2_text']
    index = ex['target']['span2_index']
    assert pronoun == text[index]
    text[index] = "".join(["*", text[index], "*"])
    
    passage = " ".join(text)
    return (f"{instruction}"
             f"\nPassage: {passage}"
             f"\nQuestion: In the passage above, what does the pronoun \"*{pronoun}*\" refer to?\nAnswer:")

def format_Winograd(ex:dict)->str:
    text = ex['text'].split(" ")
    pronoun = ex['target']['span2_text']
    index = ex['target']['span2_index']
    assert pronoun == text[index]
    text[index] = "".join(["*", text[index], "*"])
    
    passage = " ".join(text)
    return (f"{instruction}"
             f"\nPassage: {passage}"
             f"\nQuestion: In the passage above, what does the pronoun \"*{pronoun}*\" refer to?\nAnswer:{ex['target']['span1_text']}")

def first_word(s):
    return "".join([c for c in re.split(" |\n|</s>",s.strip())[0] if str.isalpha(c)]).lower()

In [13]:
def format_example(ex : dict, dataset: str):
    assert dataset in tasks
    
    templates = {"BoolQ": format_BoolQ}
    
    return templates[dataset](ex)

def format_example_in_context(ex : dict, dataset: str)->str:
    assert dataset in tasks
    
    templates = {"BoolQ": format_BoolQ}
    
    return templates[dataset](ex)

def format_few_shot(demonstrations, test, dataset):
    assert dataset in tasks
    
    templates = {"BoolQ": format_general_few_shot}
    
    return templates[dataset](demonstrations, test, dataset)

## Zero-shot functions

In [14]:
def zero_shot_inference(model, ex, dataset):
#     prompt = format_few_shot([], ex, dataset)
    prompt = format_example(ex, dataset)
    tokens = tokenizer(prompt, padding=True, return_tensors="pt").to('cuda:0')
    outputs = model.generate(
         **tokens,
         max_new_tokens=max_generated_len,
         temperature=0,
         return_dict_in_generate=True,
         output_scores=True,
         eos_token_id=198,  # special character 'ċ' (bytecode for new line?) NOTE use this for generation
    )
    output_text = tokenizer.decode(outputs.sequences.squeeze()[-max_generated_len:])
    res = first_word(output_text)    
    return res

def zero_shot_scoring(model, prompts):
    encoded_prompt = tokenizer(
        prompts,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded_prompt = encoded_prompt.to("cuda:0")
    with torch.no_grad():
        logits = model(**encoded_prompt,).logits
        
    labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)
    masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
        logits.float(), dim=-1
    )
    seq_token_log_probs = torch.gather(
        masked_log_probs, -1, encoded_prompt["input_ids"].unsqueeze(-1)
    )
    seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
    seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
    return seq_log_prob

## Batch Operations

In [15]:
batch_size = 16
gen_len = 50

def batch_inference(prompts):
    output_tokens = torch.empty(0, dtype=torch.int64).to('cuda:0')
    
    num_batches = round(len(prompts) / batch_size + 0.5)
    
    for batch in tqdm(range(num_batches)): 
        start = batch * batch_size
        end = min((batch + 1) * batch_size, len(prompts))
        
        # tokenize by batch to mitigate effect of long outliers
        tokens = tokenizer(prompts[start:end], padding=True, return_tensors="pt").to('cuda:0')
        outputs = model.generate(
            **tokens,
            max_new_tokens=gen_len,
            temperature=0,
            return_dict_in_generate=True,
            output_scores=True,
            eos_token_id=198,  # special character 'ċ' (bytecode for new line?) NOTE use this for generation
        )
        output_tokens = torch.cat((output_tokens, outputs.sequences[:, -gen_len:]))
        
    return output_tokens

def batch_scoring(model, prompts):
    log_probs = torch.empty(0, dtype=torch.float32).to('cpu:0')
    
    ids, prompts = zip(*prompt_map.items())
    
    num_batches = round(len(prompts) / batch_size + 0.5)
    with torch.no_grad():
        for batch in tqdm(range(num_batches)): 
            start = batch * batch_size
            end = min((batch + 1) * batch_size, len(prompts))

            # tokenize by batch to mitigate effect of long outliers
            tokens = tokenizer(prompts[start:end], padding=True, return_tensors="pt").to('cuda:0')
            
            logits = model(tokens.input_ids, attention_mask=tokens.attention_mask).logits
            labels_attention_mask = tokens.attention_mask.unsqueeze(-1)
            masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
                logits.float(), dim=-1
            )
            seq_token_log_probs = torch.gather(
                masked_log_probs, -1, tokens.input_ids.unsqueeze(-1)
            )
            seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
            seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
            
            log_probs = torch.cat((log_probs, seq_log_prob))
    
    return log_probs

## Verbalize 

In [16]:
def verbalize_TrueFalse(text):
    if first_word(text) in ["yes", "true"]:
        return True
    elif first_word(text) in ["no", "false"]:
        return False
    else:
        return None
    
def verbalize_CB(text):
    if first_word(text) in ["yes", "true"]:
        return "entailment"
    elif first_word(text) in ["no", "false"]:
        return "contradiction"
    elif first_word(text) in ["neither", "neutral"]:
        return "neutral"
    else:
        return None
    
def verbalize_RTE(text):
    if first_word(text) in ["yes", "true"]:
        return "entailment"
    elif first_word(text) in ["no", "false"]:
        return "not_entailment"
    else:
        return None
    
def common_words(s1, s2):
    s1, s2 = s1.split(" "), s2.split(" ")
    return len(set(s1).intersection(set(s2)))

def verbalize_WSC(text):
    first_sentence = text.split("\n")[0].lower().rstrip('.')
    return re.sub("(.+) refers to ", "", first_sentence).strip()
    
def verbalize(text : dict, dataset: str):
    assert dataset in tasks
    
    templates = {"BoolQ": verbalize_TrueFalse,
                 "CB": verbalize_CB,
                "RTE": verbalize_RTE,
                "WiC": verbalize_TrueFalse,
                "WSC": verbalize_WSC}
    
    return templates[dataset](text)

## Runs

In [None]:
t = ""
n = min(1000, len(data[t].keys()))

### One at a time

In [None]:
zero_shot_res = []
for idx, ex in tqdm(list(data[t].items())[:n]):
    zero_shot_res.append(zero_shot_inference(model, ex, t))

In [None]:
correct = 0
total = 0
tpe = 0
fpe = 0
zero_shot_correct = []
s = []
for idx, x in enumerate(zero_shot_res):
    ex = data[t][idx]
    r = verbalize(x, t)
    if r != None and r == ex["label"]:
        zero_shot_correct.append(idx)
        correct+=1
    elif r == None:
        if ex["label"]:
            tpe+=1
        else:
            fpe+=1
        s.append(x)
        
    total+=1
    
print("Accuracy", correct/total)
print("Bad (True)", tpe/total)
print("Bad (False)", fpe/total)
print(" | ".join(s))

#### Batch

In [24]:
t = "BoolQ"
d = data[t]
n = min(4000, len(d))

In [25]:
prompts = []
for idx, ex in tqdm(d.items()):
    prompts.append(format_example(ex, t))
#     prompts.append(format_few_shot([], ex, t))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3270/3270 [00:00<00:00, 210937.61it/s]


In [26]:
prompts[0]

"Ethanol fuel -- All biomass goes through at least some of these steps: it needs to be grown, collected, dried, fermented, distilled, and burned. All of these steps require resources and an infrastructure. The total amount of energy input into the process compared to the energy released by burning the resulting ethanol fuel is known as the energy balance (or ``energy returned on energy invested''). Figures compiled in a 2007 report by National Geographic Magazine point to modest results for corn ethanol produced in the US: one unit of fossil-fuel energy is required to create 1.3 energy units from the resulting ethanol. The energy balance for sugarcane ethanol produced in Brazil is more favorable, with one unit of fossil-fuel energy required to create 8 from the ethanol. Energy balance estimates are not easily produced, thus numerous such reports have been generated that are contradictory. For instance, a separate survey reports that production of ethanol from sugarcane, which requires 

In [27]:
x = batch_inference(prompts[:n])
xd =  t

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 205/205 [07:15<00:00,  2.12s/it]


In [28]:
correct = 0
tpe = 0
fpe = 0
batch_correct = []
wrong = []
lost = []
for i, e in enumerate(xd):
    v = verbalize(e, t)
    if v != None and v == data[t][i]["label"]:
        correct+=1
        batch_correct.append(i)
    elif v != None and v != data[t][i]["label"]:
        wrong.append((xd[i], first_word(e), data[t][i]["label"]))
    elif v == None and data[t][i]["label"]:
        tpe= tpe + 1
        lost.append((xd[i], first_word(e), data[t][i]["label"]))
    elif v == None and not data[t][i]["label"]:
        fpe = fpe + 1
        lost.append((xd[i], first_word(e), data[t][i]["label"]))
    
print("Accuracy", correct/n)
print("True", tpe/n)
print("False", fpe/n)

Accuracy 0.5978593272171254
True 0.010091743119266056
False 0.004281345565749235


### Labels

In [68]:
from collections import Counter
c = Counter()
for i, e in val["COPA"].items():
    c.update([e["question"]])
print(c)

Counter({'cause': 52, 'effect': 48})


In [None]:
format_example(data["RTE"][3], "RTE")

In [None]:
format_example(data[t][0], t)

In [29]:
torch.cuda.empty_cache()

In [None]:
for t in tasks:
    print(t, len(data[t].keys()))

In [40]:
len(lost)

50

In [62]:
for l in lost:
    print(l)

("\n\nI'm not", 'im', 'not_entailment')
(' "I am not sure', 'i', 'entailment')
(' Steve Jobs worked for Apple', 'steve', 'entailment')
(' 1.7 percent of', '', 'not_entailment')
('     ', '', 'not_entailment')
(' Answer:  Answer:', 'answer', 'entailment')
('     ', '', 'not_entailment')
(' The Padres won the game', 'the', 'not_entailment')
(' Fallujah and Baqu', 'fallujah', 'entailment')
(' Slovenia has 3,000', 'slovenia', 'not_entailment')
('\n\nThe tragedy of', 'the', 'entailment')
(' 50th Anniversary of Normandy', 'th', 'entailment')
('     ', '', 'not_entailment')
(' breast milk may help fight', 'breast', 'entailment')
(' Pibul was the', 'pibul', 'entailment')
('     ', '', 'entailment')
(' Qualcomm is a US company', 'qualcomm', 'entailment')
('     ', '', 'not_entailment')
(' Oracle sells financial software.', 'oracle', 'entailment')
(' "I am not sure', 'i', 'entailment')
('     ', '', 'not_entailment')
('     ', '', 'entailment')
('     ', '', 'entailment')
(' Irene is going to', 

In [163]:
print(format_example(wsc[5], 'WSC'))

Final Exam with Answer Key
Instructions: Please carefully read the following passages. For each passage, you must identify which noun the pronoun marked in *bold* refers to.
=====
Passage: The large ball crashed right through the table because *it* was made of styrofoam.
Question: In the passage above, what does the pronoun "*it*" refer to?
Answer:


In [98]:
def common_words(s1, s2):
    s1, s2 = s1.split(" "), s2.split(" ")
    return len(set(s1).intersection(set(s2)))

def verbalize_WSC(text):
    first_sentence = text.split("\n")[0].lower().rstrip('.')
    return re.sub("(.+) refers to ", "", first_sentence).strip()

In [120]:
x = batch_inference(prompts[:n])
xd = tokenizer.batch_decode(x)
correct = 0
total = 0
batch_correct = []
wrong = []
for i, e in enumerate(zip(xd, wsc)):
    v = verbalize(e[0], t)
    z = e[1]
    label = wsc[z]["target"]["span1_text"].lower()
    if label in v or v in label or common_words(v, label) >= len(label.split(" ")) / 2:
        correct+=1
        batch_correct.append(i)
    else:
        print(wsc[z]["text"])
        print(f"~{v}~", f"~{label}~")
        print("="*20)
        wrong.append((xd[i], v, data[t][i]["label"]))
    total+=1

    
print("Accuracy", correct/total)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.58s/it]

I tried to paint a picture of an orchard, with lemons in the lemon trees , but they came out looking more like light bulbs.
~the person who is painting the picture~ ~lemons~
Mr. Moncrieff visited Chester 's luxurious New York apartment, thinking that it belonged to his son Edward . The result was that Mr. Moncrieff has decided to cancel Edward 's allowance on the ground that he no longer requires his financial support.
~mr. moncrieff~ ~edward~
Meanwhile, in the forest, the elephants are calling and hunting high and low for Arthur and Celeste , and their mothers are very worried. Fortunately, in flying over the town, an old marabou bird has seen them and come back quickly to tell the news.
~the elephants~ ~arthur and celeste~
Always before, Larry had helped Dad with his work. But he could not help him now, for Dad said that his boss at the railroad company would not want anyone but him to work in the office.
~larry~ ~dad~
Papa looked down at the children 's faces , so puzzled and sad no




In [171]:
parse_MultiRC(val["MultiRC"][0])

{'0_0_0_False': 'READING COMPREHENSION ANSWER KEY\nWhat causes a change in motion? The application of a force. Any time an object changes motion, a force has been applied. In what ways can this happen? Force can cause an object at rest to start moving. Forces can cause objects to speed up or slow down. Forces can cause a moving object to stop. Forces can also cause a change in direction. In short, forces cause changes in motion. The moving object may change its speed, its direction, or both. We know that changes in motion require a force. We know that the size of the force determines the change in motion. How much an objects motion changes when a force is applied depends on two things. It depends on the strength of the force. It also depends on the objects mass. Think about some simple tasks you may regularly do. You may pick up a baseball. This requires only a very small force. \n\nWould the mass of a baseball affect how much force you have to use to pick it up?\n[False] No',
 '0_0_0_

In [173]:
def format_MultiRC(ex, in_context=False):
    prompts = {}
    
    for q_id, question in ex['passage']["questions"].items():
        prompt = (f"READING COMPREHENSION ANSWER KEY\n{ex['passage']['text']}\n\n"
            f"{question['question']}\n")
        for a_id, answer in question["answers"].items():
            prefix = "_".join([str(x) for x in [ex["idx"], q_id, a_id]])
            for label in ["False", "True"]:
                key = "_".join([prefix, label])
                prompts[key] = "".join([prompt, f"[{label}] {answer['text']}"])

    return prompts

In [178]:
%%timeit
prompt_map = {}
for idx, ex in list(val["MultiRC"].items()):
    prompt_map.update(format_MultiRC(ex))
ids, prompts = list(prompt_map.keys()), list(prompt_map.values())

14 ms ± 209 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [61]:
print("Scoring prompts...", end="")
probs = batch_scoring(model, prompts)
mapped = {idx: prob.item() for idx, prob in zip(ids, probs)}

pairs = {}
for idx, prob in mapped.items():
    ex_id, q_id, a_id, label_id = idx.split("_")
    k = "_".join([ex_id, q_id, a_id])
    if k not in pairs: 
        pairs[k] = torch.empty((2), dtype=torch.float32)
    pairs[k][int(label_id)] = prob

results = {}
for k, v in pairs.items():
    results[k] = v.argmax(-1).item()
print("done")

Scoring prompts...

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 606/606 [10:11<00:00,  1.01s/it]

done





In [65]:
#Accuracy across questions
correct = 0
total = 0
for k, v in results.items():
    ex_id, q_id, a_id = [int(i) for i in k.split("_")]
    label = val["MultiRC"][ex_id]["passage"]["questions"][q_id]["answers"][a_id]["label"]
    if v == label:
        correct+=1
    total+=1

print(correct/total)

0.5016501650165016


In [62]:
question_results = {}
for k, v in results.items():
    ex_id, q_id, a_id = [int(i) for i in k.split("_")]
    question_id = "_".join([str(x) for x in [ex_id, q_id]])
    if question_id not in question_results:
        question_results[question_id] = {}
    question_results[question_id][a_id] = v

In [63]:
correct = 0
total = 0
for idx, question in question_results.items():
    ex_id, q_id = [int(i) for i in idx.split("_")]
    c = 0
    for a_id, pred in question.items():        
        label = val["MultiRC"][ex_id]["passage"]["questions"][q_id]["answers"][a_id]["label"]
        if pred == label:
            c+=1
    if c == len(question):
        correct+=1
    total+=1
print(correct/total)

0.024134312696747113


In [183]:
%time
prompt_map = {}
for idx, ex in list(val["ReCoRD"].items()):
    prompt_map.update(format_ReCoRD(ex))
ids, prompts = list(prompt_map.keys()), list(prompt_map.values())

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.05 µs


In [190]:
list(mapped.keys())[0]

'0_0_Twitter'

In [187]:
print("Scoring prompts...", end="")
probs = batch_scoring(model, prompts)
mapped = {idx: prob.item() for idx, prob in zip(ids, probs)}

Scoring prompts...

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7078/7078 [12:25<00:00,  9.50it/s]


ValueError: too many values to unpack (expected 3)

In [195]:
results = {}
for idx, prob in mapped.items():
    try:
        ex_id, q_id, entity = idx.split("_")
        entity
    except Exception as e:
        idxs = idx.split("_")
        ex_id, q_id, entity = idxs[0], idxs[1], "_".join(idxs[2:])
    k = "_".join([ex_id, q_id])
    if k not in results: 
        results[k] = {}
    results[k][entity] = prob

In [197]:
write_jsonl(results, "ReCoRD_zero_125m_results.jsonl")

In [205]:
choices = {}
for k, v in results.items():
    choices[k] = max(v, key=v.get)

In [210]:
list(choices.items())[:10]

[('0_0', 'NBC'),
 ('1_1', 'UK'),
 ('2_2', 'United'),
 ('3_3', 'Germany'),
 ('4_4', 'Florida'),
 ('4_5', 'Lago'),
 ('4_6', 'CNN'),
 ('4_7', 'Trump'),
 ('4_8', 'Lago'),
 ('4_9', 'CNN')]

In [230]:
correct = 0
total = 0
for idx, pred in choices.items():
    ex_id, q_id = [int(i) for i in idx.split("_")]
    answers = val["ReCoRD"][ex_id]["qas"][q_id]["answers"]
    answers = set([a["text"] for a in answers])
    if pred in answers:
        correct+=1
    total+=1
print(correct/total)

0.1135


In [9]:
key_d = "_"
def format_BoolQ(ex: dict):
    prompts = {}
    
    prompt = f"{ex['passage']}\nquestion: {ex['question']}\nanswer:"
    
    labels = ["no", "yes"]
    key = key_d.join([str(x) for x in [ex["idx"]]])

    for i, e in enumerate(labels):
        key = key_d.join([str(x) for x in [ex["idx"],i]])
        prompts[key] = " ".join([prompt, e])
        
    return prompts

In [36]:
model_name = "facebook/opt-2.7b"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()

In [37]:
prompt_map = {}
for idx, ex in list(val["BoolQ"].items()):
    prompt_map.update(format_BoolQ(ex))

In [38]:
def run_task(model_name, task_name, num_examples=None):
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
    prompt_map = {}
    
    num_examples = min(num_examples, len(val[task_name]))
    for idx, ex in list(val[task_name].items())[:num_examples]:
        prompt_map = {}.update(format_BoolQ(ex))
    
    ids, prompts = zip(*prompt_map.items())
    probs = batch_scoring(model, prompts)
    
    results = {}
    for i, prob in enumerate(probs):
        key = ids[i]
        ex_id, label = key.split(key_d)
        label = int(label) #BoolQ uses 
        if ex_id not in results: 
            results[ex_id] = {}
        results[ex_id][label] = prob.item()
        
    write_jsonl(results, f"{task_name}_{model_name}_zero_results.jsonl")
    
    choices = {}
    for k, v in results.items():
        choices[k] = max(v, key=v.get)
        
    correct, total = 0, 0
    for ex_id, pred in choices.items():
        pred = True if pred == 1 else 0
        if pred == val[task_name][int(ex_id)]['label']:
            correct+=1
        total+=1

    print(correct/total)
    return correct/total

In [39]:
ids, prompts = zip(*prompt_map.items())

In [40]:
probs = batch_scoring(model, prompts)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 409/409 [08:44<00:00,  1.28s/it]


In [41]:
results = {}
for i, prob in enumerate(probs):
    key = ids[i]
    ex_id, label = key.split(key_d)
    label = int(label) #BoolQ uses 
    if ex_id not in results: 
        results[ex_id] = {}
    results[ex_id][label] = prob.item()
    
choices = {}
for k, v in results.items():
    choices[k] = max(v, key=v.get)

In [84]:
l = [1, 2, 3, 4]
l[0:-1]

[1, 2, 3]

In [83]:
def pack_COPA(ids, probs):
    results = {}
    for i, prob in enumerate(probs):
        key = ids[i]
        label = int(key.split(delim)[-1])
        dem_key = delim.join(key.split(delim)[0:-1])
        if ex_id not in results: 
            results[dem_key] = {}
        results[dem_key][label] = prob.item()
    return results

def verbalize_COPA(results):
    choices = {}
    for k, v in results.items():
        choices[k] = max(v, key=v.get)

In [42]:
correct, total = 0, 0
for ex_id, pred in choices.items():
    pred = True if pred == 1 else 0
    if pred == val["BoolQ"][int(ex_id)]['label']:
        correct+=1
    total+=1

print(correct/total)

0.500611620795107


In [None]:
def efficient_scoring(model, passages, queries):
    log_probs = torch.empty(0, dtype=torch.float32).to('cpu:0')
    
    # passages {id: text}
    # queries {passage_query_id: text}
    passages
    
    
    
    num_batches = round(len(prompts) / batch_size + 0.5)
    with torch.no_grad():
        for batch in tqdm(range(num_batches)): 
            start = batch * batch_size
            end = min((batch + 1) * batch_size, len(prompts))

            # tokenize by batch to mitigate effect of long outliers
            tokens = tokenizer(prompts[start:end], padding=True, return_tensors="pt").to('cuda:0')
            
            logits = model(tokens.input_ids, attention_mask=tokens.attention_mask).logits
            labels_attention_mask = tokens.attention_mask.unsqueeze(-1)
            masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
                logits.float(), dim=-1
            )
            seq_token_log_probs = torch.gather(
                masked_log_probs, -1, tokens.input_ids.unsqueeze(-1)
            )
            seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
            seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
            
            log_probs = torch.cat((log_probs, seq_log_prob))
        
    return log_probs

In [53]:
prompts = format_COPA(val["COPA"][0])
prompts

['The man turned on the faucet so the toilet filled with water.',
 'The man turned on the faucet so water flowed from the spout.']

In [46]:
log_probs = torch.empty(0, dtype=torch.float32).to('cpu:0')

In [54]:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, padding_side='left',)

In [55]:
tokens = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda:0')

In [56]:
tokenizer.batch_decode(tokens.input_ids)

['<pad></s>The man turned on the faucet so the toilet filled with water.',
 '</s>The man turned on the faucet so water flowed from the spout.']

In [60]:
with torch.no_grad():
    res = model(tokens.input_ids, attention_mask=tokens.attention_mask)

In [61]:
res

CausalLMOutputWithPast(loss=None, logits=tensor([[[-3.0585, -2.8703, 11.2414,  ..., -2.7916, -3.3031, -3.1531],
         [-1.3149, -1.2285,  3.2588,  ..., -1.0370, -1.3902, -1.2169],
         [-4.5053, -4.6381, -2.4266,  ..., -4.6098, -4.4628, -4.4893],
         ...,
         [-1.9493, -2.0619,  3.9064,  ..., -2.2500, -1.6910, -2.3190],
         [ 3.1190,  3.1187, 13.2996,  ...,  3.0830,  2.8017,  2.7808],
         [-2.0322, -1.4204, 10.0002,  ..., -1.4385, -1.8405, -1.6747]],

        [[-1.3149, -1.2285,  3.2588,  ..., -1.0370, -1.3902, -1.2169],
         [-4.5053, -4.6381, -2.4266,  ..., -4.6098, -4.4628, -4.4893],
         [ 1.2275,  1.1053,  6.4566,  ...,  1.0994,  1.1545,  0.8698],
         ...,
         [ 1.5653,  1.7930,  6.3448,  ...,  1.3615,  1.6330,  1.3823],
         [ 1.2497,  1.6617, 12.4435,  ...,  1.6200,  1.3743,  1.1372],
         [-2.1398, -1.5684, 10.0955,  ..., -1.6566, -1.9249, -1.8672]]],
       device='cuda:0'), past_key_values=((tensor([[[[-1.2940e+00,  3.2353e

In [62]:
labels_attention_mask = tokens.attention_mask.unsqueeze(-1)
labels_attention_mask

tensor([[[0],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]]], device='cuda:0')

In [63]:
masked_log_probs = labels_attention_mask.float() * torch.log_softmax(
    logits.float(), dim=-1
)
masked_log_probs

tensor([[[ -0.0000,  -0.0000,  -0.0000,  ...,  -0.0000,  -0.0000,  -0.0000],
         [-15.2423, -15.1559, -10.6686,  ..., -14.9644, -15.3176, -15.1443],
         [-16.5837, -16.7165, -14.5050,  ..., -16.6882, -16.5413, -16.5677],
         ...,
         [-18.2769, -18.3895, -12.4212,  ..., -18.5775, -18.0186, -18.6465],
         [-18.5857, -18.5860,  -8.4051,  ..., -18.6217, -18.9029, -18.9239],
         [-19.0999, -18.4881,  -7.0676,  ..., -18.5062, -18.9082, -18.7425]],

        [[-15.2423, -15.1559, -10.6686,  ..., -14.9644, -15.3176, -15.1443],
         [-16.5837, -16.7165, -14.5050,  ..., -16.6882, -16.5413, -16.5677],
         [-17.5555, -17.6777, -12.3264,  ..., -17.6836, -17.6285, -17.9132],
         ...,
         [-20.1722, -19.9445, -15.3928,  ..., -20.3760, -20.1045, -20.3553],
         [-19.3283, -18.9163,  -8.1345,  ..., -18.9579, -19.2036, -19.4408],
         [-18.8336, -18.2621,  -6.5983,  ..., -18.3504, -18.6187, -18.5610]]],
       device='cuda:0')

In [70]:
seq_token_log_probs = torch.gather(
    masked_log_probs, -1, tokens.input_ids.unsqueeze(-1)
)
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
seq_token_log_probs

tensor([[ -0.0000, -10.6686, -14.2186,  -9.5525, -11.4017,  -7.7413,  -7.9770,
         -12.2092, -15.9495, -13.3197,  -9.0634,  -8.2436,  -9.6349, -10.6182,
         -11.3130, -10.6234, -13.0543],
        [-10.6686, -14.2186,  -9.5525, -11.4017,  -7.7413,  -7.9770, -12.2092,
         -15.9495, -13.3197,  -9.0634,  -9.3281, -12.3265,  -9.0664,  -8.3279,
         -14.7244,  -9.8498, -12.8928]], device='cuda:0')

In [71]:
seq_log_prob = seq_token_log_probs.sum(dim=-1).to("cpu")
seq_log_prob

tensor([-175.5889, -188.6173])

In [30]:
def model_mem_usage(model, tokens):
    kv = (2 * tokens.numel() * 
          model.config.hidden_size * 
          model.config.num_hidden_layers * 4) 
    logits = tokens.numel() * model.config.vocab_size * 4
    return kv + logits

In [40]:
def memory_usage(t)-> float:
    i = 0
    if not isinstance(t, torch.Tensor):
        for x in t:
            i+=memory_usage(x)
    else:
        return t.element_size() * t.numel()
    return i

In [32]:
prompts = ["I like cheese" * 10] * 200

In [33]:
tokens = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda:0')

In [34]:
with torch.no_grad():
    res = model(**tokens)

In [42]:
model_mem_usage(model, tokens.input_ids)

1703859200

In [43]:
res

CausalLMOutputWithPast(loss=None, logits=tensor([[[-3.4348, -3.4329, 13.0188,  ..., -3.4879, -3.5034, -3.7238],
         [-7.2210, -7.2205, -0.3695,  ..., -7.2632, -7.2661, -6.8043],
         [-6.8937, -6.8857, -3.2591,  ..., -6.9950, -7.1412, -6.9377],
         ...,
         [-7.5777, -7.5518,  4.1764,  ..., -7.7254, -7.5638, -7.6214],
         [-6.4905, -6.4529,  2.4248,  ..., -6.4983, -6.4843, -6.5645],
         [-5.8134, -5.7868, 10.6286,  ..., -5.8808, -5.7975, -5.6505]],

        [[-3.4348, -3.4329, 13.0188,  ..., -3.4879, -3.5034, -3.7238],
         [-7.2210, -7.2205, -0.3695,  ..., -7.2632, -7.2661, -6.8043],
         [-6.8937, -6.8857, -3.2591,  ..., -6.9950, -7.1412, -6.9377],
         ...,
         [-7.5777, -7.5518,  4.1764,  ..., -7.7254, -7.5638, -7.6214],
         [-6.4905, -6.4529,  2.4248,  ..., -6.4983, -6.4843, -6.5645],
         [-5.8134, -5.7868, 10.6286,  ..., -5.8808, -5.7975, -5.6505]],

        [[-3.4348, -3.4329, 13.0188,  ..., -3.4879, -3.5034, -3.7238],
    

In [45]:
memory_usage(res.logits) + memory_usage(res.past_key_values)

1703859200

### MICE 

In [61]:
train.keys()

dict_keys(['BoolQ', 'CB', 'COPA', 'MultiRC', 'ReCoRD', 'RTE', 'WiC', 'WSC', 'Winograd'])

In [66]:
val["ReCoRD"][0]

{'source': 'Daily mail',
 'passage': {'text': "Tracy Morgan hasn't appeared on stage since the devastating New Jersey crash that nearly ended his life last summer, but all that will change this fall when he returns to host Saturday Night Live. NBC announced on Twitter Monday that Morgan, an SNL alum with seven seasons as a cast member under his belt, will headline the third episode of Season 41 airing October 17. For Morgan, 46, it will be a second time hosting the long-running variety show, the first since the June 2014 pileup on the New Jersey Turnpike that killed his friend and mentor James 'Jimmy Mack' McNair.\n@highlight\nMorgan, 46, will host third episode of season 41 of SNL airing October 17\n@highlight\nHe tweeted to his fans: 'Stoked to be going home...#SNL'\n@highlight\nFor the SNL alum who had spent seven years as cast member, it will be a second time hosting the show\n@highlight\nMorgan has been sidelined by severe head trauma suffered in deadly June 2014 crash on New Jers

{'6', 'c', '$', 'á', 'z', 'N', 'U', '=', '\xa0', ';', 'ö', '\xad', 'O', 'u', '*', 'Y', 'Ä', '-', 'í', '–', 'w', 'T', '"', '(', 'h', '3', 'y', '8', 'v', 'F', '’', 'l', "'", '\u200e', 'H', 'Â', '±', 'S', 'b', '—', 'A', 'X', 'ú', 'L', '_', '[', ':', 't', 'ó', 'Ã', 'M', 'k', 'Z', 'ä', 'p', '7', 'j', 'G', 'C', 'e', ')', '/', 'P', 'K', 'f', 'I', ',', 'V', '5', 'r', 'R', 'ô', 'ć', '.', ']', 'É', 'g', 's', 'i', 'è', '&', ' ', 'E', 'q', 'J', 'é', 'W', 'x', '@', 'ş', 'ã', 'ñ', 'a', 'm', '#', '4', '‘', 'n', 'ü', 'œ', '2', '?', '9', 'B', '°', '1', 'D', 'Ö', 'd', '0', 'Q', 'ç', 'o', '+', '§', '!'}
129511
CPU times: user 128 ms, sys: 0 ns, total: 128 ms
Wall time: 126 ms


### Dead Code 

In [96]:
# def batch_score(model, prompts, labels):
#     prompt_tokens = tokenizer(prompts, padding=True, return_tensors="pt", return_attention_mask=True).to("cuda:0")
#     pad_length = prompt_tokens.input_ids.shape[-1]
#     labels = tokenizer(labels, padding="max_length", max_length=pad_length, return_tensors="pt").to("cuda:0")
#     labels_mask = labels.attention_mask
# #     print("prompt_tokens", prompt_tokens)
# #     print("labels_mask", labels_mask)
#     with torch.no_grad():
#         logits = model(**prompt_tokens).logits

# #     print("logits", logits)

# #     print(prompt_tokens.input_ids)

#     sequence_logits = logits.log_softmax(-1).gather(-1, prompt_tokens.input_ids.unsqueeze(-1)).squeeze(-1)
# #     print("sequence_logits", sequence_logits)
#     labels_logits = labels_mask * sequence_logits
# #     print("labels_logits", labels_logits)
#     log_probs = labels_logits.sum(-1).to("cpu")
#     return log_probs

In [16]:
# from evaluate import load
# super_glue_metric = load('super_glue', 'record') 

# predictions = []
# for k, v in choices.items():
#     q_id = int(k.split("_")[1])
#     predictions.append({"idx": str(q_id), "prediction_text": v})
    
# references = []
# for idx, ex in val["ReCoRD"].items():
#     for q_id, query in ex["qas"].items():
#         r = set()
#         for a in query["answers"]:
#             r.add(a["text"])
#         references.append({"idx": str(q_id), "answers": list(r)})

In [92]:
chars = set()
i = 0
for k, v in train["ReCoRD"].items():
    for entity in v['passage']['entities']:
        e = v['passage']['text'][entity['start']:entity['end']+1]
        chars.update(*e)
        i+=1
print(chars)
print(i)
for k, v in train["ReCoRD"].items():
    for entity in v['passage']['entities']:
        e = v['passage']['text'][entity['start']:entity['end']+1]
        if "|" in e:
            print(e)
print("done")

{'6', '\u200b', 'c', '$', 'á', 'z', 'N', 'U', '=', 'ž', 'ö', '\xad', 'O', 'u', '*', 'Y', 'Ä', '%', '-', 'í', '–', 'õ', 'w', 'T', 'ë', '"', 'Ü', 'š', '(', '\u202a', 'h', '3', 'y', 'ė', '8', 'v', 'F', 'ª', '’', '¨', 'l', "'", '\u200e', 'H', '±', 'Â', 'č', 'S', 'Á', 'b', 'ê', 'A', 'X', 'ú', 'L', 'Ž', 'ł', '́', '_', 'ǃ', ':', '[', 'Î', '“', 't', 'ó', 'Ã', 'ý', '¼', 'º', 'M', 'ū', 'k', 'Z', 'ä', 'p', '7', 'j', 'G', 'Ó', 'C', 'å', 'ğ', '¶', 'Í', '”', 'e', '™', ')', '/', 'P', 'K', 'f', 'I', ',', '©', 'V', '5', 'r', 'R', 'ô', '.', 'Ì', ']', '\u200f', 'ð', 'ï', 'É', 'g', 's', 'i', 'è', '&', '˜', ' ', 'î', '´', 'E', 'ź', 'q', 'J', 'é', '¿', '€', 'à', '£', 'W', 'x', '@', 'ß', 'ã', 'ş', 'ñ', 'ø', 'a', 'm', '#', 'ă', '®', '4', '‘', '¯', 'n', '½', '`', 'ù', 'ü', 'ż', '2', '?', '9', 'ā', '¡', 'B', '°', '1', 'D', 'Ö', 'd', '¹', '\u2009', '0', '³', 'Q', 'ç', 'â', 'o', '+', '\\', 'ţ', '§', '!'}
1129798
done


In [93]:
"|" in chars

False

In [95]:
format_BoolQ(val["BoolQ"][1])

"Property tax -- Property tax or 'house tax' is a local tax on buildings, along with appurtenant land. It is and imposed on the Possessor (not the custodian of property as per 1978, 44th amendment of constitution). It resembles the US-type wealth tax and differs from the excise-type UK rate. The tax power is vested in the states and is delegated to local bodies, specifying the valuation method, rate band, and collection procedures. The tax base is the annual rental value (ARV) or area-based rating. Owner-occupied and other properties not producing rent are assessed on cost and then converted into ARV by applying a percentage of cost, usually four percent. Vacant land is generally exempt. Central government properties are exempt. Instead a 'service charge' is permissible under executive order. Properties of foreign missions also enjoy tax exemption without requiring reciprocity. The tax is usually accompanied by service taxes, e.g., water tax, drainage tax, conservancy (sanitation) tax,