# BERT's playground 
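Hello there! Welcome to BERT's playground. You may play with BERT here and see what he can do, but always make sure he feels respected and admired.

This notebook trains a BERT model with a masked-language-modelling objective on `ml4science_data.pkl` and prints the model's predictions for masked tokens as training progresses.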

## Setup


In [1]:
import pickle
import numpy as np
import tensorflow as tf
import masking
import BERT

from Vectorisation import Vectorisation
from Config import Config
from MaskedLanguageModel import MaskedLanguageModel
from MaskedTextGenerator import MaskedTextGenerator

# Seed Python's `random`, NumPy and TensorFlow in one call so that the
# stochastic steps further down (dataset shuffling, weight initialisation and,
# presumably, the token masking in `masking.py` -- confirm which RNG it uses)
# are reproducible under Restart Kernel -> Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

DATA_PATH = "./ml4science_data.pkl"
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load this dataset from a trusted source.
with open(DATA_PATH, "rb") as fp:
    data_dict = pickle.load(fp)

config = Config()
vec = Vectorisation(config=config)




In [2]:
# Prepare data for the masked language model: encode the raw dict, then let
# `masking` produce (masked inputs, target labels, per-token sample weights).
encoded = vec.encode_dict(data_dict)
x_masked_encoded, y_masked_encoded, sample_weights = masking.mask_input_and_labels(encoded, config.TOKEN_DICT)
print(x_masked_encoded.shape, y_masked_encoded.shape, sample_weights.shape)

# from_tensor_slices zips the three arrays element-wise, so they must align;
# fail here with a readable message instead of deep inside tf.data / fit().
assert x_masked_encoded.shape == y_masked_encoded.shape == sample_weights.shape, (
    "masked inputs, labels and sample weights must have identical shapes"
)

# A shuffle buffer at least as large as the dataset gives a uniform shuffle.
# Sizing it from the data (instead of a fixed 1000) keeps that true if the
# dataset ever grows past the old constant.
n_examples = x_masked_encoded.shape[0]
mlm_ds = (
    tf.data.Dataset.from_tensor_slices((x_masked_encoded, y_masked_encoded, sample_weights))
    .shuffle(buffer_size=n_examples)
    .batch(config.BATCH_SIZE)
)

print(mlm_ds)

(254, 128) (254, 128) (254, 128)
<_BatchDataset element_spec=(TensorSpec(shape=(None, 128), dtype=tf.int32, name=None), TensorSpec(shape=(None, 128), dtype=tf.int32, name=None), TensorSpec(shape=(None, 128), dtype=tf.float64, name=None))>


In [3]:
# TODO: implement the remaining downstream steps so BERT can run end to end.
# `sample_tokens` is handed to the MaskedTextGenerator callback when the
# model is built; here we just display the first sequence before masking.
sample_tokens = x_masked_encoded
first_encoded = encoded[0]
print(first_encoded)

[26  2 26 21 21 26 23 25 21 23  2 21 26 21 21 26 21 25  3  8  3  3  3  3
  3  3  3  3  7  8  4  2  8  3 16 20 21 26 21 23 21  3  8  7  8  3  5  8
  4  8  2  8  3  8 10  9 10  9  9 15 20 11  9 14 13 11  9 14 13 14 10 11
  9  9 13 14 13 14 11  9 10 14 13 10 11  9 14  9 14 15 20 15 15 17 20  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0]


In [4]:
# The same first sequence after masking -- compare it with the unmasked
# sequence printed above to see which token ids were altered.
masked_first = x_masked_encoded[0]
print(masked_first)

[26  2 26 21 21 26  9 25 21 23  2 21 26 21 21 26 21  1  3  8  3  1  3  3
  3  3  3  3  7  8  1  2  8  3 16 20 21 26 21 23 21  3  8  7  8  3  5  8
  4  1  2  8  3  1 10  9 10  9  9 15  1 11  9  1  1 11  9 14 13 14 10  1
  9  9 13 14 13 14 11  9 10 14 13 10 11  1 14  9  1 15 20 15 15 17 20  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0]


In [5]:
# Callback that reports the model's predictions for the masked positions of
# `sample_tokens` while training; it needs the id of the [MASK] token.
mask_token_id = config.TOKEN_DICT['[MASK]']
generator_callback = MaskedTextGenerator(sample_tokens, mask_token_id)

# Build the masked-language BERT described by `config` and show its layers.
bert_masked_model = BERT.create_masked_language_bert_model(config)
bert_masked_model.summary()


Model: "masked_bert_model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128)]                0         []                            
                                                                                                  
 word_embedding (Embedding)  (None, 128, 256)             6912      ['input_1[0][0]']             
                                                                                                  
 tf.__operators__.add (TFOp  (None, 128, 256)             0         ['word_embedding[0][0]']      
 Lambda)                                                                                          
                                                                                                  
 encoder_0/multiheadattenti  (None, 128, 256)             263168    ['tf.__operat

In [6]:
# Train on the masked-LM dataset, letting the generator callback print its
# predictions as training progresses, then persist the trained model in the
# native Keras format.
n_epochs = 5
model_path = "bert_mlm.keras"
training_history = bert_masked_model.fit(mlm_ds, epochs=n_epochs, callbacks=[generator_callback])
bert_masked_model.save(model_path)

Epoch 1/5


(254, 128, 27)
(11, 27)
[21 26 10 22 14]
{'input_seq': array([26,  2, 26, 21, 21, 26,  9, 25, 21, 23,  2, 21, 26, 21, 21, 26, 21,
        1,  3,  8,  3,  1,  3,  3,  3,  3,  3,  3,  7,  8,  1,  2,  8,  3,
       16, 20, 21, 26, 21, 23, 21,  3,  8,  7,  8,  3,  5,  8,  4,  1,  2,
        8,  3,  1, 10,  9, 10,  9,  9, 15,  1, 11,  9,  1,  1, 11,  9, 14,
       13, 14, 10,  1,  9,  9, 13, 14, 13, 14, 11,  9, 10, 14, 13, 10, 11,
        1, 14,  9,  1, 15, 20, 15, 15, 17, 20,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0]),
 'prediction': array([21, 26, 10, 22, 14], dtype=int64),
 'probability': array([0.4080506 , 0.21907467, 0.0528652 , 0.05080076, 0.04285622],
      dtype=float32)}
Epoch 2/5
(254, 128, 27)
(11, 27)
[26 21 22 23 14]
{'input_seq': array([26,  2, 26, 21, 21, 26,  9, 25, 21, 23,  2, 21, 26, 21, 21, 26, 21,
        1,  3,  8,  3,  1,  3,  3,  3,  3,  3,  3,  7,  8,  1,  2,  8,  