In [1]:
!pip install einops
!pip install fancy_einsum

import os
from dataclasses import dataclass
from typing import List, Optional, Union
import torch as t
import transformers
from einops import rearrange, repeat
from fancy_einsum import einsum
from torch import nn
from torch.nn import functional as F
from typing import Optional, Iterator, cast, TypeVar, Generic, Callable

import tempfile
import os
import time
import torch as t
from torch import nn
import transformers
import joblib
import requests
import logging
from transformers.models.bert.modeling_bert import BertForMaskedLM
import http
from functools import wraps

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Installing collected packages: fancy_einsum
Successfully installed fancy_einsum-0.0.3


In [49]:
# Set the DEBUG_TOLERANCES environment variable to any non-empty value to make the allclose*
# helpers also print their measured deviation when a check passes.
DEBUG_TOLERANCES = os.environ.get("DEBUG_TOLERANCES")

def assert_all_equal(actual: t.Tensor, expected: t.Tensor) -> None:
    """Assert that actual and expected are exactly equal (to floating point precision).

    Raises:
        AssertionError: if the shapes differ, or listing (up to 10 of) the indexes where the
            tensors differ, followed by both tensors.
    """
    # Without this check `==` would broadcast, e.g. a (1,) tensor could "equal" a (3,) tensor.
    if actual.shape != expected.shape:
        raise AssertionError(f"expected shape={expected.shape}, got {actual.shape}")
    mask = actual == expected
    if not mask.all().item():
        # Report the positions that do NOT match; mask.nonzero() alone would list the matching ones.
        bad = (~mask).nonzero()
        msg = f"Did not match at {len(bad)} indexes: {bad[:10]}{'...' if len(bad) > 10 else ''}"
        raise AssertionError(f"{msg}\nActual:\n{actual}\nExpected:\n{expected}")
        
    
def test_is_equal(actual: t.Tensor, expected: t.Tensor, test_name: str) -> None:
    """Check exact equality via run_and_report; on failure, print the test name and re-raise."""
    try:
        run_and_report(assert_all_equal, test_name, actual, expected)
    except AssertionError:
        print(f"Test failed: {test_name}")
        raise


def assert_shape_equal(actual: t.Tensor, expected: t.Tensor) -> None:
    """Raise AssertionError unless both tensors have exactly the same shape."""
    actual_shape, expected_shape = actual.shape, expected.shape
    if actual_shape == expected_shape:
        return
    raise AssertionError(f"expected shape={expected_shape}, got {actual_shape}")
        
def allclose(actual: t.Tensor, expected: t.Tensor, rtol=1e-4) -> None:
    """Assert that every entry of actual is within a relative tolerance of expected.

    An entry passes when it equals the expected entry exactly (this covers equal infinities) or
    when |actual - expected| <= rtol * |expected|. A NaN never passes, so NaNs are reported
    instead of silently slipping through.

    Raises:
        AssertionError: if the shapes differ or any entry is outside tolerance.
    """
    assert_shape_equal(actual, expected)
    left = (actual - expected).abs()
    right = rtol * expected.abs()
    # `left <= right` is False for NaN, so negating it counts NaNs as wrong; the plain
    # `left > right` test would also be False for NaN and let them pass.
    ok = (actual == expected) | (left <= right)
    num_wrong = (~ok).sum().item()
    if num_wrong > 0:
        print(f"Test failed. Max absolute deviation: {left.max()}")
        print(f"Actual:\n{actual}\nExpected:\n{expected}")
        raise AssertionError(f"allclose failed with {num_wrong} / {left.nelement()} entries outside tolerance")
    elif DEBUG_TOLERANCES and left.nelement() > 0:
        # .max() raises on an empty tensor, hence the nelement guard.
        print(f"Test passed with max absolute deviation of {left.max()}")


def allclose_atol(actual: t.Tensor, expected: t.Tensor, atol: float) -> None:
    """Assert that every entry of actual is within an absolute tolerance `atol` of expected.

    An entry passes when it equals the expected entry exactly (this covers equal infinities) or
    when |actual - expected| <= atol. A NaN never passes.

    Raises:
        AssertionError: if the shapes differ or any entry is outside tolerance.
    """
    assert_shape_equal(actual, expected)
    left = (actual - expected).abs()
    # Negating `<=` counts NaN deviations as wrong; `left > atol` would let them pass.
    ok = (actual == expected) | (left <= atol)
    num_wrong = (~ok).sum().item()
    if num_wrong > 0:
        print(f"Test failed. Max absolute deviation: {left.max()}")
        print(f"Actual:\n{actual}\nExpected:\n{expected}")
        raise AssertionError(f"allclose failed with {num_wrong} / {left.nelement()} entries outside tolerance")
    elif DEBUG_TOLERANCES and left.nelement() > 0:
        # .max() raises on an empty tensor, hence the nelement guard.
        print(f"Test passed with max absolute deviation of {left.max()}")
        


