# MathBERT

In [1]:
import os
import re

Index = str("4") # GPU Index/Indices

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= Index

import numpy as np
from scipy import stats
import torch
import torch.nn as nn 
from tqdm.notebook import tqdm
from transformers import *
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Tokenization

In [4]:
### Configuration
# model_name_or_path = 'allenai/scibert_scivocab_cased'
model_name_or_path = 'bert-base-cased'
PrecedingSentNum = 7
SucceedingSentNum = 0
AddNotationToVocab = True

data_files = {}

data_files["train"] = "../../Data/features_s2orc_external_cs_limit=1000.json"

In [5]:
model = AutoModelForMaskedLM.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
)

### Add Custom Vocabs
tokenizer.add_tokens(["SYMBOL"])
tokenizer.add_tokens(["SECTION"])
tokenizer.add_tokens(["EQUATION"])
tokenizer.add_tokens(["CITATION"])
tokenizer.add_tokens(["$$"])

# SplitMethod can be "ByPaperID", 
PreprocessedData, added_tokens = utils.Preprocessing(tokenizer, data_files["train"], SplitMethod="ByPaperID",
                                       PrecedingSentNum=PrecedingSentNum, SucceedingSentNum=SucceedingSentNum,
                                       AddNotationToVocab=AddNotationToVocab)
tokenizer.add_tokens(added_tokens)
train_src, train_tgt, valid_src, valid_tgt, test_src, test_tgt = PreprocessedData

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



AddNotationsToVocab


100%|██████████| 222401/222401 [00:01<00:00, 165535.24it/s]
  6%|▋         | 33/526 [00:00<00:01, 325.14it/s]

AddedTokens ['\\rightrightarrows', '\\subsetneqq', '\\deg', '\\blacksquare', '\\bowtie', '\\lim', '\\uplus', '\\eta', '\\textsf', '\\bigoplus', '\\rightarrow', '\\cr', '\\cap', '\\vphantom', '\\vee', '\\mid', '\\of', '\\textmd', '\\mathbfcal', '\\dots', '\\Psi', '\\mathbin', '\\right\\Vert', '\\ogreaterthan', '\\mathbf', '\\tt', '\\iint', '\\cup', '\\\\a', '\\\\v\\\\w', '\\iff', '\\\\\\qquad', '\\cite', '\\diagdown', '\\mu', '\\overrightarrow', '\\mbox', '\\makebox', '\\Bigm', '\\\\\\\\', '\\it', '\\scshape', '\\color', '\\leqslant', '\\index', '\\normalfont', '\\left\\Vert', '\\left\\lfloor', '\\div', '\\acute', '\\textstyle', '\\dashv', '\\ker', '\\Re', '\\small', '\\bigtriangleup', '\\spadesuit', '\\left\\bad', '\\dag', '\\\\\\nonumber', '\\bot', '\\complement', '\\Huge', '\\simeq', '\\ne', '\\phantom', '\\odot', '\\em', '\\left\\lbrace', '\\rightsquigarrow', '\\\\w', '\\footnotesize', '\\langle', '\\hline', '\\\\\\mathbf', '\\\\K', '\\textbf', '\\pazocal', '\\frac', '\\ddot', '\\\\

100%|██████████| 526/526 [00:01<00:00, 314.99it/s]
  0%|          | 52/222401 [00:00<07:13, 513.21it/s]


Prepare Data


100%|██████████| 222401/222401 [07:23<00:00, 501.75it/s] 


In [6]:
from datasets import Dataset, DatasetDict

MaxTokenLen = 512

def tokenize_function(examples):
    examples["text"] = tokenizer.convert_tokens_to_string(examples["text"].split())
    return tokenizer(
        examples["text"],
        truncation=False,
        max_length=MaxTokenLen,
        add_special_tokens = False,
        return_special_tokens_mask=False,
    )

## Test set Prediction

In [None]:
### Configuration
model_name_or_path = 'bert-base-cased'

Verbose = True
MaskNumConstraint = 10
FineTune = True
Test = True
EpochNum = 20

In [None]:
from termcolor import colored
from datasets.utils.logging import set_verbosity_error
import random
import copy

### Reload
model = AutoModelForMaskedLM.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),
)
model.resize_token_embeddings(len(tokenizer))
model = nn.DataParallel(model).to(device)

