In [1]:
import torch
import evaluate
from safetensors.torch import load_model
from bert.model import BertMLM, BertConfig
from bert.data import load_pretraining_dataset, TrainingCollator
from datasets import load_from_disk
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from bert.utils import decode_batch

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [3]:
config = BertConfig()
model = BertMLM(config)
weights_before = model.bias.detach().clone()
model_save_path = "/media/bryan/ssd01/expr/bert_from_scratch/debug01/checkpoints/checkpoint_10/model.safetensors"
load_model(model, model_save_path)
weight_after = model.bias.detach().clone()
assert not torch.allclose(weights_before, weight_after)

In [4]:
dataset = load_from_disk("/media/bryan/ssd01/expr/bert_from_scratch/debug01/initial_dataset_cache")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataloader = DataLoader(dataset, batch_size=10, shuffle=False, collate_fn=TrainingCollator(tokenizer))

In [5]:
batch = next(iter(dataloader))

In [6]:
with torch.inference_mode():
    token_logits = model(**batch)

In [104]:
def decode_with_mask(input_ids):
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    special_tokens = [x for x in tokenizer.all_special_tokens if x != tokenizer.mask_token]
    filtered_tokens = [x for x in tokens if x not in special_tokens]
    text = tokenizer.convert_tokens_to_string(filtered_tokens)
    clean_text = tokenizer.clean_up_tokenization(text)
    return clean_text


batch_original_text = tokenizer.batch_decode(batch["original_input_ids"], skip_special_tokens=True)


topk = 2
batch_index = 3

text = batch_original_text[batch_index]
text_with_mask = decode_with_mask(batch["input_ids"][batch_index])

mask_token_index = torch.where(batch["input_ids"][batch_index] == tokenizer.mask_token_id)[0]
mask_token_logits = token_logits[batch_index, mask_token_index,:]
mask_token_probs = torch.softmax(mask_token_logits, dim=1)
topk_tokens = torch.topk(mask_token_probs, topk, dim=1)

def decode_pred_string(k = 0, with_prob = True):
    pred_tokens = tokenizer.convert_ids_to_tokens(batch["input_ids"][batch_index])
    for i, token_index in enumerate(mask_token_index.tolist()):
        pred_token_id = topk_tokens.indices[i][k].item()
        pred_token = tokenizer.convert_ids_to_tokens(pred_token_id)
        if with_prob:
            pred_prob = topk_tokens.values[i][k].item()
            pred_token = f"{pred_token}[{pred_prob:.1%}]"
        pred_tokens[token_index] = pred_token
    filtered_pred_tokens = [x for x in pred_tokens if x not in tokenizer.all_special_tokens]
    text = tokenizer.convert_tokens_to_string(filtered_pred_tokens)
    clean_text = tokenizer.clean_up_tokenization(text)
    return clean_text

print(text)
print(text_with_mask)
for k in range(topk):
    pred_text = decode_pred_string(k, with_prob=True)
    print(f"{pred_text} (Top {k} prediction)")

he'd seen the movie almost by mistake, considering he was a little young for the pg cartoon, but with older cousins, along with her brothers, mason was often exposed to things that were older.
he [MASK] d seen [MASK] movie almost by mistake, [MASK] he was a little young for the pg cartoon, but with older [MASK], along with her brothers, mason was often exposed to things that were older [MASK]
he '[100.0%] d seen a[41.0%] movie almost by mistake, and[64.1%] he was a little young for the pg cartoon, but with older women[28.7%], along with her brothers, mason was often exposed to things that were older.[99.8%] (Top 0 prediction)
he.[0.0%] d seen the[30.6%] movie almost by mistake, because[4.9%] he was a little young for the pg cartoon, but with older men[24.4%], along with her brothers, mason was often exposed to things that were older![0.2%] (Top 1 prediction)


In [121]:
input_ids = batch["input_ids"][1]
torch.count_nonzero(input_ids == tokenizer.mask_token_id).item()

True

In [126]:
decoded_batch = decode_batch(tokenizer, batch, token_logits)

In [134]:
for decode in decoded_batch:
    for k,v in decode.items():
        tabs = "\t\t" if k == "text" else "\t"
        print(f"{k}{tabs}", v)
    print("\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\")

text		 usually, he would be tearing around the living room, playing with his toys.
text_with_mask	 [MASK], he would be tearing around the living room, playing with his toys.
pred_top_1	 now[5.8%], he would be tearing around the living room, playing with his toys.
pred_top_2	 instead[5.3%], he would be tearing around the living room, playing with his toys.
\\\\\\\\\\\\\\\\\\\\\
text		 but just one look at a minion sent him practically catatonic.
text_with_mask	 but just one look at a minion sent [MASK] [MASK] catatonic.
pred_top_1	 but just one look at a minion sent him[15.9%] to[14.6%] catatonic.
pred_top_2	 but just one look at a minion sent her[14.2%] a[9.1%] catatonic.
\\\\\\\\\\\\\\\\\\\\\
text		 that had been megan's plan when she got him dressed earlier.
text_with_mask	 that had been megan'[MASK] plan when [MASK] got him dressed earlier.
pred_top_1	 that had been megan's[99.9%] plan when she[40.7%] got him dressed earlier.
pred_top_2	 that had been megan'd[0.0%] plan when i[23.0%

In [23]:
import evaluate
import datasets

class PerplexityForMLM(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description="Perplexity for masked language models",
            citation="",
            inputs_description= "Logits and true token IDs",
            features = datasets.Features(
            {
                "logits": datasets.Value("float32"),
                "references": datasets.Value("int32"),
            }),
        )

    def _compute(self, logits, references):
        log_probs = torch.log_softmax(torch.tensor(logits), dim=-1)
        nll = -log_probs[range(len(references)), references]
        avg_nll = nll.mean()
        return {"perplexity": torch.exp(avg_nll).item()}

In [24]:
mask_token_batch_indices, mask_token_seq_indices = torch.where(batch["input_ids"] == tokenizer.mask_token_id)
mask_token_logits = token_logits[mask_token_batch_indices, mask_token_seq_indices, :]
true_token_ids = batch["labels"][mask_token_batch_indices, mask_token_seq_indices]

In [25]:
perplexity_metric = PerplexityForMLM()
result = perplexity_metric.compute(logits=mask_token_logits.cpu().numpy(), references=true_token_ids.cpu().numpy())
print(f"Custom Metric Perplexity: {result['perplexity']:.3f}")

ValueError: Module inputs don't match the expected format.
Expected format: {'logits': Value(dtype='float32', id=None), 'references': Value(dtype='int32', id=None)},
Input logits: [[ -9.368389   -4.121412   -4.546527  ...  -4.754039   -4.011625
   -4.8967376]
 [ -5.4772935  -1.5144248  -2.090969  ...  -3.2176354  -1.409075
   -1.9109528]
 [-10.223923   -4.045441   -4.095291  ...  -4.6059313  -4.705429
   -4.9013143]
 ...
 [ -9.298979   -4.258144   -4.248714  ...  -4.0665255  -4.6621933
   -4.9004464]
 [-10.275549   -6.0642056  -4.7409644 ...  -4.5790854  -4.271665
   -5.9925475]
 [-13.01876    -5.8997355  -6.4239573 ...  -5.3757486  -6.581865
   -6.3337   ]],
Input references: [13311 12756  2933  1005  1037  2021  1010  2411  2000  2108  2001  1037
  2287  1036  1029  4510  2471  2046]