In [2]:
from transformers import BertConfig, BertForMaskedLM, BertTokenizerFast, pipeline
import transformers 

In [3]:
config = BertConfig()
default_model = BertForMaskedLM(config=config)
default_model.num_parameters()

109514298

In [4]:
%%capture

!wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt

In [5]:
config = BertConfig()

In [6]:
default_tokenizer = BertTokenizerFast.from_pretrained('./vocab', config=config)
default_tokenizer.model_max_length = 512
default_tokenizer.init_kwargs['model_max_length'] = 512
default_tokenizer

PreTrainedTokenizerFast(name_or_path='./vocab', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### MLM using the original BERT (this is not fine-tuned on our COVID articles)

In [7]:
MASK_TOKEN = default_tokenizer.mask_token
MASK_TOKEN

'[MASK]'

In [8]:
default_tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [9]:
fill_mask = pipeline('fill-mask', model=default_model, tokenizer=default_tokenizer)

In [10]:
f'covid is a {MASK_TOKEN}'

'covid is a [MASK]'

In [11]:
fill_mask([f'covid is a {MASK_TOKEN}'])

[2022-07-27 18:51:38.360 pytorch-1-8-gpu-py3-ml-g4dn-xlarge-60bd0d07a83be181dcf7335baae2:10907 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-07-27 18:51:38.393 pytorch-1-8-gpu-py3-ml-g4dn-xlarge-60bd0d07a83be181dcf7335baae2:10907 INFO profiler_config_parser.py:102] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.


[{'score': 0.0002743445511441678,
  'token': 2636,
  'token_str': 'turn',
  'sequence': 'covid is a turn'},
 {'score': 0.0002334642776986584,
  'token': 21810,
  'token_str': 'doping',
  'sequence': 'covid is a doping'},
 {'score': 0.00021816801745444536,
  'token': 30044,
  'token_str': '10b',
  'sequence': 'covid is a 10b'},
 {'score': 0.00021085298794787377,
  'token': 26260,
  'token_str': 'courtroom',
  'sequence': 'covid is a courtroom'},
 {'score': 0.00019946105021517724,
  'token': 28741,
  'token_str': 'fon',
  'sequence': 'covid is a fon'}]

In [12]:
fill_mask('Covid-19 is a [MASK]')

[{'score': 0.0002585537440609187,
  'token': 19442,
  'token_str': 'bri',
  'sequence': 'covid - 19 is a bri'},
 {'score': 0.00023853234597481787,
  'token': 21019,
  'token_str': 'هذا',
  'sequence': 'covid - 19 is a هذا'},
 {'score': 0.00022808139328844845,
  'token': 10795,
  'token_str': 'recording',
  'sequence': 'covid - 19 is a recording'},
 {'score': 0.00022737719700671732,
  'token': 16085,
  'token_str': 'africans',
  'sequence': 'covid - 19 is a africans'},
 {'score': 0.00022075384913478047,
  'token': 632,
  'token_str': '所',
  'sequence': 'covid - 19 is a 所'}]

In [13]:
fill_mask('Covid is a [MASK]')

[{'score': 0.0002085938467644155,
  'token': 19926,
  'token_str': '##enta',
  'sequence': 'covid is aenta'},
 {'score': 0.00020403138478286564,
  'token': 11469,
  'token_str': '##antis',
  'sequence': 'covid is aantis'},
 {'score': 0.00020274311827961355,
  'token': 1929,
  'token_str': 'earn',
  'sequence': 'covid is a earn'},
 {'score': 0.00020188064081594348,
  'token': 19243,
  'token_str': '##agher',
  'sequence': 'covid is aagher'},
 {'score': 0.0002011073229368776,
  'token': 10248,
  'token_str': 'не',
  'sequence': 'covid is a не'}]

In [14]:
fill_mask('Omicron [MASK] in US')

[{'score': 0.00027196938754059374,
  'token': 16085,
  'token_str': 'africans',
  'sequence': 'omicron africans in us'},
 {'score': 0.0002351998700760305,
  'token': 13339,
  'token_str': 'keir',
  'sequence': 'omicron keir in us'},
 {'score': 0.00022676926164422184,
  'token': 28507,
  'token_str': 'hibern',
  'sequence': 'omicron hibern in us'},
 {'score': 0.00021411867055576295,
  'token': 1929,
  'token_str': 'earn',
  'sequence': 'omicron earn in us'},
 {'score': 0.00020654028048738837,
  'token': 24644,
  'token_str': '£500',
  'sequence': 'omicron £500 in us'}]