def allclose_scalar(actual: float, expected: float, rtol=1e-4) -> None:
    """Assert that the scalar actual is within a relative tolerance of expected.

    Passes when actual == expected (this covers equal infinities) or when
    |actual - expected| <= rtol * |expected|. A NaN never passes.

    Raises:
        AssertionError: if actual is outside tolerance.
    """
    left = abs(actual - expected)
    right = rtol * abs(expected)
    # Written as `not (... <= ...)` so a NaN deviation counts as wrong; `left > right` is False for NaN.
    wrong = not (actual == expected or left <= right)
    if wrong:
        raise AssertionError(f"Test failed. Absolute deviation: {left}\nActual:\n{actual}\nExpected:\n{expected}")
    elif DEBUG_TOLERANCES:
        print(f"Test passed with absolute deviation of {left}")


def allclose_scalar_atol(actual: float, expected: float, atol: float) -> None:
    """Assert that the scalar actual is within an absolute tolerance `atol` of expected.

    Passes when actual == expected (this covers equal infinities) or when
    |actual - expected| <= atol. A NaN never passes.

    Raises:
        AssertionError: if actual is outside tolerance.
    """
    left = abs(actual - expected)
    # Written as `not (... <= ...)` so a NaN deviation counts as wrong; `left > atol` is False for NaN.
    wrong = not (actual == expected or left <= atol)
    if wrong:
        raise AssertionError(f"Test failed. Absolute deviation: {left}\nActual:\n{actual}\nExpected:\n{expected}")
    elif DEBUG_TOLERANCES:
        print(f"Test passed with absolute deviation of {left}")
        
        

def report_success(testname, timeout: float = 10.0):
    """POST to the server indicating success at the given test.

    Used to help the TAs know how long each section takes to complete.

    Reads MLAB_SERVER and MLAB_EMAIL from the environment. If neither is set this is a no-op
    (local dev).

    Args:
        testname: qualified name of the test that passed.
        timeout: seconds to wait for the server before giving up (passed to requests).

    Raises:
        ValueError: if only one of MLAB_SERVER / MLAB_EMAIL is set, or if the server replies with
            anything other than 204 No Content.
    """
    server = os.environ.get("MLAB_SERVER")
    email = os.environ.get("MLAB_EMAIL")
    if server:
        if email:
            r = requests.post(
                server + "/api/report_success",
                json=dict(email=email, testname=testname),
                # Without a timeout requests waits forever, so an unresponsive server would hang the notebook.
                timeout=timeout,
            )
            if r.status_code != http.HTTPStatus.NO_CONTENT:
                raise ValueError(f"Got status code from server: {r.status_code}")
        else:
            raise ValueError(f"Server set to {server} but no MLAB_EMAIL set!")
    else:
        if email:
            raise ValueError(f"Email set to {email} but no MLAB_SERVER set!")
        else:
            return  # local dev, do nothing


# Registry: qualified test name (e.g. "test_w2d3.test_unidirectional_attn") -> whether that test
# has passed during the current interpreter session.
# Caveat: autoreload may re-execute this line and wipe the recorded results.
TEST_FN_PASSED = dict()


def report(test_func):
    """Decorator: register test_func as not yet passed and route its calls through run_and_report."""
    qualified_name = ".".join((test_func.__module__, test_func.__name__))
    # Re-registration is normal under autoreload, so an existing entry is silently overwritten.
    TEST_FN_PASSED[qualified_name] = False

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        return run_and_report(test_func, qualified_name, *args, **kwargs)

    return wrapper


def run_and_report(test_func: Callable, name: str, *test_func_args, **test_func_kwargs):
    """Run test_func, print how long it took, and report the first pass of `name` this session.

    Exceptions raised by test_func propagate unchanged; nothing is printed or reported then.

    Returns:
        Whatever test_func returns.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump if the wall clock changes.
    start = time.perf_counter()
    out = test_func(*test_func_args, **test_func_kwargs)
    elapsed = time.perf_counter() - start
    print(f"{name} passed in {elapsed:.2f}s.")
    # Only the first pass of each test per interpreter session is sent to the server.
    if not TEST_FN_PASSED.get(name):
        report_success(name)
        TEST_FN_PASSED[name] = True
    return out


@report
def test_layernorm_mean_1d(LayerNorm):
    """An int normalized_shape means: normalize over the last dimension, which has that size."""
    x = t.randn(20, 10)
    normalized = LayerNorm(10)(x)
    max_mean = normalized.mean(-1).abs().max().item()
    assert max_mean < 1e-5, f"Normalized mean should be about 0, got {max_mean}"


@report
def test_layernorm_mean_2d(LayerNorm):
    """A 2-D normalized_shape means: normalize jointly over the last two dimensions."""
    x = t.randn(20, 10)
    normalized = LayerNorm((20, 10))(x)
    joint_mean = normalized.mean((-1, -2))
    max_mean = joint_mean.abs().max().item()
    assert max_mean < 1e-5, f"Normalized mean should be about 0, got {max_mean}"


@report
def test_layernorm_std(LayerNorm):
    """With a tiny epsilon and no elementwise affine, each row's output variance should be ~1."""
    x = t.randn(20, 10)
    plain_norm = LayerNorm(10, eps=1e-11, elementwise_affine=False)
    row_variance = plain_norm(x).var(-1, unbiased=False)
    var_diff = (1 - row_variance).abs().max().item()
    assert var_diff < 1e-6, f"Var should be about 1, off by {var_diff}"