if FineTune:
    LossFunction = nn.CrossEntropyLoss()
    Optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.module.parameters()), lr=5e-6)
    pbar1 = tqdm(total=EpochNum)
    TopAcc = 0.
    for ep in range(EpochNum):
        ### Variables
        total_loss = []
        #####
        model.train()
        pbar2 = tqdm(total=len(train_tgt), leave=False)
        for i in range(len(train_tgt)):
            MaskNum = np.sum(np.array(train_src[i].split()) == "[MASK]")
            if MaskNum > MaskNumConstraint: # MASK Length Constraint
                pbar2.update(1); continue
            elif MaskNum == 0:
                pbar2.update(1); continue
            DataDict_train = { "train": Dataset.from_dict({ "text": [' '.join(train_src[i].split()) ] }), }
            datasets_train = DatasetDict(DataDict_train)

            tokenized_datasets = datasets_train.map(
                tokenize_function,
                batched=False,
                num_proc=1,
            )
            
            pred_token = train_src[i].split()
            tgt_token = train_tgt[i].split()
            
            for _ in range(1): # Partial Masking
                for cnt, x in enumerate(tokenized_datasets["train"]):
                    MASK_idx = [ii for ii, iids in enumerate(x["input_ids"])
                                if iids == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]]
                    outputs = model(torch.tensor(x["input_ids"]).unsqueeze(0).to(device),)

                p = []
                loss = 0
                for m in MASK_idx:
                    loss += LossFunction(outputs["logits"][0][m:m+1],
                                         torch.tensor(tokenizer.convert_tokens_to_ids([tgt_token[m]])).to(device))
                Optimizer.zero_grad()
                if loss != 0:
                    total_loss.append((loss/len(MASK_idx)).item())
                    print(loss/len(MASK_idx), end='\r')
                    loss.backward()
                Optimizer.step()
            pbar2.update(1)
        pbar2.close()
        print("{}/{}\tLoss:{:.4f}".format(ep+1, EpochNum, np.mean(total_loss)))
        pbar1.update(1)
        
        ### Evaluation
        model.eval()
        
        correct, total = 0., 0
        correct_topk = np.zeros(10+1)
        perplexity_list = []
        CorrectFlag = False

        pbar = tqdm(total = len(valid_tgt))
        for i in range(len(valid_tgt)):
            MaskNum = np.sum(np.array(valid_src[i].split()) == "[MASK]")
            if MaskNum > MaskNumConstraint: # MASK Length Constraint
                pbar.update(1); continue
            elif MaskNum == 0:
                pbar.update(1); continue
            DataDict_valid = { "valid": Dataset.from_dict({ "text": [' '.join(valid_src[i].split()) ] }), }
            datasets_valid = DatasetDict(DataDict_valid)

            tokenized_datasets = datasets_valid.map(
                tokenize_function,
                batched=False,
                num_proc=1,
            )

            for cnt, x in enumerate(tokenized_datasets["valid"]):
                with torch.no_grad():
                    MASK_idx = [ii for ii, iids in enumerate(x["input_ids"])
                                if iids == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]]
                    outputs = model(torch.tensor(x["input_ids"]).unsqueeze(0).to(device),)
                    
                    p = []
                    pred_token = valid_src[i].split()
                    tgt_token = valid_tgt[i].split()

                for m in MASK_idx:
                    predicted_index = torch.argmax(outputs["logits"][0][m]).item()
                    pred_t = tokenizer.convert_ids_to_tokens([predicted_index])[0]
#                     pred_token[m] = '[[' + pred_t + ']]' # When Visualization
                    pred_token[m] = pred_t
                    total += 1
                    if tgt_token[m] == pred_t:
                        CorrectFlag = True
                        correct += 1
                    ### Top-k
                    pred_topk = []; pred_ts = []
                    for n in range(1,10+1): # Starts from Top-1
                        pred_topk = torch.argsort(-outputs["logits"][0][m])[:n]
                        pred_ts = tokenizer.convert_ids_to_tokens([t for t in pred_topk])
                        if tgt_token[m] in pred_ts:
                            correct_topk[n] += 1

                    exp_t = outputs["logits"][0][m].cpu().numpy()
                    # https://huggingface.co/transformers/perplexity.html
                    prob = np.exp(exp_t[tokenizer.convert_tokens_to_ids(tgt_token[m])]/len(exp_t))
                    p.append(prob)
            if p: perplexity_list.append(np.sum(p))
            pbar.update(1)
        pbar.close()

        print(correct, '/', total, correct/total)
        for n in range(1,10+1):
            print("Top-"+str(n), "Accuracy", correct_topk[n]/total)
        print("Perplexity(MASK)", np.mean(perplexity_list))
        
        if correct/total > TopAcc:
            print("Model Saving in ./save/")
            model.module.save_pretrained('./save/')
            tokenizer.save_vocabulary("./save/")
            TopAcc = correct/total
    pbar1.close()
    ###

