# Questions: 
In appendix A.1 of the BERT paper, the masking procedure is described. 15% of the time the tokens are masked (meaning we process these tokens in some way, and perform gradient descent on the model's predictive loss on this token). However, masking doesn't always mean we replace the token with [MASK]. Of these 15% of cases, 80% of the time the word is replaced with [MASK], 10% of the time it is replaced with a random word, and 10% of the time it is kept unchanged. This is sometimes referred to as the 80-10-10 rule.
Why is this used?

On the one hand, the model learns to understand text by filling in masks in the sentence. In a sentence like "a [MASK] jumps over the fence", it will learn which words can plausibly be used there. On the other hand, we don't want the model to just copy all of the other tokens; it also has to learn to spot strange words that don't belong in the sentence, which is what the random replacements train. I assume gradient descent is only applied to the 15% of "masked" tokens, so the model should also sometimes get punished if it incorrectly changes a word, i.e. if it assumes a word was replaced by a random word when it was not. That is why gradient descent is also applied to unchanged tokens.

In [24]:
# ! pip install transformers
# ! pip install wandb
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w2d2/utils.py
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w2d2/solutions_build_bert.py
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w2d2/functions_from_previous_days.py
# ! wget https://raw.githubusercontent.com/DalasNoin/arena/main/w2/shakespeare.py
# ! wget https://raw.githubusercontent.com/DalasNoin/arena/main/w2/sampling.py
# ! pip install torchinfo

from dataclasses import dataclass
from typing import List, Optional

import matplotlib
import torch
import torch as t
import transformers
from torch import nn
from torch.nn import GELU, Softmax

import utils
from simon_utils import TransformerConfig


# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Reference Hugging Face masked-LM checkpoint and the tokenizer that goes with it.
bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [25]:

# Hyperparameters for our own BERT implementation. The trailing notes record
# where each value can be cross-checked on the pretrained model/tokenizer.
config = TransformerConfig(
    num_layers=12,               # check bert.bert
    num_heads=12,                # check bert.bert
    vocab_size=28996,            # tokenizer.vocab_size
    hidden_size=768,             # check bert.bert
    max_seq_len=512,             # bert.bert.embeddings.position_embeddings
    dropout=0.1,                 # check bert.bert
    layer_norm_epsilon=1e-12,    # check bert.bert
    device=device,
)

![alt text](bert.png "Title")

In [26]:
class MLP(nn.Module):
    '''
    Feed-forward sub-layer: Linear (hidden -> 4*hidden), GELU,
    Linear (4*hidden -> hidden), then dropout.
    '''

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.dropout = config.dropout

        # The inner layer is four times as wide as the residual stream.
        inner_size = 4 * self.hidden_size
        layers = [
            nn.Linear(self.hidden_size, inner_size),
            GELU(),
            nn.Linear(inner_size, self.hidden_size),
            nn.Dropout(self.dropout),
        ]
        self.mlp_block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        '''Apply the feed-forward stack to x; the last dimension must be hidden_size.'''
        out = self.mlp_block(x)
        return out

class MultiheadAttention(nn.Module):
    '''
    Bidirectional (unmasked except for padding) multi-head self-attention.

    Queries, keys and values are linear projections of the input; the
    per-head outputs are concatenated and passed through an output projection.
    '''

    def __init__(self, config: TransformerConfig):
        # `super.__init__()` (without the call parentheses) raises TypeError.
        super().__init__()
        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by num_heads ({config.num_heads})"
            )
        self.num_heads = config.num_heads
        self.head_size = config.hidden_size // config.num_heads
        self.W_Q = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_K = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_V = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_O = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x: torch.Tensor, additive_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        additive_attention_mask: broadcastable to (batch, nheads, seqQ, seqK);
            added to the attention scores before the softmax. None means no masking.

        Returns a tensor of shape (batch, seq, hidden_size).
        '''
        batch, seq, _ = x.shape

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, hidden) -> (batch, nheads, seq, head_size)
            return projected.reshape(batch, seq, self.num_heads, self.head_size).transpose(1, 2)

        queries = split_heads(self.W_Q(x))
        keys = split_heads(self.W_K(x))
        values = split_heads(self.W_V(x))

        # Scaled dot-product scores: (batch, nheads, seqQ, seqK).
        scores = queries @ keys.transpose(-2, -1) / self.head_size ** 0.5
        if additive_attention_mask is not None:
            scores = scores + additive_attention_mask
        probabilities = scores.softmax(dim=-1)

        # Weighted sum of values, then merge the heads back into the hidden dimension.
        attended = (probabilities @ values).transpose(1, 2).reshape(batch, seq, self.num_heads * self.head_size)
        return self.W_O(attended)


class BERTBlock(nn.Module):
    '''
    One post-layernorm transformer encoder block:
    attention + residual -> LayerNorm -> MLP + residual -> LayerNorm.
    '''

    def __init__(self, config):
        # `super.__init__()` (without the call parentheses) raises TypeError.
        super().__init__()
        self.attention = MultiheadAttention(config=config)
        # Use the configured epsilon so these norms agree with the embedding LayerNorm.
        self.layernorm1 = nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_epsilon)
        self.mlp = MLP(config=config)
        self.layernorm2 = nn.LayerNorm(normalized_shape=(config.hidden_size,), eps=config.layer_norm_epsilon)

    def forward(self, x: torch.Tensor, additive_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        additive_attention_mask: shape (batch, nheads=1, seqQ=1, seqK)

        Returns a tensor with the same shape as x.
        '''
        # Out-of-place residual adds: `x += ...` would mutate the caller's tensor
        # and overwrite activations that autograd still needs for the backward pass.
        attended = self.layernorm1(x + self.attention(x, additive_attention_mask))
        return self.layernorm2(attended + self.mlp(attended))



