# Test Distributional Shift
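This notebook tests whether pretrained transformer encoders expose a distributional shift: SQuAD training contexts are perturbed by removing words or a sentence, each text is embedded, and a random forest is trained to tell original embeddings from perturbed ones. The SQuADShifts Reddit dataset is loaded only to report its average number of sentences per context.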

## Imports

In [1]:
%%capture
!pip install transformers

In [2]:
%%capture
!pip install datasets

In [3]:
%%capture
!pip install nltk

In [4]:
%%capture
import nltk
nltk.download('punkt')

In [5]:
from transformers import BertTokenizer, BertModel, AlbertTokenizer, AlbertModel, DistilBertTokenizer, DistilBertModel, RobertaTokenizer, RobertaModel, ElectraTokenizer, ElectraModel
import torch
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from math import ceil

from nltk.tokenize import sent_tokenize, word_tokenize
import random
import numpy as np

## Loading and Preparing Texts

In [6]:
def remove_sentence(t, k=1):
    """Remove ``k`` randomly chosen sentences from ``t``.

    Sentences are found with ``sent_tokenize`` and sampled without
    replacement. Each sampled sentence is deleted once (first occurrence
    of its text), so a sentence that appears several times in ``t`` is not
    wiped out entirely by a single draw.

    Args:
        t: The text to shorten.
        k: Number of sentences to remove.

    Returns:
        A ``(text, n_words)`` tuple: the shortened text and the total number
        of ``word_tokenize`` tokens contained in the removed sentences.

    Raises:
        ValueError: from ``random.sample`` when ``t`` has fewer than ``k``
            sentences.
        AssertionError: when a sampled sentence could not be removed.
    """
    sentences = random.sample(sent_tokenize(t), k)
    text = t
    n_words = 0
    for s in sentences:
        shortened = text.replace(s, '', 1)
        # Check every removal, not just the first one, against the previous state.
        assert len(shortened) != len(text)
        text = shortened
        # Accumulate over all removed sentences instead of reporting only the last.
        n_words += len(word_tokenize(s))
    return text, n_words

def remove_word(t, k=1):
    """Delete up to ``k`` randomly chosen alphanumeric tokens from ``t``.

    Each round re-tokenizes the current text with ``word_tokenize``, drops
    one token for which ``str.isalnum()`` is true, and rebuilds the text.
    The loop stops early once no alphanumeric token is left.

    Note: the rebuilt text puts a single space before every alphanumeric
    token and none before other tokens, so spacing may differ from the
    input even outside the removed words.

    Args:
        t: The text to perturb.
        k: Maximum number of tokens to delete.

    Returns:
        The perturbed text.
    """
    text = t
    for _ in range(k):
        tokens = word_tokenize(text)
        candidates = [idx for idx, tok in enumerate(tokens) if tok.isalnum()]
        if not candidates:
            break
        tokens.pop(random.choice(candidates))
        pieces = []
        for tok in tokens:
            pieces.append(' ' + tok if tok.isalnum() else tok)
        text = ''.join(pieces).strip(' ')
    return text

In [7]:
# Load the 'test' split of SQuADShifts Reddit and report how many sentences
# (per sent_tokenize) a context contains on average.
dataset = load_dataset('squadshifts', 'reddit')['test']
sentence_counts = [len(sent_tokenize(sample['context'])) for sample in dataset]
print('Average number of sentences = {}'.format(np.mean(sentence_counts)))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2317.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1400.0, style=ProgressStyle(description…


Downloading and preparing dataset squad_shifts/reddit (download: 15.74 MiB, generated: 9.03 MiB, post-processed: Unknown size, total: 24.77 MiB) to /root/.cache/huggingface/datasets/squad_shifts/reddit/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=773853.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1117198.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1056991.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1030185.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset squad_shifts downloaded and prepared to /root/.cache/huggingface/datasets/squad_shifts/reddit/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039. Subsequent calls will reuse this data.
Average number of sentences = 8.860246863205141


In [8]:
N_WORD = 5  # number of words removed from each context by remove_word
SEED = 0    # fixes the shuffle and the random removals so a re-run reproduces the data

# Seed both generators: `random` drives the shuffle/removals below, and
# numpy's global state is what scikit-learn falls back to when no
# random_state is given.
random.seed(SEED)
np.random.seed(SEED)

dataset = load_dataset('squad', split='train')
# De-duplicate contexts in first-seen order. dict.fromkeys does this in linear
# time, unlike a `not in list` test per sample.
texts = list(dict.fromkeys(d['context'] for d in tqdm(dataset)))
random.shuffle(texts)
n_samples = len(texts)
original = texts
word_removed = []
sentence_removed = []
for t in tqdm(original):
    # The removed-word count returned by remove_sentence is not used here.
    sr, _n_removed_words = remove_sentence(t)
    wr = remove_word(t, k=N_WORD)
    word_removed.append(wr)
    sentence_removed.append(sr)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1946.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=955.0, style=ProgressStyle(description_…


Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.75 MiB, post-processed: Unknown size, total: 119.27 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=8116577.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1054280.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  3%|▎         | 2221/87599 [00:00<00:03, 22209.49it/s]

Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41. Subsequent calls will reuse this data.


100%|██████████| 87599/87599 [00:18<00:00, 4817.91it/s]
100%|██████████| 18891/18891 [01:25<00:00, 220.78it/s]


In [9]:
# All texts in one list, ordered: originals, word-removed, sentence-removed.
text_data = [*original, *word_removed, *sentence_removed]

# One (size, label, colour) entry per group, in the same order as text_data.
classes = [
    (n_samples, label, colour)
    for label, colour in (
        ('Original', 'tab:blue'),
        ('Word', 'tab:green'),
        ('Sentence', 'tab:red'),
    )
]

## Compare Distributional Shift

In [10]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

In [11]:
def get_accuracy_rf(x_original, x_new, depths=(2, 5, 7, 10, 15, 20), random_state=None):
  """Train a random forest to separate original from perturbed embeddings.

  The two embedding sets are stacked, labelled 0 (original) and 1 (new), and
  split into train/test with ``train_test_split``. One forest is fitted per
  candidate depth; the one with the highest out-of-bag score is kept and its
  train/test accuracy is printed.

  The value printed in parentheses after each accuracy is
  ``acc * (1 - acc) / n``, i.e. the binomial variance of the accuracy
  estimate (not its standard deviation).

  Args:
    x_original: Tensor of embeddings for the original texts, one row per text.
    x_new: Tensor of embeddings for the perturbed texts, one row per text.
    depths: Candidate ``max_depth`` values to try.
    random_state: Optional seed forwarded to ``train_test_split`` and
      ``RandomForestClassifier``.
  """
  X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
  # Build the labels from the inputs themselves rather than from a notebook
  # global, so the function also works for sets of different sizes.
  y = np.concatenate([
      np.zeros(len(x_original), dtype=int),
      np.ones(len(x_new), dtype=int),
  ])
  X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)
  # Start below any possible score so the first fitted forest is always kept.
  oob_score = float('-inf')
  best_clf = None
  for depth in depths:
    clf = RandomForestClassifier(max_depth=depth, oob_score=True, random_state=random_state)
    clf.fit(X_train, y_train)
    print('Depth {}, OOB score {}'.format(depth, clf.oob_score_))
    if clf.oob_score_>oob_score:
      oob_score = clf.oob_score_
      best_clf = clf
  print('--> Depth={}'.format(best_clf.max_depth))
  print('-'*20)
  train_acc = best_clf.score(X_train, y_train)
  test_acc = best_clf.score(X_test, y_test)
  print(X_train.shape)
  print('Train : {}({})'.format(train_acc, train_acc*(1-train_acc)/X_train.shape[0]))
  print('Test : {}({})'.format(test_acc, test_acc*(1-test_acc)/X_test.shape[0]))
  print('-'*20)

In [12]:
def process_texts(texts, model, tokenizer):
    """Embed a batch of texts with ``model`` and return element ``[1]`` of its output.

    The texts are tokenized with padding and truncation to 512 tokens, moved
    to the device the model's parameters live on, and passed through the model
    with gradient tracking disabled.

    Args:
        texts: List of strings to embed.
        model: A torch module; its output must be indexable with ``[1]``.
            NOTE(review): for the BERT/RoBERTa models used in this notebook
            that element is presumably the pooled output — confirm against
            the transformers documentation.
        tokenizer: Callable producing the model inputs; its result must
            support ``.to(device)`` and ``**`` unpacking.

    Returns:
        A detached tensor with one row per input text, on the model's device.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Take the device from the model instead of relying on a notebook global
    # that is only defined in a later cell.
    model_device = next(model.parameters()).device
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        outputs = model(**(inputs.to(model_device)))
    return outputs[1].detach()

def batch(l, size):
    """Yield consecutive slices of ``l`` holding at most ``size`` items each.

    The last slice is shorter when ``len(l)`` is not a multiple of ``size``.
    """
    yield from (l[start:start + size] for start in range(0, len(l), size))

In [13]:
def models_generator():
  """Lazily load each encoder under test and yield it with its tokenizer.

  Yields:
    ``(model, tokenizer, label)`` tuples, where ``model`` has been moved to
    the notebook-level ``device`` and ``label`` is the upper-cased part of
    the checkpoint name before the first ``-``.
  """
  # One row per checkpoint: (pretrained name, tokenizer class, model class).
  specs = (
      ('bert-base-uncased', BertTokenizer, BertModel),
      ('roberta-base', RobertaTokenizer, RobertaModel),
  )
  for checkpoint, tokenizer_cls, model_cls in specs:
    tokenizer = tokenizer_cls.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint).to(device)
    label = checkpoint.split('-')[0].upper()
    yield model, tokenizer, label

In [14]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
# For every encoder from models_generator: embed all texts in batches of 5,
# split the embeddings back into the three groups, and report how well a
# random forest separates each perturbed group from the originals.
for model, tokenizer, name in models_generator():
  print('Model: {}'.format(name))
  results = [process_texts(chunk, model, tokenizer) for chunk in batch(text_data, 5)]
  embeddings = torch.cat(results, 0)
  e_original, e_word, e_sentence = (
      embeddings[:n_samples, :],
      embeddings[n_samples:2*n_samples, :],
      embeddings[2*n_samples:, :],
  )
  for title, e_shifted in (('Sentence vs Original', e_sentence), ('Word vs Original', e_word)):
    print(title)
    get_accuracy_rf(e_original, e_shifted)
  print('\n\n')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…


Model: BERT
Sentence vs Original
Depth 2, OOB score 0.5786631846414455
Depth 5, OOB score 0.5849449463579899
Depth 7, OOB score 0.5813805759457933
Depth 10, OOB score 0.568746470920384
Depth 15, OOB score 0.5398785996612083
Depth 20, OOB score 0.5132693393562959
--> Depth=5
--------------------
(28336, 768)
Train : 0.6042843026538679(8.438904016798294e-06)
Test : 0.587232691086174(2.5660645522534845e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.7240612648221344
Depth 5, OOB score 0.7723743647656691
Depth 7, OOB score 0.782326369282891
Depth 10, OOB score 0.7876199887069453
Depth 15, OOB score 0.7872317899491813
Depth 20, OOB score 0.7847967250141163
--> Depth=10
--------------------
(28336, 768)
Train : 0.8987859966120836(3.2103941595887653e-06)
Test : 0.7876349777683676(1.770761375864818e-05)
--------------------





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501200538.0, style=ProgressStyle(descri…


Model: ROBERTA
Sentence vs Original
Depth 2, OOB score 0.7149915302089215
Depth 5, OOB score 0.7395892151326934
Depth 7, OOB score 0.7381775832862789
Depth 10, OOB score 0.7293901750423489
Depth 15, OOB score 0.6925465838509317
Depth 20, OOB score 0.6793125352907962
--> Depth=5
--------------------
(28336, 768)
Train : 0.7696922642574816(6.255861187170833e-06)
Test : 0.7607452890112216(1.9268673963418856e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.8105237154150198
Depth 5, OOB score 0.8357919254658385
Depth 7, OOB score 0.8500846979107849
Depth 10, OOB score 0.8568958215697347
Depth 15, OOB score 0.8505434782608695
Depth 20, OOB score 0.8516727837380011
--> Depth=10
--------------------
(28336, 768)
Train : 0.9332298136645962(2.1990375688925474e-06)
Test : 0.8590937963159009(1.2815122321345996e-05)
--------------------





In [16]:
def process_texts_db(texts, model, tokenizer):
    """Embed a batch of texts using the first-position vector of output ``[0]``.

    The texts are tokenized with padding and truncation to 512 tokens, moved
    to the device the model's parameters live on, and passed through the model
    with gradient tracking disabled. From element ``[0]`` of the model output
    the slice ``[:, 0, :]`` is taken, i.e. the vector at sequence position 0
    for every text.

    Args:
        texts: List of strings to embed.
        model: A torch module whose output element ``[0]`` is a 3-D tensor.
        tokenizer: Callable producing the model inputs; its result must
            support ``.to(device)`` and ``**`` unpacking.

    Returns:
        A detached CPU tensor with one row per input text.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Take the device from the model instead of relying on a notebook global.
    model_device = next(model.parameters()).device
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        outputs = model(**(inputs.to(model_device)))
    return outputs[0][:, 0, :].cpu().detach()

In [17]:
# Load the DistilBERT encoder (moved to `device`) and its matching tokenizer.
DISTILBERT_CHECKPOINT = 'distilbert-base-uncased'
model = DistilBertModel.from_pretrained(DISTILBERT_CHECKPOINT).to(device)
tokenizer = DistilBertTokenizer.from_pretrained(DISTILBERT_CHECKPOINT)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




In [18]:
# Embed all texts with the currently loaded model in batches of 2, split the
# embeddings back into the three groups, and report how well a random forest
# separates each perturbed group from the originals.
print('Model: DistilBert')
results = [process_texts_db(chunk, model, tokenizer) for chunk in batch(text_data, 2)]
embeddings = torch.cat(results, 0)
e_original, e_word, e_sentence = (
    embeddings[:n_samples, :],
    embeddings[n_samples:2*n_samples, :],
    embeddings[2*n_samples:, :],
)
for title, e_shifted in (('Sentence vs Original', e_sentence), ('Word vs Original', e_word)):
    print(title)
    get_accuracy_rf(e_original, e_shifted)
print('\n\n')

Model: DistilBert
Sentence vs Original
Depth 2, OOB score 0.5551242236024845
Depth 5, OOB score 0.5516304347826086
Depth 7, OOB score 0.534867306606437
Depth 10, OOB score 0.4763551665725579
Depth 15, OOB score 0.37238848108413325
Depth 20, OOB score 0.3277809147374365
--> Depth=2
--------------------
(28336, 768)
Train : 0.5695581592320723(8.65195025706683e-06)
Test : 0.5589667584162609(2.6098128456688368e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.7582227555053642
Depth 5, OOB score 0.7730801806888763
Depth 7, OOB score 0.7807382834556748
Depth 10, OOB score 0.7763975155279503
Depth 15, OOB score 0.7509175607001694
Depth 20, OOB score 0.7381422924901185
--> Depth=7
--------------------
(28336, 768)
Train : 0.8272162619988707(5.044096480924846e-06)
Test : 0.7839297056955326(1.7931814760072864e-05)
--------------------





In [19]:
# Load the small ELECTRA discriminator (moved to `device`) and its matching tokenizer.
ELECTRA_CHECKPOINT = 'google/electra-small-discriminator'
model = ElectraModel.from_pretrained(ELECTRA_CHECKPOINT).to(device)
tokenizer = ElectraTokenizer.from_pretrained(ELECTRA_CHECKPOINT)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=54245363.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [20]:
# Embed all texts with the currently loaded model in batches of 2, split the
# embeddings back into the three groups, and report how well a random forest
# separates each perturbed group from the originals.
print('Model: Electra')
results = [process_texts_db(chunk, model, tokenizer) for chunk in batch(text_data, 2)]
embeddings = torch.cat(results, 0)
e_original, e_word, e_sentence = (
    embeddings[:n_samples, :],
    embeddings[n_samples:2*n_samples, :],
    embeddings[2*n_samples:, :],
)
for title, e_shifted in (('Sentence vs Original', e_sentence), ('Word vs Original', e_word)):
    print(title)
    get_accuracy_rf(e_original, e_shifted)
print('\n\n')

Model: Electra
Sentence vs Original
Depth 2, OOB score 0.5856507622811971
Depth 5, OOB score 0.5957792207792207
Depth 7, OOB score 0.5955674760022586
Depth 10, OOB score 0.5758046301524562
Depth 15, OOB score 0.5295736871823828
Depth 20, OOB score 0.4966120835686053
--> Depth=5
--------------------
(28336, 256)
Train : 0.6203769057029926(8.311314249483796e-06)
Test : 0.6037476180393817(2.5326744839207555e-05)
--------------------
Word vs Original
Depth 2, OOB score 0.8801524562394127
Depth 5, OOB score 0.9017857142857143
Depth 7, OOB score 0.911914172783738
Depth 10, OOB score 0.9169254658385093
Depth 15, OOB score 0.9185135516657256
Depth 20, OOB score 0.9197487295313382
--> Depth=20
--------------------
(28336, 256)
Train : 0.9943534726143421(1.9814526094511811e-07)
Test : 0.9224010163031972(7.577533498414796e-06)
--------------------