if Test:
    correct, total = 0., 0
    total_loss = []
    TopAcc = 0.
    correct_topk = np.zeros(10+1)
    perplexity_list = []
    CorrectFlag = False

    print("Model Loading from ./save/")
    model = AutoModelForMaskedLM.from_pretrained('./save/')
    model.resize_token_embeddings(len(tokenizer))
    model = nn.DataParallel(model).to(device)
    
    model.eval()
    pbar = tqdm(total = len(test_tgt))
    for i in range(len(test_tgt)):
        MaskNum = np.sum(np.array(test_src[i].split()) == "[MASK]")
        if MaskNum > MaskNumConstraint: # MASK Length Constraint
            pbar.update(1); continue
        elif MaskNum == 0:
            pbar.update(1); continue
        DataDict_test = { "test": Dataset.from_dict({ "text": [' '.join(test_src[i].split()) ] }), }
        datasets_test = DatasetDict(DataDict_test)

        tokenized_datasets = datasets_test.map(
            tokenize_function,
            batched=False,
            num_proc=1,
        )

        for cnt, x in enumerate(tokenized_datasets["test"]):
            with torch.no_grad():
                MASK_idx = [ii for ii, iids in enumerate(x["input_ids"])
                            if iids == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]]
                outputs = model(torch.tensor(x["input_ids"]).unsqueeze(0).to(device),)
                p = []
                pred_token = test_src[i].split()
                tgt_token = test_tgt[i].split()

            for m in MASK_idx:
                predicted_index = torch.argmax(outputs["logits"][0][m]).item()
                pred_t = tokenizer.convert_ids_to_tokens([predicted_index])[0]
#                 pred_token[m] = '[[' + pred_t + ']]' # Visualization
                pred_token[m] = pred_t
                total += 1
                if tgt_token[m] == pred_t:
                    CorrectFlag = True
                    correct += 1
                ### Top-k
                pred_topk = []; pred_ts = []
                for n in range(1,10+1): # Starts from Top-1
                    pred_topk = torch.argsort(-outputs["logits"][0][m])[:n]
                    pred_ts = tokenizer.convert_ids_to_tokens([t for t in pred_topk])
                    if tgt_token[m] in pred_ts:
                        correct_topk[n] += 1

                exp_t = outputs["logits"][0][m].cpu().numpy()
                # https://huggingface.co/transformers/perplexity.html
                prob = np.exp(exp_t[tokenizer.convert_tokens_to_ids(tgt_token[m])]/len(exp_t))
                p.append(prob)
                
        if Verbose:
            if i <= 10:
                print("#Mask", MaskNum)
                ### Coloring
                for t1, t2 in zip(test_src[i].split(), pred_token):
                    if t1 == "[PAD]": break
                    else:
                        if t1 == "[MASK]":
                            print(colored(t2, 'red'), end=' ')
                        else:
                            print(t2, end=' ')
                print()
                #####
                for t1, t2 in zip(test_src[i].split(), test_tgt[i].split()):
                    if t1 == "[PAD]": break
                    else:
                        if t1 == "[MASK]":
                            print(colored(t2, 'blue'), end=' ')
                        else:
                            print(t2, end=' ')
                print()
                CorrectFlag = False

        if p: perplexity_list.append(np.sum(p))

        pbar.update(1)
    pbar.close()

    print(correct, '/', total, correct/total)
    for n in range(1,10+1):
        print("Top-"+str(n), "Accuracy", correct_topk[n]/total)
    print("Perplexity(MASK)", np.mean(perplexity_list))

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

