# Test Distributional Shift
This notebook tests distributional shift for the dataset of NYT articles.

## Imports

In [1]:
%%capture
!pip install transformers

In [2]:
%%capture
!pip install datasets

In [3]:
%%capture
!pip install nltk

In [4]:
%%capture
import nltk
nltk.download('punkt')

In [8]:
from transformers import BertTokenizer, BertModel, AlbertTokenizer, AlbertModel, DistilBertTokenizer, DistilBertModel
import torch
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from math import ceil

from nltk.tokenize import sent_tokenize, word_tokenize
import random
import numpy as np

## Loading and Preparing Texts

In [6]:
def remove_sentence(t, k=1):
    sentences = random.sample(sent_tokenize(t), k)
    text = t
    for s in sentences:
        text = text.replace(s, '')
        assert len(t) != len(text)
    return text, len(word_tokenize(s))

def remove_word(t, k=1):
    text = t
    for _ in range(k):
        words = word_tokenize(text)
        if not len([i for i,w in enumerate(words) if w.isalnum()]):
            break
        word_idx = random.choice([i for i,w in enumerate(words) if w.isalnum()])
        del words[word_idx]
        text = ''.join([(' ' if w.isalnum() else '')+w for w in words]).strip(' ')
    return text

In [9]:
dataset = load_dataset('squadshifts', 'nyt')['test']
print('Average number of sentences = {}'.format(np.mean([len(sent_tokenize(sample['context'])) for sample in dataset ])))

Reusing dataset squad_shifts (/root/.cache/huggingface/datasets/squad_shifts/nyt/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039)


Average number of sentences = 6.047093889716841


In [None]:
N_WORD = 5
texts = []
dataset = load_dataset('squadshifts', 'nyt')['test']
for d in tqdm(dataset):
    if d['context'] not in texts:
        texts.append(d['context'])
