In [12]:
from collections import Counter
import pandas as pd
import string
from transformers import AutoTokenizer, AutoConfig
import spacy
import codecs
import random
import pickle

In [2]:
def get_pos_tags(doc_path):
    nlp = spacy.load('en_core_web_sm')
    nlp.max_length = 2000000
    pos_dict = {}
    with open(doc_path, 'r', encoding='utf-8') as file:
        text = file.read()
        doc = nlp(text)
        for token in doc:
            if token.text in pos_dict and not token.pos_ in pos_dict[token.text]:
                pos_dict[token.text].append(token.pos_)
            else:
                pos_dict[token.text] = [token.pos_]
        
    return pd.DataFrame(list(pos_dict.items()), columns=['token', 'POS'])

### Creating wordbank file

In [38]:
with open('sample_data/wikitext/wikitext103_test.txt', 'r') as file:
    contents = file.read()

words = contents.split()
df = pd.DataFrame(list(Counter(words).items()), columns=['token', 'count']).sort_values('count', ascending=False).reset_index(drop=True)
df = df[df['token'].apply(lambda w: w.isascii() and w  not in string.punctuation and not w.isdigit())]
df

AttributeError: 'str' object has no attribute 'isascii'

In [1]:
import sys
print(sys.version)

3.6.8 (default, Apr 16 2020, 01:36:27) 
[GCC 8.3.1 20191121 (Red Hat 8.3.1-5)]


In [33]:
# only keeping words that are considered one token by the language model
tokenizer = AutoTokenizer.from_pretrained('google/multiberts-seed_0')
filtered_words = [word for word in df['token'].tolist() if len(tokenizer.tokenize(word)) == 1]
filtered_df = df[df['token'].isin(filtered_words)].reset_index(drop=True)

print(f'Percentage of words that are considered one token by the language model: {((len(filtered_df) / len(df)) * 100):.2f}%')

filtered_df

Percentage of words that are considered one token by the language model: 67.64%


Unnamed: 0,token,count
0,the,13988
1,of,6731
2,and,5780
3,to,4724
4,in,4495
...,...,...
13626,Silence,1
13627,hid,1
13628,eyebrows,1
13629,Straps,1


In [34]:
document = "sample_data/wikitext/wikitext103_test.txt"
pos_tags = get_pos_tags(document)
pos_tags

Unnamed: 0,token,POS
0,,[SPACE]
1,Robert,[PROPN]
2,Boulter,[PROPN]
3,is,[AUX]
4,an,[DET]
...,...,...
20889,Author,[NOUN]
20890,credibility,[NOUN]
20891,Ronnie,[PROPN]
20892,Pelkey,[PROPN]


In [36]:
merged_df.sort_values('token')

Unnamed: 0,token,count,POS
237,0,84,[NUM]
823,00,31,[PUNCT]
79,000,186,[NUM]
11241,001,1,[NUM]
5468,01,4,[NUM]
...,...,...,...
7160,→,2,[ADP]
4561,−,5,[PROPN]
11175,♯,1,[PROPN]
13116,東,1,[PROPN]


In [6]:
merged_df = pd.merge(filtered_df, pos_tags, on='token', how='inner')
merged_df.sort_values()

Unnamed: 0,token,count,POS
0,the,13988,[DET]
1,of,6731,[ADP]
2,and,5780,[CCONJ]
3,to,4724,"[ADP, PART]"
4,in,4495,[ADP]
...,...,...,...
14190,Silence,1,[PROPN]
14191,hid,1,[VERB]
14192,eyebrows,1,[NOUN]
14193,Straps,1,[NOUN]


In [27]:
merged_df.to_csv('sample_data/wikitext/wikitext_wordbank.tsv', sep='\t', index=None)

---

