In [1]:
from transformers import BertConfig, BertTokenizer, BertTokenizerFast, TFBertForMaskedLM, AdamWeightDecay, DataCollatorForLanguageModeling
import tensorflow as tf
from datasets import load_dataset
from collections import Counter
import os.path

In [2]:
# Load dataset

dataset_path = "/Users/gjuan/semiolog/models/en_bnc_berttest/corpus/"
dataset_files = ["train","dev","test"]
extension = "text"

# Use the argument split=["train[:1%]", "dev[:1%]"] to load only 1% of each
# split. However, this makes the dataset object not be a dict (with keys "train"
# and "dev") but a list. To call the train split (for instance, below), one
# should use dataset[0]["text"] instead of dataset["train"]["text"]

# dataset = load_dataset(extension, data_files={f:dataset_path+f+".txt" for f in dataset_files})#, split=["train[:1%]", "dev[:1%]"])
dataset = load_dataset(extension, data_files={k:dataset_path+f+".txt" for k,f in zip(dataset_files,["test","dev","test"])})


Using custom data configuration default-8b71f57495763543


Downloading and preparing dataset text/default to /Users/gjuan/.cache/huggingface/datasets/text/default-8b71f57495763543/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5...


100%|██████████| 3/3 [00:00<00:00, 5007.13it/s]
100%|██████████| 3/3 [00:00<00:00, 596.77it/s]


Dataset text downloaded and prepared to /Users/gjuan/.cache/huggingface/datasets/text/default-8b71f57495763543/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 318.95it/s]


In [3]:
# Build vocabulary of segmented sentences and save it to a vocab.txt file

if not os.path.isfile("vocab.txt") or False:
    tokens = []

    for sent in dataset["train"]["text"]:
            tokens.extend(sent.split())
    tokens_count = Counter(tokens) # This could probably be done directly on the counter
    vocab = [token for token, freq in tokens_count.most_common()]

    with open("vocab.txt", 'w') as f:
        for token in ["[PAD]", "[SEP]", "[CLS]", "[MASK]", "[UNK]"] + vocab:
            f.write("%s\n" % token)

In [4]:
# Build tokenizer out that vocabulary and tokenize the dataset

# tokenizer= BertTokenizer(
#         vocab_file = "vocab.txt",
#         do_lower_case=True,
#         do_basic_tokenize=True,
#         never_split=None,
#         unk_token="[UNK]",
#         sep_token="[SEP]",
#         pad_token="[PAD]",
#         cls_token="[CLS]",
#         mask_token="[MASK]",
#         tokenize_chinese_chars=True,
#         strip_accents=None,
# )

tokenizer = BertTokenizerFast(
        vocab_file="vocab.txt",
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
)

def tokenize_function(examples):
    return tokenizer(examples["text"])

tokenized_datasets = dataset.map(tokenize_function, batched=True, batch_size=10000, num_proc=8, remove_columns=["text"])

# print a decoded tokenized sentence
print(tokenizer.decode(tokenized_datasets["train"][0]["input_ids"]))

 #0:   0%|          | 0/1 [00:00<?, ?ba/s]
