In [1]:
import numpy as np

import jax
import jax.numpy as jnp

from jax import random
# Root JAX PRNG key built from a fixed seed (421), so anything derived from it is repeatable.
main_rng = random.PRNGKey(421)

from typing import List


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [2]:
from jax_mlm_helpers import apply_random_masking
from jax_mlm_helpers import pad_and_crop_to_maximum_length
from utils_display import pc

In [3]:
import spacy
from collections import Counter
# Small pretrained English spaCy pipeline; the cells below use it to split the text into sentences.
nlp = spacy.load("en_core_web_sm")

# Example text

In [4]:
# Sample English passage: the toy corpus that the following cells segment into sentences and tokenize.
text = "I asked him what he was doing with a basilisk and he said that was his business. Now completely overwhelmed by curiosity, I said that these days, with all the deaths, there could be no more secret matters, and I would tell William. Then Salvatore ardently begged me to remain silent, opened the bundle, and showed me a black cat. He drew me closer and, with an obscene smile, said that he didn’t  want the cellarer, who was powerful, or me, young and handsome, to enjoy the love of the village girls any more, when he couldn’t because he was ugly and a poor wretch. But he knew a prodigious spell that would make every woman succumb to love. You had to kill a black cat and dig out its eyes, then put them in two eggs of a black hen, one eye in one egg, one eye in the other (and he showed me two eggs that he swore he had taken from appropriate hens). Then you had to let the eggs rot in a pile of horse dung (and he had one ready in a corner of the vegetable garden where nobody ever went), and there a little devil would be born from each egg, and would then be at your service, procuring for you all the delights of this world. But, alas, he told me, for the magic spell to work, the woman whose love he wanted had to spit on the eggs before they were buried in the dung, and that problem tormented him, because he would have to have the woman in question at hand that night, and make her perform the ritual without knowing its purpose."

In [5]:
# Materialise the sentence spans found by the pipeline, then show them one per line.
sentences = list(nlp(text).sents)
for sentence in sentences:
    print(sentence)

I asked him what he was doing with a basilisk and he said that was his business.
Now completely overwhelmed by curiosity, I said that these days, with all the deaths, there could be no more secret matters, and I would tell William.
Then Salvatore ardently begged me to remain silent, opened the bundle, and showed me a black cat.
He drew me closer and, with an obscene smile, said that he didn’t  want the cellarer, who was powerful, or me, young and handsome, to enjoy the love of the village girls any more, when he couldn’t because he was ugly and a poor wretch.
But he knew a prodigious spell that would make every woman succumb to love.
You had to kill a black cat and dig out its eyes, then put them in two eggs of a black hen, one eye in one egg, one eye in the other (and he showed me two eggs that he swore he had taken from appropriate hens).
Then you had to let the eggs rot in a pile of horse dung (and he had one ready in a corner of the vegetable garden where nobody ever went), and the

In [6]:
def build_vocabulary(texts, max_vocab=10000, min_freq=1):
    """Build word <-> index lookup tables from an iterable of texts.

    Each text is tokenized with a blank English spaCy pipeline and every
    token is counted in lower-cased form.

    Args:
        texts: iterable of strings to tokenize and count.
        max_vocab: upper bound on the vocabulary size; the three special
            tokens count towards this limit.
        min_freq: words seen fewer than this many times are left out.

    Returns:
        A pair ``(dico_word2index, dico_index2word)``. Indices 0, 1 and 2 are
        taken by '[MASK]', '[PAD]' and '[UNK]'; the remaining words follow in
        decreasing order of frequency.
    """
    tokenizer = spacy.blank("en")
    frequencies = Counter(
        token.lower_
        for parsed in tokenizer.pipe(texts)
        for token in parsed
    )

    dico_word2index = {}
    dico_index2word = {}

    def _register(entry):
        # The next free index is simply the current size of the vocabulary.
        position = len(dico_word2index)
        dico_word2index[entry] = position
        dico_index2word[position] = entry

    for special in ('[MASK]', '[PAD]', '[UNK]'):
        _register(special)

    # most_common() yields words by decreasing count, so the first word that
    # is too rare (or the first one that no longer fits) ends the loop.
    for word, count in frequencies.most_common():
        if count < min_freq or len(dico_word2index) >= max_vocab:
            break
        _register(word)

    return dico_word2index, dico_index2word

