In [18]:
import torch

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM

# Run on the GPU when one is available; fall back to the CPU otherwise so the
# notebook still executes (slowly) on machines without CUDA.
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Masked-LM BERT used for inference only: eval() disables dropout, and the
# parameters are frozen so autograd never tracks them.
bert_model_mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert_model_mlm.eval()
bert_model_mlm.to(torch_device)

for param in bert_model_mlm.parameters():
    param.requires_grad = False

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Reverse lookup: BERT token ID -> wordpiece string.
bert_id2tok = {tok_id: tok for tok, tok_id in bert_tokenizer.vocab.items()}

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [None]:
#Parameters

# Maximum number of BERT wordpieces kept per masked input (longer inputs are
# truncated, and dropped if the truncation removes the [MASK]).
MAX_BERT_LEN = 256
# Only fastText neighbours whose cosine distance to the OOV token is below this
# threshold are considered as replacement candidates.
MAX_COSINE_DIST = 0.3
# Size of the BERT vocabulary; the OOV loop asserts every k-NN neighbour ID is
# below it. Derived from the tokenizer instead of hard-coded: the former value
# 30000 undercounts the bert-base-uncased vocabulary, so valid IDs near its end
# would have tripped that assertion.
BERT_VOCAB_QTY = len(bert_tokenizer.vocab)

# NOTE(review): re-assigned to 0 in the nmslib index-building cell.
num_threads = 8
# Number of nearest neighbours retrieved per OOV token.
K = 10

In [None]:
import numpy as np

# fastText embedding matrix generated from the data vocabulary; later cells
# index its rows with AllenNLP vocabulary IDs (`vocab.get_token_index`).
ft_compiled_path = "../data/jigsaw/ft_model_bert_basic_tok.npy"
with open(ft_compiled_path, 'rb') as ft_file:
    fasttext_embeds = np.load(ft_file)

In [None]:
# AllenNLP data vocabulary; its token indices are used to look up rows of
# `fasttext_embeds`.
from allennlp.data.vocabulary import Vocabulary

vocab = Vocabulary.from_files("../data/jigsaw/data_ft_vocab")

In [None]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from pytorch_pretrained_bert.tokenization import BasicTokenizer
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
import re

#_spacy_tok = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False).split_words
# BERT's basic (pre-wordpiece) tokenizer, lowercasing like bert-base-uncased.
_bert_tok = BasicTokenizer(do_lower_case=True)

# AllenNLP spaCy word splitter (no POS tags), used by the probing cells.
spacy_tokenizer = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)

from allennlp.data.token_indexers import SingleIdTokenIndexer

# One-ID-per-token indexer that lowercases tokens before lookup.
token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)

from itertools import groupby

def remove_url(s):
    """Delete from ``s`` every substring that starts with ``http`` and extends
    up to (not including) the next whitespace character."""
    url_pattern = re.compile(r"http\S+")
    return url_pattern.sub("", s)

def remove_extra_chars(s, max_qty=2):
    """Collapse every run of a repeated character in ``s`` to at most
    ``max_qty`` copies, e.g. ``"sheeeet" -> "sheet"`` with the default of 2."""
    pieces = []
    for ch, run in groupby(s):
        run_len = sum(1 for _ in run)
        pieces.append(ch * min(max_qty, run_len))
    return ''.join(pieces)

def tokenizer(x: str):
    """Tokenize ``x`` for OOV detection: strip URLs, split with BERT's basic
    tokenizer, then collapse characters repeated more than twice per token."""
    text_without_urls = remove_url(x)
    return list(map(remove_extra_chars, _bert_tok.tokenize(text_without_urls)))

In [19]:
# Is the contraction piece "n't" a token of BERT's vocabulary?
"n't" in bert_tokenizer.vocab.keys()

False

In [None]:
# Number of entries in the BERT vocabulary.
sum(1 for _ in bert_tokenizer.vocab)

