In [1]:
from functools import partial
import pickle
import math

import spacy
import numpy as np
import torch
import torch.nn as nn
import mlflow
from alibi_detect.cd import KSDrift, MMDDrift, MMDDriftOnline, LSDDDriftOnline
from alibi_detect.cd.pytorch import preprocess_drift, UAE
from alibi_detect.models.pytorch import TransformerEmbedding
from alibi_detect.saving import save_detector, load_detector

from training_helpers.dataset import ParallelLanguageDataset

mlflow.set_tracking_uri("http://0.0.0.0:8000")

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_uri="runs:/3ed0559e69954c3ba240ad6ad183089c/model_orig"
model = mlflow.pytorch.load_model(model_uri)
model.to("cpu")

Downloading artifacts:  83%|████████▎ | 5/6 [00:00<00:00, 233.54it/s]

Downloading artifacts: 100%|██████████| 6/6 [00:02<00:00,  2.09it/s] 


LanguageTransformer(
  (embed_src): Embedding(15004, 512)
  (embed_tgt): Embedding(15004, 512)
  (pos_enc): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=1024, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1024, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affin

In [4]:
class LanguageTransformerEncoder(nn.Module):
  def __init__(self, model):
    super().__init__()
    self.d_model = model.d_model
    self.embed_src = model.embed_src
    self.pos_enc = model.pos_enc
    self.encoder = model.transformer.encoder
  
  def forward(self, src):
    src_key_padding_mask = torch.where(src > 0, False, True)
    if isinstance(src, list):
      src = src[0].unsqueeze(0)
      
    src = torch.transpose(src, 0, 1)
    src = self.pos_enc(self.embed_src(src) * math.sqrt(self.d_model))
    output = self.encoder(src, src_key_padding_mask=src_key_padding_mask)
    output = torch.transpose(output, 0, 1)
    return output
model_encoder = LanguageTransformerEncoder(model)

In [11]:
embed_src = model.embed_src
enc_dim = 32
max_seq_len = 96
shape = (max_seq_len, embed_src.embedding_dim, )

uae = UAE(input_layer=model_encoder, shape=shape, enc_dim=enc_dim)

In [6]:
# tokens = torch.randint(low=0, high=10, size=(2, max_seq_len, ))
# print(embed_src(tokens).shape)
# emb_uae = uae(tokens)
# print(emb_uae.shape)

In [7]:
train_dataset = ParallelLanguageDataset(
    "./data/processed/en/train.pkl",
    "./data/processed/fr/train.pkl",
    1e9,
    96,
)

valid_dataset = ParallelLanguageDataset(
    "./data/processed/en/val.pkl",
    "./data/processed/fr/val.pkl",
    1e9,
    96,
)
X_ref = torch.IntTensor(train_dataset.data_1)
X_h0 = torch.IntTensor(valid_dataset.data_1)

print(X_ref.shape, X_h0.shape)

torch.Size([183842, 96]) torch.Size([45960, 96])


In [9]:
uae.to("cuda")
idx = torch.randperm(X_ref.size(0))
preprocess_fn = partial(preprocess_drift, model=uae, max_len=max_seq_len, batch_size=1000, device="cuda")
cd = MMDDriftOnline(X_ref[idx[:10000]], ert=200, window_size=50, preprocess_fn=preprocess_fn, backend="pytorch", input_shape=(max_seq_len, ))
save_detector(cd, "detector")
cd = load_detector("detector")

Generating permutations of kernel matrix..


100%|██████████| 1000/1000 [00:01<00:00, 891.68it/s]
Computing thresholds: 100%|██████████| 50/50 [00:03<00:00, 16.29it/s]


Generating permutations of kernel matrix..


100%|██████████| 1000/1000 [00:01<00:00, 890.77it/s]
Computing thresholds: 100%|██████████| 50/50 [00:03<00:00, 16.44it/s]


In [7]:
idx = torch.randperm(X_h0.size(0))
preds_h0 = cd.predict(X_h0[idx[:1000]])
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))

NameError: name 'cd' is not defined

In [11]:
lang_model = spacy.load("en_core_web_sm")
with open("data/processed/en/freq_list.pkl", "rb") as f:
    en_freq_list = pickle.load(f)
with open("data/processed/fr/freq_list.pkl", "rb") as f:
    fr_freq_list = pickle.load(f)

def tokenize(sentence, freq_list, lang_model):
    punctuation = ["(", ")", ":", '"', " "]

    sentence = sentence.lower()
    sentence = [
        tok.text
        for tok in lang_model.tokenizer(sentence)
        if tok.text not in punctuation
    ]
    return [
        freq_list[word] if word in freq_list else freq_list["[OOV]"]
        for word in sentence
    ]

def pad_arr(array, seq_len, freq_list):
    return array + [freq_list["[PAD]"] for i in range(seq_len - len(array))], len(array)

In [18]:
test_sentences = ["This is a test sentence"]
def process(sentence):
  s, length = pad_arr(tokenize(sentence, en_freq_list, lang_model), max_seq_len, en_freq_list)
  return np.array(s)

for sentence in test_sentences:
  preds_h0 = cd.predict(process(sentence)[0], return_test_stat=True)
  labels = ['No!', 'Yes!']
  print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
  print('test_stat: {}'.format(preds_h0['data']['test_stat']))

Drift? Yes!
test_stat: 0.004450976848602295


In [14]:
idx = torch.randperm(X_h0.size(0))
for sentence in X_h0[idx[:200]]:
  preds_h0 = cd.predict(np.array(sentence), return_test_stat=True)
  labels = ['No!', 'Yes!']
  print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
  print('test_stat: {}'.format(preds_h0['data']['test_stat']))

Drift? Yes!
test_stat: 0.03132230043411255
Drift? Yes!
test_stat: 0.030436038970947266
Drift? Yes!
test_stat: 0.03146761655807495
Drift? Yes!
test_stat: 0.029927849769592285
Drift? Yes!
test_stat: 0.030577778816223145
Drift? Yes!
test_stat: 0.02799367904663086
Drift? Yes!
test_stat: 0.026028811931610107
Drift? Yes!
test_stat: 0.026022613048553467
Drift? Yes!
test_stat: 0.023051977157592773
Drift? Yes!
test_stat: 0.023899972438812256
Drift? Yes!
test_stat: 0.020528197288513184
Drift? Yes!
test_stat: 0.019391357898712158
Drift? Yes!
test_stat: 0.017357468605041504
Drift? Yes!
test_stat: 0.015854477882385254
Drift? Yes!
test_stat: 0.015028059482574463
Drift? Yes!
test_stat: 0.014593362808227539
Drift? Yes!
test_stat: 0.01350182294845581
Drift? Yes!
test_stat: 0.012850582599639893
Drift? Yes!
test_stat: 0.011692285537719727
Drift? Yes!
test_stat: 0.009923338890075684
Drift? Yes!
test_stat: 0.00962132215499878
Drift? Yes!
test_stat: 0.008896350860595703
Drift? Yes!
test_stat: 0.009083926677