@report
def test_layernorm_exact(LayerNorm):
    """Your LayerNorm's output should match PyTorch for equal epsilon, up to floating point rounding error.

    Runs in float64, so the match should be extremely tight.
    """
    x = t.randn(2, 3, 4, 5, dtype=t.float64)
    # A deliberately large epsilon makes the comparison fail if the implementation forgets it.
    shared_kwargs = dict(dtype=t.float64, eps=1e-2)
    ln1 = LayerNorm((5,), **shared_kwargs)
    ln2 = t.nn.LayerNorm((5,), **shared_kwargs)  # type: ignore
    allclose(ln1(x), ln2(x))


@report
def test_layernorm_backward(LayerNorm):
    """Gradients flowing back through your LayerNorm should match PyTorch's."""
    x = t.randn(10, 3)
    x2 = x.clone()
    for leaf in (x, x2):
        leaf.requires_grad_(True)

    # elementwise_affine=False means no parameters, so both modules are deterministic.
    reference = nn.LayerNorm(3, elementwise_affine=False)
    reference.requires_grad_(True)
    reference(x).sum().backward()

    candidate = LayerNorm(3, elementwise_affine=False)
    candidate.requires_grad_(True)
    candidate(x2).sum().backward()

    assert isinstance(x.grad, t.Tensor)
    assert isinstance(x2.grad, t.Tensor)
    # Absolute tolerance, because these gradient entries should all be (close to) zero.
    allclose_atol(x.grad, x2.grad, atol=1e-5)
    
@report
def test_embedding(Embedding):
    """Looking up an index should return exactly that row of the embedding weight."""
    emb = Embedding(6, 100)
    indices = (1, 3, 5)
    out = emb(t.tensor(list(indices), dtype=t.int64))
    for row, idx in enumerate(indices):
        allclose(out[row], emb.weight[idx])


@report
def test_embedding_std(Embedding):
    """A freshly constructed embedding's weights should have standard deviation of roughly 0.02."""
    t.manual_seed(5)
    weight_std = Embedding(6, 100).weight.std().item()
    allclose_scalar_atol(weight_std, 0.02, atol=0.001)
    
T = TypeVar("T")


class StaticModuleList(nn.ModuleList, Generic[T]):
    """ModuleList where the user vouches that it only contains objects of type T.

    This allows the static checker to work instead of only knowing that the contents are Modules.
    """

    # TBD lowpri: is it possible to do this just with signatures, without actually overriding the method bodies to add a cast?

    def __getitem__(self, index: int) -> T:
        return cast(T, super().__getitem__(index))

    def __iter__(self) -> Iterator[T]:
        return cast(Iterator[T], iter(self._modules.values()))

    def __repr__(self) -> str:
        """Abbreviated repr: the first child in full, then a count of the remaining children.

        Modified from t.nn.Module.__repr__; extra_repr lines are listed like sub-modules, one per line.
        An empty list renders as "StaticModuleList()".
        """
        # Private torch helper (indents every line after the first); it is not imported at file level.
        from torch.nn.modules.module import _addindent

        extra_repr = self.extra_repr()
        # An empty string would split into [''], so only split when there is something.
        extra_lines = extra_repr.split("\n") if extra_repr else []

        child_lines = []
        items = list(self._modules.items())
        if items:
            key, module = items[0]
            child_lines.append("(" + key + "): " + _addindent(repr(module), 2))
            n_rest = len(items) - 1
            if n_rest:
                child_lines.append(f"+ {n_rest} more...")
        lines = extra_lines + child_lines

        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str
    

    

# NOTE(review): joblib stores the model as a pickle under the shared temp dir. On a multi-user
# machine another user could plant a malicious cache file there; unpickling it runs arbitrary code.
mem = joblib.Memory(os.path.join(tempfile.gettempdir(), "joblib_cache"))


@mem.cache
def load_pretrained_bert() -> BertForMaskedLM:
    """Load the HuggingFace BERT ("bert-base-cased"), cached on disk via joblib.

    Suppresses the spurious warning about some weights not being used by temporarily disabling
    the "transformers.modeling_utils" logger.
    """
    logger = logging.getLogger("transformers.modeling_utils")
    was_disabled = logger.disabled
    logger.disabled = True
    try:
        bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
    finally:
        # Restore even if download/loading fails, so later transformers warnings aren't silenced.
        logger.disabled = was_disabled
    return cast(BertForMaskedLM, bert)