In [None]:
# Look up the wordpiece stored under BERT token ID 2.
bert_id2tok[2]

In [None]:
# BERT token ID of the wordpiece 'sh'.
bert_tokenizer.vocab['sh']

In [None]:
def get_bert_masked_inputs(toks, bert_tokenizer):
    """Build one masked BERT input sentence per out-of-vocabulary token.

    Args:
        toks: tokens of a sentence produced by the *original* (non-BERT)
            tokenizer. Plain strings and token objects exposing a ``.text``
            attribute (e.g. the tokens of a spaCy ``Doc``) are both accepted.
        bert_tokenizer: object whose ``vocab`` supports ``in`` tests on token
            strings.

    Returns:
        A list (empty if every token is in the BERT vocabulary) of
        ``(pos, text)`` tuples, one per OOV token, in sentence order. ``pos`` is
        the index of the OOV token in ``toks`` (original tokenizer, not BERT),
        and ``text`` is ``'[CLS] <left> [MASK] <right> [SEP]'`` with the OOV
        token replaced by ``[MASK]``.
    """
    # Work on strings: spaCy Token objects never match the string keys of the
    # BERT vocabulary (so every token would look OOV) and cannot be passed to
    # str.join.
    tok_strs = [getattr(tok, 'text', tok) for tok in toks]
    bert_vocab = bert_tokenizer.vocab

    res = []
    for pos, tok in enumerate(tok_strs):
        if tok not in bert_vocab:
            left = ' '.join(tok_strs[:pos])
            right = ' '.join(tok_strs[pos + 1:])
            res.append((pos, '[CLS] %s [MASK] %s [SEP]' % (left, right)))
    return res

In [9]:
import spacy
from spacy.tokenizer import Tokenizer

# Full spaCy English pipeline; its vocabulary backs the tokenizer below.
nlp = spacy.load("en_core_web_sm")

# NOTE(review): this rebinds `tokenizer`, shadowing the BasicTokenizer-based
# `tokenizer()` function defined in an earlier cell. Every later cell that calls
# `tokenizer` -- including the OOV-replacement loop -- gets this spaCy tokenizer,
# which returns a spaCy Doc (Token objects) rather than a list of strings. It is
# built from the vocabulary alone, with no prefix/suffix/infix rules passed.
# Confirm which tokenizer the pipeline is meant to use.
tokenizer = Tokenizer(nlp.vocab)

In [None]:
# Demo: one masked BERT input per token of this sentence that is missing from
# BERT's vocabulary.
get_bert_masked_inputs(tokenizer('This is a *strangge* sentence.'), bert_tokenizer)

In [None]:
# Probe: how does BERT's tokenizer split a masked sentence containing a
# misspelling ("detete")?
toks = bert_tokenizer.tokenize('[CLS] detete what the [MASK] are you doing here ? [SEP]')
toks

In [None]:
# Same probe with AllenNLP's spaCy word splitter.
toks = spacy_tokenizer.split_words("[CLS] what the [MASK] are you don't here ? [SEP]")
toks

In [27]:
# Probe the full spaCy pipeline. Per the recorded output, it splits "[CLS]" into
# "[", "CLS", "]" and "don't" into "do", "n't".
doc = nlp("[CLS] what the [MASK] are you don't here sh#t fcuk? [SEP]")
print([token.text for token in doc])


['[', 'CLS', ']', 'what', 'the', '[', 'MASK', ']', 'are', 'you', 'do', "n't", 'here', 'sh#t', 'fcuk', '?', '[', 'SEP', ']']


In [26]:
# Is "uck" a lexeme in spaCy's vocabulary?
"uck" in nlp.vocab

False

In [10]:
# Number of lexemes currently in spaCy's vocabulary.
len(nlp.vocab)

57852

In [None]:
# Probe contractions and character elongation with the current `tokenizer`
# (whichever object that name is bound to when this cell runs).
tokenizer("don't  couldn't can't you're I'm sheeeet")