def make_additive_attention_mask(one_zero_attention_mask: torch.Tensor, big_negative_number: float = -10000) -> torch.Tensor:
    '''
    Convert a 1/0 padding mask into a mask that can be added to attention scores.

    one_zero_attention_mask: 
        shape (batch, seq)
        Contains 1 if this is a valid token and 0 if it is a padding token.
        Integer, boolean and floating-point masks are all accepted.

    big_negative_number:
        Any negative number large enough in magnitude that exp(big_negative_number) is 0.0 for the floating point precision used.

    Out: 
        shape (batch, nheads=1, seqQ=1, seqK)
        Contains 0 if attention is allowed, big_negative_number if not.
        Floating point: the mask's own dtype if it is already floating point, otherwise torch's default dtype.
    '''
    # Compare against 0 instead of using `~`: on an integer mask `~` is a
    # bitwise NOT (1 -> -2, 0 -> -1), which marks every position as non-zero.
    is_padding = one_zero_attention_mask == 0
    if one_zero_attention_mask.is_floating_point():
        out_dtype = one_zero_attention_mask.dtype
    else:
        out_dtype = torch.get_default_dtype()
    additive_mask = is_padding.to(out_dtype) * big_negative_number
    # Insert singleton head and query dimensions so the mask broadcasts over them.
    return additive_mask[:, None, None, :]

# utils.test_make_additive_attention_mask(make_additive_attention_mask)

In [27]:
class BertCommon(nn.Module):
    '''
    Shared BERT trunk: token + position + token-type embeddings, LayerNorm,
    dropout, then a stack of `config.num_layers` BERTBlocks.
    '''

    def __init__(self, config: TransformerConfig):
        # `super.__init__()` (without the call parentheses) raises TypeError.
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size)
        self.position_embedding = nn.Embedding(num_embeddings=config.max_seq_len, embedding_dim=config.hidden_size)
        self.token_type_embedding = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size)
        self.layernorm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(p=config.dropout)
        # ModuleList rather than Sequential: each block also needs the attention
        # mask, and Sequential can only thread a single positional argument through.
        # Parameter names ("bert_blocks.<i>....") are the same for both containers.
        blocks = [BERTBlock(config) for _ in range(config.num_layers)]
        self.bert_blocks = nn.ModuleList(blocks)

    def forward(
        self,
        x: t.Tensor,
        one_zero_attention_mask: Optional[t.Tensor] = None,
        token_type_ids: Optional[t.Tensor] = None,
    ) -> t.Tensor:
        '''
        x: (batch, seq) - the token ids
        one_zero_attention_mask: (batch, seq) - only used in training, passed to `make_additive_attention_mask` and used in the attention blocks.
        token_type_ids: (batch, seq) - only used for NSP, passed to token type embedding. Defaults to all zeros.

        Returns the output of the last block, shape (batch, seq, hidden_size).
        '''

        # `is not None`, not truthiness: bool() of a multi-element tensor raises.
        attention_mask = None
        if one_zero_attention_mask is not None:
            attention_mask = make_additive_attention_mask(one_zero_attention_mask=one_zero_attention_mask)

        # Build index tensors on the same device as the input ids.
        positions = t.arange(x.shape[1], device=x.device)

        if token_type_ids is None:
            # Embedding lookups need integer indices with the same (batch, seq) shape as x.
            token_type_ids = t.zeros_like(x)

        hidden = self.token_embedding(x) + self.position_embedding(positions) + self.token_type_embedding(token_type_ids)

        hidden = self.layernorm(hidden)
        hidden = self.dropout(hidden)

        for block in self.bert_blocks:
            hidden = block(hidden, attention_mask)
        return hidden