[A

[A[A


[A[A[A



[A[A[A[A




[A[A[A[A[A





 #0: 100%|██████████| 1/1 [00:00<00:00,  6.78ba/s]
 #1: 100%|██████████| 1/1 [00:00<00:00,  6.08ba/s]

 #2: 100%|██████████| 1/1 [00:00<00:00,  5.53ba/s]


 #3: 100%|██████████| 1/1 [00:00<00:00,  5.47ba/s]





[A[A[A[A[A


[A[A[A



[A[A[A[A





 #6: 100%|██████████| 1/1 [00:00<00:00,  5.31ba/s]
 #4: 100%|██████████| 1/1 [00:00<00:00,  4.78ba/s]
 #5: 100%|██████████| 1/1 [00:00<00:00,  4.83ba/s]
 #7: 100%|██████████| 1/1 [00:00<00:00,  5.21ba/s]


 #0:   0%|          | 0/1 [00:00<?, ?ba/s]
[A


[A[A[A



[A[A[A[A




[A[A[A[A[A





 #1: 100%|██████████| 1/1 [00:00<00:00,  9.56ba/s]
 #3: 100%|██████████| 1/1 [00:00<00:00,  9.38ba/s]



[A[A[A



 #4: 100%|██████████| 1/1 [00:00<00:00,  5.96ba/s]
 #5: 100%|██████████| 1/1 [00:00<00:00,  5.99ba/s]





 #6: 100%|██████████| 1/1 [00:00<00:00,  5.63ba/s]
 #0: 100%|██████████| 1/1 [00:00<00:00,  4.30ba/s

[CLS] then using all thep an ache of basil faw l ty i display ed a carefully laid out tray of h oot ers for aparty go er [SEP]


In [5]:
# Build the model (Huggingface Tensor Flow Bert for Mask Language Model: TFBertForMaskedLM)

configuration = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
)

model = TFBertForMaskedLM(
    configuration
)

learning_rate = 5e-5 #2e-5
weight_decay = 0.01

optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)

model.compile(
    optimizer = optimizer
    # optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    # loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # metrics=tf.metrics.SparseCategoricalAccuracy(),
)

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! Please ensure your labels are passed as the 'labels' key of the input dict so that they are accessible to the model during the forward pass. To disable this behaviour, please pass a loss argument, or explicitly pass loss=None if you do not want your model to compute a loss.


In [6]:
# Build a Data Collator and train and validation sets. The Data Collator
# construct the batches, with padding, and in this particular case, random
# masking at a probability defined in the argument: mlm_probability"). Outputing
# TensorFlow tensors must be asked explicitly.

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15,
return_tensors="tf"
)

train_set = tokenized_datasets["train"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "labels"],
    shuffle=True,
    batch_size=16,
    collate_fn=data_collator,
)

validation_set = tokenized_datasets["dev"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "labels"],
    shuffle=False,
    batch_size=16,
    collate_fn=data_collator,
)

In [7]:
# Train the model

model.fit(train_set, validation_data=validation_set, epochs=2)

Epoch 1/2
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x179db30d0>

In [8]:
model.save_pretrained(save_directory="./")

In [19]:
config2 = "./config.json"
model2 = TFBertForMaskedLM.from_pretrained('./tf_model.h5', config=config2)

All model checkpoint layers were used when initializing TFBertForMaskedLM.

All the layers of TFBertForMaskedLM were initialized from the model checkpoint at ./tf_model.h5.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.


In [23]:
sent = dataset["dev"][200]["text"]
n = 4
sent_mask = " ".join([t if i!=n else "[MASK]" for i,t in enumerate(sent.split())])
sent_mask

'these canget into food [MASK] water especially in countries with poor s an itary facilities and thus infect otherpeople'

In [24]:
outputs = model(tokenizer(sent_mask, return_tensors="tf")["input_ids"])
outputs

TFMaskedLMOutput(loss=None, logits=<tf.Tensor: shape=(1, 21, 26290), dtype=float32, numpy=
array([[[-3.7524655, -1.9820781, -1.8103092, ..., -4.796723 ,
         -5.1437397, -5.0155263],
        [-3.7525015, -1.9820858, -1.8103209, ..., -4.796742 ,
         -5.1437283, -5.015563 ],
        [-3.7523944, -1.9820297, -1.8102915, ..., -4.796576 ,
         -5.143581 , -5.0154   ],
        ...,
        [-3.752418 , -1.9820464, -1.8103127, ..., -4.796629 ,
         -5.143641 , -5.0154343],
        [-3.7524612, -1.9820689, -1.8103261, ..., -4.7967014,
         -5.1437335, -5.01552  ],
        [-3.7524457, -1.9820572, -1.8103211, ..., -4.7966614,
         -5.14363  , -5.0154767]]], dtype=float32)>, hidden_states=None, attentions=None)

In [25]:
sent

'these canget into food or water especially in countries with poor s an itary facilities and thus infect otherpeople'

In [26]:
parad = [[(tokenizer.ids_to_tokens[k],v) for v,k in sorted([(v,i) for i,v in enumerate(outputs.logits[:,p].numpy().tolist()[0])], reverse=True)] for p in range(20)]

In [29]:
[parad[i][:4] for i in range(len(parad))]

[[('and', 4.18861722946167),
  ('in', 3.6120033264160156),
  ('the', 3.423407793045044),
  ('a', 3.3643510341644287)],
 [('and', 4.188627243041992),
  ('in', 3.6119933128356934),
  ('the', 3.423452138900757),
  ('a', 3.364354133605957)],
 [('and', 4.18848991394043),
  ('in', 3.611909866333008),
  ('the', 3.4233429431915283),
  ('a', 3.364262342453003)],
 [('and', 4.1886138916015625),
  ('in', 3.611983299255371),
  ('the', 3.4234120845794678),
  ('a', 3.3643462657928467)],
 [('and', 4.188564300537109),
  ('in', 3.6119673252105713),
  ('the', 3.4233763217926025),
  ('a', 3.3643195629119873)],
 [('and', 4.188572406768799),
  ('in', 3.611945152282715),
  ('the', 3.423354387283325),
  ('a', 3.3643226623535156)],
 [('and', 4.188621997833252),
  ('in', 3.6120097637176514),
  ('the', 3.4233970642089844),
  ('a', 3.3643224239349365)],
 [('and', 4.18859338760376),
  ('in', 3.611996650695801),
  ('the', 3.423382043838501),
  ('a', 3.364325523376465)],
 [('and', 4.188673973083496),
  ('in', 3.6120

In [30]:
parad[-1]

[('and', 4.188596725463867),
 ('in', 3.611995220184326),
 ('the', 3.4233922958374023),
 ('a', 3.3643202781677246),
 ('of', 3.1104166507720947),
 ('s', 2.9200122356414795),
 ('as', 2.7588236331939697),
 ('to', 2.752182722091675),
 ('for', 2.6618316173553467),
 ('is', 2.543020486831665),
 ('ing', 2.5419626235961914),
 ('that', 2.469393730163574),
 ('ed', 2.4655823707580566),
 ('with', 2.2697136402130127),
 ('but', 2.2521259784698486),
 ('it', 2.1866116523742676),
 ('on', 2.175072193145752),
 ('are', 2.151090145111084),
 ('was', 2.099869728088379),
 ('their', 2.0783278942108154),
 ('by', 2.038564443588257),
 ('from', 1.9598013162612915),
 ('or', 1.9573495388031006),
 ('which', 1.919629454612732),
 ('sof', 1.9009828567504883),
 ('at', 1.8680275678634644),
 ('an', 1.8291583061218262),
 ('all', 1.7533565759658813),
 ('inthe', 1.6526025533676147),
 ('ofthe', 1.6392264366149902),
 ('so', 1.5874210596084595),
 ('his', 1.5371545553207397),
 ('ly', 1.5112706422805786),
 ('forthe', 1.4921764135360