In [319]:
import torch
from torch import nn
from torch.nn import functional as F

In [312]:
class Transformer_Attention(nn.Module):
    """One head of scaled dot-product self/cross-attention.

    Args:
        embed_dim: feature width of the incoming query / key-value sequences.
        masked: if True, apply a causal (lower-triangular) mask so that query
            position i can only attend to key positions j <= i.
        dropout: dropout probability applied to the attention weights.
        kdim: width of the query and key projections.
        vdim: width of the value projection (and of this head's output).
    """
    def __init__(self, embed_dim:int, masked:bool, dropout:float, kdim:int, vdim:int):
        super().__init__()

        self.WQ = nn.Linear(embed_dim, kdim, bias=False)
        self.WK = nn.Linear(embed_dim, kdim, bias=False)
        self.WV = nn.Linear(embed_dim, vdim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.masked = masked

    def forward(self, *x):
        """Attend from ``x[0]`` (queries) to ``x[-1]`` (keys and values).

        Pass one tensor for self-attention, or two tensors
        ``(query_source, key_value_source)`` for cross-attention.
        Inputs are assumed to be ``(B, T, embed_dim)``; the key transpose below
        swaps dims 1 and 2.

        Returns:
            Tensor of shape ``(B, T_q, vdim)``.
        """
        assert x[0].size(-1) == x[-1].size(-1), "mismatched shapes"
        q = self.WQ(x[0])   # (B, T_q, kdim)
        k = self.WK(x[-1])  # (B, T_k, kdim)
        v = self.WV(x[-1])  # (B, T_k, vdim)

        scores = q @ k.transpose(1, 2) * k.size(-1) ** -0.5  # (B, T_q, T_k)

        if self.masked:
            # The mask has to hit the raw scores *before* softmax (so masked
            # positions get zero probability), and masked_fill is out-of-place,
            # so its result must be kept.
            t_q, t_k = scores.size(-2), scores.size(-1)
            causal = torch.ones(t_q, t_k, dtype=torch.bool, device=scores.device).tril()
            scores = scores.masked_fill(~causal, float("-inf"))

        attention_weights = self.dropout(F.softmax(scores, dim=-1))  # (B, T_q, T_k)
        ## Weighted sum of values
        return attention_weights @ v

In [313]:
class MultiHead_TransformerAttention(nn.Module):
    """Several ``Transformer_Attention`` heads run side by side and merged.

    Head outputs are concatenated along the feature axis and mapped back to
    ``embed_dim`` by the output projection ``WO``, followed by dropout.

    Args:
        embed_dim: width of the inputs and of the merged output.
        num_heads: number of parallel heads.
        masked: causal-mask flag forwarded to every head.
        dropout: dropout probability used inside each head and after ``WO``.
        kdim: per-head query/key width; any falsy value falls back to ``embed_dim``.
        vdim: per-head value width; any falsy value falls back to ``embed_dim``.
    """
    def __init__(self, embed_dim, num_heads, masked:bool, dropout=0.0, kdim=None, vdim=None):
        super().__init__()
        kdim = kdim or embed_dim
        vdim = vdim or embed_dim

        heads = [Transformer_Attention(embed_dim, masked, dropout, kdim, vdim)
                 for _ in range(num_heads)]
        self.head_list = nn.ModuleList(heads)
        self.WO = nn.Linear(vdim * num_heads, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *x):
        """Feed the same positional inputs to every head, then merge the results."""
        per_head = [head(*x) for head in self.head_list]
        merged = torch.cat(per_head, dim=-1)  # (..., num_heads * vdim)
        return self.dropout(self.WO(merged))

In [169]:
class FeedForward_Network(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    Args:
        d_model: input and output feature width.
        dim_feedforward: hidden width of the inner projection.
        dropout: dropout probability applied to the output.
    """
    def __init__(self, d_model, dim_feedforward, dropout):
        super().__init__()
        layers = [
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        out = self.ff(x)
        return out

### EncoderLayer

In [248]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.

    Args:
        d_model: model width.
        nhead: number of attention heads; each head gets ``d_model // nhead`` features.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used in both sub-layers.
    """
    def __init__(self, d_model:int, nhead:int, dim_feedforward:int, dropout=0.0):
        super().__init__()
        per_head = d_model // nhead
        self.mha = MultiHead_TransformerAttention(d_model, nhead, masked=False, dropout=dropout,
                                                  kdim=per_head, vdim=per_head)
        self.ln_mha = nn.LayerNorm(d_model)

        self.ffn = FeedForward_Network(d_model, dim_feedforward, dropout)
        self.ln_ffn = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run the block; the output has the same shape as ``x`` (assumed (B, T, d_model))."""
        # Post-LN: normalize *after* each residual add, as in the original paper.
        # Pre-LN (normalizing before each sub-layer) is often reported to train
        # more stably, but this notebook deliberately reproduces the paper.
        attended = self.ln_mha(x + self.mha(x))
        return self.ln_ffn(attended + self.ffn(attended))

In [275]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` ``EncoderLayer`` blocks applied one after another.

    Args:
        num_layers: number of stacked encoder layers.
        d_model, nhead, dim_feedforward, dropout: forwarded to every ``EncoderLayer``.
    """
    def __init__(self, num_layers: int, d_model:int, nhead:int, dim_feedforward:int, dropout=0.0):
        super().__init__()
        layers = [EncoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Return the contextualized source representations."""
        encoded = self.blocks(x)
        return encoded

### DecoderLayer
- There are two common kinds of decoder:
    1. A standalone autoregressive decoder (as in decoder-only models), which uses masked self-attention but no cross-attention.
    2. A decoder paired with an encoder, in which each masked self-attention sub-layer is followed by a cross-attention sub-layer over the encoder output (as in the paper).
- This repo implements the second kind: the cross-attention decoder.

In [276]:
class DecoderLayer(nn.Module):
    """One decoder block with three sub-layers, each followed by a residual add and LayerNorm.

    The sub-layers are:
    1. masked (causal) self-attention over the decoder input,
    2. cross-attention from the decoder stream to the encoder output,
    3. position-wise feed-forward.

    Args:
        d_model: model width.
        nhead: heads per attention block; each head gets ``d_model // nhead`` features.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used in every sub-layer.
    """
    def __init__(self, d_model:int, nhead:int, dim_feedforward:int, dropout=0.0):
        super().__init__()
        per_head = d_model // nhead

        self.self_mha = MultiHead_TransformerAttention(d_model, nhead, True, dropout,
                                                       kdim=per_head, vdim=per_head)
        self.ln_self_mha = nn.LayerNorm(d_model)

        ## Cross attention (queries from the decoder, keys/values from the encoder)
        self.cross_mha = MultiHead_TransformerAttention(d_model, nhead, False, dropout,
                                                        kdim=per_head, vdim=per_head)
        self.ln_cross_mha = nn.LayerNorm(d_model)

        self.ffn = FeedForward_Network(d_model, dim_feedforward, dropout)
        self.ln_ffn = nn.LayerNorm(d_model)

    def forward(self, *x):
        """x[0]: decoder input embeddings; x[1]: encoder output embeddings."""
        # x[1] (not x[-1]) is used so that forgetting the encoder output raises
        # an IndexError instead of silently doing self-attention twice.
        x_dec, x_enc = x[0], x[1]

        hidden = self.ln_self_mha(x_dec + self.self_mha(x_dec))
        hidden = self.ln_cross_mha(hidden + self.cross_mha(hidden, x_enc))
        return self.ln_ffn(hidden + self.ffn(hidden))

In [277]:
class Decoder(nn.Module):
    """A stack of ``num_layers`` ``DecoderLayer`` blocks that all attend to the same encoder output.

    ``nn.Sequential`` serves only as a container here: its own ``forward``
    cannot pass the extra encoder argument, so the layers are applied manually.

    Args:
        num_layers: number of stacked decoder layers.
        d_model, nhead, dim_feedforward, dropout: forwarded to every ``DecoderLayer``.
    """
    def __init__(self, num_layers: int, d_model:int, nhead:int, dim_feedforward:int, dropout=0.0):
        super().__init__()
        layers = [DecoderLayer(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)]
        self.blocks = nn.Sequential(*layers)

    def forward(self, *x):
        """x[0]: decoder input embeddings; x[1]: encoder output shared by every layer."""
        hidden, memory = x[0], x[1]
        for layer in self.blocks:
            hidden = layer(hidden, memory)
        return hidden

### Encoder-Decoder Transformer

In [278]:
## Hyperparameters (for test) — read as defaults by the Transformer class below
layers_repeat = 6               # number of encoder layers, and likewise of decoder layers
heads = 8                       # number of heads in every multi-head attention block
d_model = 512                   # embedding / model width
dim_feedforward = 4 * d_model   # hidden width of each feed-forward sub-layer
dropout = 0.1                   # dropout probability used throughout
enc_vocab_size = 30000          # source (encoder) vocabulary size
dec_vocab_size = 40000          # target (decoder) vocabulary size = classifier output width

In [367]:
class Identity(nn.Module):
    """Placeholder positional encoding that contributes nothing.

    Despite the name, this is *not* an identity map: it returns a zero tensor
    with the input's shape, dtype and device. ``embed + self.pos(embed)``
    therefore leaves ``embed`` unchanged until a real positional encoding is added.
    """
    def __init__(self, ):
        super().__init__()

    def forward(self, x):
        return x.new_zeros(x.size())

class Transformer(nn.Module):
    """Encoder-decoder Transformer that maps source token ids to target-vocabulary logits.

    Every constructor argument defaults to ``None``. ``None`` means "use the
    notebook-level hyperparameter of the same role":

    - ``num_layers`` -> ``layers_repeat``
    - ``nhead`` -> ``heads``
    - ``model_dim`` -> ``d_model``
    - ``ff_dim`` -> ``dim_feedforward``
    - ``p_dropout`` -> ``dropout``
    - ``src_vocab`` -> ``enc_vocab_size``
    - ``tgt_vocab`` -> ``dec_vocab_size``

    So ``Transformer()`` builds the same model as before, while other sizes can
    be built without editing globals.

    Args:
        num_layers: number of encoder layers and of decoder layers.
        nhead: attention heads per multi-head attention block.
        model_dim: embedding / model width.
        ff_dim: hidden width of the feed-forward sub-layers.
        p_dropout: dropout probability.
        src_vocab: source (encoder) vocabulary size.
        tgt_vocab: target (decoder) vocabulary size; also the classifier output width.
    """
    def __init__(self, num_layers=None, nhead=None, model_dim=None, ff_dim=None,
                 p_dropout=None, src_vocab=None, tgt_vocab=None):
        super().__init__()
        # Fall back to the notebook globals, looked up at construction time.
        num_layers = layers_repeat if num_layers is None else num_layers
        nhead = heads if nhead is None else nhead
        model_dim = d_model if model_dim is None else model_dim
        ff_dim = dim_feedforward if ff_dim is None else ff_dim
        p_dropout = dropout if p_dropout is None else p_dropout
        src_vocab = enc_vocab_size if src_vocab is None else src_vocab
        tgt_vocab = dec_vocab_size if tgt_vocab is None else tgt_vocab

        # Module creation order is kept identical so seeded initialization is reproducible.
        self.encEmbed = nn.Embedding(src_vocab, model_dim)
        self.decEmbed = nn.Embedding(tgt_vocab, model_dim)
        self.pos = Identity()  # placeholder positional encoding: adds zeros for now

        self.encoder = Encoder(num_layers=num_layers, d_model=model_dim,
                               nhead=nhead, dim_feedforward=ff_dim, dropout=p_dropout)

        self.decoder = Decoder(num_layers=num_layers, d_model=model_dim,
                               nhead=nhead, dim_feedforward=ff_dim, dropout=p_dropout)

        self.classifier = nn.Linear(model_dim, tgt_vocab)

    def _embed(self, embedding, tokens):
        """Look up token embeddings and add the positional term."""
        token_embed = embedding(tokens)
        return token_embed + self.pos(token_embed)

    def _encode(self, src_tokens):
        """Return the encoder context for ``src_tokens``."""
        return self.encoder(self._embed(self.encEmbed, src_tokens))

    def _decode(self, trgt_tokens, encoder_context):
        """Return classifier logits for ``trgt_tokens`` attending to ``encoder_context``."""
        decoder_context = self.decoder(self._embed(self.decEmbed, trgt_tokens), encoder_context)
        return self.classifier(decoder_context)

    def forward(self, src_tokens, trgt_tokens):
        """Teacher-forced pass: logits of shape (B, T_trgt, tgt_vocab)."""
        encoder_context = self._encode(src_tokens)
        return self._decode(trgt_tokens, encoder_context)

    @torch.no_grad()
    def inference(self, src_tokens, bos_id=2, eos_id=3, max_steps=None):
        """Greedily decode a single source sequence.

        The default ids follow the convention <pad>:0, <unk>:1, <s>:2, </s>:3.
        This method does not change train/eval mode; call ``self.eval()``
        first so dropout is disabled.

        Args:
            src_tokens: source token ids of shape (1, T_src).
            bos_id: id the generated sequence starts with.
            eos_id: decoding stops once this id has been generated.
            max_steps: maximum number of tokens to generate; defaults to
                ``T_src + 6``, the same budget as before.

        Returns:
            ``(token_list, confidences)``. ``token_list`` holds the generated
            ids, starting with ``bos_id``. ``confidences`` holds the softmax
            probability of each generated token, so it has one fewer entry
            than ``token_list``.

        Raises:
            ValueError: if ``src_tokens`` is not a single sequence of shape (1, T_src).
        """
        if src_tokens.dim() != 2 or src_tokens.size(0) != 1:
            raise ValueError(f"inference expects src_tokens of shape (1, T), got {tuple(src_tokens.shape)}")

        encoder_context = self._encode(src_tokens)
        if max_steps is None:
            max_steps = src_tokens.size(-1) + 6

        token_list = [bos_id]
        confidences = []
        for _ in range(max_steps):
            if token_list[-1] == eos_id:
                break
            # The whole prefix is re-decoded at every step (no key/value cache).
            trgt_tokens = torch.tensor([token_list], device=src_tokens.device)
            logits = self._decode(trgt_tokens, encoder_context)  # (1, T, vocab_size)

            probs = F.softmax(logits[:, -1, :], dim=-1)  # (1, vocab_size)
            confidence, token = torch.max(probs, dim=-1)

            token_list.append(token.item())
            confidences.append(confidence.item())

        return token_list, confidences

In [368]:
torch.manual_seed(0)  # reproducible weights and dummy tokens
transformer = Transformer()
## Dummy batch: 32 source sequences of length 10, 32 target sequences of length 15.
src_tokens = torch.randint(low=0, high=enc_vocab_size, size=(32, 10))
## Target ids must be drawn from the *decoder* vocabulary.
trgt_tokens = torch.randint(low=0, high=dec_vocab_size, size=(32, 15))

logits = transformer(src_tokens, trgt_tokens)
assert logits.shape == (32, 15, dec_vocab_size)
print(logits.shape)

torch.Size([32, 15, 40000])


In [388]:
## Dummy source sentence wrapped in <s> (2) ... </s> (3); 1 is <unk>.
x_text = torch.tensor([[2, 500, 100, 8564, 21, 1, 754, 3]])
transformer.eval()  # disable dropout while decoding
try:
    tokens, confidences = transformer.inference(x_text)
finally:
    transformer.train()  # restore training mode even if decoding fails
print(tokens, confidences, sep='\n')
print("Not Trained yet.")

[2, 24271, 37332, 18889, 26387, 19201, 34973, 17613, 26224, 18935, 20685, 20685, 20685, 20685, 20685]
[0.00022207263100426644, 0.0003712518373504281, 0.0003272329340688884, 0.00031365477479994297, 0.0002349430724279955, 0.00019178724323865026, 0.00018867503968067467, 0.00022262873244471848, 0.00023418100317940116, 0.00020738778403028846, 0.00019847130170091987, 0.0002041447296505794, 0.00020849690190516412, 0.0002112431247951463]
Not Trained yet.


In [381]:
def get_parameters_info(model):
    """Count a model's parameter elements, split by whether they require gradients.

    Args:
        model: any ``nn.Module``.

    Returns:
        ``(trainable, nontrainable)`` element counts.
    """
    trainable = nontrainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            nontrainable += param.numel()
    return trainable, nontrainable

tr, nontr = get_parameters_info(transformer)
print(f"Total trainable parameters= {tr:,}\n"
      f"Total non-trainable parameters= {nontr:,}")

Total trainable parameters= 100,470,848
Total non-trainable parameters= 0
