# Test Distributional Shift
This notebook tests distributional shift for the dataset SQuADShifts Amazon.

## Imports

In [1]:
%%capture
!pip install transformers

In [2]:
%%capture
!pip install datasets

In [3]:
%%capture
!pip install nltk

In [4]:
%%capture
import nltk
nltk.download('punkt')

In [5]:
from transformers import BertTokenizer, BertModel, AlbertTokenizer, AlbertModel, DistilBertTokenizer, DistilBertModel, RobertaTokenizer, RobertaModel, ElectraTokenizer, ElectraModel
import torch
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from math import ceil

from nltk.tokenize import sent_tokenize, word_tokenize
import random
import numpy as np

## Loading and Preparing Texts

In [6]:
def remove_sentence(t, k=1):
    sentences = random.sample(sent_tokenize(t), k)
    text = t
    for s in sentences:
        text = text.replace(s, '')
        assert len(t) != len(text)
    return text, len(word_tokenize(s))

def remove_word(t, k=1):
    text = t
    for _ in range(k):
        words = word_tokenize(text)
        if not len([i for i,w in enumerate(words) if w.isalnum()]):
            break
        word_idx = random.choice([i for i,w in enumerate(words) if w.isalnum()])
        del words[word_idx]
        text = ''.join([(' ' if w.isalnum() else '')+w for w in words]).strip(' ')
    return text

