In [3]:
import torch
import torch.nn as nn
import torch.functional as F

In [None]:
# Pick the compute device once: prefer CUDA when a GPU is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class multihead(nn.Module):
    """Multi-head (bidirectional) self-attention on top of PyTorch's SDPA kernel.

    The input is projected to queries, keys and values. Each projection is
    split into ``n_heads`` heads of size ``d_out // n_heads``. The heads attend
    with ``scaled_dot_product_attention`` (no causal mask), then they are
    merged and passed through a final linear projection.

    Args:
        d_in: feature size of the input.
        d_out: feature size of the output; must be divisible by ``n_heads``.
        context_size: maximum sequence length. It is stored on the module but
            not used by the computation itself.
        n_heads: number of attention heads.
        dropout: dropout probability on the attention weights. It is applied
            only while the module is in training mode.
    """

    def __init__(self, d_in, d_out, context_size, n_heads, dropout=0.0):
        super().__init__()

        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads
        self.context_size = context_size
        self.d_head = d_out // n_heads
        self.dropout = dropout

        # Creation order (q, v, k, out) is kept so parameter initialisation
        # under a fixed seed and state_dict layout stay unchanged.
        self.q = nn.Linear(self.d_in, self.d_out, bias=False)
        self.v = nn.Linear(self.d_in, self.d_out, bias=False)
        self.k = nn.Linear(self.d_in, self.d_out, bias=False)

        self.out = nn.Linear(self.d_out, self.d_out, bias=False)

    def _split_heads(self, t, b, num_tokens):
        """Reshape (b, num_tokens, d_out) -> (b, n_heads, num_tokens, d_head)."""
        return t.view(b, num_tokens, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x, attn_mask=None):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (b, num_tokens, d_in).
            attn_mask: optional mask broadcastable to
                (b, n_heads, num_tokens, num_tokens). It is passed unchanged to
                ``scaled_dot_product_attention``: a boolean mask uses True for
                "may attend", and a float mask is added to the attention
                scores. ``None`` attends over every token.

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.size()

        q = self._split_heads(self.q(x), b, num_tokens)
        k = self._split_heads(self.k(x), b, num_tokens)
        v = self._split_heads(self.v(x), b, num_tokens)

        # Call through nn.functional explicitly. The notebook's module-level
        # `F` alias is bound to `torch.functional`, which has no SDPA.
        # SDPA applies dropout whenever dropout_p > 0 and ignores
        # self.training, so dropout must be switched off explicitly in eval.
        att = nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )

        # Merge heads back: (b, n_heads, num_tokens, d_head) -> (b, num_tokens, d_out).
        att = att.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)

        return self.out(att)


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, apply GELU, project back, then dropout.

    Args:
        n_embed: input and output feature size.
        dropout: dropout probability applied to the MLP output.
    """

    def __init__(self, n_embed, dropout=0.0):
        super().__init__()
        expanded = 4 * n_embed
        # Order matters: the indices net.0 .. net.3 define the state_dict keys.
        layers = [
            nn.Linear(n_embed, expanded),
            nn.GELU(),
            nn.Linear(expanded, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP over the last dimension of ``x``."""
        result = self.net(x)
        return result

In [None]:
class Block(nn.Module):
    """Post-LayerNorm transformer encoder block.

    Computes ``h = LN1(x + Attention(x))`` and returns ``LN2(h + FeedForward(h))``.

    Args:
        d_in: input feature size. It must equal ``d_out``, because the
            attention output is added back onto ``x`` as a residual.
        d_out: model width used by attention, the feed-forward net and both
            LayerNorms.
        context_size: forwarded to the attention module.
        n_heads: number of attention heads.
        dropout: dropout probability forwarded to attention and feed-forward.
    """

    def __init__(self, d_in, d_out, context_size, n_heads, dropout=0.0):
        super().__init__()
        self.att = multihead(d_in, d_out, context_size, n_heads, dropout)
        self.ff = FeedForward(d_out, dropout)
        self.layer_norm1 = nn.LayerNorm(d_out)
        self.layer_norm2 = nn.LayerNorm(d_out)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with residual + LayerNorm."""
        # First residual sub-layer: self-attention.
        x = self.layer_norm1(x + self.att(x))
        # Second residual sub-layer: position-wise feed-forward network.
        x = self.layer_norm2(x + self.ff(x))
        return x

In [None]:
class BERTModel(nn.Module):
    """BERT-style encoder with a classification head on the first token.

    Token, absolute-position and segment embeddings are summed, then passed
    through LayerNorm and dropout. The result runs through ``n_layers``
    ``Block`` modules, and the logits are taken from the hidden state at
    position 0.

    Args:
        vocab_size: number of distinct token ids.
        hidden_size: model width used by every embedding and block.
        n_layers: number of encoder blocks.
        n_heads: attention heads per block.
        context_size: maximum sequence length (size of the position table).
        dropout: dropout probability for the embeddings and the blocks.
        n_classes: number of output logits. Defaults to 2, the original fixed
            head size.
    """

    def __init__(self, vocab_size, hidden_size, n_layers, n_heads, context_size, dropout=0.0, n_classes=2):
        super().__init__()

        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.context_size = context_size
        self.vocab_size = vocab_size
        self.n_classes = n_classes

        self.token_embed = nn.Embedding(self.vocab_size, self.hidden_size)
        self.position_embed = nn.Embedding(self.context_size, self.hidden_size)
        # Two segment ids (0 / 1), e.g. for sentence-pair inputs.
        self.segment_embed = nn.Embedding(2, self.hidden_size)

        self.embed_layernorm = nn.LayerNorm(self.hidden_size)
        self.embed_dropout = nn.Dropout(dropout)

        self.blocks = nn.Sequential(
            *[Block(d_in=self.hidden_size, d_out=self.hidden_size,
                    context_size=self.context_size, n_heads=self.n_heads, dropout=dropout)
            for _ in range(self.n_layers)]
        )

        self.fc = nn.Linear(self.hidden_size, self.n_classes)

    def forward(self, x, seg_ids=None):
        """Classify each sequence in the batch.

        Args:
            x: integer token ids of shape (b, num_tokens).
            seg_ids: segment ids in {0, 1}, broadcastable against ``x``
                (normally the same shape). ``None`` means every token is in
                segment 0.

        Returns:
            Logits of shape (b, n_classes).

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_size``. Without
                this check the position lookup would fail with an opaque
                index error.
        """
        _, num_tokens = x.size()
        if num_tokens > self.context_size:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_size {self.context_size}"
            )

        if seg_ids is None:
            seg_ids = torch.zeros_like(x)

        # Shape (1, num_tokens): broadcasts over the batch dimension.
        pos_ids = torch.arange(num_tokens, dtype=torch.long, device=x.device).unsqueeze(0)

        embeddings = (
            self.token_embed(x)
            + self.position_embed(pos_ids)
            + self.segment_embed(seg_ids)
        )
        embeddings = self.embed_dropout(self.embed_layernorm(embeddings))

        output = self.blocks(embeddings)

        # Hidden state of the first position (the [CLS] slot by BERT convention).
        cls_tokens = output[:, 0, :]
        return self.fc(cls_tokens)