In [1]:
import os
import re
import transformers

import pandas as pd

from collections import Counter

### Part 1: Preprocessing

Lowercase the text and replace every run of special (non-alphanumeric) characters with a single space.

In [2]:
def basicPreprocess(text):
    """Lowercase ``text`` and squash non-alphanumeric runs into single spaces.

    Each maximal run of characters outside ``[a-zA-Z0-9]`` (punctuation,
    whitespace, newlines, non-ASCII letters, ...) is replaced by exactly one
    space. A run at the very start or end of the input therefore leaves one
    leading or trailing space in the result.

    Parameters
    ----------
    text : str
        Raw text to normalise.

    Returns
    -------
    str
        The lowercased text containing only ASCII letters, digits and spaces.
    """
    lowered = text.lower()
    return re.sub(r"[^a-zA-Z0-9]+", ' ', lowered)

This dataset is based on the [CORD Challenge](https://www.kaggle.com/allen-institute-for-ai/CORD-19-research-challenge) on Kaggle. To see how I cleaned it up from source and converted it into a CSV for ease of use, please check out [my notebook](https://github.com/lordtt13/word-embeddings/blob/master/COVID-19%20Research%20Data/prep_pdf.ipynb).

[Download this CSV](https://drive.google.com/file/d/1n6r40XFGlYF9phWP-Hx6Y4QiTqw_I7uS/view?usp=sharing) for yourself. It's approximately 4 GB.

In [3]:
# Load the whole cleaned CSV into memory (the text above puts it at ~4 GB).
# NOTE(review): only the "abstract" column is read later in this notebook, but
# the subsequent dropna() looks at every column, so restricting the columns
# here would change which rows survive -- confirm intent before optimising.
complete_df = pd.read_csv("data/clean_df.csv")

In [4]:
# Shuffle the rows and drop every row that has a missing value in any column.
# One sample(frac = 1) already produces a uniformly random permutation, so a
# second shuffle only costs another full copy of the frame. dropna() is chained
# rather than applied in place so that re-running the cell is idempotent.
data = complete_df.sample(frac = 1).dropna()
# Release the original, unshuffled frame to free memory.
del complete_df

In [5]:
# Normalise every abstract. basicPreprocess already turns newlines (and every
# other non-alphanumeric run) into single spaces, so no further newline
# handling is needed; a whole-value Series.replace("\n", " ") could never
# match any of the cleaned strings.
data = data["abstract"].apply(basicPreprocess)

In [6]:
# Build one corpus string from all cleaned abstracts.
# Joining with a space keeps the last word of one abstract from fusing with the
# first word of the next (which plain concatenation does whenever an abstract
# ends in a letter or digit), and builds the string in a single pass instead of
# repeated `+=` on an ever-growing string.
text = ' '.join(data.values)
del data

In [7]:
# Tally how often each whitespace-separated word occurs in the corpus string,
# then drop the corpus string itself to free memory.
counter = Counter()
counter.update(text.split())
del text

Remove words that are too frequent or too infrequent: keep only words that occur more than 100 and fewer than 10,000 times.

In [8]:
# Keep only mid-frequency words: strictly more than 100 and strictly fewer
# than 10000 occurrences in the corpus.
vocab = [word for word, freq in counter.items() if 100 < freq < 10000]

In [9]:
# Number of words that passed the frequency filter.
len(vocab)

6735

### Part 2: Load the Pretrained Model and Expand Its Vocabulary

#### Load the awesome [Allen AI SciBERT Model](https://github.com/allenai/scibert) which is a BERT model trained on scientific text.

In [10]:
# The tokenizer and the LM-head model are loaded from the same pretrained
# checkpoint; the model is then moved to the GPU.
scibert_checkpoint = 'allenai/scibert_scivocab_uncased'
tokenizer = transformers.AutoTokenizer.from_pretrained(scibert_checkpoint)
model = transformers.AutoModelWithLMHead.from_pretrained(scibert_checkpoint)
model = model.to('cuda')

In [11]:
# Display the loaded configuration; vocab_size here is the value before any
# new tokens are added.
model.config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

In [12]:
# Tokenizer size before adding the corpus words.
print(len(tokenizer))

31090


Add new tokens to the existing tokenizer.

In [13]:
# Register the frequency-filtered corpus words as additional tokens, then print
# the new tokenizer size.
# NOTE(review): the recorded sizes grow by far fewer entries than len(vocab),
# presumably because words already in the base vocabulary are not re-added.
tokenizer.add_tokens(vocab)
print(len(tokenizer))

31941


Now we need to resize the model's token-embedding layer so that it matches the enlarged tokenizer vocabulary.

In [14]:
# Resize the embedding layer to the enlarged tokenizer length so the newly
# added token ids have rows to index, then display the config to confirm the
# updated vocab_size.
model.resize_token_embeddings(len(tokenizer)) 
model.config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31941
}

In [15]:
# The word list has been added to the tokenizer and is not used again below.
del vocab

### Part 3: Fine Tune Model for Language Modeling

In [16]:
# Create the output directory (and the parent `models/` directory if missing)
# before saving the extended tokenizer into it. exist_ok = True makes the cell
# safe to re-run: os.mkdir would raise FileExistsError on a second run and
# FileNotFoundError when `models/` does not exist yet.
os.makedirs('models/COVID-scibert-latest', exist_ok = True)
tokenizer.save_pretrained('models/COVID-scibert-latest')

('models/COVID-scibert-latest/vocab.txt',
 'models/COVID-scibert-latest/special_tokens_map.json',
 'models/COVID-scibert-latest/added_tokens.json')

In [17]:
# Build the training dataset from a plain-text file, tokenised with the
# extended tokenizer and a block size of 16.
# NOTE(review): no cell visible in this notebook writes "lm_data/train.txt";
# it has to exist before this cell runs -- confirm how it is produced.
dataset = transformers.LineByLineTextDataset(
    tokenizer = tokenizer,
    file_path = "lm_data/train.txt",
    block_size = 16,
)

In [18]:
# Collator for masked language modelling (mlm = True) with a masking
# probability of 0.15.
data_collator = transformers.DataCollatorForLanguageModeling(
    tokenizer = tokenizer, mlm = True, mlm_probability = 0.15
)

In [19]:
# Training configuration: 5 epochs, batch size 16 per device, a checkpoint
# every 10,000 steps with at most 3 kept.
# output_dir is the same directory the tokenizer was saved to above, and
# overwrite_output_dir = True allows writing into that already-populated folder.
training_args = transformers.TrainingArguments(
    output_dir = "models/COVID-scibert-latest",
    overwrite_output_dir = True,
    num_train_epochs = 5,
    per_device_train_batch_size = 16,
    save_steps = 10_000,
    save_total_limit = 3,
)

# Wire the resized model, the training arguments, the MLM collator and the
# dataset together. No evaluation dataset is supplied.
trainer = transformers.Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = dataset,
    prediction_loss_only = True,
)

