# Test Distributional Shift
This notebook tests distributional shift for the dataset SQuADShifts NYT.

## Imports

In [1]:
%%capture
!pip install transformers

In [2]:
%%capture
!pip install datasets

In [3]:
%%capture
!pip install nltk

In [4]:
%%capture
import nltk
nltk.download('punkt')

In [5]:
from transformers import BertTokenizer, BertModel, AlbertTokenizer, AlbertModel, DistilBertTokenizer, DistilBertModel, RobertaTokenizer, RobertaModel, ElectraTokenizer, ElectraModel
import torch
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from math import ceil

from nltk.tokenize import sent_tokenize, word_tokenize
import random
import numpy as np

## Loading and Preparing Texts

In [6]:
def remove_sentence(t, k=1):
    sentences = random.sample(sent_tokenize(t), k)
    text = t
    for s in sentences:
        text = text.replace(s, '')
        assert len(t) != len(text)
    return text, len(word_tokenize(s))

def remove_word(t, k=1):
    text = t
    for _ in range(k):
        words = word_tokenize(text)
        if not len([i for i,w in enumerate(words) if w.isalnum()]):
            break
        word_idx = random.choice([i for i,w in enumerate(words) if w.isalnum()])
        del words[word_idx]
        text = ''.join([(' ' if w.isalnum() else '')+w for w in words]).strip(' ')
    return text

