# Test Distributional Shift
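This notebook measures the distributional shift that small text perturbations (removing a few words or a whole sentence) induce in the embeddings of Amazon Reviews contexts from the SQuADShifts dataset.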

## Imports

In [1]:
%%capture
!pip install transformers

In [2]:
%%capture
!pip install datasets

In [3]:
%%capture
!pip install nltk

In [4]:
%%capture
import nltk
# Download the 'punkt' resource up front; presumably this is the data that
# sent_tokenize / word_tokenize (imported below) rely on -- TODO confirm.
nltk.download('punkt')

In [5]:
from transformers import BertTokenizer, BertModel, AlbertTokenizer, AlbertModel, DistilBertTokenizer, DistilBertModel
import torch
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from math import ceil

from nltk.tokenize import sent_tokenize, word_tokenize
import random
import numpy as np

## Loading and Preparing Texts

In [6]:
def remove_sentence(t, k=1):
    """Remove ``k`` randomly sampled sentences from the text ``t``.

    The text is split with ``sent_tokenize`` and ``k`` of the resulting
    sentences are drawn with ``random.sample``; each drawn sentence is then
    cut out of the text.

    Args:
        t: The input text.
        k: Number of sentences to remove (default 1).

    Returns:
        A tuple ``(text, n_words)`` where ``text`` is ``t`` with the sampled
        sentences removed and ``n_words`` is the total number of
        ``word_tokenize`` tokens over all removed sentences (0 when ``k`` is 0).

    Raises:
        ValueError: propagated from ``random.sample`` when ``k`` is larger
            than the number of sentences found in ``t``.
    """
    sentences = random.sample(sent_tokenize(t), k)
    text = t
    n_words = 0
    for s in sentences:
        # Remove a single occurrence only: a sentence that is repeated in the
        # text must cost one removal per sampled copy, not wipe all copies.
        text = text.replace(s, '', 1)
        # Accumulate over every removed sentence instead of reporting only
        # the last one (and avoid reading an unbound ``s`` when k == 0).
        n_words += len(word_tokenize(s))
        assert len(t) != len(text), 'sampled sentence was not found in the text'
    return text, n_words

def remove_word(t, k=1):
    """Delete up to ``k`` randomly chosen alphanumeric tokens from ``t``.

    Each round re-tokenizes the current text with ``word_tokenize``, picks one
    token for which ``str.isalnum()`` holds, drops it and glues the remaining
    tokens back together: alphanumeric tokens get a leading space, every other
    token is appended directly to what precedes it. The loop stops early once
    no alphanumeric token is left.

    Args:
        t: The input text.
        k: Maximum number of tokens to delete (default 1).

    Returns:
        The text after the deletions, with surrounding spaces stripped.
    """
    current = t
    for _ in range(k):
        tokens = word_tokenize(current)
        removable = [pos for pos, tok in enumerate(tokens) if tok.isalnum()]
        if not removable:
            break
        tokens.pop(random.choice(removable))
        pieces = []
        for tok in tokens:
            pieces.append(' ' + tok if tok.isalnum() else tok)
        current = ''.join(pieces).strip(' ')
    return current

In [10]:
# Load the 'amazon' configuration of 'squadshifts' (test split) and report
# how many sentences a context contains on average.
dataset = load_dataset('squadshifts', 'amazon')['test']
sentence_counts = [len(sent_tokenize(sample['context'])) for sample in dataset]
print('Average number of sentences = {}'.format(np.mean(sentence_counts)))

Reusing dataset squad_shifts (/root/.cache/huggingface/datasets/squad_shifts/amazon/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039)


Average number of sentences = 7.833687405159332


In [None]:
N_WORD = 5  # number of words deleted from each text for the "word removed" variant
SEED = 0    # fixes the shuffle and the random removals so a re-run yields the same texts
random.seed(SEED)

dataset = load_dataset('squadshifts', 'amazon')['test']
# Dataset rows can repeat the same context; keep each context once, in
# first-seen order. dict.fromkeys de-duplicates in linear time, whereas the
# previous `not in list` check was quadratic in the number of rows.
texts = list(dict.fromkeys(d['context'] for d in tqdm(dataset)))
random.shuffle(texts)
n_samples = len(texts)
original = texts

# Build the two perturbed variants of every original text.
word_removed = []
sentence_removed = []
for t in tqdm(original):
    # remove_sentence also returns a word count, which is not needed here.
    sr = remove_sentence(t)[0]
    wr = remove_word(t, k=N_WORD)
    word_removed.append(wr)
    sentence_removed.append(sr)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2317.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1400.0, style=ProgressStyle(description…


Downloading and preparing dataset squad_shifts/amazon (download: 15.74 MiB, generated: 9.00 MiB, post-processed: Unknown size, total: 24.74 MiB) to /root/.cache/huggingface/datasets/squad_shifts/amazon/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=773853.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1117198.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1056991.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1030185.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

 21%|██        | 2056/9885 [00:00<00:00, 20554.38it/s]

Dataset squad_shifts downloaded and prepared to /root/.cache/huggingface/datasets/squad_shifts/amazon/1.0.0/f6c7b6f10e62b342754f88631c92624a2033652e3a3e129b8d979726dec04039. Subsequent calls will reuse this data.


100%|██████████| 9885/9885 [00:00<00:00, 17492.47it/s]
100%|██████████| 2050/2050 [00:11<00:00, 183.52it/s]


In [None]:
# Single flat list fed to the encoders: originals first, then the
# word-removed texts, then the sentence-removed texts.
text_data = [*original, *word_removed, *sentence_removed]
# One (size, label, colour) entry per group of texts.
classes = [
    (n_samples, label, colour)
    for label, colour in (
        ('Original', 'tab:blue'),
        ('Word', 'tab:green'),
        ('Sentence', 'tab:red'),
    )
]

## Compare Distributional Shift
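For different encoders (BERT, ALBERT, DistilBERT): a random forest classifier is trained to tell the embeddings of the original texts apart from those of the perturbed texts, so an accuracy close to 0.5 means the shift is hard to detect.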

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
def get_accuracy_rf(x_original, x_new, random_state=None):
    """Train random forests to separate ``x_original`` from ``x_new`` and print accuracies.

    The two embedding tensors are stacked, labelled 0 (original) and 1 (new),
    and split with ``train_test_split``. One ``RandomForestClassifier`` is
    fitted per candidate ``max_depth``; the one with the highest out-of-bag
    score is kept and its train/test accuracy is printed. The value printed in
    parentheses after each accuracy ``p`` is ``p * (1 - p) / n`` with ``n`` the
    number of rows in that split (a variance, not a standard deviation).

    Args:
        x_original: 2-D torch tensor of embeddings for the original texts.
        x_new: 2-D torch tensor of embeddings for the perturbed texts.
        random_state: Optional seed forwarded to ``train_test_split`` and to
            every ``RandomForestClassifier``; ``None`` (default) keeps the
            previous non-deterministic behaviour.

    Returns:
        None. Results are reported through ``print``.
    """
    X = torch.cat([x_original, x_new], dim=0).cpu().numpy()
    # Derive the labels from the actual inputs rather than the notebook-level
    # ``n_samples``, so X and y always have matching lengths.
    y = np.array([0] * len(x_original) + [1] * len(x_new))
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)
    # Start below any attainable score so a classifier is always selected,
    # even if every OOB score happens to be 0.
    oob_score = float('-inf')
    best_clf = None
    for depth in [2, 5, 7, 10, 15, 20]:
        clf = RandomForestClassifier(max_depth=depth, oob_score=True, random_state=random_state)
        clf.fit(X_train, y_train)
        print('Depth {}, OOB score {}'.format(depth, clf.oob_score_))
        # Strict comparison: on ties the shallower (earlier) forest is kept.
        if clf.oob_score_ > oob_score:
            oob_score = clf.oob_score_
            best_clf = clf
    print('--> Depth={}'.format(best_clf.max_depth))
    print('-'*20)
    train_acc = best_clf.score(X_train, y_train)
    test_acc = best_clf.score(X_test, y_test)
    print(X_train.shape)
    print('Train : {}({})'.format(train_acc, train_acc*(1-train_acc)/X_train.shape[0]))
    print('Test : {}({})'.format(test_acc, test_acc*(1-test_acc)/X_test.shape[0]))
    print('-'*20)

In [None]:
def process_texts(texts, model, tokenizer):
    """Encode a batch of texts and return the second element of the model output.

    Args:
        texts: List of strings to encode.
        model: Model called as ``model(**inputs)``; its output is indexed
            with ``[1]`` (presumably the pooled output for BERT/ALBERT --
            TODO confirm against the transformers version in use).
        tokenizer: Tokenizer matching ``model``; inputs are padded and
            truncated to at most 512 tokens.

    Returns:
        A detached tensor with one row per input text, left on ``device``.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Inference only: do not record the autograd graph during the forward
    # pass, which otherwise holds activations in memory for no benefit.
    with torch.no_grad():
        outputs = model(**(inputs.to(device)))
    return outputs[1].detach()

def batch(l, size):
    """Yield consecutive slices of ``l`` that hold at most ``size`` items each.

    The final slice is shorter when ``len(l)`` is not a multiple of ``size``.
    """
    yield from (l[start:start + size] for start in range(0, len(l), size))

In [None]:
def models_generator():
    """Lazily load the pretrained encoders one at a time.

    Yields:
        Tuples ``(model, tokenizer, label)`` where ``model`` has been moved to
        ``device`` and ``label`` is the upper-cased first dash-separated part
        of the checkpoint name (e.g. ``'BERT'``).
    """
    # (checkpoint name, tokenizer class, model class) for every encoder.
    specs = (
        ('bert-base-uncased', BertTokenizer, BertModel),
        ('albert-base-v2', AlbertTokenizer, AlbertModel),
    )
    for checkpoint, tokenizer_cls, model_cls in specs:
        tokenizer = tokenizer_cls.from_pretrained(checkpoint)
        model = model_cls.from_pretrained(checkpoint).to(device)
        label = checkpoint.split('-')[0].upper()
        yield model, tokenizer, label

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# For every encoder: embed all texts in small batches, split the embedding
# matrix back into its three groups (original / word-removed /
# sentence-removed) and check how well each perturbed group can be told
# apart from the originals.
for model, tokenizer, name in models_generator():
    print('Model: {}'.format(name))
    embeddings = torch.cat(
        [process_texts(chunk, model, tokenizer) for chunk in batch(text_data, 5)],
        0,
    )
    e_original = embeddings[:n_samples, :]
    e_word = embeddings[n_samples:2*n_samples, :]
    e_sentence = embeddings[2*n_samples:, :]
    comparisons = (
        ('Sentence vs Original', e_sentence),
        ('Word vs Original', e_word),
    )
    for title, e_shifted in comparisons:
        print(title)
        get_accuracy_rf(e_original, e_shifted)
    print('\n\n')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…


Model: BERT
Sentence vs Original
Depth 2, OOB score 0.5499186991869919
Depth 5, OOB score 0.528130081300813
Depth 7, OOB score 0.5128455284552845
Depth 10, OOB score 0.45170731707317074
Depth 15, OOB score 0.4367479674796748
Depth 20, OOB score 0.41203252032520327
--> Depth=2
--------------------
(3075, 768)
Train : 0.5808130081300813(7.917699437950161e-05)
Test : 0.5463414634146342(0.00024180728660350257)
--------------------
Word vs Original
Depth 2, OOB score 0.7271544715447155
Depth 5, OOB score 0.7547967479674796
Depth 7, OOB score 0.7580487804878049
Depth 10, OOB score 0.7521951219512195
Depth 15, OOB score 0.7486178861788618
Depth 20, OOB score 0.7388617886178862
--> Depth=7
--------------------
(3075, 768)
Train : 0.9082926829268293(2.7088482949076952e-05)
Test : 0.7541463414634146(0.00018088745084952337)
--------------------





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760289.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=684.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=47376696.0, style=ProgressStyle(descrip…


Model: ALBERT
Sentence vs Original
Depth 2, OOB score 0.5704065040650407
Depth 5, OOB score 0.5570731707317074
Depth 7, OOB score 0.5226016260162601
Depth 10, OOB score 0.4796747967479675
Depth 15, OOB score 0.4416260162601626
Depth 20, OOB score 0.4250406504065041
--> Depth=2
--------------------
(3075, 768)
Train : 0.6133333333333333(7.712375790424571e-05)
Test : 0.5541463414634147(0.00024104212068890468)
--------------------
Word vs Original
Depth 2, OOB score 0.6933333333333334
Depth 5, OOB score 0.7203252032520325
Depth 7, OOB score 0.712520325203252
Depth 10, OOB score 0.6832520325203252
Depth 15, OOB score 0.6790243902439025
Depth 20, OOB score 0.6826016260162602
--> Depth=5
--------------------
(3075, 768)
Train : 0.8364227642276423(4.449421909249828e-05)
Test : 0.7307317073170732(0.00019196378462297414)
--------------------





In [None]:
def process_texts_db(texts, model, tokenizer):
    """Encode a batch of texts and return the first-token vector of each.

    Args:
        texts: List of strings to encode.
        model: Model called as ``model(**inputs)``; element ``[0]`` of its
            output is indexed as ``[:, 0, :]``, i.e. position 0 along the
            second axis is kept for every row (presumably the [CLS] hidden
            state of DistilBERT -- TODO confirm).
        tokenizer: Tokenizer matching ``model``; inputs are padded and
            truncated to at most 512 tokens.

    Returns:
        A detached CPU tensor with one row per input text.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Inference only: do not record the autograd graph during the forward
    # pass, which otherwise holds activations in memory for no benefit.
    with torch.no_grad():
        outputs = model(**(inputs.to(device)))
    return outputs[0][:, 0, :].cpu().detach()

In [None]:
# DistilBERT is loaded separately from models_generator(); its embeddings are
# extracted with process_texts_db.
DISTILBERT_CHECKPOINT = 'distilbert-base-uncased'
model = DistilBertModel.from_pretrained(DISTILBERT_CHECKPOINT).to(device)
tokenizer = DistilBertTokenizer.from_pretrained(DISTILBERT_CHECKPOINT)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




In [None]:
# Same evaluation as for the other encoders, using the DistilBERT-specific
# embedding function and a batch size of 2.
print('Model: DistilBert')
embeddings = torch.cat(
    [process_texts_db(chunk, model, tokenizer) for chunk in batch(text_data, 2)],
    0,
)
e_original = embeddings[:n_samples, :]
e_word = embeddings[n_samples:2*n_samples, :]
e_sentence = embeddings[2*n_samples:, :]
comparisons = (
    ('Sentence vs Original', e_sentence),
    ('Word vs Original', e_word),
)
for title, e_shifted in comparisons:
    print(title)
    get_accuracy_rf(e_original, e_shifted)
print('\n\n')

Model: DistilBert
Sentence vs Original
Depth 2, OOB score 0.5232520325203251
Depth 5, OOB score 0.4484552845528455
Depth 7, OOB score 0.38341463414634147
Depth 10, OOB score 0.3173983739837398
Depth 15, OOB score 0.27772357723577235
Depth 20, OOB score 0.26471544715447154
--> Depth=2
--------------------
(3075, 768)
Train : 0.6185365853658537(7.673140745684673e-05)
Test : 0.551219512195122(0.0002413429868980427)
--------------------
Word vs Original
Depth 2, OOB score 0.7772357723577236
Depth 5, OOB score 0.7954471544715447
Depth 7, OOB score 0.7723577235772358
Depth 10, OOB score 0.7528455284552845
Depth 15, OOB score 0.7398373983739838
Depth 20, OOB score 0.7365853658536585
--> Depth=5
--------------------
(3075, 768)
Train : 0.8663414634146341(3.765656331161765e-05)
Test : 0.8146341463414634(0.00014732229654241817)
--------------------



