In [1]:
import argparse
from dataclasses import dataclass
import json
import logging
import os
import random
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

import transformers
import wandb
from tqdm.auto import tqdm, trange


logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s: %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

In [2]:
@dataclass
class CustomArguments(transformers.TrainingArguments):
    sample_train: int = 0
    sample_eval: int = 0
    num_choices: int = 0
    model_name_or_path: str = "asdf"  # this is no longer a TrainingArgument attribute
        
    # python dataclasses cannot have positional attributes in subclass,
    # so give all attributes defaults and then make sure they are changed
    def __post_init__(self):
        if not (self.sample_train * self.sample_eval * self.num_choices) or \
               self.model_name_or_path == "asdf":  # make sure none are still default value
            raise TypeError("__init__ missing required argument(s)")

def get_args():
    """ Set hyperparameters """
    args = CustomArguments(
        output_dir="checkpoint",
        model_name_or_path="roberta-base",
        overwrite_output_dir=True,
        do_train=False,  # Zero shot
        do_eval=True,
        per_device_eval_batch_size=8,
        learning_rate=1e-5,  # Should not matter because not training
        weight_decay=0.1,
        save_total_limit=2,
        seed=123,
        sample_train=200,
        sample_eval=-1,
        num_choices=2,
    )
    
    return args

In [3]:
def get_data(file_path, sample, num_choices):
    data_file = open(file_path, "r")
    logger.info("Reading QA instances from jsonl dataset at: %s", file_path)
    item_jsons = []
    item_ids = []
    questions = []
    choice_lists = []
    answer_ids = []
    for line in data_file:
        item_jsons.append(json.loads(line.strip()))

    if sample != -1:
        item_jsons = random.sample(item_jsons, sample)
        logger.info("Sampling %d examples", sample)

    for item_json in tqdm(item_jsons,total=len(item_jsons)):
        item_id = item_json["id"]

        question_text = item_json["question"]["stem"]

        choice_label_to_id = {}
        choice_text_list = []
        choice_context_list = []
        choice_label_list = []
        choice_annotations_list = []

        any_correct = False
        choice_id_correction = 0

        for choice_id, choice_item in enumerate(item_json["question"]["choices"]):
            choice_label = choice_item["label"]
            choice_label_to_id[choice_label] = choice_id - choice_id_correction
            choice_text = choice_item["text"]

            choice_text_list.append(choice_text)
            choice_label_list.append(choice_label)

            if item_json.get('answerKey') == choice_label:
                if any_correct:
                    raise ValueError("More than one correct answer found for {item_json}!")
                any_correct = True


        if not any_correct and 'answerKey' in item_json:
            raise ValueError("No correct answer found for {item_json}!")


        answer_id = choice_label_to_id.get(item_json.get("answerKey"))
        # Pad choices with empty strings if not right number
        if len(choice_text_list) != num_choices:
            choice_text_list = (choice_text_list + num_choices * [''])[:num_choices]
            choice_context_list = (choice_context_list + num_choices * [None])[:num_choices]
            if answer_id is not None and answer_id >= num_choices:
                logging.warning(f"Skipping question with more than {num_choices} answers: {item_json}")
                continue

        item_ids.append(item_id)
        questions.append(question_text)
        choice_lists.append(choice_text_list)
        answer_ids.append(answer_id)

    data_file.close()
    return questions, choice_lists, answer_ids

In [4]:
class BERTDataset(Dataset):
    
    def __init__(self, questions, choices, answer_ids, tokenizer):
        out = tokenizer(questions, max_length=25, padding="max_length")
        self.input_ids = out["input_ids"]
        self.token_type_ids = out["token_type_ids"]
        self.attention_mask = out["attention_mask"]
        self.questions = questions
        self.choices = choices
        self.answer_ids = answer_ids
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, i):
        return {
            "input_ids": self.input_ids[i], 
            "attention_mask": self.attention_mask[i], 
            "token_type_ids": self.token_type_ids[i],
            "choice_list": self.choices[i], 
            "answer_id": self.answer_ids[i],
        }
    

class RoBERTaDataset(Dataset):
    
    def __init__(self, questions, choices, answer_ids, tokenizer):
#         if "t5" in tokenizer.name_or_path.lower():
#             questions = [question.replace('[MASK]', '') for question in questions]
#         else:
        questions = [question.replace('[MASK]', tokenizer.mask_token) for question in questions]
        out = tokenizer(questions, max_length=25, padding="max_length")
        self.input_ids = out["input_ids"]
        self.attention_mask = out["attention_mask"]
        self.questions = questions
        self.choices = choices
        self.answer_ids = answer_ids
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, i):
        return {
            "input_ids": self.input_ids[i], 
            "attention_mask": self.attention_mask[i], 
            "choice_list": self.choices[i], 
            "answer_id": self.answer_ids[i],
        }

