# GMM Anomaly Detection on Contextual Token Embeddings

In [1]:
import sys
sys.path.append('../')

import pickle
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import random
import sklearn.mixture

import src.sent_encoder

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Load the pretrained tokenizer and masked-LM weights for the checkpoint named
# by model_name.
# NOTE(review): bert_tokenizer and bert_model are not referenced by any later
# visible cell (those go through `enc`) — confirm whether this cell is needed.
model_name = 'roberta-base'
bert_tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModelForMaskedLM.from_pretrained(model_name)

Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Pick random subset of sentences

In [3]:
# Load the object pickled in ../data/bnc.pkl (presumably a list of BNC
# sentences — verify against whatever script produced the file).
# NOTE: pickle.load can execute arbitrary code; only use it on a trusted file.
with open('../data/bnc.pkl', 'rb') as f:
  bnc_sentences = pickle.load(f)

# Seed the global RNG so the same 1000-item subset is drawn on every run.
# random.sample raises ValueError if fewer than 1000 items were loaded.
random.seed(12345)
bnc_sentences = random.sample(bnc_sentences, 1000)

## Feed them through BERT

In [4]:
# Build the project's sentence encoder. Later cells use its
# contextual_token_vecs() method and its auto_tokenizer attribute.
enc = src.sent_encoder.SentEncoder()

In [5]:
# Encode the sampled sentences into token vectors, requesting layer=-2.
# NOTE(review): presumably layer=-2 is the second-to-last hidden layer and the
# result has one row per token — confirm in src.sent_encoder.
bnc_vecs = enc.contextual_token_vecs(bnc_sentences, layer=-2)

## Train GMM, test on grammatical and ungrammatical sentences

In [6]:
# Fit a Gaussian mixture with default hyperparameters to the token vectors.
# Per the repr in this cell's output, the defaults are a single component
# (n_components=1) with a full covariance matrix.
gmm = sklearn.mixture.GaussianMixture()
gmm.fit(bnc_vecs)

GaussianMixture(covariance_type='full', init_params='kmeans', max_iter=100,
                means_init=None, n_components=1, n_init=1, precisions_init=None,
                random_state=None, reg_covar=1e-06, tol=0.001, verbose=0,
                verbose_interval=10, warm_start=False, weights_init=None)

In [7]:
def infer_new_sentence(sent, layer=-2):
  """Print each token of `sent` next to its log-likelihood under the fitted GMM.

  Relies on the notebook globals `enc` (sentence encoder) and `gmm` (fitted
  GaussianMixture). Lower scores mean the token's vector is less likely under
  the mixture.

  Args:
    sent: The sentence to score, as a single string.
    layer: Layer passed to `enc.contextual_token_vecs`. Defaults to -2, the
      layer the GMM in this notebook was fitted on, so that scoring uses the
      same feature space as training.

  Raises:
    AssertionError: If the number of non-special token ids differs from the
      number of vectors returned by the encoder.
  """
  tokenizer = enc.auto_tokenizer
  # Look the special ids up once instead of once per token id.
  special_ids = set(tokenizer.all_special_ids)
  ids = [x for x in tokenizer(sent)['input_ids'] if x not in special_ids]

  # Pass the layer explicitly: the GMM was fitted on layer=-2 vectors, so
  # relying on the encoder's default here could silently mismatch the features.
  sent_vecs = enc.contextual_token_vecs([sent], layer=layer)
  assert len(ids) == sent_vecs.shape[0], (
    f"{len(ids)} token ids but {sent_vecs.shape[0]} token vectors")

  # Nothing to score (e.g. empty sentence); score_samples rejects empty input.
  if not ids:
    return

  # One vectorized call; for a single row, gmm.score equals its score_samples value.
  token_scores = gmm.score_samples(sent_vecs)
  for token_id, token_score in zip(ids, token_scores):
    print(tokenizer.decode(token_id), token_score)

In [8]:
# Ungrammatical example: "won't" is followed by "eating" instead of the bare verb "eat".
infer_new_sentence("The cats won't eating the food that Mary gives them.")

The -24.88312349970215
 cats -111.1238954929845
 won -376.85479956653944
't -314.6890912956078
 eating -401.75939332260396
 the -16.839375473423047
 food -34.187998326945376
 that 13.502896209406344
 Mary -82.7106539223663
 gives -203.89372120158964
 them -220.90611577235813
. 207.03518002565397


In [9]:
# Grammatical control sentence.
infer_new_sentence("The student laughs.")

The -24.570347497591
 student -108.71452760265811
 laughs -88.1599491002903
. 207.72688568806518


In [10]:
# Subject-verb agreement error: "laugh" instead of "laughs".
infer_new_sentence("The student laugh.")

The -67.13459467313692
 student -185.33300475802707
 laugh -180.1849407009704
. 207.6739630696121
