In [1]:
!pip install nlp
!pip install git+https://github.com/huggingface/transformers.git

Collecting nlp
[?25l  Downloading https://files.pythonhosted.org/packages/09/e3/bcdc59f3434b224040c1047769c47b82705feca2b89ebbc28311e3764782/nlp-0.4.0-py3-none-any.whl (1.7MB)
[K     |████████████████████████████████| 1.7MB 2.8MB/s 
Collecting xxhash
[?25l  Downloading https://files.pythonhosted.org/packages/f7/73/826b19f3594756cb1c6c23d2fbd8ca6a77a9cd3b650c9dec5acc85004c38/xxhash-2.0.0-cp36-cp36m-manylinux2010_x86_64.whl (242kB)
[K     |████████████████████████████████| 245kB 20.2MB/s 
[?25hCollecting pyarrow>=0.16.0
[?25l  Downloading https://files.pythonhosted.org/packages/f3/99/0a605f016121ca314d1469dc9069e4978395bc46fda40f73099d90ad3ba4/pyarrow-1.0.1-cp36-cp36m-manylinux2014_x86_64.whl (17.3MB)
[K     |████████████████████████████████| 17.3MB 198kB/s 
Installing collected packages: xxhash, pyarrow, nlp
  Found existing installation: pyarrow 0.14.1
    Uninstalling pyarrow-0.14.1:
      Successfully uninstalled pyarrow-0.14.1
Successfully installed nlp-0.4.0 pyarrow-1.0.1 xx

In [2]:
from pathlib import Path
from copy import deepcopy

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from transformers import XLNetLMHeadModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from nlp import load_dataset
from tqdm.notebook import tqdm
from sklearn.metrics import matthews_corrcoef

  import pandas.util.testing as tm


In [3]:
np.random.seed(314)
train_full = load_dataset('winogrande', 'winogrande_xl', split='train')
train_full = train_full.shuffle()
inds = np.arange(len(train_full))
split_loc = int(len(inds)*.8)
train = train_full.select(inds[:split_loc])
val = train_full.select(inds[split_loc:])

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5659.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=9381.0, style=ProgressStyle(description…


Downloading and preparing dataset winogrande/winogrande_xl (download: 2.67 MiB, generated: 5.35 MiB, post-processed: Unknown sizetotal: 8.01 MiB) to /root/.cache/huggingface/datasets/winogrande/winogrande_xl/1.0.0/4582c2f2a293ae12f94efa46a02a9c2156d9ce9c207a419dc18692449aae236e...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2797793.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset winogrande downloaded and prepared to /root/.cache/huggingface/datasets/winogrande/winogrande_xl/1.0.0/4582c2f2a293ae12f94efa46a02a9c2156d9ce9c207a419dc18692449aae236e. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=41.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=33.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




In [4]:
device = torch.device('cuda')

In [5]:
tokenizer = AutoTokenizer.from_pretrained("xlnet-large-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-large-cased", mem_len=2**14).to(device)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=761.0, style=ProgressStyle(description_…






HBox(children=(FloatProgress(value=0.0, description='Downloading', max=798011.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1441285815.0, style=ProgressStyle(descr…




In [6]:
for p in model.parameters():
    p.requires_grad = False
    p.grad = None

In [7]:
def create_pred_encodings(encodings, pred_loc):
    """ Create encodings needed for XLNet """
    encodings = deepcopy(encodings)
    encodings['input_ids'][:,pred_loc] = tokenizer.mask_token_id
    encodings['attention_mask'][:,pred_loc] = 0.

    seqlen = encodings.input_ids.size(1)
    perm_mask = torch.zeros((1, seqlen, seqlen))
    perm_mask[:,:,pred_loc[0]] = 1.0 # todo: make sure this is right
    encodings['perm_mask'] = perm_mask

    target_mapping = torch.zeros((1, 1, seqlen))
    target_mapping[:,:,pred_loc[0]] = 1.
    encodings['target_mapping'] = target_mapping

    return encodings

In [8]:
def long_seq_pass(model, encodings):
    """ Get outputs when output is greater than 1024 """
    max_len = model.base_model.mem_len
    num_passes = encodings.input_ids.size(1) // max_len + 1
    mems = None
    for i in range(num_passes):
        ran = slice(i*max_len, (i+1)*max_len)
        trunc_encodings = {
            'input_ids': encodings.input_ids[:,ran].to(device),
            'token_type_ids': encodings.token_type_ids[:,ran].to(device),
            'attention_mask': encodings.attention_mask[:,ran].to(device),
            'target_mapping': encodings.target_mapping[...,ran].to(device),
            'perm_mask': encodings.perm_mask[...,ran,ran].to(device)
        }
        if trunc_encodings['perm_mask'].bool().any():
            trunc_encodings['labels'] = encodings.labels.to(device)
        outputs = model(**trunc_encodings, mems=mems)
        mems = outputs[-1]
    return outputs

In [12]:
def get_ppl(model, cloze_sequence, target, context_examples=[]):
    """ Calculate the "perplexity" of the target subsequence given the sequence and context examples """
    # create context
    replaced_examples = []
    for ex in context_examples:
        label = int(ex['answer']) - 1
        options = [ex['option1'], ex['option2']]
        replaced_examples.append(ex['sentence'].replace('_', options[label]))
    context = ' '.join(replaced_examples)

    # create target encoding
    target_ids = tokenizer.encode(target, add_special_tokens=False)
    repl_id = tokenizer.encode('_', add_special_tokens=False)[0]
    encodings = tokenizer(context, cloze_sequence, return_tensors='pt')
    target_start = torch.where(encodings.input_ids[0] == repl_id)[0]
    target_locs = list(range(target_start, target_start + len(target_ids)))

    outputs = None
    lls = []
    for trg_loc, trg in zip(target_locs, target_ids):
        trg_loc = list(range(trg_loc, target_locs[-1] + 1))
        pred_encodings = create_pred_encodings(encodings, trg_loc).to(device)
        pred_encodings['labels'] = torch.tensor(trg).view(-1)
        outputs = long_seq_pass(model, pred_encodings)
        lls.append(outputs[0])

    return (torch.mean(torch.stack(lls)),) + outputs[1:]

In [13]:
from tqdm.notebook import tqdm

In [None]:
num_reps = 1 # number of times to replicate the experiment
accs = {n_train: [] for n_train in [0,1,50]} # number of training examples to try
num_eval = 500

for rep in range(num_reps):
    val = val.shuffle(seed=np.random.randint(100000)) # workaround bc of seed bug in nlp
    for num_train in accs:
        print('Number of training points:', num_train)
        preds = []
        labels = []
        it = tqdm(val, miniters=5)
        for iexample, example in enumerate(it):
            context_examples = [train[int(i)] for i in np.random.choice(len(train), num_train, replace=False)]

            first_outputs = get_ppl(model, example['sentence'], example['option1'], context_examples)
            second_outputs = get_ppl(model, example['sentence'], example['option2'], context_examples)
            first_loss = first_outputs[0]
            second_loss = second_outputs[0]

            pred = torch.stack([first_loss, second_loss]).argmin()
            preds.append(pred.item())
            labels.append(int(example['answer']) - 1)

            acc = (np.array(preds) == np.array(labels)).mean()
            if iexample >= 0 and iexample % 5 == 0:
                it.set_description(f'acc: {acc*100:0.4f}%')

            if (iexample + 1) % num_eval == 0:
                break

        print(num_train, acc)
        accs[num_train].append(acc)

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Number of training points: 0


HBox(children=(FloatProgress(value=0.0, max=8080.0), HTML(value='')))

0 0.602
Number of training points: 1


HBox(children=(FloatProgress(value=0.0, max=8080.0), HTML(value='')))

1 0.632
Number of training points: 50


HBox(children=(FloatProgress(value=0.0, max=8080.0), HTML(value='')))