class BertLanguageModel(nn.Module):
    '''
    BERT trunk plus a masked-language-model head.

    The head is Linear -> GELU -> LayerNorm, followed by an unembedding that
    reuses (ties) the token embedding matrix and adds a separate learned bias.
    '''

    def __init__(self, config):
        super().__init__()
        self.bert_common = BertCommon(config)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.tied_unembed_bias = nn.Parameter(t.zeros(config.vocab_size))

    def forward(
        self,
        input_ids: t.Tensor,
        one_zero_attention_mask: Optional[t.Tensor] = None,
        token_type_ids: Optional[t.Tensor] = None,
    ) -> torch.Tensor:
        '''
        input_ids: (batch, seq) - the token ids
        one_zero_attention_mask: (batch, seq) - forwarded to the trunk
        token_type_ids: (batch, seq) - forwarded to the trunk

        Returns logits of shape (batch, seq, vocab_size).
        '''
        hidden = self.bert_common(input_ids, one_zero_attention_mask, token_type_ids)
        hidden = self.layernorm(self.gelu(self.linear(hidden)))
        # Tied unembedding: project back onto the vocabulary with the transposed
        # token embedding matrix, then add the head's own bias.
        return hidden @ self.bert_common.token_embedding.weight.T + self.tied_unembed_bias


## load and test

In [None]:
def copy_weights_from_bert(my_bert: BertLanguageModel, bert: transformers.models.bert.modeling_bert.BertForMaskedLM) -> BertLanguageModel:
    '''
    Copy over the weights from bert to your implementation of bert.

    bert should be imported using: 
        bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")

    Returns your bert model, with weights loaded in.
    '''

    # FILL IN CODE: define a state dict from my_bert.named_parameters() and bert.named_parameters()
    # NOTE(review): this function is still incomplete. `state_dict` is never
    # assigned, so the call below raises NameError until the mapping from
    # my_bert's parameter names to bert's tensors is written.

    my_bert.load_state_dict(state_dict)
    return my_bert

In [None]:
def predict(model, tokenizer, text: str, k=15) -> List[List[str]]:
    '''
    Return a list of k strings for each [MASK] in the input.

    model: an nn.Module that maps (batch, seq) token ids either to a logits
        tensor of shape (batch, seq, vocab) or to an object with a `.logits`
        attribute of that shape.
    tokenizer: callable as tokenizer(text, return_tensors="pt")["input_ids"],
        exposing `mask_token_id` and `decode`.
    text: the prompt, containing zero or more [MASK] tokens.
    k: number of candidates per mask, best first (capped at the vocabulary size).

    Returns one inner list per [MASK], in order of appearance; an empty list
    if the text contains no [MASK].
    '''
    input_ids = tokenizer(text, return_tensors="pt")["input_ids"]

    # Run on whichever device the model lives on (if it has parameters at all).
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)

    # Disable dropout for prediction, then restore the caller's train/eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(input_ids)
    finally:
        model.train(was_training)
    logits = output if isinstance(output, torch.Tensor) else output.logits

    # Boolean indexing keeps only the rows at [MASK] positions: (n_masks, vocab).
    mask_logits = logits[input_ids == tokenizer.mask_token_id]
    top_k = min(k, mask_logits.shape[-1])
    top_ids = mask_logits.topk(top_k, dim=-1).indices.tolist()
    return [[tokenizer.decode([token_id]) for token_id in candidates] for candidates in top_ids]

def test_bert_prediction(predict, model, tokenizer):
    '''Smoke test: the candidates for the first [MASK] after "George" must include two presidential surnames.'''
    text = "Former President of the United States of America, George[MASK][MASK]"
    predictions = predict(model, tokenizer, text)

    # Show the prompt and one line of candidates per mask before asserting.
    print(f"Prompt: {text}")
    rendered = "\n".join(str(candidates) for candidates in predictions)
    print("Model predicted: \n", rendered)

    first_mask_candidates = predictions[0]
    for surname in ("Washington", "Bush"):
        assert surname in first_mask_candidates

# NOTE(review): `my_bert` is not assigned anywhere in the visible part of this
# file, so this call raises NameError as written. Presumably it should be a
# BertLanguageModel with weights loaded via copy_weights_from_bert - confirm.
test_bert_prediction(predict, my_bert, tokenizer)