In [7]:
dataset = load_dataset('squadshifts', 'nyt')['test']
print('Average number of sentences = {}'.format(np.mean([len(sent_tokenize(sample['context'])) for sample in dataset ])))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2317.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1400.0, style=ProgressStyle(description…


Downloading and preparing dataset squad_shifts/nyt (download: 15.74 MiB, generated: 10.29 MiB, post-processed: Unknown size, total: 26.03 MiB) to /root/.cache/huggingface/datasets/squad_shifts/nyt/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=773853.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1117198.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1056991.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1030185.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset squad_shifts downloaded and prepared to /root/.cache/huggingface/datasets/squad_shifts/nyt/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039. Subsequent calls will reuse this data.
Average number of sentences = 6.047093889716841


In [8]:
N_WORD = 5
texts = []
dataset = load_dataset('squad', split='train')
for d in tqdm(dataset):
    if d['context'] not in texts:
        texts.append(d['context'])
random.shuffle(texts)        
n_samples = len(texts)
original = texts#random.sample(texts, n_samples)#texts[:n_samples]
word_removed = []
sentence_removed = []
for t in tqdm(original):
    sr, nw = remove_sentence(t)
    wr = remove_word(t, k=N_WORD)
    word_removed.append(wr)
    sentence_removed.append(sr)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1946.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=955.0, style=ProgressStyle(description_…


Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.75 MiB, post-processed: Unknown size, total: 119.27 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=8116577.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1054280.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  2%|▏         | 1827/87599 [00:00<00:04, 18266.19it/s]

Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41. Subsequent calls will reuse this data.


100%|██████████| 87599/87599 [00:20<00:00, 4357.13it/s]
100%|██████████| 18891/18891 [01:37<00:00, 194.27it/s]


In [9]:
text_data = original+word_removed+sentence_removed
classes = [
    (n_samples, 'Original', 'tab:blue'),
    (n_samples, 'Word', 'tab:green'),
    (n_samples, 'Sentence', 'tab:red')
]

## Compare Distributional Shift

In [10]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

In [11]:
#def get_accuracy_rf(x_original, x_new):
#  X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
#  y = np.array([0]*n_samples+[1]*n_samples)
#  X_train, X_test, y_train, y_test = train_test_split(X, y)
#  clf = RandomForestClassifier()
#  clf.fit(X_train, y_train)
#  print('Train : {}'.format(clf.score(X_train, y_train)))
#  print('Test : {}'.format(clf.score(X_test, y_test)))
def get_accuracy_rf(x_original, x_new):
  X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
  y = np.array([0]*n_samples+[1]*n_samples)
  X_train, X_test, y_train, y_test = train_test_split(X, y)
  oob_score = 0
  best_clf = None
  for depth in [2, 5, 7, 10, 15, 20]:
    clf = RandomForestClassifier(max_depth=depth, oob_score=True)
    clf.fit(X_train, y_train)
    print('Depth {}, OOB score {}'.format(depth, clf.oob_score_))
    if clf.oob_score_>oob_score:
      oob_score = clf.oob_score_
      best_clf = clf
  print('--> Depth={}'.format(best_clf.max_depth))
  print('-'*20)
  train_acc = best_clf.score(X_train, y_train)
  test_acc = best_clf.score(X_test, y_test)
  print(X_train.shape)
  print('Train : {}({})'.format(train_acc, train_acc*(1-train_acc)/X_train.shape[0]))
  print('Test : {}({})'.format(test_acc, test_acc*(1-test_acc)/X_test.shape[0]))
  print('-'*20)

In [12]:
def process_texts(texts, model, tokenizer):
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    return model(**(inputs.to(device)))[0][:, 0, :].cpu().detach()

def batch(l, size):
    for i in range(0, len(l), size):
        yield l[i:i+size]

In [13]:
def models_generator():
  model_classes = [
                   BertModel,
                   RobertaModel
  ]
  tokenizer_classes = [
                       BertTokenizer,
                       RobertaTokenizer
  ]
  model_names = [
                 'bert-base-uncased',
                 'roberta-base'
  ]
  for model_name, tokenizer_class, model_class in zip(model_names, tokenizer_classes, model_classes):
    tokenizer = tokenizer_class.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name).to(device)
    yield model, tokenizer, model_name.split('-')[0].upper()

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
for model, tokenizer, name in models_generator():
  print('Model: {}'.format(name))
  results = []
  for b in batch(text_data, 5):
      results.append(process_texts(b, model, tokenizer))
  embeddings = torch.cat(results, 0)
  e_original = embeddings[:n_samples, :]
  e_word = embeddings[n_samples:2*n_samples, :]
  e_sentence = embeddings[2*n_samples:, :]
  print('Sentence vs Original')
  get_accuracy_rf(e_original, e_sentence)
  print('Word vs Original')
  get_accuracy_rf(e_original, e_word)
  print('\n\n')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…


Model: BERT
Sentence vs Original
Depth 2, OOB score 0.5664878599661208
Depth 5, OOB score 0.5784867306606437
Depth 7, OOB score 0.5705463015245624
Depth 10, OOB score 0.5383963862224732
Depth 15, OOB score 0.4667207792207792
Depth 20, OOB score 0.4308300395256917
--> Depth=5
--------------------
(28336, 768)
Train : 0.6096485036702428(8.398405055155165e-06)
Test : 0.5858564471734067(2.5685863908401443e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.7751623376623377
Depth 5, OOB score 0.7933018068887634
Depth 7, OOB score 0.8020186335403726
Depth 10, OOB score 0.8011010728402033
Depth 15, OOB score 0.7856789949181253
Depth 20, OOB score 0.7787972896668549
--> Depth=7
--------------------
(28336, 768)
Train : 0.8417913608130999(4.699981143193441e-06)
Test : 0.8055261486343426(1.658413852431353e-05)
--------------------





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501200538.0, style=ProgressStyle(descri…


Model: ROBERTA
Sentence vs Original
Depth 2, OOB score 0.732813382269904
Depth 5, OOB score 0.754199604743083
Depth 7, OOB score 0.7583639186900056
Depth 10, OOB score 0.7474943534726144
Depth 15, OOB score 0.7161208356860531
Depth 20, OOB score 0.6978402032749859
--> Depth=7
--------------------
(28336, 768)
Train : 0.8147233201581028(5.327118568205122e-06)
Test : 0.7577810713529536(1.943139098582717e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.812605872388481
Depth 5, OOB score 0.8377682100508187
Depth 7, OOB score 0.8513904573687182
Depth 10, OOB score 0.8587309429700734
Depth 15, OOB score 0.855378317334839
Depth 20, OOB score 0.8525903444381705
--> Depth=10
--------------------
(28336, 768)
Train : 0.9321357989836251(2.232448166176735e-06)
Test : 0.8726445056108406(1.1765411013979694e-05)
--------------------





In [16]:
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




In [17]:
print('Model: DistilBert')
results = []
for b in batch(text_data, 2):
    results.append(process_texts(b, model, tokenizer))
embeddings = torch.cat(results, 0)
e_original = embeddings[:n_samples, :]
e_word = embeddings[n_samples:2*n_samples, :]
e_sentence = embeddings[2*n_samples:, :]
print('Sentence vs Original')
get_accuracy_rf(e_original, e_sentence)
print('Word vs Original')
get_accuracy_rf(e_original, e_word)
print('\n\n')

Model: DistilBert
Sentence vs Original
Depth 2, OOB score 0.5496894409937888
Depth 5, OOB score 0.5548418972332015
Depth 7, OOB score 0.5346908526256352
Depth 10, OOB score 0.47663749294184077
Depth 15, OOB score 0.37450592885375494
Depth 20, OOB score 0.3264398644833427
--> Depth=5
--------------------
(28336, 768)
Train : 0.6301171654432524(8.225209036455861e-06)
Test : 0.5664831674782977(2.5998304937757e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.7539172783738001
Depth 5, OOB score 0.7705392433653303
Depth 7, OOB score 0.7786914172783738
Depth 10, OOB score 0.7729390175042349
Depth 15, OOB score 0.7478119706380576
Depth 20, OOB score 0.7386363636363636
--> Depth=7
--------------------
(28336, 768)
Train : 0.8252752682100508(5.08879163928852e-06)
Test : 0.7794833792081304(1.8197018922973152e-05)
--------------------





In [18]:
model = ElectraModel.from_pretrained('google/electra-small-discriminator').to(device)
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=54245363.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [19]:
print('Model: Electra')
results = []
for b in batch(text_data, 2):
    results.append(process_texts(b, model, tokenizer))
embeddings = torch.cat(results, 0)
e_original = embeddings[:n_samples, :]
e_word = embeddings[n_samples:2*n_samples, :]
e_sentence = embeddings[2*n_samples:, :]
print('Sentence vs Original')
get_accuracy_rf(e_original, e_sentence)
print('Word vs Original')
get_accuracy_rf(e_original, e_word)
print('\n\n')

Model: Electra
Sentence vs Original
Depth 2, OOB score 0.5850508187464709
Depth 5, OOB score 0.5931324110671937
Depth 7, OOB score 0.5902738565782044
Depth 10, OOB score 0.5748164878599661
Depth 15, OOB score 0.5296089779785432
Depth 20, OOB score 0.49594155844155846
--> Depth=5
--------------------
(28336, 256)
Train : 0.6204474872953134(8.310714384678336e-06)
Test : 0.6051238619521491(2.529631311118627e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.882128740824393
Depth 5, OOB score 0.9029503105590062
Depth 7, OOB score 0.9108554488989271
Depth 10, OOB score 0.9168195934500283
Depth 15, OOB score 0.918548842461886
Depth 20, OOB score 0.9187958780350085
--> Depth=20
--------------------
(28336, 256)
Train : 0.9939299830604178(2.1291543739184734e-07)
Test : 0.9145670124920602(8.271669717701463e-06)
--------------------



