# Transformer Architecture
Follows the architecture described in the paper "Attention Is All You Need" (Vaswani et al., 2017).

In [12]:
import torch
from torch import nn
import math

# Input

## Token Embeddings

In [13]:
class EmeddingsLayer(nn.Module):
    """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model: Width of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model:int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, X):
        """Return the embeddings of the token ids in ``X`` times ``sqrt(d_model)``."""
        # "In the embedding layers, we multiply those weights by sqrt(d_model)"
        scale = math.sqrt(self.d_model)
        token_vectors = self.embedding(X)
        return token_vectors * scale


## Positional Encodings

In [14]:
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encodings.

    For position ``pos`` and pair index ``k`` the table holds::

        pe[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        pe[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))

    so each sine/cosine pair shares one frequency.

    Args:
        d_model: Width of each encoding vector (odd values are supported;
            the last column then has a sine but no paired cosine).
        context_size: Number of positions to precompute.
    """

    def __init__(self, d_model:int, context_size:int):
        super().__init__()
        self.d_model = d_model
        self.context_size = context_size

        # Build the angles in float64 and cast once at the end, so the stored
        # float32 values are as accurate as a per-element math.sin/math.cos.
        positions = torch.arange(context_size, dtype=torch.float64).unsqueeze(1)
        # even_dims already holds 2k (0, 2, 4, ...), so the exponent is
        # even_dims / d_model -- not 2 * even_dims / d_model.
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float64)
        angles = positions / (10000.0 ** (even_dims / d_model))

        pe = torch.zeros(context_size, d_model)
        pe[:, 0::2] = torch.sin(angles).to(pe.dtype)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2].to(pe.dtype)

        # A buffer follows the module through .to()/.cuda(); persistent=False
        # keeps it out of state_dict so existing checkpoints still load.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, seq_len=None):
        """Return the encodings with a leading broadcast (batch) dimension.

        Args:
            seq_len: Optional number of leading positions to return. ``None``
                (the default) returns all ``context_size`` positions.

        Returns:
            Tensor of shape ``(1, seq_len or context_size, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``context_size``.
        """
        if seq_len is None:
            return self.pe.unsqueeze(0)
        if seq_len > self.context_size:
            raise ValueError(
                f"seq_len ({seq_len}) exceeds context_size ({self.context_size})"
            )
        return self.pe[:seq_len].unsqueeze(0)

## Test embedding and position encodings

In [15]:

# Smoke test: embed a random batch of token ids and add the positional
# encodings, printing the shape and first 10 values at each step.
embed_test = EmeddingsLayer(d_model=512,vocab_size=50000)
pencoding = PositionalEncoding(d_model=512,context_size=1024)
# 64 sequences of 1024 token ids drawn from [1, 50000).
example_data = torch.randint(1,50000,(64,1024))
print(example_data.shape)
embed_output = embed_test(example_data)
print(embed_output.shape,embed_output[0][0][:10])
pe_output = pencoding()
print(pe_output.shape,pe_output[0][0][:10])
# pe_output has a leading dimension of 1, so it broadcasts over the batch.
embed_pos_output = embed_output + pe_output
print(embed_pos_output.shape,embed_pos_output[0][0][:10])

torch.Size([64, 1024])
torch.Size([64, 1024, 512]) tensor([  8.1030,  23.7976,  56.3750, -24.6229,   4.5044,  23.9931,   0.3400,
          5.2031, -21.3926,  55.1794], grad_fn=<SliceBackward0>)
torch.Size([1, 1024, 512]) tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
torch.Size([64, 1024, 512]) tensor([  8.1030,  24.7976,  56.3750, -23.6229,   4.5044,  24.9931,   0.3400,
          6.2031, -21.3926,  56.1794], grad_fn=<SliceBackward0>)


# Attention