random.shuffle(texts)        
n_samples = len(texts)
original = texts#random.sample(texts, n_samples)#texts[:n_samples]
word_removed = []
sentence_removed = []
for t in tqdm(original):
    sr, nw = remove_sentence(t)
    wr = remove_word(t, k=N_WORD)
    word_removed.append(wr)
    sentence_removed.append(sr)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2317.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1400.0, style=ProgressStyle(description…


Downloading and preparing dataset squad_shifts/nyt (download: 15.74 MiB, generated: 10.29 MiB, post-processed: Unknown size, total: 26.03 MiB) to /root/.cache/huggingface/datasets/squad_shifts/nyt/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=773853.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1117198.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1056991.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1030185.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

 18%|█▊        | 1844/10065 [00:00<00:00, 18436.19it/s]

Dataset squad_shifts downloaded and prepared to /root/.cache/huggingface/datasets/squad_shifts/nyt/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039. Subsequent calls will reuse this data.


100%|██████████| 10065/10065 [00:00<00:00, 14487.05it/s]
100%|██████████| 2017/2017 [00:12<00:00, 165.78it/s]


In [None]:
text_data = original+word_removed+sentence_removed
classes = [
    (n_samples, 'Original', 'tab:blue'),
    (n_samples, 'Word', 'tab:green'),
    (n_samples, 'Sentence', 'tab:red')
]

## Compare Distributional Shift
For different Classifiers (BERT, ALBERT, DistilBERT)

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
def get_accuracy_rf(x_original, x_new):
    X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
    y = np.array([0]*n_samples+[1]*n_samples)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    oob_score = 0
    best_clf = None
    for depth in [2, 5, 7, 10, 15, 20]:
        clf = RandomForestClassifier(max_depth=depth, oob_score=True)
        clf.fit(X_train, y_train)
        print('Depth {}, OOB score {}'.format(depth, clf.oob_score_))
        if clf.oob_score_>oob_score:
            oob_score = clf.oob_score_
            best_clf = clf
    print('--> Depth={}'.format(best_clf.max_depth))
    print('-'*20)
    train_acc = best_clf.score(X_train, y_train)
    test_acc = best_clf.score(X_test, y_test)
    print(X_train.shape)
    print('Train : {}({})'.format(train_acc, train_acc*(1-train_acc)/X_train.shape[0]))
    print('Test : {}({})'.format(test_acc, test_acc*(1-test_acc)/X_test.shape[0]))
    print('-'*20)

In [None]:
def process_texts(texts, model, tokenizer):
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    return model(**(inputs.to(device)))[1].detach()

def batch(l, size):
    for i in range(0, len(l), size):
        yield l[i:i+size]

In [None]:
def models_generator():
    model_classes = [
        BertModel,
        AlbertModel
    ]
    tokenizer_classes = [
        BertTokenizer,
        AlbertTokenizer
    ]
    model_names = [
        'bert-base-uncased',
        'albert-base-v2'
    ]
    for model_name, tokenizer_class, model_class in zip(model_names, tokenizer_classes, model_classes):
        tokenizer = tokenizer_class.from_pretrained(model_name)
        model = model_class.from_pretrained(model_name).to(device)
        yield model, tokenizer, model_name.split('-')[0].upper()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
for model, tokenizer, name in models_generator():
    print('Model: {}'.format(name))
    results = []
    for b in batch(text_data, 5):
        results.append(process_texts(b, model, tokenizer))
    embeddings = torch.cat(results, 0)
    e_original = embeddings[:n_samples, :]
    e_word = embeddings[n_samples:2*n_samples, :]
    e_sentence = embeddings[2*n_samples:, :]
    print('Sentence vs Original')
    get_accuracy_rf(e_original, e_sentence)
    print('Word vs Original')
    get_accuracy_rf(e_original, e_word)
    print('\n\n')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…


Model: BERT
Sentence vs Original
Depth 2, OOB score 0.5689256198347108
Depth 5, OOB score 0.5447933884297521
Depth 7, OOB score 0.5203305785123967
Depth 10, OOB score 0.47768595041322315
Depth 15, OOB score 0.4452892561983471
Depth 20, OOB score 0.44264462809917354
--> Depth=2
--------------------
(3025, 768)
Train : 0.5897520661157025(7.998167491833474e-05)
Test : 0.5490584737363726(0.0002453848029283029)
--------------------
Word vs Original
Depth 2, OOB score 0.7444628099173554
Depth 5, OOB score 0.7966942148760331
Depth 7, OOB score 0.8125619834710743
Depth 10, OOB score 0.8132231404958677
Depth 15, OOB score 0.8105785123966942
Depth 20, OOB score 0.8023140495867769
--> Depth=10
--------------------
(3025, 768)
Train : 0.9877685950413223(3.993982707905629e-06)
Test : 0.8245787908820614(0.0001433583830619812)
--------------------





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760289.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=684.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=47376696.0, style=ProgressStyle(descrip…


Model: ALBERT
Sentence vs Original
Depth 2, OOB score 0.5785123966942148
Depth 5, OOB score 0.5752066115702479
Depth 7, OOB score 0.5494214876033058
Depth 10, OOB score 0.5087603305785124
Depth 15, OOB score 0.4677685950413223
Depth 20, OOB score 0.4571900826446281
--> Depth=2
--------------------
(3025, 768)
Train : 0.6105785123966943(7.860244383343278e-05)
Test : 0.5718533201189296(0.00024265322139532865)
--------------------
Word vs Original
Depth 2, OOB score 0.8300826446280992
Depth 5, OOB score 0.8684297520661157
Depth 7, OOB score 0.8737190082644628
Depth 10, OOB score 0.8704132231404959
Depth 15, OOB score 0.8664462809917355
Depth 20, OOB score 0.8644628099173554
--> Depth=7
--------------------
(3025, 768)
Train : 0.9705785123966942(9.439954932401417e-06)
Test : 0.8919722497522299(9.549827098530775e-05)
--------------------





In [None]:
def process_texts_db(texts, model, tokenizer):
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    return model(**(inputs.to(device)))[0][:, 0, :].cpu().detach()

In [None]:
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




In [None]:
print('Model: DistilBert')
results = []
for b in batch(text_data, 2):
    results.append(process_texts_db(b, model, tokenizer))
embeddings = torch.cat(results, 0)
e_original = embeddings[:n_samples, :]
e_word = embeddings[n_samples:2*n_samples, :]
e_sentence = embeddings[2*n_samples:, :]
print('Sentence vs Original')
get_accuracy_rf(e_original, e_sentence)
print('Word vs Original')
get_accuracy_rf(e_original, e_word)
print('\n\n')

Model: DistilBert
Sentence vs Original
Depth 2, OOB score 0.5302479338842975
Depth 5, OOB score 0.4618181818181818
Depth 7, OOB score 0.41024793388429753
Depth 10, OOB score 0.32859504132231404
Depth 15, OOB score 0.2872727272727273
Depth 20, OOB score 0.2809917355371901
--> Depth=2
--------------------
(3025, 768)
Train : 0.617190082644628(7.810462298503975e-05)
Test : 0.533201189296333(0.0002466775827842508)
--------------------
Word vs Original
Depth 2, OOB score 0.7851239669421488
Depth 5, OOB score 0.8079338842975207
Depth 7, OOB score 0.7983471074380165
Depth 10, OOB score 0.7748760330578512
Depth 15, OOB score 0.7583471074380165
Depth 20, OOB score 0.7596694214876033
--> Depth=5
--------------------
(3025, 768)
Train : 0.9071074380165289(2.785571368979109e-05)
Test : 0.8235877106045589(0.00014399503820188326)
--------------------