In [5]:
def evaluate(args, model, tokenizer, eval_dataset):
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.per_device_eval_batch_size)

    logger.info(f"***** Running evaluation  *****")
    logger.info(f"  Num examples = {len(eval_dataset)}")
    logger.info(f"  Batch size = {args.eval_batch_size}")
    eval_dataloader = tqdm(eval_dataloader, desc="Evaluating")
    
    print(tokenizer.mask_token)
    MASK_ID = tokenizer.encode(tokenizer.mask_token, add_special_tokens=False)
    assert len(MASK_ID) == 1
    MASK_ID = MASK_ID[0]
#     if "t5" in args.model_name_or_path.lower():
# #         LABELS = tokenizer("<extra_id_0>", add_special_tokens=False, return_tensors="pt")
#         LABELS = tokenizer("<extra_id_0> potato <extra_id_1> </s>", add_special_tokens=False, return_tensors="pt")
#         LABELS = LABELS.input_ids#.cuda() 
    
    all_answers = []
    all_preds = []
    all_attentions = torch.zeros((model.config.num_hidden_layers, model.config.num_attention_heads))  # (24, 16) for BERT-large
    
    head_mask = torch.ones(model.config.num_hidden_layers, model.config.num_attention_heads)
    
    for batch in eval_dataloader:
        model.eval()
        
        # batch["choice_list"] is [num_choices, batch_size]
        for i in range(len(batch["choice_list"][0])):
            all_answers.append(batch["choice_list"][batch["answer_id"][i]][i])
        
        choice_lists = batch.pop("choice_list")
        batch_len = len(batch["answer_id"])
        del batch["answer_id"] 
        for key in batch:
            batch[key] = torch.stack(batch[key], dim=-1)#.cuda()
        
        with torch.no_grad():
            #if "gpt" not in args.model_name_or_path.lower():
            if "t5" not in args.model_name_or_path.lower():
                outputs = model(**batch, output_attentions=True, head_mask=head_mask)
            else:
                outputs = model(input_ids=batch["input_ids"], decoder_input_ids=torch.zeros((len(batch["input_ids"]), 1), dtype=torch.int))
#                 outputs = model(input_ids=batch["input_ids"], decoder_input_ids=batch["input_ids"])
#                 BATCH_LABELS = LABELS.repeat(batch_len, 1)
#                 outputs = model(input_ids=batch["input_ids"], labels=BATCH_LABELS)
            
#             attentions = torch.stack(outputs.attentions) #[:,:,:,:-1, :-1]
            
#             for b in range(attentions.size()[1]):
#                 #sep_ind = (batch["input_ids"][b] == tokenizer.encode(tokenizer.sep_token, add_special_tokens=False)[0]).nonzero(as_tuple=True)[0].item()
#                 sep_ind = (batch["input_ids"][b] == tokenizer.encode(tokenizer.sep_token, add_special_tokens=False)[0]).nonzero(as_tuple=True)[0].item()
#                 for seq_ind1 in range(attentions.size()[-1]):
#                     for seq_ind2 in range(attentions.size()[-1]):
#                         if seq_ind1 == sep_ind or seq_ind2 == sep_ind or seq_ind1 == 0 or seq_ind2 == 0:
#                             attentions[:, b, :, seq_ind1, seq_ind2] = 0
            
#             maxes = torch.amax(attentions, dim=(3, 4))
#             sums = torch.sum(maxes, dim=1)
#             torch.add(all_attentions, sums, out=all_attentions)
            
            logits = outputs.logits
            print(logits.size())
            choice_ids = []
            
            for i, logit in enumerate(logits):  # Assuming all are single tokens
                choice_ids = torch.tensor([tokenizer.encode(" " + choice_lists[j][i], add_special_tokens=False)[0] for j in range(len(choice_lists))])
#                 print(choice_ids)
                if "t5" in args.model_name_or_path.lower():
#                     probs = logit[0].index_select(0, choice_ids)#.cuda()
                    probs = logit[0].index_select(0, choice_ids)#.cuda()
#                     print(probs)
                else:
                    MASK_INDEX = batch["input_ids"][i].tolist().index(MASK_ID) 
                    probs = logit[MASK_INDEX].index_select(0, choice_ids)#.cuda())
                
                max_ind = torch.argmax(probs)
                all_preds.append(choice_lists[max_ind][i])

    torch.div(all_attentions, len(eval_dataloader) * args.per_device_eval_batch_size, out=all_attentions)
    return all_answers, all_preds, all_attentions

In [7]:
args = get_args()
'''
"bert-base-uncased"
"distilbert-base-uncased"
"bert-large-uncased"
"bert-large-uncased-whole-word-masking" 
"roberta-large"
"facebook/bart-large"
"t5-large"
"albert-large-v1"
'''

args.model_name_or_path = "t5-large"
transformers.set_seed(args.seed)
if "t5" in args.model_name_or_path.lower():
    model = transformers.T5ForConditionalGeneration.from_pretrained(args.model_name_or_path)#.cuda()
#     args.per_device_eval_batch_size = 1
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.mask_token = "<extra_id_0>"
elif "gpt" in args.model_name_or_path.lower():
    model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name_or_path)#.cuda()
    args.per_device_eval_batch_size = 1
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.mask_token = "[MASK]"
else:
    model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name_or_path)#.cuda()
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)

KeyboardInterrupt: 

