In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class PositionalEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, seq_length):
        super().__init__()
        self.normal_embedding = nn.Embedding(num_embeddings, embedding_dim) ## output is input x embedding_dim
        position = torch.arange(seq_length).unsqueeze(1)
        
        self.positional_encoding = torch.zeros(seq_length, embedding_dim)
        self.positional_encoding[:, 0::2] = torch.sin(position/10_000**(torch.arange(0, embedding_dim, 2) / embedding_dim))
        self.positional_encoding[:, 1::2] = torch.cos(position/10_000**(torch.arange(0, embedding_dim, 2) / embedding_dim))

        # self.register_buffer("positional_encoding", self.positional_encoding)
    def forward(self, x): # input size is batch x seq_length x d_model
        return self.normal_embedding(x) + self.positional_encoding

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dk, dv):
        super().__init__()
        self.n_heads = n_heads
        self.w_q = nn.Linear(d_model, n_heads * dk, bias=False) # we know that n_heads * dk or n_heads * dv is d_model 
        self.w_k = nn.Linear(d_model, n_heads * dk, bias=False)
        self.w_v = nn.Linear(d_model, n_heads * dv, bias=False)
        self.w_o = nn.Linear(n_heads * dv, d_model, bias=False)
        self.dk = dk
        self.dv = dv
        

    @staticmethod
    def attention(Q, K, V, mask=None):
        dk = Q.shape[-1]
        ## Q --> batch x h x ds x dk, K --> batch x h x ds x dk, V --> batch x h x ds x dv
        attention = (Q @ K.transpose(-1, -2)) / (dk ** 0.5) # attention --> batch x h x ds x ds
        if mask is not None:
            attention *= mask ## mask needs to be ds x ds
        attention = F.softmax(attention, dim=-1)
        print(attention.shape)
        print(V.shape)
        return attention @ V, attention

    def forward(self, Q, K, V, mask=None):
        query = self.w_q(Q) ## Batch x seq_len x d_model --> batch x seq_length x d_model
        key = self.w_k(K) ## Batch x seq_len x d_model --> batch x seq_length x d_model
        value = self.w_v(V) ## Batch x seq_len x d_model --> batch x seq_length x d_model
        batch_size, seq_length = query.shape[0], query.shape[1]
        query = query.view(batch_size, seq_length, self.n_heads, self.dk).transpose(1, 2)
        value = value.view(batch_size, seq_length, self.n_heads, self.dk).transpose(1, 2)
        value = key.view(batch_size, seq_length, self.n_heads, self.dv).transpose(1, 2)
        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask=mask)
        x = x.transpose(1,2).contigious().view(batch_size, seq_length, -1)

        x = self.w_o(x)
        return x

class LayerNormalisation(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.ones(1))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return (self.alpha) * (x - mean) / (std + self.eps) + self.bias


class ResidualConnection(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = LayerNormalisation()
    
    def forward(self, x, sublayer):
        return x + sublayer(self.norm(x))
    
class PositionWiseFFN(torch.nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
    
    def forward(self, x):
        return self.linear_2(F.relu(self.linear_1(x)))

In [17]:
class EncoderBlock(nn.Module):
    def __init__(self, n_heads, d_model, dk, dv, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dk=dk, dv=dv)
        self.ff =  PositionWiseFFN(d_model=d_model, d_ff=d_ff)
        self.residuals = nn.ModuleList([ResidualConnection() for _ in range(2)])
    
    def forward(self, x, mask=None):
        x = self.residuals[0](x, lambda x: self.attention(x, x, x, mask=mask))
        x = self.residuals[1](x, self.ff)
        return x
    

class Encoder(nn.Module): 
    def __init__(self, num,  n_heads, d_model, dk, dv, d_ff):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(n_heads, d_model, dk, dv, d_ff) for _ in range(num)])
        # TODO: Build some LayerNormalisation Here
    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x # Maybe needs to be a LayerNorm here as well

class DecoderBlock(nn.Module):
    def __init__(self, n_heads, d_model, dk, dv, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dk=dk, dv=dv)
        self.ff =  PositionWiseFFN(d_model=d_model, d_ff=d_ff)
        self.attention_2 = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dk=dk, dv=dv)
        self.residuals = nn.ModuleList([ResidualConnection() for _ in range(3)])
    
    def forward(self, x, encoder_output, encoder_mask=None, decoder_mask=None):
        x = self.residuals[0](x, lambda x: self.attention(x, x, x, mask=decoder_mask))
        x = self.residuals[1](x, lambda x: self.attention_2(x, encoder_output, encoder_output, mask=encoder_mask))
        x = self.residuals[2](x, self.ff)
        return x
    
class Decoder(nn.Module):
    def __init__(self,num, n_heads, d_model, dk, dv, d_ff):
        super().__init__()
        self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, dk, dv, d_ff) for _ in range(num)])
        # TODO: Build some LayerNormalisation Here
    def forward(self, x, encoder_output, encoder_mask=None, decoder_mask=None):
        for layer in self.layers:
            x = layer(x, encoder_output, encoder_mask, decoder_mask)
        return x # Maybe needs to be a LayerNorm here as well
    
class Projection(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return F.log_softmax(self.linear(x), dim=-1)

In [18]:
class Transformer(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embedding: PositionalEmbedding, tgt_embedding: PositionalEmbedding, projection: Projection):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.proj = projection
        self.src_embedding = src_embedding
        self.tgt_embedding = tgt_embedding
    
    def encode(self, x, encoder_mask):
        embed = self.src_embedding(x)
        return self.encoder(embed, encoder_mask)
    
    def decode(self, x, encoder_output, encoder_mask, decoder_mask):
        embed = self.tgt_embedding(x)
        return self.decoder(embed, encoder_output, encoder_mask, decoder_mask)
    
    def forward(self, x, encoder_mask=None, decoder_mask=None):
        x_encoder = self.encode(x, encoder_mask)
        x_decoder = self.decode(x, x_encoder, encoder_mask, decoder_mask)
        x = self.proj(x_decoder)
        return x

In [19]:
d_model = 512
num_encoder = 6
num_decoder = 6
encoder_heads = 8
decoder_heads = 8
dk = d_model // encoder_heads
dv = d_model // encoder_heads
d_ff = 2048
src_seq_len = 10
tgt_seq_len = 10
vocab_size = 2755 # 2499 bpe pairs and 256 utf8 


In [20]:
encoder_transformer = Encoder(num_encoder, encoder_heads, d_model, dk, dv, d_ff)
decoder_transformer = Decoder(num_decoder, decoder_heads, d_model, dk, dv, d_ff)
src_embeddings = PositionalEmbedding(vocab_size, d_model, src_seq_len)
tgt_embeddings = PositionalEmbedding(vocab_size, d_model, tgt_seq_len)
projection = Projection(d_model, vocab_size)

In [21]:
model = Transformer(encoder_transformer, decoder_transformer, src_embeddings, tgt_embeddings, projection)

In [22]:
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [23]:
test_x = torch.randint(low=0, high=vocab_size, size=(12, src_seq_len))

In [24]:

model(test_x)

RuntimeError: The size of tensor a (8) must match the size of tensor b (12) at non-singleton dimension 1