@report
def test_bert_prediction(predict, model, tokenizer):
    """Your Bert should be able to complete the name of an American president."""
    text = "Former President of the United States of America, George[MASK][MASK]"
    predictions = predict(model, tokenizer, text)
    print(f"Prompt: {text}")
    print("Model predicted: \n", "\n".join(str(p) for p in predictions))
    first_mask_guesses = predictions[0]
    assert "Washington" in first_mask_guesses
    assert "Bush" in first_mask_guesses
    
def remove_hooks(module: t.nn.Module) -> None:
    """Remove all forward/backward hooks registered directly on module.

    Use module.apply(remove_hooks) to do this recursively.

    Newer PyTorch versions keep hooks in extra dicts (e.g. _backward_pre_hooks,
    _forward_hooks_with_kwargs). Those are cleared too when present; missing ones are skipped so
    this also works on older versions. State-dict hooks are deliberately left alone, because
    modules may register those internally.
    """
    hook_dict_names = (
        "_backward_hooks",
        "_backward_pre_hooks",
        "_forward_hooks",
        "_forward_hooks_with_kwargs",
        "_forward_hooks_always_called",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
    )
    for attr in hook_dict_names:
        hooks = getattr(module, attr, None)
        if hooks is not None:
            hooks.clear()

In [13]:
@dataclass(frozen=True)
class BertConfig:
    """Hyperparameters shared by every part of the Bert model (immutable).

    Most fields are self-explanatory. The less obvious ones:
    - intermediate_size: number of hidden neurons in the MLP (see schematic).
    - type_vocab_size: only used for "next sentence prediction" pretraining, which we aren't doing.
    - head_size: for these defaults it equals hidden_size // num_heads, but that isn't required,
      so code should read head_size rather than derive it.
    """

    # Field order is part of the positional constructor signature; keep it stable.
    vocab_size: int = 28996
    intermediate_size: int = 3072
    hidden_size: int = 768
    num_layers: int = 12
    num_heads: int = 12
    head_size: int = 64
    max_position_embeddings: int = 512
    dropout: float = 0.1
    type_vocab_size: int = 2
    layer_norm_epsilon: float = 1e-12


config = BertConfig()

In [38]:
class BertSelfAttention(nn.Module):
    """Multi-head self-attention: per-head query/key/value projections, scaled dot-product
    attention, and an output projection back to hidden_size.

    No residual connection, dropout or layer norm is applied in this module.
    """

    project_query: nn.Linear
    project_key: nn.Linear
    project_value: nn.Linear
    project_output: nn.Linear

    def __init__(self, config: BertConfig):
        super().__init__()
        self.config = config
        # Width of all heads concatenated; head_size comes from config, not hidden_size // num_heads.
        all_heads_size = config.num_heads * config.head_size
        self.project_query = nn.Linear(config.hidden_size, all_heads_size)
        self.project_key = nn.Linear(config.hidden_size, all_heads_size)
        self.project_value = nn.Linear(config.hidden_size, all_heads_size)
        self.project_output = nn.Linear(all_heads_size, config.hidden_size)

    def _split_heads(self, projected: t.Tensor) -> t.Tensor:
        """Reshape (batch, seq, num_heads * head_size) -> (batch, num_heads, seq, head_size)."""
        return rearrange(projected, 'b seq (head head_size) -> b head seq head_size', head=self.config.num_heads)

    def attention_pattern_pre_softmax(self, x: t.Tensor) -> t.Tensor:
        """
        x: shape (batch, seq, hidden_size)
        Return the attention pattern after scaling but before softmax.

        pattern[batch, head, q, k] should be the match between a query at sequence position q and a key at sequence position k.
        """
        q = self._split_heads(self.project_query(x))
        k = self._split_heads(self.project_key(x))
        scores = einsum('b head seq_q head_size, b head seq_k head_size -> b head seq_q seq_k', q, k)
        # Scaled dot-product attention: divide by sqrt(head_size).
        return scores / self.config.head_size ** 0.5

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor] = None) -> t.Tensor:
        """
        additive_attention_mask: shape (batch, head=1, seq_q=1, seq_k) - used in training to prevent copying data from padding tokens. Contains 0 for a real input token and a large negative number for a padding token. If provided, add this to the attention pattern (pre softmax).

        Return: (batch, seq, hidden_size)
        """
        attention_pattern = self.attention_pattern_pre_softmax(x)
        if additive_attention_mask is not None:
            # Broadcasts over the head and query dimensions.
            attention_pattern = attention_pattern + additive_attention_mask
        softmax_attention = attention_pattern.softmax(dim=-1)
        v = self._split_heads(self.project_value(x))
        weighted_attention = einsum('b head seq_q seq_k, b head seq_k head_size -> b head seq_q head_size',
                                    softmax_attention, v)
        # Concatenate heads back into one vector per position.
        weighted_attention = rearrange(weighted_attention, 'b head seq_q head_size -> b seq_q (head head_size)')
        return self.project_output(weighted_attention)
    