In [8]:
model.encoder.block

ModuleList(
  (0): T5Block(
    (layer): ModuleList(
      (0): T5LayerSelfAttention(
        (SelfAttention): T5Attention(
          (q): Linear(in_features=1024, out_features=1024, bias=False)
          (k): Linear(in_features=1024, out_features=1024, bias=False)
          (v): Linear(in_features=1024, out_features=1024, bias=False)
          (o): Linear(in_features=1024, out_features=1024, bias=False)
          (relative_attention_bias): Embedding(32, 16)
        )
        (layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (1): T5LayerFF(
        (DenseReluDense): T5DenseReluDense(
          (wi): Linear(in_features=1024, out_features=4096, bias=False)
          (wo): Linear(in_features=4096, out_features=1024, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layer_norm): T5LayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (1): T5Block(
    (layer): ModuleList(
      (0): T5LayerS

In [60]:
for i in range(5, model.config.num_layers):
    model.encoder.block[i] = model.encoder.block[4]

In [14]:
model.encoder.block[4].layer[0].SelfAttention.q.weight

Parameter containing:
tensor([[-0.0168,  0.0388,  0.0312,  ...,  0.0302, -0.0036, -0.0067],
        [ 0.0136,  0.0199,  0.0136,  ...,  0.0464,  0.0172,  0.0315],
        [ 0.0084,  0.0007,  0.0025,  ..., -0.0500, -0.0018,  0.0039],
        ...,
        [ 0.0270, -0.0004, -0.0171,  ...,  0.0078, -0.0005, -0.0009],
        [-0.0159,  0.0104, -0.0110,  ..., -0.0173, -0.0422,  0.0272],
        [ 0.0054, -0.0069, -0.0070,  ...,  0.0013, -0.0332, -0.0007]],
       requires_grad=True)

In [21]:
model.encoder.block[15].layer[0].SelfAttention.q.weight

Parameter containing:
tensor([[-0.0168,  0.0388,  0.0312,  ...,  0.0302, -0.0036, -0.0067],
        [ 0.0136,  0.0199,  0.0136,  ...,  0.0464,  0.0172,  0.0315],
        [ 0.0084,  0.0007,  0.0025,  ..., -0.0500, -0.0018,  0.0039],
        ...,
        [ 0.0270, -0.0004, -0.0171,  ...,  0.0078, -0.0005, -0.0009],
        [-0.0159,  0.0104, -0.0110,  ..., -0.0173, -0.0422,  0.0272],
        [ 0.0054, -0.0069, -0.0070,  ...,  0.0013, -0.0332, -0.0007]],
       requires_grad=True)

In [22]:
'''
"data/number_comparison_age_compare_masked_dev.jsonl"
"data/coffee_cats_quantifiers_dev.jsonl"
"data/size_comparison_dev.jsonl"
"data/antonym_synonym_negation_dev.jsonl"
"data/hypernym_conjunction_dev.jsonl"
'''
train_path = "data/number_comparison_age_compare_masked_dev.jsonl"
# train_path = "data/number_comparison_age_compare_masked_train.jsonl"
# train_path = "data/size_comparison_dev.jsonl"
# train_path = "data/coffee_cats_quantifiers_dev.jsonl"
args.num_choices = 2
eval_path = train_path #"data/coffee_cats_quantifiers_dev.jsonl"
train_questions, train_choices, train_answer_ids = get_data(train_path, args.sample_train, args.num_choices)
eval_questions, eval_choices, eval_answer_ids = get_data(eval_path, args.sample_eval, args.num_choices)
AgeDataset = RoBERTaDataset if any(prefix in args.model_name_or_path.lower() for prefix in ("roberta", "bart", "distil", "electra", "t5")) else BERTDataset
train_dataset = AgeDataset(train_questions, train_choices, train_answer_ids, tokenizer)
eval_dataset = AgeDataset(eval_questions, eval_choices, eval_answer_ids, tokenizer)
# eval_dataset = AgeDataset(eval_questions[:500], eval_choices[:500], eval_answer_ids[:500], tokenizer)

10/10/2021 12:55:49: Reading QA instances from jsonl dataset at: data/number_comparison_age_compare_masked_dev.jsonl
10/10/2021 12:55:49: Sampling 200 examples


  0%|          | 0/200 [00:00<?, ?it/s]

10/10/2021 12:55:49: Reading QA instances from jsonl dataset at: data/number_comparison_age_compare_masked_dev.jsonl


  0%|          | 0/500 [00:00<?, ?it/s]

In [77]:
all_answers, all_preds, all_attentions = evaluate(args, model, tokenizer, eval_dataset)
# 0.956 t5 normal
print((np.array(all_answers) == np.array(all_preds)).mean())  # t5albert all - 0.492, t5normal=0.494

10/10/2021 17:31:05: ***** Running evaluation  *****
10/10/2021 17:31:05:   Num examples = 500
10/10/2021 17:31:05:   Batch size = 8


Evaluating:   0%|          | 0/63 [00:00<?, ?it/s]

<extra_id_0>
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])
torch.Size([8, 1, 32128])