In [7]:
def get_sample_sentences(tokenizer, wordbank_file, tokenized_examples_file,
                         max_seq_len, min_seq_len, max_samples, bidirectional=True):
    short_sents = 0
    long_sents = 0
    superfluous = 0
    num_lines = 0
    # Each entry of token data is a tuple of token, token_id, masked_sample_sentences.
    token_data = []
    # Load words.
    df = pd.read_csv(wordbank_file, sep='\t')
    wordbank_tokens = df.token.unique().tolist()
    # Get token ids.
    for token in wordbank_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id != tokenizer.unk_token_id:
            token_data.append(tuple([token, token_id, []]))
    # Load sentences.
    print(f"Loading sentences from {tokenized_examples_file}.")
    infile = codecs.open(tokenized_examples_file, 'rb', encoding='utf-8')
    for line_count, line in enumerate(infile):
        num_lines += 1
        if line_count % 100000 == 0:
            print("Finished line {}.".format(line_count))
        example_string = line.strip()
        example = [int(token_id) for token_id in example_string.split()]
        # Use the pair of sentences (instead of individual sentences), to have
        # longer sequences. Also more similar to training.
        if len(example) < min_seq_len:
            short_sents += 1
            continue
        if len(example) > max_seq_len:
            long_sents += 1
            example = example[:max_seq_len]
        for token, token_id, sample_sents in token_data:
            if len(sample_sents) >= max_samples:
                # This token already has enough sentences.
                superfluous += 1
                continue
            token_indices = [index for index, curr_id in enumerate(example) if curr_id == token_id]
            # Warning: in bidirectional contexts, the mask can be in the first or last position,
            # which can cause no mask prediction to be made for the biLSTM.
            if not bidirectional:
                # The token must have enough unidirectional context.
                # The sequence length (including the target token) must be at least min_seq_len.
                token_indices = [index for index in token_indices if index >= min_seq_len-1]
            if len(token_indices) > 0:
                new_example = example.copy()
                mask_idx = random.choice(token_indices)
                new_example[mask_idx] = tokenizer.mask_token_id
                sample_sents.append(new_example)
    infile.close()
    # Logging.
    print(f'{superfluous} out of {len(token_data)} tokens ({(superfluous/len(token_data))*100:.2f}%) had more than {max_samples} samples.')
    print(f'{((len(token_data) - superfluous)/len(token_data))*100:.2f}% of tokens had a maximum of {max_samples} samples.')
    print(f'{short_sents} examples were shorter than {min_seq_len} tokens and were thus disregarded. {((num_lines-short_sents)/num_lines)*100:.2f}% of examples were kept.')
    print(f'{long_sents} examples were longer than {max_seq_len} tokens and were thus clipped.')
    return token_data

In [8]:
wordbank_file = 'sample_data/wikitext/wikitext_wordbank.tsv'
tokenized_examples_file = 'sample_data/wikitext/test_tokenized.txt'
config = AutoConfig.from_pretrained('google/multiberts-seed_0')
max_seq_len = config.max_position_embeddings
min_seq_len = 8
max_samples = 512

token_data = get_sample_sentences(
    tokenizer, wordbank_file, tokenized_examples_file, max_seq_len, min_seq_len, max_samples
    )

Loading sentences from sample_data/wikitext/test_tokenized.txt.
Finished line 0.
4841 out of 9886 tokens (48.97%) had more than 512 samples.
51.03% of tokens had a maximum of 512 samples.
6 examples were shorter than 8 tokens and were thus disregarded. 99.45% of examples were kept.
60 examples were longer than 512 tokens and were thus clipped.


In [10]:
len(token_data)

9886

In [14]:
with open('sample_data/wikitext/bidirectional_samples.pickle', 'rb') as f:
    bidirectional_samples = pd.DataFrame(pickle.load(f), columns=['token', 'token_id', 'sample_sents'])

bidirectional_samples

Unnamed: 0,token,token_id,sample_sents
0,the,1996,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
1,of,1997,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
2,and,1998,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
3,to,2000,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
4,in,1999,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
...,...,...,...
9881,investigators,14766,"[[101, 2035, 9171, 1997, 1996, 4484, 4641, 199..."
9882,clawed,22544,"[[101, 2474, 18834, 2050, 6104, 1017, 4832, 10..."
9883,hid,11041,"[[101, 12500, 8042, 1516, 2085, 3938, 2086, 22..."
9884,eyebrows,8407,"[[101, 2474, 18834, 2050, 6104, 1017, 4832, 10..."


In [19]:
final = bidirectional_samples[bidirectional_samples['sample_sents'].apply(lambda x: len(x) >= 8)]
final

Unnamed: 0,token,token_id,sample_sents
0,the,1996,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
1,of,1997,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
2,and,1998,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
3,to,2000,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
4,in,1999,"[[101, 2728, 8945, 11314, 2121, 2003, 2019, 23..."
...,...,...,...
9559,warrior,6750,"[[101, 1043, 4135, 7442, 1005, 1055, 4799, 673..."
9574,i,1045,"[[101, 1999, 2294, 8945, 11314, 2121, 5652, 19..."
9671,ve,2310,"[[101, 1996, 10860, 2189, 2678, 1010, 2856, 20..."
9709,summers,10945,"[[101, 4649, 11802, 6006, 2006, 16779, 2547, 2..."