In [39]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing `normalized_shape` dimensions (mirrors t.nn.LayerNorm).

    Each slice over the last len(normalized_shape) dims is shifted to mean 0 and divided by
    sqrt(biased variance + eps). If elementwise_affine, the result is then scaled by `weight`
    and shifted by `bias` elementwise.
    """

    weight: nn.Parameter
    bias: nn.Parameter

    def __init__(
        self, normalized_shape: Union[int, tuple, t.Size],
        eps = 1e-05, elementwise_affine = True, device = None,
        dtype = None,
    ):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)

        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # Negative dims (-1, -2, ...) so any number of leading batch dims is supported.
        self.normalized_dims = tuple(range(-1, -1 - len(self.normalized_shape), -1))
        if elementwise_affine:
            self.weight = nn.Parameter(t.empty(self.normalized_shape, device = device, dtype = dtype))
            self.bias = nn.Parameter(t.empty(self.normalized_shape, device = device, dtype = dtype))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the weight and bias, if applicable."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """x and the output should both have shape (batch, *)."""
        mean = x.mean(dim = self.normalized_dims, keepdim = True)
        # Biased (population) variance, as t.nn.LayerNorm uses.
        var = x.var(dim = self.normalized_dims, keepdim = True, unbiased = False)
        output = (x - mean) / ((var + self.eps) ** 0.5)
        if self.elementwise_affine:
            output = output * self.weight + self.bias
        return output

    def extra_repr(self) -> str:
        """Match t.nn.LayerNorm's repr, e.g. "LayerNorm((768,), eps=1e-12, elementwise_affine=True)"."""
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
    
# Run every LayerNorm check against the implementation above.
for layernorm_check in (
    test_layernorm_mean_1d,
    test_layernorm_mean_2d,
    test_layernorm_std,
    test_layernorm_exact,
    test_layernorm_backward,
):
    layernorm_check(LayerNorm)

__main__.test_layernorm_mean_1d passed in 0.00s.
__main__.test_layernorm_mean_2d passed in 0.00s.
__main__.test_layernorm_std passed in 0.00s.
__main__.test_layernorm_exact passed in 0.00s.
__main__.test_layernorm_backward passed in 0.00s.


In [40]:
class Embedding(nn.Module):
    """Lookup table mapping integer ids to learned vectors (a minimal t.nn.Embedding)."""

    num_embeddings: int
    embedding_dim: int
    weight: nn.Parameter

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(t.empty(self.num_embeddings, self.embedding_dim))
        # t.empty leaves the memory uninitialized (arbitrary values, possibly NaN/inf), so fill it.
        self.reset_parameters()

    def reset_parameters(self, std: float = 0.02) -> None:
        """Redraw the weights from N(0, std**2); 0.02 is the std that test_embedding_std expects."""
        nn.init.normal_(self.weight, mean=0.0, std=std)

    def forward(self, x: t.LongTensor) -> t.Tensor:
        """For each integer in the input, return that row of the embedding.

        Don't convert x to one-hot vectors - this works but is too slow.
        """
        return self.weight[x]

    def extra_repr(self) -> str:
        return f"{self.num_embeddings}, {self.embedding_dim}"
    
# Our Embedding's repr (built from extra_repr) should be indistinguishable from PyTorch's.
our_repr = repr(Embedding(10, 20))
assert our_repr == repr(t.nn.Embedding(10, 20))
test_embedding(Embedding)

__main__.test_embedding passed in 0.00s.


In [41]:
class BertMLP(nn.Module):
    """Feed-forward sublayer: Linear -> GELU -> Linear -> dropout, then residual add and LayerNorm."""

    first_linear: nn.Linear
    second_linear: nn.Linear
    layer_norm: LayerNorm

    def __init__(self, config: BertConfig):
        super().__init__()
        # Submodules are created in this order so parameter initialization draws from the RNG in the same sequence.
        self.first_linear = nn.Linear(config.hidden_size, config.intermediate_size)
        self.act = nn.GELU()
        self.second_linear = nn.Linear(config.intermediate_size, config.hidden_size)
        self.layer_norm = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: t.Tensor) -> t.Tensor:
        hidden = self.act(self.first_linear(x))
        update = self.dropout(self.second_linear(hidden))
        # Residual connection followed by post-LayerNorm.
        return self.layer_norm(update + x)
    


In [42]:
class BertAttention(nn.Module):
    """Attention sublayer: self-attention -> dropout, then residual add and LayerNorm."""

    self_attn: BertSelfAttention
    layer_norm: LayerNorm

    def __init__(self, config: BertConfig):
        super().__init__()
        self.layer_norm = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.self_attn = BertSelfAttention(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor] = None) -> t.Tensor:
        attended = self.dropout(self.self_attn(x, additive_attention_mask))
        # Residual connection followed by post-LayerNorm.
        return self.layer_norm(attended + x)
    