In [25]:
# BERT wordpiece tokenization of the same kinds of words. Per the recorded
# output, contractions are split at the apostrophe and "fcuk" becomes
# "fc" + "##uk".
bert_tokenizer.tokenize("don't  couldn't can't you're I'm fcuk")

['don',
 "'",
 't',
 'couldn',
 "'",
 't',
 'can',
 "'",
 't',
 'you',
 "'",
 're',
 'i',
 "'",
 'm',
 'fc',
 '##uk']

In [None]:
# Probe how detached and attached apostrophe suffixes are tokenized.
tokenizer("You ' re right. It ' s a miracle! You'd been deceived!") # 've', 're', 's', 'd', 'll'

In [None]:
from collections import namedtuple
# One record per OOV token. NOTE: pos_oov is the OOV index with respect to the
# original (not BERT) tokenizer!!!
UtterData = namedtuple('SentData', ['batch_sent_id', 'pos_oov', 'tok_ids', 'oov_token'])

def get_batch_data(torch_device, tokenizer, bert_tokenizer, sent_list, max_len=MAX_BERT_LEN):
    """Turn a batch of sentences into masked BERT inputs, one per OOV token.

    Each sentence is tokenized with ``tokenizer``. Every token missing from the
    BERT vocabulary yields one input in which that token is replaced by
    ``[MASK]`` (see ``get_bert_masked_inputs``). Each input is re-tokenized with
    ``bert_tokenizer`` and truncated to ``max_len`` wordpieces. It is dropped if
    the truncation cut off the ``[MASK]``.

    Args:
        torch_device: device the returned tensor is moved to.
        tokenizer: callable mapping a sentence to its tokens. Plain strings and
            objects with a ``.text`` attribute (e.g. a spaCy ``Doc``) are accepted.
        bert_tokenizer: BERT tokenizer (``tokenize``, ``convert_tokens_to_ids``,
            ``vocab``).
        sent_list: sentences of the batch.
        max_len: maximum number of BERT wordpieces kept per input.

    Returns:
        ``(batch_data_raw, tok_ids_batch)``. ``batch_data_raw`` is a list of
        ``UtterData``: ``batch_sent_id`` indexes ``sent_list``, ``pos_oov`` is the
        OOV position in the *original* tokenization, and ``oov_token`` is that
        token as a string. ``tok_ids_batch`` is an int64 tensor of shape
        ``(len(batch_data_raw), longest input)`` holding the BERT token IDs,
        right-padded with 0.
    """
    batch_data_raw = []
    batch_max_seq_qty = 0
    for batch_sent_id, sent in enumerate(sent_list):
        # Work with plain strings, so OOV lookup, the [MASK] template and
        # `oov_token` behave the same whether `tokenizer` returns strings or
        # spaCy Token objects.
        sent_toks = [getattr(tok, 'text', tok) for tok in tokenizer(sent)]
        for sent_oov_pos, text in get_bert_masked_inputs(sent_toks, bert_tokenizer):
            # The [MASK] position according to BERT can only be found by
            # re-tokenizing the masked text with the BERT tokenizer.
            all_bert_toks = bert_tokenizer.tokenize(text)
            bert_toks = all_bert_toks[:max_len]

            if '[MASK]' not in bert_toks:
                # Only truncation may remove the [MASK]; such inputs are skipped.
                assert len(all_bert_toks) > max_len
                continue

            tok_ids = bert_tokenizer.convert_tokens_to_ids(bert_toks)
            batch_max_seq_qty = max(batch_max_seq_qty, len(tok_ids))
            batch_data_raw.append(
                UtterData(batch_sent_id=batch_sent_id,
                          pos_oov=sent_oov_pos,
                          tok_ids=tok_ids,
                          oov_token=sent_toks[sent_oov_pos]))

    batch_qty = len(batch_data_raw)
    tok_ids_batch = np.zeros((batch_qty, batch_max_seq_qty), dtype=np.int64) # zero is a padding symbol
    for k, e in enumerate(batch_data_raw):
        tok_ids_batch[k, :len(e.tok_ids)] = e.tok_ids

    tok_ids_batch = torch.from_numpy(tok_ids_batch).to(device=torch_device)

    return batch_data_raw, tok_ids_batch