tensor(0.2429, device='cuda:0', grad_fn=<DivBackward0>))0>)



1/20	Loss:1.7341evice='cuda:0', grad_fn=<DivBackward0>))0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 140045.0 / 195005 0.7181610727930053
Top-1 Accuracy 0.7181610727930053
Top-2 Accuracy 0.7770159739493859
Top-3 Accuracy 0.8053485808056203
Top-4 Accuracy 0.8244968077741597
Top-5 Accuracy 0.8387477244173226
Top-6 Accuracy 0.8499423091715597
Top-7 Accuracy 0.8592497628265942
Top-8 Accuracy 0.8671572523781441
Top-9 Accuracy 0.8738647726981359
Top-10 Accuracy 0.8797415450885875
Perplexity(MASK) 4.770176012085848
Model Saving in ./save/Left3/4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

2/20	Loss:1.2780evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143890.0 / 195005 0.7378785159354888
Top-1 Accuracy 0.7378785159354888
Top-2 Accuracy 0.7935334991410476
Top-3 Accuracy 0.8209584369631547
Top-4 Accuracy 0.8397784672187892
Top-5 Accuracy 0.8541165611138176
Top-6 Accuracy 0.86510089484885
Top-7 Accuracy 0.8742852747365453
Top-8 Accuracy 0.8813517602112766
Top-9 Accuracy 0.8873875028845414
Top-10 Accuracy 0.8929463347093665
Perplexity(MASK) 4.770283215645581
Model Saving in ./save/Left3/4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

3/20	Loss:1.1359evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144868.0 / 195005 0.7428937719545653
Top-1 Accuracy 0.7428937719545653
Top-2 Accuracy 0.7974923719904617
Top-3 Accuracy 0.8259377964667572
Top-4 Accuracy 0.8442501474321171
Top-5 Accuracy 0.8577062126612138
Top-6 Accuracy 0.8687623394271942
Top-7 Accuracy 0.8773005820363581
Top-8 Accuracy 0.8843209148483372
Top-9 Accuracy 0.8906694700135894
Top-10 Accuracy 0.8964385528576191
Perplexity(MASK) 4.770297205057709
Model Saving in ./save/Left3/4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

4/20	Loss:1.0276evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144837.0 / 195005 0.742734801671752
Top-1 Accuracy 0.742734801671752
Top-2 Accuracy 0.7982769672572498
Top-3 Accuracy 0.8261480474859619
Top-4 Accuracy 0.8447014179123612
Top-5 Accuracy 0.8585779851798672
Top-6 Accuracy 0.8694289890002821
Top-7 Accuracy 0.8778646701366631
Top-8 Accuracy 0.8852234558088254
Top-9 Accuracy 0.8914232968385426
Top-10 Accuracy 0.8967462372759673
Perplexity(MASK) 4.770350409970787


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

5/20	Loss:0.9338evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144686.0 / 195005 0.7419604625522422
Top-1 Accuracy 0.7419604625522422
Top-2 Accuracy 0.7965590625881388
Top-3 Accuracy 0.8251429450526909
Top-4 Accuracy 0.8437373400682033
Top-5 Accuracy 0.8575421143047615
Top-6 Accuracy 0.8684751673034025
Top-7 Accuracy 0.8771826363426579
Top-8 Accuracy 0.8844286043947591
Top-9 Accuracy 0.8906745980872285
Top-10 Accuracy 0.8959924104510141
Perplexity(MASK) 4.770400250063041


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

6/20	Loss:0.8488evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144831.0 / 195005 0.7427040332299172
Top-1 Accuracy 0.7427040332299172
Top-2 Accuracy 0.7968411066382913
Top-3 Accuracy 0.8237173405810108
Top-4 Accuracy 0.8423219917438014
Top-5 Accuracy 0.8559677956975462
Top-6 Accuracy 0.8667008538242609
Top-7 Accuracy 0.8753006333170944
Top-8 Accuracy 0.8823466064972693
Top-9 Accuracy 0.888490038716956
Top-10 Accuracy 0.8939668213635548
Perplexity(MASK) 4.770463049876609


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