In [7]:
def apply_word_tokenization(text) -> List[str]:
    """Tokenize `text` with a blank English spaCy pipeline.

    Returns:
        The tokens as lower-cased strings, in document order.
    """
    tokenizer = spacy.blank("en")
    return [str(token).lower() for token in tokenizer(text)]

In [8]:
# Full-pipeline parse of the passage; kept in case cells outside this section use `doc`.
doc = nlp(text)
# Tokenize the raw string rather than the parsed Doc: `apply_word_tokenization`
# feeds its argument to a blank pipeline, and calling a pipeline on a Doc is
# only accepted by newer spaCy releases. This also matches the later call,
# which passes `str(...)`.
my_tokens = apply_word_tokenization(text=text)
print(my_tokens)

['i', 'asked', 'him', 'what', 'he', 'was', 'doing', 'with', 'a', 'basilisk', 'and', 'he', 'said', 'that', 'was', 'his', 'business', '.', 'now', 'completely', 'overwhelmed', 'by', 'curiosity', ',', 'i', 'said', 'that', 'these', 'days', ',', 'with', 'all', 'the', 'deaths', ',', 'there', 'could', 'be', 'no', 'more', 'secret', 'matters', ',', 'and', 'i', 'would', 'tell', 'william', '.', 'then', 'salvatore', 'ardently', 'begged', 'me', 'to', 'remain', 'silent', ',', 'opened', 'the', 'bundle', ',', 'and', 'showed', 'me', 'a', 'black', 'cat', '.', 'he', 'drew', 'me', 'closer', 'and', ',', 'with', 'an', 'obscene', 'smile', ',', 'said', 'that', 'he', 'did', 'n’t', ' ', 'want', 'the', 'cellarer', ',', 'who', 'was', 'powerful', ',', 'or', 'me', ',', 'young', 'and', 'handsome', ',', 'to', 'enjoy', 'the', 'love', 'of', 'the', 'village', 'girls', 'any', 'more', ',', 'when', 'he', 'could', 'n’t', 'because', 'he', 'was', 'ugly', 'and', 'a', 'poor', 'wretch', '.', 'but', 'he', 'knew', 'a', 'prodigious'

In [9]:
# Build the lookup tables from the token list (each token is treated as one text).
dico_word2index, dico_index2word = build_vocabulary(texts=my_tokens)
# The reverse table was originally bound to a misspelled name; keep that
# spelling as an alias in case other cells still refer to it.
dico_index2wordd = dico_index2word
print(dico_word2index)

{'[MASK]': 0, '[PAD]': 1, '[UNK]': 2, ',': 3, 'the': 4, 'he': 5, 'and': 6, 'a': 7, '.': 8, 'to': 9, 'that': 10, 'in': 11, 'me': 12, 'would': 13, 'of': 14, 'had': 15, 'was': 16, 'then': 17, 'eggs': 18, 'one': 19, 'i': 20, 'with': 21, 'said': 22, 'be': 23, 'black': 24, 'love': 25, 'woman': 26, 'you': 27, 'him': 28, 'all': 29, 'there': 30, 'could': 31, 'more': 32, 'showed': 33, 'cat': 34, 'n’t': 35, 'because': 36, 'but': 37, 'spell': 38, 'make': 39, 'its': 40, 'two': 41, 'eye': 42, 'egg': 43, '(': 44, 'from': 45, ')': 46, 'dung': 47, 'at': 48, 'for': 49, 'have': 50, 'asked': 51, 'what': 52, 'doing': 53, 'basilisk': 54, 'his': 55, 'business': 56, 'now': 57, 'completely': 58, 'overwhelmed': 59, 'by': 60, 'curiosity': 61, 'these': 62, 'days': 63, 'deaths': 64, 'no': 65, 'secret': 66, 'matters': 67, 'tell': 68, 'william': 69, 'salvatore': 70, 'ardently': 71, 'begged': 72, 'remain': 73, 'silent': 74, 'opened': 75, 'bundle': 76, 'drew': 77, 'closer': 78, 'an': 79, 'obscene': 80, 'smile': 81, 'd

In [10]:
input_sentence = sentences[0]
word_tokens = apply_word_tokenization(text=str(input_sentence))

# Out-of-vocabulary words must fall back to the integer index of '[UNK]',
# not to the literal string '[UNK]', so the result is a homogeneous list of ints.
unk_index = dico_word2index['[UNK]']
input_indices = [dico_word2index.get(w, unk_index) for w in word_tokens]

pc("Input sentence", input_sentence)
pc("Word tokens", word_tokens)
pc("Input indices", input_indices)

[34mInput sentence[0m: I asked him what he was doing with a basilisk and he said that was his business.
[34mWord tokens[0m: ['i', 'asked', 'him', 'what', 'he', 'was', 'doing', 'with', 'a', 'basilisk', 'and', 'he', 'said', 'that', 'was', 'his', 'business', '.']
[34mInput indices[0m: [20, 51, 28, 52, 5, 16, 53, 21, 7, 54, 6, 5, 22, 10, 16, 55, 56, 8]


In [11]:
# --- Masking configuration ---
masking_probability = 0.15        # passed to apply_random_masking as `masking_probability`
label_for_unmasked_values = -100  # label written at positions that were not masked
maximum_sequence_length = 25      # target length for padding / cropping

# Vocabulary indices of the two special tokens used by the masking / padding steps.
mask_index, pad_index = (dico_word2index[token] for token in ("[MASK]", "[PAD]"))

In [12]:
# Randomly mask the sentence; the helper returns the inputs, the mask,
# the masked inputs and the training labels, in that order.
masking_outputs = apply_random_masking(
    input_indices=input_indices,
    index_for_masked_values=mask_index,
    label_for_unmasked_values=label_for_unmasked_values,
    masking_probability=masking_probability,
)
input_indices, mask, masked_indices, labels = masking_outputs

print("Before padding / cropping")
for title, values in (
    ("Input indices", input_indices),
    ("Mask", mask),
    ("Masked indices", masked_indices),
    ("Labels", labels),
):
    pc(title, values)

Before padding / cropping
[34mInput indices[0m: [20 51 28 52  5 16 53 21  7 54  6  5 22 10 16 55 56  8]
[34mMask[0m: [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0]
[34mMasked indices[0m: [ 0 51  0 52  5 16 53 21  7 54  6  5 22 10 16  0 56  8]
[34mLabels[0m: [  20 -100   28 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
 -100   55 -100 -100]


In [13]:
# Pad / crop each array to `maximum_sequence_length`, each with its own fill value.
input_indices = pad_and_crop_to_maximum_length(input_indices, padding_value=pad_index, maximum_sequence_length=maximum_sequence_length)
# NOTE(review): padded positions get mask value 1, the same value that marks
# masked tokens above, while their label is the "unmasked" label — confirm
# this is what the downstream loss expects.
mask = pad_and_crop_to_maximum_length(mask, padding_value=1, maximum_sequence_length=maximum_sequence_length)
masked_indices = pad_and_crop_to_maximum_length(masked_indices, padding_value=pad_index, maximum_sequence_length=maximum_sequence_length)
# Reuse the configured label constant instead of repeating the literal -100,
# so padding stays in sync with the masking step if the constant changes.
labels = pad_and_crop_to_maximum_length(labels, padding_value=label_for_unmasked_values, maximum_sequence_length=maximum_sequence_length)

print("After padding / cropping")
pc("Input indices", input_indices)
pc("Mask", mask)
pc("Masked indices", masked_indices)
pc("Labels", labels)

After padding / cropping
[34mInput indices[0m: [20 51 28 52  5 16 53 21  7 54  6  5 22 10 16 55 56  8  1  1  1  1  1  1
  1]
[34mMask[0m: [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 1 1 1 1 1 1]
[34mMasked indices[0m: [ 0 51  0 52  5 16 53 21  7 54  6  5 22 10 16  0 56  8  1  1  1  1  1  1
  1]
[34mLabels[0m: [  20 -100   28 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
 -100   55 -100 -100 -100 -100 -100 -100 -100 -100 -100]