In [None]:
import torch
import numpy as np
from collections import namedtuple
import sys

BertPredProbs = namedtuple('BertPred', ['batch_sent_id', 'pos_oov', 'logits'])

def get_bert_preds_for_words_batch(torch_device, bert_model_mlm,
                                   batch_data_raw, tok_ids_batch, # comes from get_batch_data
                                   word_ids, # a list of IDS for which we generate logits
                                   max_len=MAX_BERT_LEN,
                                   mask_tok_id=103,
                                   pad_tok_id=0):
    """Compute masked-LM logits for a restricted set of words at each [MASK].

    Only the rows of BERT's output projection that belong to ``word_ids`` are
    used, so the logits cover just those candidate words instead of the whole
    vocabulary.

    Args:
        torch_device: device holding the model and ``tok_ids_batch``.
        bert_model_mlm: a pytorch_pretrained_bert ``BertForMaskedLM``.
        batch_data_raw: ``UtterData`` records from ``get_batch_data``.
        tok_ids_batch: padded BERT token IDs from ``get_batch_data``, one row
            per record.
        word_ids: BERT token IDs of the candidate words to score.
        max_len: unused; kept for backward compatibility.
        mask_tok_id: BERT ID of ``[MASK]``. 103 is its ID in bert-base-uncased;
            pass ``bert_tokenizer.vocab['[MASK]']`` for another vocabulary.
        pad_tok_id: ID used to right-pad ``tok_ids_batch`` (``get_batch_data``
            pads with 0).

    Returns:
        One ``BertPredProbs`` per record. ``batch_sent_id`` and ``pos_oov`` are
        copied from the record. ``logits`` is a NumPy vector with one logit per
        entry of ``word_ids``, taken at the record's [MASK] position.
    """
    res = []
    if not batch_data_raw:
        # Nothing was masked in this batch; don't run BERT on an empty tensor.
        return res

    # `pos_oov` is the OOV position in the *original* tokenization. It does not
    # match the position inside the BERT input (the [CLS] prefix and wordpiece
    # splits shift it), so locate the [MASK] token in each BERT input instead.
    mask_positions = [list(e.tok_ids).index(mask_tok_id) for e in batch_data_raw]

    seg_ids = torch.zeros_like(tok_ids_batch, device=torch_device)
    # Rows are right-padded; mask the padding out so it does not influence the
    # representations of the real tokens.
    attention_mask = (tok_ids_batch != pad_tok_id).long()

    # Main BERT model see modeling.py in https://github.com/huggingface/pytorch-pretrained-BERT
    bert = bert_model_mlm.bert
    # cls is an instance of BertOnlyMLMHead; its predictions are a BertLMPredictionHead
    predictions = bert_model_mlm.cls.predictions
    transform = predictions.transform

    # We don't use the complete decoding matrix, but only selected rows
    word_ids = torch.from_numpy(np.array(word_ids, dtype=np.int64)).to(device=torch_device)
    weight = predictions.decoder.weight[word_ids, :]
    bias = predictions.bias[word_ids]

    with torch.no_grad():
        # Transformations from the main BERT model
        sequence_output, _ = bert(tok_ids_batch, seg_ids,
                                  attention_mask=attention_mask,
                                  output_all_encoded_layers=False)
        # Transformations from the BertLMPredictionHead with the restricted last layer,
        # applied only to the hidden state at each record's [MASK] position.
        hidden_states = transform(sequence_output)
        row_idx = torch.arange(len(batch_data_raw), device=hidden_states.device)
        col_idx = torch.tensor(mask_positions, dtype=torch.long, device=hidden_states.device)
        mask_states = hidden_states[row_idx, col_idx]
        logits = torch.nn.functional.linear(mask_states, weight) + bias

    logits = logits.cpu().numpy()

    for k, e in enumerate(batch_data_raw):
        res.append(BertPredProbs(batch_sent_id=e.batch_sent_id,
                                 pos_oov=e.pos_oov,
                                 logits=logits[k]))
    return res