7/20	Loss:0.7732evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144466.0 / 195005 0.740832286351632
Top-1 Accuracy 0.740832286351632
Top-2 Accuracy 0.7947078280044101
Top-3 Accuracy 0.8216917514935514
Top-4 Accuracy 0.8398964129124894
Top-5 Accuracy 0.853198635932412
Top-6 Accuracy 0.863834260659983
Top-7 Accuracy 0.8724750647419297
Top-8 Accuracy 0.8799671803287096
Top-9 Accuracy 0.8865464988077228
Top-10 Accuracy 0.8919463603497346
Perplexity(MASK) 4.770453070554677


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

8/20	Loss:0.7013evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144720.0 / 195005 0.742134817055973
Top-1 Accuracy 0.742134817055973
Top-2 Accuracy 0.7942821978923618
Top-3 Accuracy 0.820481526114715
Top-4 Accuracy 0.8380605625496782
Top-5 Accuracy 0.8510397169303351
Top-6 Accuracy 0.8609779236429835
Top-7 Accuracy 0.8693059152329428
Top-8 Accuracy 0.8767005974205789
Top-9 Accuracy 0.8828389015666265
Top-10 Accuracy 0.8880951770467423
Perplexity(MASK) 4.770519262318138


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

9/20	Loss:0.6359evice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144429.0 / 195005 0.7406425476269839
Top-1 Accuracy 0.7406425476269839
Top-2 Accuracy 0.7926617266223943
Top-3 Accuracy 0.8188815671393042
Top-4 Accuracy 0.8364759877951847
Top-5 Accuracy 0.849321812261224
Top-6 Accuracy 0.8598600035896515
Top-7 Accuracy 0.8684136304197329
Top-8 Accuracy 0.8752596087279814
Top-9 Accuracy 0.8814799620522551
Top-10 Accuracy 0.8868900797415451
Perplexity(MASK) 4.770506537959315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

10/20	Loss:0.5767vice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144015.0 / 195005 0.738519525140381
Top-1 Accuracy 0.738519525140381
Top-2 Accuracy 0.7896977000589729
Top-3 Accuracy 0.8161431758160047
Top-4 Accuracy 0.8335632419681547
Top-5 Accuracy 0.8459834363221456
Top-6 Accuracy 0.8560806133176072
Top-7 Accuracy 0.8645573190431015
Top-8 Accuracy 0.8718289274633984
Top-9 Accuracy 0.8778185174739109
Top-10 Accuracy 0.883192738647727
Perplexity(MASK) 4.770620676154674


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

11/20	Loss:0.5217vice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 144047.0 / 195005 0.7386836234968334
Top-1 Accuracy 0.7386836234968334
Top-2 Accuracy 0.7890054101176893
Top-3 Accuracy 0.8142304043486065
Top-4 Accuracy 0.8310915104740904
Top-5 Accuracy 0.8434860644598856
Top-6 Accuracy 0.853014025281403
Top-7 Accuracy 0.8616343170687931
Top-8 Accuracy 0.8685161918925156
Top-9 Accuracy 0.8747724417322633
Top-10 Accuracy 0.8801928155688316
Perplexity(MASK) 4.770718111939788


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

12/20	Loss:0.4737vice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143496.0 / 195005 0.7358580549216687
Top-1 Accuracy 0.7358580549216687
Top-2 Accuracy 0.7865541909181816
Top-3 Accuracy 0.8117586728545422
Top-4 Accuracy 0.8288607984410656
Top-5 Accuracy 0.8416758544652702
Top-6 Accuracy 0.8513217609804877
Top-7 Accuracy 0.8594189892566857
Top-8 Accuracy 0.8664341939950257
Top-9 Accuracy 0.8724750647419297
Top-10 Accuracy 0.877669803338376
Perplexity(MASK) 4.770684454217437


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

13/20	Loss:0.4288, device='cuda:0', grad_fn=<DivBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143493.0 / 195005 0.7358426707007513
Top-1 Accuracy 0.7358426707007513
Top-2 Accuracy 0.7844260403579395
Top-3 Accuracy 0.8096561626624958
Top-4 Accuracy 0.8257583138893875
Top-5 Accuracy 0.8377682623522473
Top-6 Accuracy 0.8480244096305223
Top-7 Accuracy 0.8562344555267813
Top-8 Accuracy 0.8628189020794339
Top-9 Accuracy 0.8686751621753288
Top-10 Accuracy 0.8736904181944053
Perplexity(MASK) 4.770759175650433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

