# Libraries

In [5]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Let's build Bert

In [19]:
# q, k and v have shape (batch_size, seq_len, d_model)


class Attention(nn.Module):
    """Single scaled dot-product attention head with its own Q/K/V projections.

    Args:
        dk: Output width of the query and key projections; the raw scores
            are divided by ``sqrt(dk)`` before the softmax.
        dv: Output width of the value projection, i.e. the width of the
            tensor returned by :meth:`forward`.
        d_model: Width of the last dimension of the incoming q/k/v tensors.
        attention_probs_dropout_prob: Dropout probability applied to the
            attention probabilities (the softmax output).
        device: Device on which the projection weights are created.
    """

    def __init__(self, dk, dv, d_model,attention_probs_dropout_prob,device='cpu'):
        super().__init__()
        self.WQ = nn.Linear(d_model, dk, device=device)
        self.WK = nn.Linear(d_model, dk, device=device)
        self.WV = nn.Linear(d_model, dv, device=device)
        self.dk = dk
        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def forward(self, q, k, v):
        """Project q/k/v and return the attention-weighted sum of the values.

        Args:
            q, k, v: Tensors whose last dimension is ``d_model`` and whose
                second-to-last dimension indexes sequence positions. Any
                number of leading batch dimensions (including none) works.

        Returns:
            Tensor with the same leading dimensions as ``q`` and a last
            dimension of size ``dv``.
        """
        q = self.WQ(q)
        k = self.WK(k)
        v = self.WV(v)

        # Swap the last two axes (rather than hard-coded axes 1 and 2) so the
        # head is not tied to exactly three-dimensional inputs.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dk ** 0.5)
        probs = F.softmax(scores, dim=-1)

        # Dropout acts on the attention probabilities, as the parameter name
        # says: individual query->key links are dropped, not output features.
        probs = self.dropout(probs)

        return torch.matmul(probs, v)
    

