In [None]:
# import torch
# from torch import nn
# from d2l import torch as d2l

# Model creation

In [3]:
import torch
from torch.nn import CrossEntropyLoss
from torch import nn
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 6
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

In [4]:
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

# Predict all tokens
predictions = model(tokens_tensor, segments_tensors)

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

In [5]:
#model

In [6]:
model.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): BertLayerNorm()
    )
    (decoder): Linear(in_features=768, out_features=30522, bias=False)
  )
)

In [11]:
vocab_size = 30522
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = model(input_ids=tokens, token_type_ids=segments, masked_lm_labels=torch.randint(0, 1, (2, 8)))
print(encoded_X)
encoded_X = model(input_ids=tokens, token_type_ids=segments)
encoded_X.shape 

tensor(13.9482, grad_fn=<NllLossBackward0>)


torch.Size([2, 8, 30522])

In [16]:
tokens

tensor([[17448, 17381, 25836, 17989, 29736, 24292,  3180, 25951],
        [ 3730, 29404, 10609,  4483, 25146, 13973,  3167, 29150]])

## Changing head

In [10]:
class MaskLM(nn.Module):
    """The masked language model task of BERT."""
    def __init__(self, num_hiddens, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_hiddens, num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens, 1),
                                 nn.Sigmoid())

    def forward(self, X):
        mlm_Y_hat = self.mlp(X)
        return mlm_Y_hat

In [19]:
model.cls = MaskLM(768)

In [32]:
tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
mlm_positions = torch.tensor([[0,0,0,1,0,0,0,1], [0,1,0,0,1,0,0,1]], dtype=torch.float32)
mlm_Y_hat = model(input_ids=tokens, token_type_ids=segments) #masked_lm_labels=mlm_positions)


In [33]:
mlm_Y_hat.shape, mlm_positions.shape

(torch.Size([2, 8, 1]), torch.Size([2, 8]))

In [34]:
mlm_positions.dtype

torch.float32

In [35]:
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X,
                         masked_lm_labels):
    # Forward pass
    mlm_Y_hat = net(input_ids=tokens_X, token_type_ids=segments_X)
    # loss_fct = CrossEntropyLoss(ignore_index=-1)
    masked_lm_loss = loss(mlm_Y_hat.view(-1), masked_lm_labels.view(-1))

    return masked_lm_loss

In [36]:
_get_batch_loss_bert(model, CrossEntropyLoss(ignore_index=-1), vocab_size, tokens, segments, mlm_positions)

tensor(13.8079, grad_fn=<DivBackward1>)

# Training loop TODO

In [26]:
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net(*next(iter(train_iter))[:4])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')

In [37]:
def gpu(i=0):
    """Defined in :numref:`sec_use_gpu`"""
    return torch.device(f'cuda:{i}')

In [38]:
def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists.
    Defined in :numref:`sec_use_gpu`"""
    return [gpu(i) for i in range(torch.cuda.device_count())]

In [39]:
devices = try_all_gpus()

In [40]:
devices

[device(type='cuda', index=0)]