In [None]:
# BERT token IDs of candidate replacement words, similar to those scored in
# the demo cell.
bert_tokenizer.convert_tokens_to_ids(['hell', 'fuck', 'heck', 'devil', 'shit', 'doing', 
                                      'doin', 'wearing', 'making', 'thinking', 'all', 'test'])

In [None]:
# Demo: score a few hand-picked candidate words at the [MASK] position of every
# OOV token in two noisy sentences.
sent_list = ['What the fcuk are you doingg here?',
             'This is a *strangge* sentence']

batch_data_raw, tok_ids_batch = get_batch_data(torch_device, tokenizer, bert_tokenizer,
                                               sent_list, max_len=MAX_BERT_LEN)

demo_word_ids = bert_tokenizer.convert_tokens_to_ids(
    ['hell', 'fuck', 'heck', 'devil', 'shit', 'doing',
     'doin', 'wearing', 'making', 'thinking'])

get_bert_preds_for_words_batch(torch_device, bert_model_mlm,
                               batch_data_raw, tok_ids_batch,
                               demo_word_ids)

In [None]:
#bert_model_mlm.to('cpu')
#torch.zeros(3,device=torch.device("cuda"))

In [None]:
# Probe: how does BERT's tokenizer handle a single Cyrillic character?
bert_tokenizer.tokenize('б')

In [None]:
from typing import *
from overrides import overrides
import allennlp
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
from allennlp.data.fields import TextField, SequenceLabelField, LabelField, MetadataField, ArrayField
class MemoryOptimizedTextField(TextField):
    """TextField variant intended to reduce memory use.

    The constructor assigns the fields itself instead of calling
    ``TextField.__init__`` (skipping its token checks), and ``index`` drops the
    raw tokens once they have been indexed.

    NOTE(review): not used anywhere in this notebook. Any TextField method that
    reads ``self.tokens`` after ``index`` would see ``None`` -- confirm before use.
    """
    @overrides
    def __init__(self, tokens: List[str], token_indexers: Dict[str, TokenIndexer]) -> None:
        # Presumably mirrors the attributes TextField.__init__ sets (TODO confirm
        # against the installed allennlp version), without calling it.
        self.tokens = tokens
        self._token_indexers = token_indexers
        # NOTE(review): `TokenList` is not imported in this notebook (it is not
        # provided by `from typing import *`); confirm it resolves, presumably
        # from allennlp's token_indexer module, or change the annotation.
        self._indexed_tokens: Optional[Dict[str, TokenList]] = None
        self._indexer_name_to_indexed_token: Optional[Dict[str, List[str]]] = None
        # skip checks for tokens
    @overrides
    def index(self, vocab):
        # Index with the parent implementation, then release the raw tokens.
        super().index(vocab)
        self.tokens = None # empty tokens


In [None]:
# *MAIN* (data) vocabulary IDs and BERT IDs of the words present in both
# vocabularies. Data-vocabulary IDs 0 and 1 (presumably AllenNLP's padding and
# OOV entries, returned for words the data vocabulary lacks) are skipped.
shared_id_pairs = []
for bert_word, bert_word_id in bert_tokenizer.vocab.items():
    data_word_id = vocab.get_token_index(bert_word)
    if data_word_id > 1:
        shared_id_pairs.append((data_word_id, bert_word_id))