class MultiHeadAttention(nn.Module):
    """Run several independent ``Attention`` heads and merge their outputs.

    Args:
        n_head: Number of attention heads to build.
        dk: Query/key projection width passed to every head.
        dv: Value projection width passed to every head.
        d_model: Input width of each head and output width of the final
            linear projection.
        attention_probs_dropout_prob: Dropout probability passed to every head.
        device: Device on which all weights are created.
    """

    def __init__(self, n_head, dk, dv, d_model,attention_probs_dropout_prob, device='cpu'):
        super().__init__()
        heads = []
        for _ in range(n_head):
            heads.append(
                Attention(dk, dv, d_model, attention_probs_dropout_prob, device)
            )
        self.attentions = nn.ModuleList(heads)

        # The concatenated heads are n_head * dv wide; map them back to d_model.
        self.output = nn.Linear(n_head * dv, d_model, device=device)

    def forward(self, q, k, v):
        """Apply every head to (q, k, v), concatenate on the last axis, project."""
        per_head = tuple(head(q, k, v) for head in self.attentions)
        merged = torch.cat(per_head, dim=-1)
        return self.output(merged)
    

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Linear -> Dropout.

    Args:
        d_model: Width of the input and of the output.
        hidden_dim: Width of the intermediate activation.
        hidden_dropout_prob: Dropout probability applied to the output.
        device: Device on which the linear layers are created.
    """

    def __init__(self, d_model, hidden_dim,hidden_dropout_prob, device = 'cpu'):
        super().__init__()
        # NOTE: ff2 is registered before ff1 on purpose; this keeps the
        # parameter order and the seeded initialisation unchanged.
        self.ff2 = nn.Linear(in_features=hidden_dim, out_features=d_model, device=device)
        self.ff1 = nn.Linear(in_features=d_model, out_features=hidden_dim, device=device)
        self.GELU = nn.GELU()
        self.dropout = nn.Dropout(p=hidden_dropout_prob)

    def forward(self, x):
        """Expand to ``hidden_dim``, apply GELU, project back, then dropout."""
        hidden = self.GELU(self.ff1(x))
        return self.dropout(self.ff2(hidden))

class BERTLayer(nn.Module):
    """One encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer output is added to its input (residual connection) and
    the sum is passed through its own ``LayerNorm``.

    Args:
        n_head, dk, dv, attention_probs_dropout_prob: Forwarded to
            ``MultiHeadAttention``.
        hidden_dim, hidden_dropout_prob: Forwarded to ``FeedForward``.
        d_model: Feature width used by both sub-layers and both norms.
        device: Device on which all weights are created.
    """

    def __init__(self, n_head,dk, dv, d_model, hidden_dim,hidden_dropout_prob,attention_probs_dropout_prob, device = 'cpu'):
        super().__init__()
        self.mha = MultiHeadAttention(
            n_head=n_head,
            dk=dk,
            dv=dv,
            d_model=d_model,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            device=device,
        )
        self.feedforward = FeedForward(
            d_model=d_model,
            hidden_dim=hidden_dim,
            hidden_dropout_prob=hidden_dropout_prob,
            device=device,
        )
        self.norm1 = nn.LayerNorm(d_model, device=device)
        self.norm2 = nn.LayerNorm(d_model, device=device)

    def forward(self, x):
        """Apply self-attention then feed-forward, each with add-and-norm."""
        # Self-attention: the same tensor serves as query, key and value.
        attended = self.mha(x, x, x)
        x = self.norm1(attended + x)

        transformed = self.feedforward(x)
        return self.norm2(transformed + x)
    


class Bert(nn.Module):
    """Stack of ``BERTLayer`` blocks on top of summed token/position/segment embeddings.

    Args:
        n_layer: Number of ``BERTLayer`` blocks.
        n_head, dk, dv, hidden_dim, hidden_dropout_prob,
        attention_probs_dropout_prob: Forwarded to every ``BERTLayer``.
        d_model: Embedding width and feature width of every layer.
        vocab_size: Number of rows in the token embedding table.
        max_seq_len: Number of rows in the position embedding table, i.e. the
            longest sequence that can be embedded.
        device: Device on which all weights are created.
    """

    def __init__(self, n_layer, n_head,dk, dv, d_model, hidden_dim, vocab_size,max_seq_len,hidden_dropout_prob,attention_probs_dropout_prob, device = 'cpu'):
        super().__init__()
        self.layers = nn.ModuleList([BERTLayer(n_head,dk, dv, d_model, hidden_dim,hidden_dropout_prob,attention_probs_dropout_prob, device) for _ in range(n_layer)])
        self.n_layer = n_layer
        self.tok_embedding = nn.Embedding(vocab_size, d_model, device=device)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model, device=device)
        # Two rows: one embedding per segment id (0 or 1).
        self.segment_embedding = nn.Embedding(2, d_model, device=device)
        # Kept for backward compatibility; forward() no longer relies on it.
        self.device = device

    def forward(self, x, segment_ids):
        """Embed the token ids and run them through every layer.

        Args:
            x: Integer token ids of shape (batch_size, seq_len).
            segment_ids: Integer segment ids (0 or 1) with the same shape as ``x``.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        _, seq_len = x.size()

        # Build the position ids on the input's device instead of the device
        # recorded at construction time, so the model keeps working after
        # ``model.to(...)`` moves its weights elsewhere.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)

        # positions is (1, seq_len); its embedding broadcasts over the batch.
        x = (
            self.tok_embedding(x)
            + self.pos_embedding(positions)
            + self.segment_embedding(segment_ids)
        )
        for layer in self.layers:
            x = layer(x)

        return x



        
## In the BERT paper, the hyperparameters of BERT-base were:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 12 layers, 12 heads of width 64, 768-wide hidden states, 3072-wide feed-forward.
model = Bert(
    n_layer=12, n_head=12,
    dk=64, dv=64,
    d_model=768, hidden_dim=3072,
    vocab_size=30522, max_seq_len=512,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    device=device,
)

# Smoke test: random token ids of shape (batch_size=2, seq_len=20), all in segment 0.
input_ids = torch.randint(low=0, high=30522, size=(2, 20)).to(device)
token_type_ids = input_ids.new_zeros(input_ids.shape)
output = model(input_ids, segment_ids=token_type_ids)
        

        




        
        



    
        


torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 768])
torch.Size([2, 20, 64])
torch.Size([2, 20, 7

In [10]:
# Switch the model to evaluation mode (dropout layers become no-ops).
model.eval()

Bert(
  (layers): ModuleList(
    (0-11): 12 x BERTLayer(
      (mha): MultiHeadAttention(
        (attentions): ModuleList(
          (0-11): 12 x Attention(
            (WQ): Linear(in_features=64, out_features=768, bias=True)
            (WK): Linear(in_features=64, out_features=768, bias=True)
            (WV): Linear(in_features=768, out_features=64, bias=True)
          )
        )
        (output): Linear(in_features=768, out_features=768, bias=True)
      )
      (feedforward): FeedForward(
        (ff2): Linear(in_features=3072, out_features=768, bias=True)
        (ff1): Linear(in_features=768, out_features=3072, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
  )
  (tok_embedding): Embedding(30000, 768)
  (pos_embedding): Embedding(512, 768)
  (segment_embedding): Embedding(2, 768)
)