In [16]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head with learned Q/K/V projections.

    Args:
        head_dim: Feature width of the per-head inputs and outputs.
        masked: When True, apply a causal (lower-triangular) mask so a query
            position cannot attend to later key positions.
    """

    def __init__(self, head_dim:int = 64, masked:bool = False) -> None:
        super().__init__()
        self.masked = masked
        # Biases are left at the nn.Linear default (enabled).
        self.queries = nn.Linear(in_features=head_dim, out_features=head_dim)
        self.keys = nn.Linear(in_features=head_dim, out_features=head_dim)
        self.values = nn.Linear(in_features=head_dim, out_features=head_dim)

    def forward(self,Q:torch.Tensor,K:torch.Tensor,V:torch.Tensor) -> torch.Tensor:
        """Compute ``softmax(q @ k^T / sqrt(head_dim)) @ v``.

        Args:
            Q: Query inputs, last dimension ``head_dim``.
            K: Key inputs, last dimension ``head_dim``.
            V: Value inputs, same sequence length as ``K``.

        Returns:
            Attention output with the query's sequence length and width
            ``head_dim``.
        """
        # Apply the learned projections; without this the three Linear
        # layers would be dead parameters that never receive gradients.
        q = self.queries(Q)
        k = self.keys(K)
        v = self.values(V)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        if self.masked:
            # Size the mask from the score matrix itself and create it on the
            # same device, so it works on GPU and when the query and key
            # lengths differ.
            q_len, k_len = scores.shape[-2:]
            allowed = torch.ones(
                q_len, k_len, dtype=torch.bool, device=scores.device
            ).tril()
            scores = scores.masked_fill(~allowed, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        return weights @ v


In [17]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head attention followed by an output projection, residual add and LayerNorm.

    The caller supplies the per-head query/key/value inputs (one chunk of
    width ``d_model // num_heads`` per head).

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        masked: Forwarded to every head to enable the causal mask.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model`` evenly.
    """

    def __init__(self,d_model:int = 512, num_heads:int = 8, masked = False) -> None:
        super().__init__()
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.head_dim = d_model // num_heads
        self.layer_norm = nn.LayerNorm(d_model)
        # nn.ModuleList registers the heads as submodules; a plain Python list
        # would hide their parameters from .parameters(), .to() and state_dict().
        self.heads = nn.ModuleList(
            [AttentionHead(head_dim=self.head_dim, masked=masked) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(d_model,d_model)

    def forward(self,X:torch.Tensor,Q:tuple,K:tuple,V:tuple) ->torch.Tensor:
        """Run every head on its own chunk and merge the results.

        Args:
            X: Residual input added to the projected attention output.
            Q: Per-head query inputs, indexed by head.
            K: Per-head key inputs, indexed by head.
            V: Per-head value inputs, indexed by head.

        Returns:
            ``LayerNorm(X + Linear(concat(head outputs)))``.
        """
        # Each head is independent, so this loop could be distributed across
        # devices for parallel processing.
        heads_output = [
            head(Q[head_index], K[head_index], V[head_index])
            for head_index, head in enumerate(self.heads)
        ]

        concatenated = torch.cat(heads_output,dim=-1)
        linear_output = self.linear(concatenated)
        return self.layer_norm(X+linear_output)


# Smoke test: run masked multi-head attention on random data and display the output shape.
mhsa = MultiHeadSelfAttention(d_model=512,num_heads=8,masked=True)
sample_data = torch.randn((5,10,512))
# Chunks of width 64 along the feature axis: 512 / 8 heads = one chunk per head.
splits = torch.split(sample_data,64,dim=2)
mhsa(sample_data,splits,splits,splits).shape

torch.Size([5, 10, 512])

# Position-wise Feedforward Network

In [18]:
class FeedForward(nn.Module):
    """Two-layer ReLU MLP with a residual connection and LayerNorm.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer.
    """

    def __init__(self,d_model:int=512,d_ff:int=2048) -> None:
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(X + ffn(X))``."""
        residual = X + self.ffn(X)
        return self.layer_norm(residual)


# feedforward = FeedForward()
# feedforward.state_dict()

# Encoder

In [19]:
class EncoderLayer(nn.Module):
    """One encoder block: unmasked multi-head self-attention, then the feed-forward sublayer.

    Args:
        d_model: Model width.
        d_ff: Hidden width of the feed-forward sublayer.
        num_heads: Number of attention heads.
        **kwargs: Accepted and ignored, so a shared config dict can be passed.
    """

    def __init__(self,d_model:int,d_ff:int,num_heads:int,**kwargs) -> None:
        super().__init__()
        self.head_dim = d_model // num_heads
        self.multihead_self_attention = MultiHeadSelfAttention(
            d_model=d_model,
            num_heads=num_heads,
        )
        self.feedforward = FeedForward(d_model=d_model, d_ff=d_ff)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        """Apply self-attention over ``X`` followed by the feed-forward sublayer."""
        # Cut the feature axis into chunks of width head_dim; the same chunks
        # serve as queries, keys and values.
        per_head = torch.split(X, self.head_dim, dim=2)
        attended = self.multihead_self_attention(X, per_head, per_head, per_head)
        return self.feedforward(attended)
    
class Encoder(nn.Module):
    """A stack of ``EncoderLayer`` blocks applied one after another.

    Args:
        number_of_encoder_blocks: How many layers to stack.
        **kwargs: Forwarded unchanged to every ``EncoderLayer``.
    """

    def __init__(self, number_of_encoder_blocks:int=6,**kwargs) -> None:
        super().__init__()
        blocks = []
        for _ in range(number_of_encoder_blocks):
            blocks.append(EncoderLayer(**kwargs))
        self.encode = nn.Sequential(*blocks)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        """Pass ``X`` through every encoder layer in order."""
        encoded = self.encode(X)
        return encoded


# Decoder

In [20]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention over the encoder output, feed-forward.

    Args:
        d_model: Model width.
        d_ff: Hidden width of the feed-forward sublayer.
        num_heads: Number of attention heads.
        **kwargs: Accepted and ignored, so a shared config dict can be passed.
    """

    def __init__(self,d_model:int,d_ff:int,num_heads:int,**kwargs) -> None:
        super().__init__()
        self.head_dim = d_model // num_heads
        self.masked_multi_head_self_attention = MultiHeadSelfAttention(d_model=d_model,num_heads=num_heads,masked=True)
        self.multi_head_self_attention = MultiHeadSelfAttention(d_model=d_model,num_heads=num_heads,masked=False)
        self.feedforward = FeedForward(d_model=d_model,d_ff=d_ff)

    def forward(self,outputs:torch.Tensor,encoded_sequence:torch.Tensor) -> torch.Tensor:
        """Decode ``outputs`` while attending to ``encoded_sequence``.

        Args:
            outputs: Decoder-side input; it is split along dim 2, so it must
                be 3-D with ``d_model`` features last.
            encoded_sequence: Encoder output, e.g. torch.Size([64, 10, 512]).

        Returns:
            The result of the feed-forward sublayer.
        """
        output_splits = torch.split(outputs,self.head_dim,dim=2)
        encoded_sequence_splits = torch.split(encoded_sequence,self.head_dim,dim=2)
        masked_output = self.masked_multi_head_self_attention(outputs,output_splits,output_splits,output_splits)

        # The attention module indexes Q by head, so the queries must be split
        # per head as well. Passing the unsplit tensor would make
        # Q[head_index] select a batch row instead of a head's features.
        query_splits = torch.split(masked_output,self.head_dim,dim=2)
        mhsa_output = self.multi_head_self_attention(
            masked_output,
            Q=query_splits,
            K=encoded_sequence_splits,
            V=encoded_sequence_splits,
        )

        return self.feedforward(mhsa_output)


class Decoder(nn.Module):
    """A stack of ``DecoderLayer`` blocks that all attend to the same encoder output.

    Args:
        number_of_decoder_blocks: How many layers to stack.
        **kwargs: Forwarded unchanged to every ``DecoderLayer``.
    """

    def __init__(self, number_of_decoder_blocks:int, **kwargs) -> None:
        super().__init__()
        layers = []
        for _ in range(number_of_decoder_blocks):
            layers.append(DecoderLayer(**kwargs))
        self.decoder_layers = nn.ModuleList(layers)

    def forward(self,outputs:torch.Tensor,encoded_sequence:torch.Tensor) -> torch.Tensor:
        """Feed ``outputs`` through each layer in turn, reusing ``encoded_sequence`` every time."""
        hidden = outputs
        for layer in self.decoder_layers:
            hidden = layer(hidden, encoded_sequence)
        return hidden



# Transformer

In [34]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that returns next-token probabilities.

    Args:
        vocab_size: Size of the shared source/target vocabulary.
        batch_size: Forwarded to the encoder/decoder, which accept it through
            ``**kwargs``; not used directly by this class.
        context_size: Maximum sequence length covered by the positional table.
        d_model: Model width.
        d_ff: Hidden width of the feed-forward sublayers.
        num_heads: Number of attention heads per attention sublayer.
        number_of_encoder_blocks: Number of stacked encoder layers.
        number_of_decoder_blocks: Number of stacked decoder layers.
        p_drop: Dropout probability applied to the embedding + positional sums.
    """

    def __init__(self,
                 vocab_size:int,
                 batch_size:int,
                 context_size:int,
                 d_model:int,
                 d_ff:int,
                 num_heads:int,
                 number_of_encoder_blocks:int,
                 number_of_decoder_blocks:int,
                 p_drop:float):

        super().__init__()
        self.context_size = context_size

        # One embedding table is shared by the source and target sequences.
        self.embedding = EmeddingsLayer(d_model=d_model,vocab_size=vocab_size)
        self.positional_encoding = PositionalEncoding(d_model=d_model,context_size=context_size)
        self.dropout = nn.Dropout(p=p_drop)
        self.encoder = Encoder(
                            vocab_size=vocab_size,
                            batch_size=batch_size,
                            context_size=context_size,
                            d_model=d_model,
                            d_ff=d_ff,
                            num_heads=num_heads,
                            number_of_encoder_blocks=number_of_encoder_blocks,
                            p_drop=p_drop)

        self.decoder = Decoder(
                            vocab_size=vocab_size,
                            batch_size=batch_size,
                            context_size=context_size,
                            d_model=d_model,
                            p_drop=p_drop,
                            d_ff=d_ff,
                            num_heads=num_heads,
                            number_of_decoder_blocks=number_of_decoder_blocks)

        self.linear = nn.Linear(in_features=d_model,out_features=vocab_size)


    def forward(self,X:torch.Tensor,y:torch.Tensor) -> torch.Tensor:
        """Encode ``X``, decode ``y`` against it, and return vocabulary probabilities.

        Args:
            X: Source token ids, shape ``(batch, source_len)``.
            y: Target token ids, shape ``(batch, target_len)``.

        Returns:
            Softmax probabilities of shape ``(batch, target_len, vocab_size)``.

        Raises:
            ValueError: If either sequence is longer than ``context_size``.
        """
        source_len, target_len = X.size(1), y.size(1)
        if max(source_len, target_len) > self.context_size:
            raise ValueError(
                f"sequence length {max(source_len, target_len)} exceeds "
                f"context_size ({self.context_size})"
            )

        input_embeddings = self.embedding(X)
        # Move the positional table to the embeddings' device and slice it to
        # each sequence's actual length, so shorter inputs still broadcast.
        pos_encoding = self.positional_encoding().to(input_embeddings.device)
        inputs = self.dropout(input_embeddings + pos_encoding[:, :source_len]) # B*T*C

        encoded_sequence = self.encoder(inputs)
        output_embedding = self.embedding(y)
        outputs = self.dropout(output_embedding + pos_encoding[:, :target_len]) # B*T*C

        decoder_output = self.decoder(outputs,encoded_sequence)

        output_logits = self.linear(decoder_output)
        output_probs = torch.softmax(output_logits,dim=-1)

        return output_probs

In [33]:
# Hyperparameters for a forward-pass smoke test; the dict is unpacked
# straight into Transformer(**config), so its keys must match that signature.
config = {
    "vocab_size":37000,
    "batch_size":64,
    "context_size":10,
    "d_model":512,
    "num_heads":1,
    "d_ff":2048,
    "number_of_encoder_blocks": 6,
    "number_of_decoder_blocks": 6,
    "p_drop":0.1
}
# NOTE: `device` is computed here but not used in the statements below;
# the model and the tensors stay on their default device.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Transformer(**config)

# Random source (X) and target (y) token ids of shape (batch_size, context_size),
# drawn from [1, vocab_size).
X = torch.randint(1,config["vocab_size"],(config["batch_size"],config["context_size"]))
y = torch.randint(1,config["vocab_size"],(config["batch_size"],config["context_size"]))

o = model(X,y)
print(o.shape)

torch.Size([64, 10, 37000])


# Train Transformer