bert_vocab_term_glob_ids = np.array([glob_id for glob_id, _ in shared_id_pairs])
bert_vocab_term_bert_ids = np.array([bert_id for _, bert_id in shared_id_pairs])
fasttext_embeds[bert_vocab_term_glob_ids].shape

In [None]:
# Largest BERT token ID stored in the k-NN index. The OOV loop asserts that
# neighbour IDs are below BERT_VOCAB_QTY.
max(bert_vocab_term_bert_ids)

In [None]:
import nmslib, time

# HNSW index-construction parameters.
M = 30
efC = 200

# NOTE(review): overrides num_threads=8 from the parameters cell and is also used
# by the batched k-NN queries in the OOV loop. 0 presumably lets nmslib choose
# the thread count -- confirm.
num_threads = 0
# Defined once, so the printed parameters are exactly the ones passed to
# createIndex.
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}

# Space name should correspond to the space name used for brute-force search.
# With 'cosinesimil', reported distances are 1 - cosine similarity.
space_name = 'cosinesimil'

# Initialize the library, specify the space and the vector type, then add the
# fastText vectors of the shared vocabulary labelled with their BERT token IDs,
# so k-NN queries return BERT IDs.
index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR)
index.addDataPointBatch(fasttext_embeds[bert_vocab_term_glob_ids], bert_vocab_term_bert_ids)

# Create the index
start = time.time()
index.createIndex(index_time_params)
end = time.time()
print('Index-time parameters', index_time_params)
print('Indexing time = %f' % (end - start))


In [None]:
# Setting query-time parameters
# Query-time parameters of the HNSW index.
efS = 200
K = 10  # neighbours returned per k-NN query
query_time_params = dict(efSearch=efS)
print('Setting query-time parameters', query_time_params)
index.setQueryTimeParams(query_time_params)

In [None]:
#import pickle
#with open("../data/jigsaw/val_ds.bin", "rb") as f:
#    val_ds = pickle.load(f)

In [None]:
from itertools import groupby
import re

def remove_extra_chars(s, max_qty=2):
    """Shorten every run of identical consecutive characters in ``s`` to at most
    ``max_qty`` characters (e.g. ``"cooool" -> "cool"`` for ``max_qty=2``).

    NOTE(review): identical to the definition in the tokenizer cell; presumably
    repeated so this cell can run on its own.
    """
    collapsed = []
    for char, run in groupby(s):
        run_len = len(list(run))
        collapsed.append(char * (run_len if run_len < max_qty else max_qty))
    return ''.join(collapsed)

def is_apost_token(s):
    """True if ``s`` is an apostrophe followed by 1-3 lowercase ASCII letters,
    e.g. ``"'re"``, ``"'s"`` or ``"'ll"``."""
    apost_re = re.compile(r"'[a-z]{1,3}$")
    return bool(apost_re.match(s))


def replace_by_patterns(tokenizer, s, replace_dict):
    """Re-tokenize ``s`` and substitute selected tokens.

    Args:
        tokenizer: callable returning the tokens of ``s``. Plain strings and
            objects with a ``.text`` attribute (e.g. a spaCy ``Doc``) are accepted.
        s: sentence to rewrite.
        replace_dict: maps token positions in ``tokenizer(s)`` to replacement
            strings.

    Returns:
        The tokens of ``s`` joined by single spaces, with replacements applied.
    """
    # Build a fresh list of strings: a spaCy Doc does not support item
    # assignment, and its Token objects cannot be passed to str.join.
    res = [getattr(tok, 'text', tok) for tok in tokenizer(s)]
    for pos, repl in replace_dict.items():
        res[pos] = repl
    return ' '.join(res)
    