class BertBlock(nn.Module):
    """One encoder layer: the attention sublayer followed by the MLP sublayer."""

    attention: BertAttention
    mlp: BertMLP

    def __init__(self, config: BertConfig):
        super().__init__()
        self.mlp = BertMLP(config)
        self.attention = BertAttention(config)

    def forward(self, x: t.Tensor, additive_attention_mask: Optional[t.Tensor] = None) -> t.Tensor:
        after_attention = self.attention(x, additive_attention_mask)
        return self.mlp(after_attention)

In [43]:
def make_additive_attention_mask(one_zero_attention_mask: t.Tensor, big_negative_number: float = -10000) -> t.Tensor:
    """
    one_zero_attention_mask: shape (batch, seq). Contains 1 if this is a valid token and 0 if it is a padding token.
        A bool mask (True = valid token) is also accepted.
    big_negative_number: Any negative number large enough in magnitude that exp(big_negative_number) is 0.0 for the floating point precision used.

    Out: shape (batch, 1, 1, seq), i.e. (batch, heads, seq_q, seq_k) with size-1 head and query
    dimensions that broadcast against the attention scores. Contains 0 where attending to that key
    is allowed and big_negative_number where it is not.

    Raises:
        ValueError: if the mask is not 2-D.
    """
    if one_zero_attention_mask.ndim != 2:
        raise ValueError(f"expected a (batch, seq) mask, got shape {tuple(one_zero_attention_mask.shape)}")
    if one_zero_attention_mask.dtype == t.bool:
        # torch does not support `1 - bool_tensor`; treat True/False as 1/0.
        one_zero_attention_mask = one_zero_attention_mask.long()
    additive = (1 - one_zero_attention_mask) * big_negative_number
    # (batch, seq_k) -> (batch, 1, 1, seq_k)
    return additive[:, None, None, :]

class BertCommon(nn.Module):
    """Bert trunk shared by the task heads.

    Sums the token, position and token-type embeddings, applies LayerNorm and dropout, and then
    runs the stack of BertBlocks.
    """

    token_embedding: Embedding
    pos_embedding: Embedding
    token_type_embedding: Embedding
    layer_norm: LayerNorm
    blocks: StaticModuleList[BertBlock]

    def __init__(self, config: BertConfig):
        super().__init__()
        self.token_embedding = Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedding = Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embedding = Embedding(config.type_vocab_size, config.hidden_size)
        self.layer_norm = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = StaticModuleList([BertBlock(config) for _ in range(config.num_layers)])

    def forward(
        self,
        input_ids: t.Tensor,
        token_type_ids: Optional[t.Tensor] = None,
        one_zero_attention_mask: Optional[t.Tensor] = None,
    ) -> t.Tensor:
        """
        input_ids: (batch, seq) - the token ids
        token_type_ids: (batch, seq) - only used for next sentence prediction.
        one_zero_attention_mask: (batch, seq) - only used in training. See make_additive_attention_mask.
        """
        # Without explicit segment ids, every token belongs to segment 0.
        if token_type_ids is None:
            token_type_ids = t.zeros_like(input_ids, dtype=t.int64)

        # Position ids 0..seq-1, one copy per batch row.
        positions = t.arange(input_ids.shape[1], device=input_ids.device)
        positions = repeat(positions, 'n -> b n', b=input_ids.shape[0])

        additive_attention_mask = (
            None
            if one_zero_attention_mask is None
            else make_additive_attention_mask(one_zero_attention_mask)
        )

        hidden = (
            self.token_embedding(input_ids)
            + self.pos_embedding(positions)
            + self.token_type_embedding(token_type_ids)
        )
        hidden = self.dropout(self.layer_norm(hidden))
        for block in self.blocks:
            hidden = block(hidden, additive_attention_mask=additive_attention_mask)
        return hidden
        
        

In [44]:
class BertLanguageModel(nn.Module):
    """Masked-language-model Bert.

    Runs the BertCommon trunk, then a Linear -> GELU -> LayerNorm transform. Vocabulary logits
    come from the token-embedding matrix (tied weights) plus a separate learned bias.
    """

    common: BertCommon
    lm_linear: nn.Linear
    lm_layer_norm: LayerNorm
    unembed_bias: nn.Parameter

    def __init__(self, config: BertConfig):
        super().__init__()
        self.common = BertCommon(config)
        self.lm_linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_layer_norm = LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.unembed_bias = nn.Parameter(t.zeros(config.vocab_size))

    def forward(
        self,
        input_ids: t.Tensor,
        token_type_ids: Optional[t.Tensor] = None,
        one_zero_attention_mask: Optional[t.Tensor] = None,
    ) -> t.Tensor:
        """Compute logits for each token in the vocabulary.

        Return: shape (batch, seq, vocab_size)
        """
        hidden = self.common(input_ids, token_type_ids, one_zero_attention_mask)
        hidden = self.lm_layer_norm(F.gelu(self.lm_linear(hidden)))
        # Unembed with the (tied) token embedding matrix: (vocab, hidden) x (b, s, hidden) -> (b, s, vocab).
        logits = t.einsum("vh,bsh->bsv", self.common.token_embedding.weight, hidden)
        return logits + self.unembed_bias
    
    

