In [None]:
# %%
"""
# Week 2 Day 1 - Build Your Own BERT
Today you'll implement your own BERT model such that it can load the weights from the actual BERT and predict some masked tokens.
Reading: [Language Modelling with Transformers](https://docs.google.com/document/d/1XJQT8PJYzvL0CLacctWcT0T5NfL7dwlCiIqRtdTcIqA/edit#)
Reading: [BERT Paper, Section 3.1 "Pre-Training BERT"](https://arxiv.org/pdf/1810.04805.pdf)
Refer to the schematic below for the architecture of BERT. You can ignore the classification head.

See here for BERT architecture:
https://i.imgur.com/2ekVyly.png

If this is too hard, I'd recommend
(foundations) getting familiar with classes here https://realpython.com/python3-object-oriented-programming/
(building on this) understand what the point of modules is: https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html (classes, that can be nested, and always implement a forward method, that takes in a tensor and returns a tensor)

"""

In [None]:
# Detect the runtime: Colab ships the `google.colab` package, a local Jupyter does not.
try:
    import google.colab  # noqa: F401  (imported only to probe for Colab)

    IN_COLAB = True
    print("Running as a Colab notebook")

except ImportError:
    # Catch only ImportError so that Ctrl-C or an unrelated failure is not mistaken for "not on Colab".
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    try:
        from IPython import get_ipython
    except ImportError:
        ipython = None
    else:
        ipython = get_ipython()

    # get_ipython() returns None outside an IPython session (e.g. when run as a plain script).
    if ipython is not None:
        # Automatically reload edited modules without restarting the kernel.
        # run_line_magic replaces the deprecated InteractiveShell.magic().
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

In [None]:
import os
import subprocess
import sys

# Install into the interpreter backing this kernel (`sys.executable -m pip`) rather than whichever
# `pip` happens to be first on PATH, and raise CalledProcessError if the install fails instead of
# silently ignoring a non-zero exit status the way os.system() does.
subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers", "einops", "fancy_einsum"])

This cell just makes utilities for tests

In [None]:
import tempfile
import os
import time
import torch as t
import transformers
import requests
import logging
import http
from functools import wraps
from transformers.models.bert.modeling_bert import BertForMaskedLM
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from typing import Optional




def load_pretrained_gpt() -> GPT2LMHeadModel:
    """Fetch the pretrained HuggingFace GPT-2 checkpoint named "gpt2".

    The first call downloads roughly 500MB from the Internet; later calls
    should be served from the local cache and take under a second.
    """
    checkpoint_name = "gpt2"
    gpt = transformers.AutoModelForCausalLM.from_pretrained(checkpoint_name)
    return gpt



def load_pretrained_bert() -> BertForMaskedLM:
    """Load the HuggingFace ``bert-base-cased`` masked-language model.

    Suppresses the spurious warning about some weights not being used by
    temporarily disabling the ``transformers.modeling_utils`` logger.

    Returns:
        The pretrained ``BertForMaskedLM`` instance.
    """
    logger = logging.getLogger("transformers.modeling_utils")
    was_disabled = logger.disabled
    logger.disabled = True
    try:
        return transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
    finally:
        # Restore the logger even when loading raises (e.g. no network), so a failed
        # download does not leave this logger switched off for the rest of the session.
        logger.disabled = was_disabled


def assert_all_equal(actual: t.Tensor, expected: t.Tensor):
    """Assert that two tensors are elementwise equal.

    Args:
        actual: tensor produced by the code under test.
        expected: reference tensor; compared with ``==``, so it must be broadcastable against ``actual``.

    Raises:
        AssertionError: if any element differs. The message reports how many positions
            differ and lists the first ten mismatching indexes.
    """
    mask = actual == expected
    if not mask.all():
        # Invert the mask first: nonzero() on `mask` itself would list the positions that MATCH.
        bad = (~mask).nonzero()
        msg = f"Did not match at {len(bad)} indexes: {bad[:10]}{'...' if len(bad) > 10 else ''}"
        raise AssertionError(f"{msg}\nActual:\n{actual}\nExpected:\n{expected}")


def test_is_equal(actual: t.Tensor, expected: t.Tensor, test_name: str):
    """Run ``assert_all_equal`` through ``run_and_report`` under the name ``test_name``.

    On a mismatch, prints a failure line naming the test and re-raises the AssertionError.
    """
    try:
        run_and_report(assert_all_equal, test_name, actual, expected)
    except AssertionError as failure:
        print(f"Test failed: {test_name}")
        raise failure


def assert_shape_equal(actual: t.Tensor, expected: t.Tensor):
    """Raise AssertionError unless ``actual`` and ``expected`` have the same shape."""
    got, want = actual.shape, expected.shape
    if got == want:
        return
    raise AssertionError(f"expected shape={want}, got {got}")


def allclose(actual: t.Tensor, expected: t.Tensor, rtol=1e-4, atol: Optional[float] = None) -> None:
    """Assert that ``actual`` is elementwise close to ``expected``.

    Exactly one of ``rtol`` and ``atol`` must be non-None. Because ``rtol`` defaults to
    1e-4, a caller who wants an absolute tolerance has to pass ``rtol=None`` explicitly.

    Args:
        actual: tensor produced by the code under test.
        expected: reference tensor with the same shape.
        rtol: relative tolerance; an entry passes when ``|actual - expected| <= rtol * |expected|``.
        atol: absolute tolerance; an entry passes when ``|actual - expected| <= atol``.

    Raises:
        Exception: if both or neither of ``rtol`` and ``atol`` are given.
        AssertionError: if the shapes differ, or if at least one entry is outside
            tolerance. NaN entries count as outside tolerance; exactly equal entries
            (including equal infinities) always pass.
    """
    if (rtol is None) == (atol is None):
        raise Exception("This version of allclose expects exactly one of rtol and atol")

    assert_shape_equal(actual, expected)

    left = (actual - expected).abs()
    tolerance = rtol * expected.abs() if rtol is not None else atol
    # Written as "not within tolerance" rather than `left > tolerance` so that NaN
    # differences are flagged; the equality term keeps inf == inf from being flagged.
    wrong = ~((left <= tolerance) | (actual == expected))
    # Count failures instead of truncating a percentage: int(100 * mean) rounds anything
    # below 1% down to zero (letting a few bad entries pass) and fails on empty tensors.
    num_wrong = int(wrong.sum())

    if num_wrong > 0:
        pct_wrong = 100 * num_wrong / wrong.numel()
        print(f"Test failed. Max absolute deviation: {left.max()}")
        print(f"Actual:\n{actual}\nExpected:\n{expected}")
        raise AssertionError(
            f"allclose failed with {num_wrong} of {wrong.numel()} entries "
            f"({pct_wrong:.2f} percent) outside tolerance"
        )


def allclose_atol(actual: t.Tensor, expected: t.Tensor, atol: float) -> None:
    """Assert that every entry of ``actual`` is within ``atol`` of ``expected``.

    Args:
        actual: tensor produced by the code under test.
        expected: reference tensor with the same shape.
        atol: absolute tolerance; an entry passes when ``|actual - expected| <= atol``.

    Raises:
        AssertionError: if the shapes differ, or if at least one entry is outside
            tolerance. NaN entries count as outside tolerance; exactly equal entries
            (including equal infinities) always pass.
    """
    assert_shape_equal(actual, expected)
    left = (actual - expected).abs()
    # "Not within tolerance" (rather than `left > atol`) so NaN differences are flagged.
    wrong = ~((left <= atol) | (actual == expected))
    # Count failures instead of truncating a percentage: int(100 * mean) rounds anything
    # below 1% down to zero (letting a few bad entries pass) and fails on empty tensors.
    num_wrong = int(wrong.sum())
    if num_wrong > 0:
        pct_wrong = 100 * num_wrong / wrong.numel()
        print(f"Test failed. Max absolute deviation: {left.max()}")
        print(f"Actual:\n{actual}\nExpected:\n{expected}")
        raise AssertionError(
            f"allclose failed with {num_wrong} of {wrong.numel()} entries "
            f"({pct_wrong:.2f} percent) outside tolerance"
        )


def allclose_scalar(actual: float, expected: float, rtol=1e-4, atol: Optional[float] = None) -> None:
    """Assert that the scalar ``actual`` is close to ``expected``.

    Exactly one of ``rtol`` and ``atol`` must be non-None. Because ``rtol`` defaults to
    1e-4, a caller who wants an absolute tolerance has to pass ``rtol=None`` explicitly.

    Args:
        actual: value produced by the code under test.
        expected: reference value.
        rtol: relative tolerance; passes when ``|actual - expected| <= rtol * |expected|``.
        atol: absolute tolerance; passes when ``|actual - expected| <= atol``.

    Raises:
        Exception: if both or neither of ``rtol`` and ``atol`` are given.
        AssertionError: if the value is outside tolerance (a NaN difference counts as outside).
    """
    if (rtol is None) == (atol is None):
        raise Exception("This version of allclose expects exactly one of rtol and atol")
    left = abs(actual - expected)
    tolerance = rtol * abs(expected) if rtol is not None else atol
    # "Not within tolerance" (rather than `left > tolerance`) so a NaN difference is flagged;
    # the equality term keeps two equal infinities from being flagged.
    wrong = not (left <= tolerance or actual == expected)

    if wrong:
        print(f"Test failed. Absolute deviation: {left}")
        print(f"Actual:\n{actual}\nExpected:\n{expected}")
        # Fail like the tensor variants do instead of only printing and returning normally.
        raise AssertionError(f"allclose_scalar failed: absolute deviation {left} is outside tolerance")


def report_success(testname):
    """POST to the server indicating success at the given test.

    Used to help the TAs know how long each section takes to complete.

    Reads ``MLAB_SERVER`` and ``MLAB_EMAIL`` from the environment. When neither is
    set (local development) the function does nothing.

    Args:
        testname: name of the test that passed; sent in the JSON payload.

    Raises:
        ValueError: if only one of the two environment variables is set, or if the
            server answers with anything other than 204 No Content.
        requests.exceptions.RequestException: if the request fails or times out.
    """
    server = os.environ.get("MLAB_SERVER")
    email = os.environ.get("MLAB_EMAIL")
    if not server and not email:
        return  # local dev, do nothing
    if not server:
        raise ValueError(f"Email set to {email} but no MLAB_SERVER set!")
    if not email:
        raise ValueError(f"Server set to {server} but no MLAB_EMAIL set!")

    r = requests.post(
        server + "/api/report_success",
        json=dict(email=email, testname=testname),
        # Without a timeout requests waits indefinitely, so an unresponsive server
        # would freeze the notebook right after a test passes.
        timeout=10,
    )
    if r.status_code != http.HTTPStatus.NO_CONTENT:
        raise ValueError(f"Got status code from server: {r.status_code}")


# Map from a test's qualified name (e.g. "test_w2d3.test_unidirectional_attn") to whether that test
# has passed in the current interpreter session. run_and_report consults it so that each success
# is reported to the server only once.
# Note this can get clobbered during autoreload, in which case a passing test is reported again.
TEST_FN_PASSED = {}


def report(test_func):
    """Decorator that routes every call of ``test_func`` through ``run_and_report``.

    At decoration time the test is registered in ``TEST_FN_PASSED`` as not yet
    passed, keyed by "<module>.<function name>".
    """
    qualified_name = f"{test_func.__module__}.{test_func.__name__}"
    # Registering the same name twice is expected under autoreload, so an
    # existing entry is simply reset rather than treated as an error.
    TEST_FN_PASSED[qualified_name] = False

    @wraps(test_func)
    def reporting_wrapper(*args, **kwargs):
        return run_and_report(test_func, qualified_name, *args, **kwargs)

    return reporting_wrapper


def run_and_report(test_func, name, *test_func_args, **test_func_kwargs):
    """Call ``test_func``, print how long it took, and report the first success.

    If ``test_func`` raises, the exception propagates and nothing is printed or
    reported. Otherwise ``report_success(name)`` is called unless ``name`` is
    already marked as passed in ``TEST_FN_PASSED``.

    Returns:
        Whatever ``test_func`` returned.
    """
    started_at = time.time()
    result = test_func(*test_func_args, **test_func_kwargs)
    elapsed = time.time() - started_at
    print(f"{name} passed in {elapsed:.2f}s.")

    already_reported = TEST_FN_PASSED.get(name)
    if already_reported:
        return result

    report_success(name)
    TEST_FN_PASSED[name] = True
    return result


def remove_hooks(module: t.nn.Module):
    """Clear the backward, forward and forward-pre hooks registered directly on ``module``.

    Only this module is touched; pass the function to ``module.apply`` to do
    the same for every submodule.
    """
    for registry_name in ("_backward_hooks", "_forward_hooks", "_forward_pre_hooks"):
        getattr(module, registry_name).clear()

In [None]:
Setup the BERT objects

In [None]:
from dataclasses import dataclass
from typing import List, Optional, Union


import torch as t
import transformers
from einops import rearrange, repeat
from fancy_einsum import einsum
from torch import nn
from torch.nn import functional as F

@dataclass(frozen=True)
class BertConfig:
    """Constants used throughout the Bert model. Most are self-explanatory.

    intermediate_size is the number of hidden neurons in the MLP (see schematic).
    type_vocab_size is only used for pretraining on "next sentence prediction", which we aren't doing.

    The dataclass is frozen, so instances are immutable; the tests build a fresh
    BertConfig with keyword overrides whenever they need different sizes.
    """

    # NOTE(review): the defaults presumably mirror the pretrained "bert-base-cased" checkpoint
    # loaded by load_pretrained_bert() — confirm before changing any of them.
    vocab_size: int = 28996  # number of token ids in the vocabulary
    intermediate_size: int = 3072  # hidden neurons in the MLP
    hidden_size: int = 768  # width of the residual stream; must be divisible by num_heads
    num_layers: int = 12  # presumably the number of stacked BertBlocks — TODO confirm
    num_heads: int = 12  # attention heads; head size is hidden_size // num_heads
    max_position_embeddings: int = 512  # presumably the longest supported sequence — TODO confirm
    dropout: float = 0.1  # probability passed to nn.Dropout
    type_vocab_size: int = 2  # number of token-type (segment) ids
    layer_norm_epsilon: float = 1e-12  # presumably the eps handed to LayerNorm — TODO confirm


# Module-level default configuration; passed to load_pretrained_weights() later in the notebook.
config = BertConfig()


@dataclass
class BertOutput:
    """The output of your Bert model.

    logits is used for W2D1 and is the prediction for each token in the vocabulary.
    The other fields are used on W2D2 for the sentiment task.
    Every field defaults to None, so a model only needs to fill in what it computes.
    """

    # Vocabulary predictions; shape is presumably (batch, seq, vocab_size) — TODO confirm.
    logits: Optional[t.Tensor] = None
    # Sentiment-task outputs (W2D2); nothing in this notebook sets them.
    is_positive: Optional[t.Tensor] = None
    star_rating: Optional[t.Tensor] = None



Now the exercises begin

In [None]:

# %%
"""
# Embedding
Implement your version of PyTorch's `nn.Embedding` module. The PyTorch version has some extra options in the constructor, but you don't need to implement those since BERT doesn't use them.
The `Parameter` should be named `weight` and initialized with normally distributed random values.
"""
# %%
class Embedding(nn.Module):
    """Exercise: a minimal re-implementation of ``nn.Embedding`` (a lookup table of learned vectors)."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Create the lookup table.

        Args:
            num_embeddings: number of rows in the table.
            embedding_dim: length of each embedding vector.
        """
        super().__init__()
        # One row per index, initialized with standard-normal random values.
        self.weight = nn.Parameter(t.randn(num_embeddings, embedding_dim))

    def forward(self, x: t.LongTensor) -> t.Tensor:
        """For each integer in the input, return that row of the embedding.
        Don't convert x to one-hot vectors - this works but is too slow.
        """
        # TODO(exercise): implement. The stub returns None, so test_embedding fails until then.
        pass

def test_embedding(Embedding):
    """Looking up indices 1, 3 and 5 should return rows 1, 3 and 5 of the weight matrix."""
    table = Embedding(6, 10)
    wanted_rows = [1, 3, 5]
    looked_up = table(t.LongTensor(wanted_rows))
    for position, row in enumerate(wanted_rows):
        allclose(looked_up[position], table.weight[row])

# Run the check against the Embedding class in this cell; it fails until forward() is implemented.
test_embedding(Embedding)



In [None]:

# %%
"""
# Layer Normalization
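Use the [PyTorch docs](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) for Layer Normalization to implement your own version which exactly mimics the official API. Use the biased estimator for Var[x] as shown in the docs. You can assume elementwise affine is always True: some of the tests below pass `elementwise_affine=False`, but with the default initialization (a weight of ones and a bias of zeros) the affine step is the identity, so ignoring the flag still gives the right answer.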
"""
# %%
class LayerNorm(nn.Module):
    """Exercise: re-implementation of ``torch.nn.LayerNorm`` with the same constructor signature."""

    def __init__(
        self, normalized_shape: Union[int, tuple, t.Size], eps=1e-5, elementwise_affine=True, device=None, dtype=None
    ):
        """Store the configuration and create the affine parameters.

        Args:
            normalized_shape: an int, or the shape of the trailing dimensions to normalize
                over (the tests pass 10, (20, 10) and (5,)).
            eps: the tests vary this between 1e-11 and 1e-2; presumably added to the
                variance as in the PyTorch docs — TODO confirm when implementing.
            elementwise_affine: the exercise text says this may be treated as always True,
                although some tests pass False.
            device, dtype: accepted for signature compatibility; the stub does not use them.
        """
        super().__init__()
        # TODO(exercise): replace both placeholders with real parameters.
        self.weight = None
        self.bias = None # implement these

    def forward(self, x: t.Tensor):
        """x and the output should both have shape (N, *)."""
        # TODO(exercise): implement. The stub returns None, so every LayerNorm test fails until then.
        pass

def test_layernorm_mean_1d(LayerNorm):
    """If an integer is passed, this means normalize over the last dimension which should have that size."""
    x = t.randn(20, 10)
    ln1 = LayerNorm(10)
    out = ln1(x)
    max_mean = out.mean(-1).abs().max().item()
    assert max_mean < 1e-5, f"Normalized mean should be about 0, got {max_mean}"



def test_layernorm_mean_2d(LayerNorm):
    """If normalized_shape is 2D, should normalize over both the last two dimensions."""
    x = t.randn(20, 10)
    ln1 = LayerNorm((20, 10))
    out = ln1(x)
    max_mean = out.mean((-1, -2)).abs().max().item()
    assert max_mean < 1e-5, f"Normalized mean should be about 0, got {max_mean}"



def test_layernorm_std(LayerNorm):
    """If epsilon is small enough and no elementwise_affine, the output variance should be very close to 1."""
    x = t.randn(20, 10)
    ln1 = LayerNorm(10, eps=1e-11, elementwise_affine=False)
    out = ln1(x)
    var_diff = (1 - out.var(-1, unbiased=False)).abs().max().item()
    assert var_diff < 1e-6, f"Var should be about 1, off by {var_diff}"



def test_layernorm_exact(LayerNorm):
    """For the same epsilon, the output should match ``torch.nn.LayerNorm`` within allclose's tolerance."""
    # Deliberately large, so an implementation that forgets epsilon fails the comparison.
    eps = 1e-2
    sample = t.randn(2, 3, 4, 5)
    candidate = LayerNorm((5,), eps=eps)
    official = t.nn.LayerNorm((5,), eps=eps)  # type: ignore
    actual = candidate(sample)
    expected = official(sample)
    allclose(actual, expected)

def test_layernorm_backward(LayerNorm):
    """Input gradients of sum(LayerNorm(x)) should match ``nn.LayerNorm`` to within 1e-5."""
    reference_input = t.randn(10, 3)
    candidate_input = reference_input.clone()
    for leaf in (reference_input, candidate_input):
        leaf.requires_grad_(True)

    # No affine parameters, so the comparison does not depend on random initialization.
    reference_layer = nn.LayerNorm(3, elementwise_affine=False)
    reference_layer.requires_grad_(True)
    reference_layer(reference_input).sum().backward()

    candidate_layer = LayerNorm(3, elementwise_affine=False)
    candidate_layer.requires_grad_(True)
    candidate_layer(candidate_input).sum().backward()

    reference_grad = reference_input.grad
    candidate_grad = candidate_input.grad
    assert isinstance(reference_grad, t.Tensor)
    assert isinstance(candidate_grad, t.Tensor)
    # Absolute tolerance, because the gradient entries are supposed to be zero here.
    allclose_atol(reference_grad, candidate_grad, atol=1e-5)

# Run the LayerNorm checks against the class in this cell; each one raises until LayerNorm is implemented.
test_layernorm_mean_1d(LayerNorm)
test_layernorm_mean_2d(LayerNorm)
test_layernorm_std(LayerNorm)
test_layernorm_exact(LayerNorm)
test_layernorm_backward(LayerNorm) # optional extra test for the backward pass - we aren't doing training so don't worry if it fails


In [None]:

# %%
"""
# BertMLP
Make the MLP block, following the schematic. Use `nn.Dropout` for the dropout layer.
"""
# %%

class BertMLP(nn.Module):
    """Exercise: the MLP block of a BERT layer, to be built by following the schematic."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # TODO(exercise): create the sub-layers here (the exercise text asks for nn.Dropout as the dropout layer).
        pass # refer to the diagram!!!

    def forward(self, x: t.Tensor) -> t.Tensor:
        # TODO(exercise): implement. The stub returns None.
        # The test feeds x of shape (batch, seq, hidden_size).
        pass


def test_bert_mlp(BertMLP, batch_size=2, seq_len=5, hidden_size=6, dropout=0.0):
    """Build BertMLP twice from the same seed and check both give the same output for one random input.

    Dropout is not tested.

    NOTE(review): both instances come from the same ``BertMLP`` argument, so this
    checks reproducibility rather than agreement with a separate reference solution.
    """
    small_config = BertConfig(
        hidden_size=hidden_size,
        intermediate_size=3 * hidden_size,
        dropout=dropout,
    )
    inputs = t.randn(batch_size, seq_len, hidden_size)

    def seeded_output():
        # Reseed immediately before construction so both instances see the same RNG stream.
        t.manual_seed(988)
        return BertMLP(small_config)(inputs)

    expected = seeded_output()
    actual = seeded_output()

    allclose(actual, expected)

# NOTE(review): dropout=0.5 is used without calling eval(); the two outputs stay comparable only
# because test_bert_mlp reseeds before building and running each instance — presumably intentional.
test_bert_mlp(BertMLP, dropout=0.5)



In [None]:

# %%
"""
# Batched Self-Attention
We're going to implement a version of self-attention that computes all sequences in a batch at once, and all heads at once. Make sure you understand how single sequence, single head attention works first.
# Attention Pattern Pre-Softmax
Spend at least 5 minutes thinking about how to batch the computation before looking at the spoilers.
<details>
<summary>What should the shape of `project_query` be?</summary>
`project_query` should go from `hidden_size` to `num_heads * self.head_size` which in this case is equal to `hidden_size`. This represents all the heads' Q matrices concatenated together, and one call to it now computes all the queries at once (broadcasting over the leading batch and seq dimensions of the input x).
</details>
<details>
<summary>Should my Linear layers have a bias?</summary>
While these Linear layers are traditionally referred to as projections, in BERT they DO have a bias.
</details>
<details>
<summary>What does the einsum to make the attention pattern look like?</summary>
We need to sum out the head_size and keep the seq_q dimension before the seq_k dimension. For a single batch and single head, it would be: `einsum("seq_q head_size, seq_k head_size -> seq_q seq_k")`. You'll want to do a `rearrange` before your `einsum`.
</details>
<details>
<summary>Which dimension do I softmax over?</summary>
The desired property is that `pattern[batch,head,q]` sums to 1 for all `q`. So the softmax needs to be over the `k` dimension.
</details>
<details>
<summary>I'm still confused about how to batch the computation.</summary>
## Pre Softmax
- Apply Q, K, and V to the input x
- rearrange Q and K to split the `hidden_size` dimension apart into heads and head_size dimensions.
- Einsum Q and K to get a (batch, head, seq_q, seq_k) shape. 
- Divide by the square root of the head size.
## Forward
- Softmax over the `k` dimension to obtain attention probs
- rearrange V just like Q and K previously
- einsum V and your attention probs to get the weighted Vs
- rearrange weighted Vs to combine head and head_size and put that at the end
- apply O
</details>
# Attention Forward Function
Your forward should call `attention_pattern_pre_softmax` and then finish the computations using `einsum` and `rearrange` again. Remember to apply the output projection.
"""


class BertSelfAttention(nn.Module):
    """Exercise: multi-head self-attention computed for all sequences in a batch and all heads at once.

    The projections are created here; the two methods are left for the student to implement.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.num_heads = config.num_heads
        # hidden_size must split evenly across the heads.
        assert config.hidden_size % config.num_heads == 0
        self.head_size = config.hidden_size // config.num_heads
        # Q, K and V each map hidden_size -> num_heads * head_size, i.e. all heads' matrices
        # concatenated. nn.Linear is left at its default, which includes a bias.
        self.project_query = nn.Linear(config.hidden_size, config.num_heads * self.head_size)
        self.project_key = nn.Linear(config.hidden_size, config.num_heads * self.head_size)
        self.project_value = nn.Linear(config.hidden_size, config.num_heads * self.head_size)
        # Maps the concatenated heads back to hidden_size.
        self.project_output = nn.Linear(config.num_heads * self.head_size, config.hidden_size)
        # Built from config.dropout; where to apply it is left to the exercise.
        self.dropout = nn.Dropout(config.dropout)

    def attention_pattern_pre_softmax(self, x: t.Tensor) -> t.Tensor:
        """Return the attention pattern after scaling but before softmax.
        pattern[batch, head, q, k] should be the match between a query at sequence position q and a key at sequence position k.
        """
        # TODO(exercise): implement. The stub returns None.
        pass

    def forward(self, x: t.Tensor) -> t.Tensor:
        # TODO(exercise): implement. The stub returns None.
        # The tests feed x of shape (batch, seq, hidden_size).
        pass


def test_attention_pattern_pre_softmax(BertSelfAttention, batch_size=2, seq_len=5, hidden_size=6, num_heads=2):
    """Two identically seeded instances should produce the same pre-softmax attention pattern.

    NOTE(review): both instances come from the same ``BertSelfAttention`` argument, so this
    checks reproducibility rather than agreement with a separate reference solution.
    """
    small_config = BertConfig(hidden_size=hidden_size, num_heads=num_heads)
    inputs = t.randn(batch_size, seq_len, hidden_size)

    patterns = []
    for _ in range(2):
        # Reseed before each construction so both instances get identical initial weights.
        t.manual_seed(987)
        patterns.append(BertSelfAttention(small_config).attention_pattern_pre_softmax(inputs))
    expected, actual = patterns

    allclose(actual, expected)



def test_attention(BertSelfAttention, batch_size=2, seq_len=5, hidden_size=6, num_heads=2):
    """Two identically seeded attention layers should produce the same output.

    Dropout is not tested!

    NOTE(review): both instances come from the same ``BertSelfAttention`` argument, so this
    checks reproducibility rather than agreement with a separate reference solution.
    """
    small_config = BertConfig(hidden_size=hidden_size, num_heads=num_heads, dropout=0.0)
    inputs = t.randn(batch_size, seq_len, hidden_size)

    outputs = []
    for _ in range(2):
        # Reseed before each construction so both instances get identical initial weights.
        t.manual_seed(988)
        outputs.append(BertSelfAttention(small_config)(inputs))
    expected, actual = outputs

    allclose(actual, expected)

# Run both attention checks against the class in this cell; they fail until its methods are implemented.
test_attention_pattern_pre_softmax(BertSelfAttention)
test_attention(BertSelfAttention)


In [None]:

# %%
"""
# Bert Block
Assemble the BertAttention and BertBlock classes following the schematic.
"""
# %%
class BertAttention(nn.Module):
    """Exercise: the attention part of a BERT block, to be assembled by following the schematic."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # TODO(exercise): create the sub-modules shown in the schematic.
        pass

    def forward(self, x: t.Tensor) -> t.Tensor:
        # TODO(exercise): implement. The stub returns None.
        pass


class BertBlock(nn.Module):
    """Exercise: one full BERT block, to be assembled by following the schematic."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # TODO(exercise): create the sub-modules shown in the schematic.
        pass

    def forward(self, x: t.Tensor) -> t.Tensor:
        # TODO(exercise): implement. The stub returns None.
        # test_bert_block feeds x of shape (2, 3, 768).
        pass


def test_bert_block(BertBlock):
    """Two identically seeded BertBlocks in eval mode should produce the same output.

    Dropout is not tested.

    NOTE(review): both instances come from the same ``BertBlock`` argument, so this
    checks reproducibility rather than agreement with a separate reference solution.
    """
    config = BertConfig()

    def seeded_eval_block():
        # Same seed before each construction, then eval mode.
        t.random.manual_seed(0)
        block = BertBlock(config)
        block.eval()
        return block

    reference = seeded_eval_block()
    theirs = seeded_eval_block()
    input_activations = t.rand((2, 3, 768))
    actual = theirs(input_activations)
    expected = reference(input_activations)
    allclose(actual, expected)

# Run the check against the BertBlock class in this cell; it fails until the block is implemented.
test_bert_block(BertBlock)



In [None]:

# %%
"""
# Putting it All Together
Now fill in the entire Bert module, following the schematic. Tips:
- The language modelling `Linear` after the blocks has shape `(hidden_size, hidden_size)`
- If `token_type_ids` isn't provided to `forward`, make it the same shape as `input_ids` but all zeros.
- The unembedding at the end that takes data from `hidden_size` to `vocab_size` shouldn't be its own Linear layer because it shares the same data as `token_embedding.weight`. Just reuse `token_embedding.weight` and add a bias term.
- Print your model out to see if it resembles the schematic.
"""
# %%
class Bert(nn.Module):
    """Exercise: the complete BERT model, to be assembled by following the schematic."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # TODO(exercise): create the embeddings, blocks and language-modelling head.
        pass

    def forward(self, input_ids, token_type_ids=None) -> BertOutput:
        # TODO(exercise): implement. The stub returns None.
        # Per the exercise text, a missing token_type_ids should be treated as all zeros with
        # the shape of input_ids; test_bert reads `.logits` from the returned BertOutput.
        pass

def test_bert(your_module):
    """``your_module`` should produce the same logits as the module-level ``Bert`` in eval mode.

    Both models are built from the default config right after seeding with 0.
    Dropout is not tested.
    """
    config = BertConfig()

    models = []
    for model_class in (Bert, your_module):
        # Same seed before each construction, then eval mode.
        t.random.manual_seed(0)
        model = model_class(config)
        model.eval()
        models.append(model)
    reference, theirs = models

    token_ids = t.LongTensor([[101, 1309, 6100, 1660, 1128, 1146, 102]])
    actual_logits = theirs(input_ids=token_ids).logits
    expected_logits = reference(input_ids=token_ids).logits
    allclose(actual_logits, expected_logits)


# NOTE(review): test_bert builds its reference from this same Bert class, so this call
# checks reproducibility only.
test_bert(Bert)



In [None]:

# %%
"""
# Loading Pretrained Weights 
Now copy parameters from the pretrained BERT returned by `load_pretrained_bert()` into your BERT.
This is somewhat tedious, but is representative of real ML work. Race yourself to see if you can do it more quickly than last time!
Remember that the embedding and unembedding weights are tied, so `hf_bert.bert.embeddings.word_embeddings.weight` and `hf_bert.cls.predictions.decoder.weight` should be equal and you should only use one of them.
You can look at the solution if you get frustrated.
<details>
<summary>I'm confused about my Parameter not being a leaf!</summary>
When you copied data from the HuggingFace version, PyTorch tracked the history of the copy operation. This means if you were to call `backward`, it would try to backpropagate through your Parameter back to the HuggingFace version, which is not what we want.
To fix this, you can call `detach()` to make a new tensor that shares storage with the original but doesn't have any history.
</details>
"""
# %%
def load_pretrained_weights(config: BertConfig) -> Bert:
    """Exercise: build a ``Bert`` from ``config`` and copy in the pretrained HuggingFace parameters.

    The HuggingFace model is already loaded into ``hf_bert``; the copying is left to the student.
    """
    hf_bert = load_pretrained_bert()
    # TODO(exercise): create the Bert, copy the parameters across and return it.
    # The stub returns None.
    pass

my_bert = load_pretrained_weights(config)

# Every parameter must be a leaf tensor; a copy that keeps autograd history is not.
for name, p in my_bert.named_parameters():
    # The message needs the f-prefix, otherwise "{name}" is printed literally.
    assert (
        p.is_leaf
    ), f"Parameter {name} is not a leaf node, which will cause problems in training. Try adding detach() somewhere."



In [None]:

# %%
"""
# Tokenization
We're going to use a HuggingFace tokenizer for now to encode text into a sequence of tokens that our model can use. The tokenizer has to match the model - our model was trained with the `bert-base-cased` tokenizer which is case-sensitive. If you tried to use the `bert-base-uncased` tokenizer which is case-insensitive, it wouldn't work at all.
Use `transformers.AutoTokenizer.from_pretrained` to fetch the appropriate tokenizer and try encoding and decoding some text.
## Vocabulary
Check out `tokenizer.vocab` to get an idea of what sorts of strings are assigned to tokens. In WordPiece, tokens represent a whole word unless they start with `##`, which denotes this token is part of a word. 
## Special Tokens
Check out `tokenizer.special_tokens_map`. The strings here are mapped to tokens which have special meaning - for example `tokenizer.mask_token` which is the literal string '[MASK]' is converted to `tokenizer.mask_token_id` which is 103.
## Predicting Masked Tokens
Write the `predict` function which takes a string with one or more instances of the substring '[MASK]', runs it through your model, finds the top K predictions and decodes each prediction.
Tips:
- `torch.topk` is useful
- The model should be in evaluation mode for predictions - this disables dropout and makes the predictions deterministic.
- If your model gives different predictions than the HuggingFace section, proceed to the next section on debugging.
"""
# %%
def predict(model: Bert, tokenizer, text: str, k=15) -> List[List[str]]:
    """
    Return a list of k strings for each [MASK] in the input.

    The outer list has one entry per [MASK]; test_bert_prediction reads entry 0 as the
    candidates for the first mask.
    """
    # TODO(exercise): implement. The stub returns None.
    pass


def test_bert_prediction(predict, model, tokenizer):
    """Your Bert should know some names of American presidents."""
    text = "Former President of the United States of America, George[MASK][MASK]"
    predictions = predict(model, tokenizer, text)
    print(f"Prompt: {text}")
    print("Model predicted: \n", "\n".join(map(str, predictions)))
    assert "Washington" in predictions[0]
    assert "Bush" in predictions[0]

# The tokenizer has to match the checkpoint the weights came from ("bert-base-cased").
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
test_bert_prediction(predict, my_bert, tokenizer)

In [None]:
# now play with your BERT!!!

In [None]:
# Free-form playground: edit the prompt (keeping at least one [MASK]) and re-run the cell.
your_text = "The Answer to the Ultimate Question of Life, The Universe, and Everything is [MASK]."
predictions = predict(my_bert, tokenizer, your_text)
# One line of candidates is printed per [MASK].
print("Model predicted: \n", "\n".join(map(str, predictions)))


# %%
"""
# Model Debugging
If your model works correctly at this point then congratulations, you can skip this section. 
The challenge with debugging ML code is that it often silently computes the wrong result instead of erroring out. Some things you can check:
- Do I have any square matrices transposed, so the shapes still match but they do the wrong thing?
- Did I forget to pass any optional arguments, and the wrong default is being used?
- If I `print` my model, do the layers look right?
- Can I add asserts in my code to check assumptions that I've made? In particular, sometimes unintentional broadcasting creates outputs of the wrong shape.
You won't always have a reference implementation, but given that you do, a good technique is to use hooks to collect the inputs and outputs that should be identical, and compare when they start to diverge. This narrows down the number of places where you have to look for the bug.
Read the [documentation](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook) for `register_forward_hook` on a `nn.Module` and try logging the input and output of each block on your model and the HuggingFace version.
"""
# %%

# Reference HuggingFace model for debugging: load it, clear any hooks registered on it or its
# submodules, and switch it to evaluation mode.
hf_bert = load_pretrained_bert()
hf_bert.apply(remove_hooks)
hf_bert.eval()
# this should load in a BERT that can be used for reference and debugging