In [None]:
# (source CSV, destination CSV) pairs: OOV tokens in the `comment_text` column of
# each source file are replaced, and the result is written to the destination.
repl_oov_files = [
    ("../data/jigsaw/test_proced.csv", "../data/jigsaw/test_proced_no_oov1.csv"),
    ("../data/jigsaw/train.csv", "../data/jigsaw/train_no_oov1.csv"),
]
# Alternative (validation set):
#repl_oov_files = [ ("../data/jigsaw/val.csv", "../data/jigsaw/val_no_oov1.csv")  ]

In [None]:
import time, gc
import pandas as pd
from scipy.special import softmax


DEBUG_PRINT = False
# Number of source sentences processed per BERT / k-NN batch.
BATCH_QTY_STEP = 20


def _find_oov_replacements(batch_sents):
    """Choose a replacement wordpiece for the OOV tokens of ``batch_sents``.

    Candidates are the fastText nearest neighbours (cosine distance below
    MAX_COSINE_DIST) of each OOV token. They are ranked by BERT's masked-LM
    softmax over the candidates, multiplied by their fastText similarity.

    Returns:
        dict mapping each index of ``batch_sents`` to a dict
        {OOV position in the `tokenizer` output -> replacement wordpiece}.
        Sentences without a replacement map to an empty dict.
    """
    replace_dict = {k: dict() for k in range(len(batch_sents))}

    # batch_data_raw holds UtterData(batch_sent_id, pos_oov, tok_ids, oov_token),
    # one per OOV token. NOTE: pos_oov is relative to the original (not BERT)
    # tokenizer. tok_ids_batch holds the padded BERT token IDs, ready to be fed
    # into the model.
    batch_data_raw, tok_ids_batch = get_batch_data(torch_device,
                                                   tokenizer, bert_tokenizer,
                                                   batch_sents,
                                                   MAX_BERT_LEN)
    if not batch_data_raw:
        # No OOV token in this batch. An empty query matrix would be 1-D and
        # break the k-NN search and the BERT call.
        return replace_dict

    # Data-vocabulary ID of each OOV token. Prefer the form with repeated
    # characters collapsed; fall back to the raw token if only that is known.
    query_tok_oov_id = []
    for e in batch_data_raw:
        w = e.oov_token
        w_compr = remove_extra_chars(w)
        wid = vocab.get_token_index(w_compr)
        if w != w_compr and wid < 2:
            wid = vocab.get_token_index(w)
        query_tok_oov_id.append(wid)

    query_matrix = fasttext_embeds[query_tok_oov_id]
    query_qty = query_matrix.shape[0]

    if DEBUG_PRINT: print('Query matrix shape:', query_matrix.shape)

    start = time.time()
    # nbrs is an array of tuples (neighbor array, distance array).
    # For cosine, the distance is 1 - cosine similarity.
    # k-NN search returns BERT-specific token IDs.
    nbrs = index.knnQueryBatch(query_matrix, k=K, num_threads=num_threads)
    end = time.time()
    if DEBUG_PRINT:
        print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' %
              (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty))

    # Union over all queries of the close-enough neighbours: BERT scores only these.
    neighb_tok_ids = set()
    for qid in range(query_qty):
        if query_tok_oov_id[qid] > 1:
            for bert_tok_id, dist in zip(nbrs[qid][0], nbrs[qid][1]):
                if dist < MAX_COSINE_DIST:
                    assert bert_tok_id < BERT_VOCAB_QTY
                    neighb_tok_ids.add(bert_tok_id)
    neighb_tok_ids = list(neighb_tok_ids)

    preds = get_bert_preds_for_words_batch(torch_device,
                                           bert_model_mlm,
                                           batch_data_raw, tok_ids_batch,
                                           neighb_tok_ids)
    assert len(preds) == query_qty

    for qid, e in enumerate(batch_data_raw):
        assert 0 <= e.batch_sent_id < len(batch_sents)
        if is_apost_token(e.oov_token) or e.oov_token == "n't":
            # Things like "I don't" or "You're" may be tokenized as "I do n't" or
            # "You 're"; these pieces are left untouched for now.
            continue  # TODO fix this
        if query_tok_oov_id[qid] <= 1:
            # The OOV token is unknown to the fastText vocabulary too: no neighbours.
            continue

        # Map BERT-specific token IDs of the scored neighbours to this query's logits.
        assert len(preds[qid].logits) == len(neighb_tok_ids)
        logit_map = dict(zip(neighb_tok_ids, preds[qid].logits))

        if DEBUG_PRINT:
            print(batch_sents[e.batch_sent_id])
            print("### OOV ###", e.oov_token)
            print([bert_id2tok[bert_tok_id] for bert_tok_id in nbrs[qid][0]])

        nbrs_sel_logits = []
        nbrs_sel_toks = []
        nbrs_sel_dists = []
        for bert_tok_id, dist in zip(nbrs[qid][0], nbrs[qid][1]):
            if bert_tok_id not in logit_map:
                if DEBUG_PRINT:
                    print('Missing %s distance %g '
                          % (bert_id2tok[bert_tok_id], dist))
            elif dist < MAX_COSINE_DIST:
                nbrs_sel_logits.append(logit_map[bert_tok_id])
                nbrs_sel_toks.append(bert_id2tok[bert_tok_id])
                nbrs_sel_dists.append(dist)

        if nbrs_sel_logits:
            # Combine BERT's contextual preference (softmax over the candidates)
            # with the fastText similarity to the OOV token.
            nbrs_softmax = softmax(np.array(nbrs_sel_logits))
            nbrs_simil = 1 - np.array(nbrs_sel_dists)
            nbrs_simil_adj = nbrs_softmax * nbrs_simil

            best_idx = np.argmax(nbrs_simil_adj)

            assert e.pos_oov not in replace_dict[e.batch_sent_id]
            replace_dict[e.batch_sent_id][e.pos_oov] = nbrs_sel_toks[best_idx]

            if DEBUG_PRINT:
                print('Selected info, best_tok:', nbrs_sel_toks[best_idx])
                for k in range(len(nbrs_sel_logits)):
                    print(nbrs_sel_toks[k], nbrs_softmax[k],
                          nbrs_sel_dists[k], nbrs_simil_adj[k])
        elif DEBUG_PRINT:
            print('Nothing found!')

        if DEBUG_PRINT:
            print("====================================================================")

    return replace_dict