In [45]:
def load_pretrained_weights(config: BertConfig) -> BertLanguageModel:
    """Build a BertLanguageModel from `config` and fill it with HuggingFace's bert-base-cased weights.

    Every parameter is first set to NaN. Any parameter this function forgets to copy is therefore
    detected at the end, instead of silently keeping its random initial value.

    Returns:
        The populated model, with requires_grad enabled on every parameter.

    Raises:
        AssertionError: if HF's embedding and decoder weights are not tied, or if any parameter of
            our model was not overwritten by a pretrained value.
    """
    hf_bert = load_pretrained_bert()

    def _copy(mine, theirs):
        # Write through a detached view so autograd doesn't record it and the parameter stays a leaf.
        mine.detach().copy_(theirs)

    def _copy_weight_bias(mine, theirs, transpose=False):
        """Copy `weight` (optionally transposed) and, if `mine` has one, `bias` from theirs."""
        _copy(mine.weight, theirs.weight.T if transpose else theirs.weight)
        if getattr(mine, "bias", None) is not None:
            _copy(mine.bias, theirs.bias)

    mine = BertLanguageModel(config)
    # Poison every parameter with NaN; the check at the end reports any that were not copied.
    for p in mine.parameters():
        p.requires_grad = False
        p.fill_(t.nan)

    _copy_weight_bias(mine.common.token_embedding, hf_bert.bert.embeddings.word_embeddings)
    _copy_weight_bias(mine.common.pos_embedding, hf_bert.bert.embeddings.position_embeddings)
    _copy_weight_bias(mine.common.token_type_embedding, hf_bert.bert.embeddings.token_type_embeddings)
    _copy_weight_bias(mine.common.layer_norm, hf_bert.bert.embeddings.LayerNorm)

    # Set up type hints so our autocomplete works properly
    from transformers.models.bert.modeling_bert import BertLayer

    my_block: BertBlock
    hf_block: BertLayer

    for my_block, hf_block in zip(mine.common.blocks, hf_bert.bert.encoder.layer):  # type: ignore
        _copy_weight_bias(my_block.attention.self_attn.project_query, hf_block.attention.self.query)
        _copy_weight_bias(my_block.attention.self_attn.project_key, hf_block.attention.self.key)
        _copy_weight_bias(my_block.attention.self_attn.project_value, hf_block.attention.self.value)
        _copy_weight_bias(my_block.attention.self_attn.project_output, hf_block.attention.output.dense)
        _copy_weight_bias(my_block.attention.layer_norm, hf_block.attention.output.LayerNorm)

        _copy_weight_bias(my_block.mlp.first_linear, hf_block.intermediate.dense)
        _copy_weight_bias(my_block.mlp.second_linear, hf_block.output.dense)
        _copy_weight_bias(my_block.mlp.layer_norm, hf_block.output.LayerNorm)

    _copy_weight_bias(mine.lm_linear, hf_bert.cls.predictions.transform.dense)
    _copy_weight_bias(mine.lm_layer_norm, hf_bert.cls.predictions.transform.LayerNorm)

    # Our model has no separate unembedding weight: it is meant to reuse the token embedding matrix
    # directly. That keeps the weights tied during training. Assigning or copying into a separate
    # weight either fails (non-leaf tensor) or drifts apart once training starts. The reuse is only
    # faithful if HF's weights are tied as well, so check that, then copy just the decoder bias.
    assert t.allclose(
        hf_bert.bert.embeddings.word_embeddings.weight,
        hf_bert.cls.predictions.decoder.weight,
    ), "Embed and unembed weight should be the same"
    mine.unembed_bias.detach().copy_(hf_bert.cls.predictions.decoder.bias)

    missing = [name for name, p in mine.named_parameters() if t.isnan(p).any()]
    for name in missing:
        print(f"Forgot to initialize: {name}")
    if missing:
        # Explicit raise rather than a bare `assert`, so the check also runs under `python -O`.
        raise AssertionError(f"{len(missing)} parameter(s) not copied from the pretrained model: {missing}")
    for p in mine.parameters():
        p.requires_grad_(True)
    return mine
    



my_bert = load_pretrained_weights(config)
# Non-leaf tensors don't get .grad populated by backward(), so an optimizer would never update them.
for name, p in my_bert.named_parameters():
    assert (
        p.is_leaf
    ), f"Parameter {name} is not a leaf node, which will cause problems in training. Try adding detach() somewhere."



In [48]:
def predict(model: BertLanguageModel, tokenizer, text: str, k=15) -> List[List[str]]:
    """
    Return a list of k strings for each [MASK] in the input.

    The outer list follows the order of the [MASK] tokens in `text`. Each inner list holds the
    decoded k highest-scoring vocabulary entries for that position, highest first.
    Note: switches `model` to eval mode (disabling dropout) and leaves it there.
    """
    model.eval()
    input_ids = tokenizer(text, return_tensors='pt')['input_ids']
    # Inference only: don't build an autograd graph (the loaded parameters require grad).
    with t.no_grad():
        logits = model(input_ids)
    # Boolean-mask indexing selects the logits at [MASK] positions -> (num_masks, vocab_size).
    mask_logits = logits[input_ids == tokenizer.mask_token_id]
    top_ids = mask_logits.topk(k, dim=-1).indices
    # `token_id` rather than `t`, to avoid shadowing the torch alias.
    return [[tokenizer.decode(token_id) for token_id in row] for row in top_ids]
    

tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-cased')
test_bert_prediction(predict, my_bert, tokenizer)

# Try the model on a prompt of your own.
your_text = "The Answer to the Ultimate Question of Life, The Universe, and Everything is [MASK]."
predictions = predict(my_bert, tokenizer, your_text)
print("Model predicted: \n", "\n".join(str(guesses) for guesses in predictions))

Prompt: Former President of the United States of America, George[MASK][MASK]
Model predicted: 
 ['W', 'Washington', 'Bush', 'Wallace', 'Dewey', 'Polk', 'Patton', 'H', 'Marshall', 'C', 'Buchanan', 'Clinton', 'G', 'E', 'Carter']
['.', ';', '?', '!', '|', '...', 'Johnson', ',', 'Press', 'Brown', 'Smith', 'Anderson', 'Carter', 'Jones', 'III']
__main__.test_bert_prediction passed in 0.09s.
Model predicted: 
 ['Everything', 'Life', 'One', 'Love', 'Good', 'God', 'Nothing', 'Time', 'Here', 'Infinite', 'Space', 'Impossible', 'Free', 'Earth', 'Truth']


In [59]:
# Layer-by-layer comparison with HuggingFace. Record the (input, output) of every attention sublayer
# and every full block in both models on the same input, then check the recorded pairs match.
input_ids = tokenizer('Hello there', return_tensors='pt')['input_ids']
expected = []

def hook(module, inputs, output):
    """Record a HF module's input and output.

    NOTE(review): takes output[0], assuming HF modules return a tuple whose first item is the hidden state.
    """
    x = inputs[0]
    out = output[0]
    expected.append((x, out))

hf_bert = load_pretrained_bert()
hf_bert.apply(remove_hooks)
hf_bert.eval()
hf_handles = []
for layer in hf_bert.bert.encoder.layer:
    hf_handles.append(layer.attention.register_forward_hook(hook))
    hf_handles.append(layer.register_forward_hook(hook))
try:
    with t.no_grad():
        hf_bert(input_ids)
finally:
    # Detach the hooks so later forward passes don't keep appending to `expected`.
    for handle in hf_handles:
        handle.remove()
actual = []

def my_hook(module, inputs, output):
    """Record our module's input and output (our modules return the tensor itself)."""
    x = inputs[0]
    actual.append((x, output))

my_bert.eval()
my_bert.apply(remove_hooks)
my_handles = []
for layer in my_bert.common.blocks:
    my_handles.append(layer.attention.register_forward_hook(my_hook))
    my_handles.append(layer.register_forward_hook(my_hook))
try:
    with t.no_grad():
        my_bert(input_ids)
finally:
    # Otherwise every later my_bert call (e.g. in predict) would keep appending to `actual`.
    for handle in my_handles:
        handle.remove()

# A non-empty check, so the comparison below can't pass vacuously.
assert expected and len(expected) == len(actual), f"recorded {len(expected)} HF steps vs {len(actual)} of ours"
for i, ((ex_in, ex_out), (ac_in, ac_out)) in enumerate(zip(expected, actual)):
    print(f"Step {i} input:", end="")
    allclose_atol(ac_in, ex_in, atol=1e-5)
    print('OK')
    print(f"Step {i} output:", end="")
    allclose_atol(ac_out, ex_out, atol=1e-5)
    print('OK')



Step 0 input:OK
Step 0 output:OK
Step 1 input:OK
Step 1 output:OK
Step 2 input:OK
Step 2 output:OK
Step 3 input:OK
Step 3 output:OK
Step 4 input:OK
Step 4 output:OK
Step 5 input:OK
Step 5 output:OK
Step 6 input:OK
Step 6 output:OK
Step 7 input:OK
Step 7 output:OK
Step 8 input:OK
Step 8 output:OK
Step 9 input:OK
Step 9 output:OK
Step 10 input:OK
Step 10 output:OK
Step 11 input:OK
Step 11 output:OK
Step 12 input:OK
Step 12 output:OK
Step 13 input:OK
Step 13 output:OK
Step 14 input:OK
Step 14 output:OK
Step 15 input:OK
Step 15 output:OK
Step 16 input:OK
Step 16 output:OK
Step 17 input:OK
Step 17 output:OK
Step 18 input:OK
Step 18 output:OK
Step 19 input:OK
Step 19 output:OK
Step 20 input:OK
Step 20 output:OK
Step 21 input:OK
Step 21 output:OK
Step 22 input:OK
Step 22 output:OK
Step 23 input:OK
Step 23 output:OK
