In [66]:
from transformers import BertTokenizer, DataCollatorForWholeWordMask
import torch

In [96]:
# NOTE(review): absolute local path -- consider deriving it from a configurable base directory.
VOCAB_FILE = '/Users/americanthinker1/NationalSecurityBERT/Preprocessing/Tokenization/wp-vocab-30500-vocab.txt'
MAX_SEQ_LEN = 512

# Build the WordPiece tokenizer directly from its vocab file; passing a single
# file path to `from_pretrained` is deprecated. `model_max_length` replaces the
# deprecated `max_len` alias. `return_special_tokens_mask` was dropped: it is an
# option of the encoding call `tok(...)`, so setting it here did not enable it.
tok = BertTokenizer(vocab_file=VOCAB_FILE, model_max_length=MAX_SEQ_LEN)

In [97]:
# Path to one pre-processed text chunk (presumably one document per line -- confirm);
# NOTE(review): absolute local path, will not exist on other machines.
data = '/Users/americanthinker1/aws_data/processed_data/processed_chunks/english_docs_aa.txt'

In [10]:
# Read the whole chunk into memory as a list of lines (line terminators stripped).
# Explicit UTF-8 so decoding does not depend on the platform's locale default;
# assumes the chunk was written as UTF-8 -- confirm against the preprocessing step.
with open(data, encoding='utf-8') as f:
    lines = f.read().splitlines()

In [29]:
# Take the first 512 whitespace-delimited words of the first line, then cut back
# to the last complete sentence (everything up to and including the final '.').
# The previous version re-joined the '.'-split fragments with ' ', which deleted
# every period and left double spaces. If the excerpt contains no '.', the result
# is '' (same as before).
# NOTE(review): 512 words is not 512 WordPiece tokens; the tokenizer may still truncate.
first_words = ' '.join(lines[0].split()[:512])
head, dot, _tail = first_words.rpartition('.')
text = (head + dot).strip()

In [160]:
# Encode the first 10 lines as PyTorch tensors of exactly 512 tokens each:
# longer lines are truncated, shorter ones padded; token_type_ids are not returned.
batch = tok(
    lines[:10],
    return_tensors='pt',
    return_token_type_ids=False,
    truncation=True,
    padding='max_length',
    max_length=512,
)

In [180]:
def mlm_pipe(batch: object, mlm_prob=0.15, mask_token_id=4, max_special_id=4,
             generator=None) -> dict:
    '''
    Given a batch of encodings, return randomly masked inputs and associated arrays.

    Every token whose id is greater than ``max_special_id`` is independently
    replaced by ``mask_token_id`` with probability ``mlm_prob``. Ids at or below
    ``max_special_id`` are never masked. The defaults assume the vocabulary puts
    its special tokens at ids 0..4, with [MASK] at 4 -- TODO confirm against the
    vocab file.

    Parameters
    ----------
    batch : object
        Anything exposing ``input_ids`` and ``attention_mask`` tensor attributes,
        e.g. the tokenizer output with ``return_tensors='pt'``.
    mlm_prob : float
        Per-token masking probability (default 15%).
    mask_token_id : int
        Id written into masked positions.
    max_special_id : int
        Largest id that is protected from masking.
    generator : torch.Generator, optional
        Random source, for reproducible masking.

    Returns
    -------
    dict
        ``input_ids``: masked copy of ``batch.input_ids``;
        ``attention_mask``: ``batch.attention_mask`` (same object);
        ``labels``: the original, unmasked ``batch.input_ids`` (same object).
    '''
    labels = batch.input_ids
    mask = batch.attention_mask
    # Clone so masking below never writes into ``labels`` / the caller's batch.
    input_ids = labels.detach().clone()

    # One uniform [0, 1) draw per token; this was previously an undefined global
    # `rand`, which raised NameError on a fresh kernel.
    rand = torch.rand(input_ids.shape, generator=generator)
    mask_arr = (rand < mlm_prob) & (input_ids > max_special_id)
    input_ids[mask_arr] = mask_token_id

    # NOTE(review): labels keep every position, padding included. Hugging Face MLM
    # heads only ignore label -100, so loss is computed on all tokens, not just masked ones.
    return {'input_ids': input_ids, 'attention_mask': mask, 'labels': labels}

In [181]:
# Randomly mask `batch`; `encs` is a dict with masked 'input_ids', 'attention_mask' and unmasked 'labels'.
encs = mlm_pipe(batch)

In [186]:
class Dataset(torch.utils.data.Dataset):
    '''
    Map-style dataset over a dict of tensors that share a first dimension.

    Sample ``i`` is a dict with the same keys as ``encodings``, each mapped to
    row ``i`` of the corresponding tensor. The dataset length is taken from
    ``encodings['input_ids']``.
    '''

    def __init__(self, encodings):
        # Stored by reference; no copy is made.
        self.encodings = encodings

    def __len__(self):
        num_rows = self.encodings['input_ids'].shape[0]
        return num_rows

    def __getitem__(self, i):
        sample = {}
        for name, values in self.encodings.items():
            sample[name] = values[i]
        return sample

In [199]:
# Wrap the masked encodings so a DataLoader can fetch them one sample (row) at a time.
dataset = Dataset(encs)

In [204]:
# Shuffled mini-batches of 32 samples. Pinned host memory only speeds up copies to
# a CUDA device; without one it is unused and torch warns, so enable it only then.
# NOTE(review): shuffle has no seeded generator, so batch order is not reproducible.
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=32,
                                     shuffle=True,
                                     pin_memory=torch.cuda.is_available())

In [207]:
from transformers import BertConfig

In [212]:
# Default BERT architecture; only the embedding table is sized to the custom tokenizer.
# NOTE(review): `tok.vocab_size` excludes tokens added after loading; `len(tok)` includes them.
config = BertConfig(vocab_size=tok.vocab_size)

In [213]:
from transformers import BertForMaskedLM
# Built from `config` rather than `from_pretrained`, so weights are freshly initialised (training from scratch).
model = BertForMaskedLM(config)