for src_file, dst_file in repl_oov_files:
    src_data = pd.read_csv(src_file)

    t0 = time.time()
    # Empty comments are read as NaN; treat them as empty strings so the
    # tokenizer does not fail on a float.
    all_src_sents = list(src_data['comment_text'].fillna(''))
    all_dst_sents = []

    for batch_start_sent_id in range(0, len(all_src_sents), BATCH_QTY_STEP):
        print('Batch start', batch_start_sent_id)

        batch_sents = all_src_sents[batch_start_sent_id:batch_start_sent_id + BATCH_QTY_STEP]
        replace_dict = _find_oov_replacements(batch_sents)

        for k, src_sent in enumerate(batch_sents):
            rd = replace_dict[k]
            dst_sent = replace_by_patterns(tokenizer, src_sent, rd)
            all_dst_sents.append(dst_sent)
            if DEBUG_PRINT:
                print("====================================================================")
                print('Replacement dict:', rd)
                print(src_sent)
                print('------------')
                print(dst_sent)
                print("====================================================================")

    t1 = time.time()
    print('# of src sentences:', len(all_src_sents),
          "# of dst sentences:", len(all_dst_sents),
          ' time elapsed:', t1 - t0)
    assert len(all_dst_sents) == len(all_src_sents)
    src_data['comment_text'] = all_dst_sents
    src_data.to_csv(dst_file, index=False)

In [None]:
#src_data[src_data["toxic"]==1].head(20)

In [None]:
#fl = pd.read_csv(src_file)

In [None]:
#fl[fl["toxic"]==1].head(20)