In [20]:
# Run fine-tuning; the recorded output logs the loss every 500 steps.
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=5.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2259.0, style=ProgressStyle(description_w…

{"loss": 2.9176381278038024, "learning_rate": 4.778663125276671e-05, "epoch": 0.2213368747233289, "step": 500}
{"loss": 2.6581398136615753, "learning_rate": 4.557326250553343e-05, "epoch": 0.4426737494466578, "step": 1000}
{"loss": 2.5586014343500136, "learning_rate": 4.335989375830013e-05, "epoch": 0.6640106241699867, "step": 1500}
{"loss": 2.436116787314415, "learning_rate": 4.114652501106684e-05, "epoch": 0.8853474988933157, "step": 2000}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2259.0, style=ProgressStyle(description_w…

{"loss": 2.3573547369241714, "learning_rate": 3.893315626383356e-05, "epoch": 1.1066843736166445, "step": 2500}
{"loss": 2.2459749839305876, "learning_rate": 3.671978751660027e-05, "epoch": 1.3280212483399734, "step": 3000}
{"loss": 2.218323789358139, "learning_rate": 3.450641876936698e-05, "epoch": 1.5493581230633025, "step": 3500}
{"loss": 2.2082864822149277, "learning_rate": 3.229305002213369e-05, "epoch": 1.7706949977866313, "step": 4000}
{"loss": 2.141116299211979, "learning_rate": 3.00796812749004e-05, "epoch": 1.9920318725099602, "step": 4500}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2259.0, style=ProgressStyle(description_w…

{"loss": 2.041673380970955, "learning_rate": 2.786631252766711e-05, "epoch": 2.213368747233289, "step": 5000}
{"loss": 2.007178209066391, "learning_rate": 2.565294378043382e-05, "epoch": 2.434705621956618, "step": 5500}
{"loss": 2.0318748190402984, "learning_rate": 2.3439575033200534e-05, "epoch": 2.6560424966799467, "step": 6000}
{"loss": 1.9773010560274125, "learning_rate": 2.1226206285967244e-05, "epoch": 2.8773793714032756, "step": 6500}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2259.0, style=ProgressStyle(description_w…

{"loss": 1.9445158489942551, "learning_rate": 1.9012837538733954e-05, "epoch": 3.098716246126605, "step": 7000}
{"loss": 1.8315122603178025, "learning_rate": 1.6799468791500664e-05, "epoch": 3.3200531208499338, "step": 7500}
{"loss": 1.8183673046827316, "learning_rate": 1.4586100044267376e-05, "epoch": 3.5413899955732626, "step": 8000}
{"loss": 1.8523061389923097, "learning_rate": 1.2372731297034086e-05, "epoch": 3.7627268702965915, "step": 8500}
{"loss": 1.8157472529411316, "learning_rate": 1.0159362549800798e-05, "epoch": 3.9840637450199203, "step": 9000}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2259.0, style=ProgressStyle(description_w…

{"loss": 1.7939044281244279, "learning_rate": 7.945993802567508e-06, "epoch": 4.20540061974325, "step": 9500}
{"loss": 1.765317411363125, "learning_rate": 5.732625055334219e-06, "epoch": 4.426737494466578, "step": 10000}




{"loss": 1.7384810733795166, "learning_rate": 3.51925630810093e-06, "epoch": 4.648074369189907, "step": 10500}
{"loss": 1.7090454839468003, "learning_rate": 1.3058875608676407e-06, "epoch": 4.869411243913236, "step": 11000}




TrainOutput(global_step=11295, training_loss=2.084379817630507)

In [21]:
# Save the fine-tuned model next to the tokenizer files saved earlier.
trainer.save_model("models/COVID-scibert-latest")

### Part 4: Pipeline the Model for mask filling

In [22]:
# Reload the fine-tuned model from disk, replacing the in-memory training copy.
model = transformers.AutoModelWithLMHead.from_pretrained('models/COVID-scibert-latest')

In [23]:
# Reload the extended tokenizer from the same directory as the model.
tokenizer = transformers.AutoTokenizer.from_pretrained("models/COVID-scibert-latest")

In [24]:
# Wrap the fine-tuned model and tokenizer in a fill-mask pipeline and ask it to
# complete a prompt that ends in the tokenizer's mask token.
nlp_fill = transformers.pipeline('fill-mask', model = model, tokenizer = tokenizer)
# The prompt ends with a space so the mask token is a separate word instead of
# being glued onto the preceding "a"; this no longer depends on the tokenizer
# splitting special tokens out of surrounding text.
nlp_fill('Coronavirus or COVID-19 can be prevented by a ' + nlp_fill.tokenizer.mask_token)

[{'sequence': '[CLS] coronavirus or covid - 19 can be prevented by a combination [SEP]',
  'score': 0.1719885915517807,
  'token': 2702},
 {'sequence': '[CLS] coronavirus or covid - 19 can be prevented by a simple [SEP]',
  'score': 0.054218728095293045,
  'token': 2177},
 {'sequence': '[CLS] coronavirus or covid - 19 can be prevented by a novel [SEP]',
  'score': 0.043364267796278,
  'token': 3045},
 {'sequence': '[CLS] coronavirus or covid - 19 can be prevented by a high [SEP]',
  'score': 0.03732519596815109,
  'token': 597},
 {'sequence': '[CLS] coronavirus or covid - 19 can be prevented by a vaccine [SEP]',
  'score': 0.021863549947738647,
  'token': 7039}]