14/20	Loss:0.3916vice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 142993.0 / 195005 0.7332786338811825
Top-1 Accuracy 0.7332786338811825
Top-2 Accuracy 0.7819850773057101
Top-3 Accuracy 0.807425450629471
Top-4 Accuracy 0.8231942770698187
Top-5 Accuracy 0.8355067818773878
Top-6 Accuracy 0.8454706289582319
Top-7 Accuracy 0.8536191379708212
Top-8 Accuracy 0.8606138304146047
Top-9 Accuracy 0.8664341939950257
Top-10 Accuracy 0.8713622727622369
Perplexity(MASK) 4.770843298243514


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

15/20	Loss:0.3562vice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143372.0 / 195005 0.7352221737904157
Top-1 Accuracy 0.7352221737904157
Top-2 Accuracy 0.7827235199097459
Top-3 Accuracy 0.8072510961257404
Top-4 Accuracy 0.8235378580036409
Top-5 Accuracy 0.8355016538037486
Top-6 Accuracy 0.845101407656214
Top-7 Accuracy 0.8529627445450116
Top-8 Accuracy 0.8596805210122818
Top-9 Accuracy 0.8653778108253635
Top-10 Accuracy 0.8706853670418707
Perplexity(MASK) 4.770915974282884


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

16/20	Loss:0.3263vice='cuda:0', grad_fn=<DivBackward0>)d0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143201.0 / 195005 0.7343452731981231
Top-1 Accuracy 0.7343452731981231
Top-2 Accuracy 0.7809492064306044
Top-3 Accuracy 0.8046357785697803
Top-4 Accuracy 0.8207635701648676
Top-5 Accuracy 0.8323068639265659
Top-6 Accuracy 0.8416861106125484
Top-7 Accuracy 0.8497064177841593
Top-8 Accuracy 0.8562498397476987
Top-9 Accuracy 0.8620958436963155
Top-10 Accuracy 0.8670341786108049
Perplexity(MASK) 4.770979901087079


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

17/20	Loss:0.2999, device='cuda:0', grad_fn=<DivBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143279.0 / 195005 0.7347452629419758
Top-1 Accuracy 0.7347452629419758
Top-2 Accuracy 0.7813902207635701
Top-3 Accuracy 0.8050306402399938
Top-4 Accuracy 0.8210917668777724
Top-5 Accuracy 0.8330350503833235
Top-6 Accuracy 0.8428553114022718
Top-7 Accuracy 0.850439732314556
Top-8 Accuracy 0.8569728981308171
Top-9 Accuracy 0.8627624932694034
Top-10 Accuracy 0.8678803107612625
Perplexity(MASK) 4.771026826873388


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

18/20	Loss:0.2762, device='cuda:0', grad_fn=<DivBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 143031.0 / 195005 0.7334735006794697
Top-1 Accuracy 0.7334735006794697
Top-2 Accuracy 0.7802620445629599
Top-3 Accuracy 0.8041229712058665
Top-4 Accuracy 0.8198712853516577
Top-5 Accuracy 0.8315940616907259
Top-6 Accuracy 0.8406451116638035
Top-7 Accuracy 0.8484859362580447
Top-8 Accuracy 0.8549985897797492
Top-9 Accuracy 0.8607984410656137
Top-10 Accuracy 0.865588061844568
Perplexity(MASK) 4.77111567217422


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

19/20	Loss:0.2555, device='cuda:0', grad_fn=<DivBackward0>)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=216333.0), HTML(value='')))


Top-1 Accuracy 142866.0 / 195005 0.7326273685290121
Top-1 Accuracy 0.7326273685290121
Top-2 Accuracy 0.77966718802082
Top-3 Accuracy 0.8031588933617086
Top-4 Accuracy 0.8190200251275608
Top-5 Accuracy 0.8307838260557422
Top-6 Accuracy 0.8400912797107767
Top-7 Accuracy 0.8477269813594523
Top-8 Accuracy 0.8539216943155303
Top-9 Accuracy 0.85969077715956
Top-10 Accuracy 0.8645983436322145
Perplexity(MASK) 4.771162248244989


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1910591.0), HTML(value='')))

tensor(0.5056, device='cuda:0', grad_fn=<DivBackward0>)d0>)