In [7]:
dataset = load_dataset('squadshifts', 'amazon')['test']
print('Average number of sentences = {}'.format(np.mean([len(sent_tokenize(sample['context'])) for sample in dataset ])))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2317.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1400.0, style=ProgressStyle(description…


Downloading and preparing dataset squad_shifts/amazon (download: 15.74 MiB, generated: 9.00 MiB, post-processed: Unknown size, total: 24.74 MiB) to /root/.cache/huggingface/datasets/squad_shifts/amazon/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=773853.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1117198.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1056991.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1030185.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset squad_shifts downloaded and prepared to /root/.cache/huggingface/datasets/squad_shifts/amazon/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039. Subsequent calls will reuse this data.
Average number of sentences = 7.833687405159332


In [8]:
N_WORD = 5
texts = []
dataset = load_dataset('squad', split='train')
for d in tqdm(dataset):
    if d['context'] not in texts:
        texts.append(d['context'])
random.shuffle(texts)        
n_samples = len(texts)
original = texts#random.sample(texts, n_samples)#texts[:n_samples]
word_removed = []
sentence_removed = []
for t in tqdm(original):
    sr, nw = remove_sentence(t)
    wr = remove_word(t, k=N_WORD)
    word_removed.append(wr)
    sentence_removed.append(sr)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1946.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=955.0, style=ProgressStyle(description_…


Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.75 MiB, post-processed: Unknown size, total: 119.27 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=8116577.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1054280.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  2%|▏         | 1944/87599 [00:00<00:04, 19433.81it/s]

Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41. Subsequent calls will reuse this data.


100%|██████████| 87599/87599 [00:19<00:00, 4385.81it/s]
100%|██████████| 18891/18891 [01:34<00:00, 200.61it/s]


In [9]:
text_data = original+word_removed+sentence_removed
classes = [
    (n_samples, 'Original', 'tab:blue'),
    (n_samples, 'Word', 'tab:green'),
    (n_samples, 'Sentence', 'tab:red')
]

## Compare Distributional Shift

In [10]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

In [11]:
#def get_accuracy_rf(x_original, x_new):
#  X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
#  y = np.array([0]*n_samples+[1]*n_samples)
#  X_train, X_test, y_train, y_test = train_test_split(X, y)
#  clf = RandomForestClassifier()
#  clf.fit(X_train, y_train)
#  print('Train : {}'.format(clf.score(X_train, y_train)))
#  print('Test : {}'.format(clf.score(X_test, y_test)))
def get_accuracy_rf(x_original, x_new):
  X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
  y = np.array([0]*n_samples+[1]*n_samples)
  X_train, X_test, y_train, y_test = train_test_split(X, y)
  oob_score = 0
  best_clf = None
  for depth in [2, 5, 7, 10, 15, 20]:
    clf = RandomForestClassifier(max_depth=depth, oob_score=True)
    clf.fit(X_train, y_train)
    print('Depth {}, OOB score {}'.format(depth, clf.oob_score_))
    if clf.oob_score_>oob_score:
      oob_score = clf.oob_score_
      best_clf = clf
  print('--> Depth={}'.format(best_clf.max_depth))
  print('-'*20)
  train_acc = best_clf.score(X_train, y_train)
  test_acc = best_clf.score(X_test, y_test)
  print(X_train.shape)
  print('Train : {}({})'.format(train_acc, train_acc*(1-train_acc)/X_train.shape[0]))
  print('Test : {}({})'.format(test_acc, test_acc*(1-test_acc)/X_test.shape[0]))
  print('-'*20)

In [12]:
def process_texts(texts, model, tokenizer):
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    return model(**(inputs.to(device)))[0][:, 0, :].cpu().detach()

def batch(l, size):
    for i in range(0, len(l), size):
        yield l[i:i+size]

In [13]:
def models_generator():
  model_classes = [
                   BertModel,
                   RobertaModel
  ]
  tokenizer_classes = [
                       BertTokenizer,
                       RobertaTokenizer
  ]
  model_names = [
                 'bert-base-uncased',
                 'roberta-base'
  ]
  for model_name, tokenizer_class, model_class in zip(model_names, tokenizer_classes, model_classes):
    tokenizer = tokenizer_class.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name).to(device)
    yield model, tokenizer, model_name.split('-')[0].upper()

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
for model, tokenizer, name in models_generator():
  print('Model: {}'.format(name))
  results = []
  for b in batch(text_data, 5):
      results.append(process_texts(b, model, tokenizer))
  embeddings = torch.cat(results, 0)
  e_original = embeddings[:n_samples, :]
  e_word = embeddings[n_samples:2*n_samples, :]
  e_sentence = embeddings[2*n_samples:, :]
  print('Sentence vs Original')
  get_accuracy_rf(e_original, e_sentence)
  print('Word vs Original')
  get_accuracy_rf(e_original, e_word)
  print('\n\n')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…


Model: BERT
Sentence vs Original
Depth 2, OOB score 0.5678994918125353
Depth 5, OOB score 0.5813099943534726
Depth 7, OOB score 0.5711462450592886
Depth 10, OOB score 0.5385022586109542
Depth 15, OOB score 0.4688029361942405
Depth 20, OOB score 0.43298277809147373
--> Depth=5
--------------------
(28336, 768)
Train : 0.6115894974590627(8.383250425495266e-06)
Test : 0.5861740419225069(2.5680079874945797e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.7754446640316206
Depth 5, OOB score 0.79492518351214
Depth 7, OOB score 0.7987012987012987
Depth 10, OOB score 0.7969367588932806
Depth 15, OOB score 0.7865612648221344
Depth 20, OOB score 0.7809147374364765
--> Depth=7
--------------------
(28336, 768)
Train : 0.842285431959345(4.6880534679702355e-06)
Test : 0.8078551767944103(1.6432901770154813e-05)
--------------------





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501200538.0, style=ProgressStyle(descri…


Model: ROBERTA
Sentence vs Original
Depth 2, OOB score 0.7304136081309994
Depth 5, OOB score 0.7482707509881423
Depth 7, OOB score 0.7536702428006776
Depth 10, OOB score 0.7427300959909655
Depth 15, OOB score 0.7118859401468097
Depth 20, OOB score 0.6932171089779785
--> Depth=7
--------------------
(28336, 768)
Train : 0.8122882552230378(5.3810010463633115e-06)
Test : 0.7597925047639212(1.9321178749575268e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.8136293054771315
Depth 5, OOB score 0.8403444381705252
Depth 7, OOB score 0.8526962168266516
Depth 10, OOB score 0.8620482778091474
Depth 15, OOB score 0.8571781479390175
Depth 20, OOB score 0.8567546583850931
--> Depth=10
--------------------
(28336, 768)
Train : 0.9334062676453981(2.19364085141465e-06)
Test : 0.8654456912979038(1.2327910937094767e-05)
--------------------





In [16]:
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




In [17]:
print('Model: DistilBert')
results = []
for b in batch(text_data, 2):
    results.append(process_texts(b, model, tokenizer))
embeddings = torch.cat(results, 0)
e_original = embeddings[:n_samples, :]
e_word = embeddings[n_samples:2*n_samples, :]
e_sentence = embeddings[2*n_samples:, :]
print('Sentence vs Original')
get_accuracy_rf(e_original, e_sentence)
print('Word vs Original')
get_accuracy_rf(e_original, e_word)
print('\n\n')

Model: DistilBert
Sentence vs Original
Depth 2, OOB score 0.5548066064370413
Depth 5, OOB score 0.55957086391869
Depth 7, OOB score 0.53906691134952
Depth 10, OOB score 0.47575522303783174
Depth 15, OOB score 0.37743506493506496
Depth 20, OOB score 0.3316629023150762
--> Depth=5
--------------------
(28336, 768)
Train : 0.627540937323546(8.248634574627022e-06)
Test : 0.5605547321617615(2.607803561431477e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.7610460191981931
Depth 5, OOB score 0.7733272162619989
Depth 7, OOB score 0.7768210050818747
Depth 10, OOB score 0.7753035008469791
Depth 15, OOB score 0.7461180124223602
Depth 20, OOB score 0.7385304912478825
--> Depth=7
--------------------
(28336, 768)
Train : 0.825592885375494(5.081496082469664e-06)
Test : 0.7858352741901334(1.781687444724235e-05)
--------------------





In [18]:
model = ElectraModel.from_pretrained('google/electra-small-discriminator').to(device)
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=54245363.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [19]:
print('Model: Electra')
results = []
for b in batch(text_data, 2):
    results.append(process_texts(b, model, tokenizer))
embeddings = torch.cat(results, 0)
e_original = embeddings[:n_samples, :]
e_word = embeddings[n_samples:2*n_samples, :]
e_sentence = embeddings[2*n_samples:, :]
print('Sentence vs Original')
get_accuracy_rf(e_original, e_sentence)
print('Word vs Original')
get_accuracy_rf(e_original, e_word)
print('\n\n')

Model: Electra
Sentence vs Original
Depth 2, OOB score 0.5898503670242801
Depth 5, OOB score 0.5998729531338227
Depth 7, OOB score 0.5928147939017504
Depth 10, OOB score 0.5796160361377752
Depth 15, OOB score 0.5297501411631846
Depth 20, OOB score 0.4999647092038396
--> Depth=5
--------------------
(28336, 256)
Train : 0.6228825522303784(8.289803725202863e-06)
Test : 0.597713317806479(2.5455442253149486e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.8828345567476003
Depth 5, OOB score 0.9000917560700169
Depth 7, OOB score 0.9087380011293055
Depth 10, OOB score 0.9151962168266516
Depth 15, OOB score 0.9172430830039525
Depth 20, OOB score 0.9179488989271598
--> Depth=20
--------------------
(28336, 256)
Train : 0.9943887634105025(1.969138415243561e-07)
Test : 0.9230362058013974(7.520682678504945e-06)
--------------------



