In [1]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

In [2]:
seq_len = 10
d_model = 512
h = 8
d_k = d_model//h

X = torch.randn((seq_len, d_model))
Q = torch.randn((d_model, d_model))
K = torch.randn((d_model, d_model))
V = torch.randn((d_model, d_model))
print(d_k, X.shape, Q.shape, K.shape, V.shape)

64 torch.Size([10, 512]) torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])


In [3]:
# Self Attention

q = torch.matmul(X, Q)
k = torch.matmul(X, K)
v = torch.matmul(X, V)
print(q.shape, k.shape, v.shape)

attention = torch.matmul(q, k.T)/np.sqrt(d_model)
attention_softmax = nn.functional.softmax(attention, -1)
attention = torch.matmul(attention_softmax, v)
print(attention.shape)

torch.Size([10, 512]) torch.Size([10, 512]) torch.Size([10, 512])
torch.Size([10, 512])


In [4]:
# Multi-Headed Attendtion

W_o = torch.randn((d_model, d_model))
q = torch.matmul(X, Q)
k = torch.matmul(X, K)
v = torch.matmul(X, V)

q = torch.reshape(q, (seq_len, h, d_k)).transpose(0, 1)
k = torch.reshape(k, (seq_len, h, d_k)).transpose(0, 1)
v = torch.reshape(v, (seq_len, h, d_k)).transpose(0, 1)
print(q.shape, k.shape, v.shape)

k = torch.transpose(k, -2, -1)
multihead_attention = torch.matmul(q, k)/np.sqrt(d_k)
multihead_attention_softmax = nn.functional.softmax(multihead_attention, -1)
multihead_attention = torch.matmul(multihead_attention, v)
multihead_attention = multihead_attention.transpose(0, 1).reshape((seq_len, d_model))
multihead_attention = torch.matmul(multihead_attention, W_o)
print(multihead_attention.shape)

torch.Size([8, 10, 64]) torch.Size([8, 10, 64]) torch.Size([8, 10, 64])
torch.Size([10, 512])


In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, seq_len, d_model):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        pos = torch.arange(0, seq_len).reshape((seq_len, 1))
        index = torch.arange(0, d_model).reshape((1, d_model))
        denominator = torch.pow(10000, (2*index)/d_model)
        pe = torch.zeros((seq_len, d_model))
        pe[:,0::2] = torch.sin(pos/denominator[:, 0::2])
        pe[:,1::2] = torch.cos(pos/denominator[:, 1::2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe.requires_grad_(False)


pe = PositionalEncoding(seq_len, d_model)
x = X.unsqueeze(0)
x = pe(x)
print(x, x.shape)

tensor([[[ 0.2192,  1.2478,  2.1913,  ...,  0.5312, -0.2162,  2.0037],
         [-0.6920, -0.3131, -0.9358,  ..., -0.3090,  0.7039,  1.2872],
         [ 1.3077, -0.9405,  0.9326,  ...,  0.5825,  0.1900,  1.7588],
         ...,
         [ 1.5464,  2.0699,  0.2880,  ...,  0.4341, -2.4510,  1.1655],
         [ 1.1835,  0.1523,  0.9788,  ...,  1.0372, -0.0145,  1.1234],
         [ 1.4313, -2.0561,  0.0652,  ...,  0.6187,  0.7517,  0.4005]]]) torch.Size([1, 10, 512])


In [6]:
import math

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, h, d_k):
        super().__init__()
        self.h = h
        self.d_k = d_k
        self.d_model = self.h * self.d_k
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    @staticmethod
    def attention(q, k, v, mask):
        k = k.transpose(-2, -1)
        attention_scores = torch.matmul(q, k)/math.sqrt(q.shape[-1])
        if mask is not None:
            attention_scores = torch.masked_fill(attention_scores, mask, -1e-9)
        attention_scores = nn.functional.softmax(attention_scores, -1)
        attention = torch.matmul(attention_scores, v)
        return attention

    def forward(self, q, k, v, mask=None):
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)

        q = torch.reshape(q, (q.shape[0], q.shape[1], self.h, self.d_k)).transpose(1, 2)
        k = torch.reshape(k, (k.shape[0], k.shape[1], self.h, self.d_k)).transpose(1, 2)
        v = torch.reshape(v, (v.shape[0], v.shape[1], self.h, self.d_k)).transpose(1, 2)
        x = MultiHeadAttentionBlock.attention(q, k, v, mask)

        x = x.transpose(1, 2)
        x = x.reshape((x.shape[0], x.shape[1], self.d_model))

        return self.w_o(x)
    
src_mask = torch.mean(x, -1) == 0
mha = MultiHeadAttentionBlock(h, d_k)
x = mha(x, x, x, src_mask)
print(x.shape, src_mask.shape)
        

torch.Size([1, 10, 512]) torch.Size([1, 10])


In [7]:
class FeedForwardLayer(nn.Module):
    def __init__(self, d_model:int, d_hidden:int):
        super().__init__()
        self.ff1 = nn.Linear(d_model, d_hidden)
        self.ff2 = nn.Linear(d_hidden, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.ff1(x))
        return self.ff2(x)
    
ffl = FeedForwardLayer(d_model, 2048)
x = ffl(x)
print(x.shape)

torch.Size([1, 10, 512])


In [8]:
class ResidualBlock(nn.Module):
    def __init__(self, normalized_shape):
        super().__init__()
        self.layernorm = nn.LayerNorm(normalized_shape)

    def forward(self, x, sublayer):
        x = self.layernorm(x + sublayer(x))
        return x

res = ResidualBlock(x.shape)
x = res(x, lambda x: mha(x, x, x, src_mask))
print(x.shape)
x = res(x, ffl)
print(x.shape)

torch.Size([1, 10, 512])
torch.Size([1, 10, 512])


In [9]:
class EncoderBlock(nn.Module):
    def __init__(self, normalized_shape, mha: MultiHeadAttentionBlock, ffl: FeedForwardLayer):
        super().__init__()
        self.mha = mha
        self.ffl = ffl
        self.res = nn.ModuleList([ResidualBlock(normalized_shape) for _ in range(2)])
        
    def forward(self, x, src_mask):
        x = self.res[0](x, lambda x: self.mha(x, x, x, src_mask))
        x = self.res[1](x, self.ffl)

        return x

eb = EncoderBlock(x.shape, mha, ffl)
x = eb(x, src_mask)
print(x.shape)

torch.Size([1, 10, 512])


In [10]:
class Encoder(nn.Module):
    def __init__(self, encoder_blocks: nn.ModuleList):
        super().__init__()
        self.encoder_blocks = encoder_blocks

    def forward(self, x, src_mask):
        for encoder_block in self.encoder_blocks:
            x = encoder_block(x, src_mask)
        
        return x

encoder_blocks = []
d_hidden = 2048
for _ in range(5):
    mha = MultiHeadAttentionBlock(h, d_k)
    ffl = FeedForwardLayer(d_model, d_hidden)
    encoder_block = EncoderBlock(x.shape, mha, ffl)
    encoder_blocks.append(encoder_block)
    
encoder_blocks = nn.ModuleList(encoder_blocks)
encoder = Encoder(encoder_blocks)
x = encoder(x, src_mask)
print(x.shape)

torch.Size([1, 10, 512])


In [11]:
class DecoderBlock(nn.Module):
    def __init__(self, normalized_shape, mha: MultiHeadAttentionBlock, ca: MultiHeadAttentionBlock, ffl: FeedForwardLayer):
        super().__init__()
        self.mha = mha
        self.ca = ca
        self.ffl = ffl
        self.res = nn.ModuleList([ResidualBlock(normalized_shape) for _ in range(3)])

    def forward(self, x, y, tgt_mask):
        x = self.res[0](x, lambda x: self.mha(x, x, x, tgt_mask))
        x = self.res[1](x, lambda x: self.ca(y, y, x))
        x = self.res[2](x, self.ffl)
        
        return x
    
tgt_mask = torch.ones((seq_len, seq_len)).triu(diagonal=1) == 1
tgt_mask = src_mask * tgt_mask
eb = EncoderBlock(x.shape, mha, ffl)
y = eb(x, src_mask)
db = DecoderBlock(x.shape, mha, mha, ffl)
z = db(x, y, tgt_mask)
print(z.shape)

torch.Size([1, 10, 512])


In [12]:
class Decoder(nn.Module):
    def __init__(self, decoder_blocks: nn.ModuleList):
        super().__init__()
        self.decoder_blocks = decoder_blocks

    def forward(self, x, y, tgt_mask):
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x, y, tgt_mask)
        
        return x


encoder_blocks = []
d_hidden = 2048
for _ in range(5):
    mha = MultiHeadAttentionBlock(h, d_k)
    ffl = FeedForwardLayer(d_model, d_hidden)
    encoder_block = EncoderBlock(x.shape, mha, ffl)
    encoder_blocks.append(encoder_block)
    
encoder_blocks = nn.ModuleList(encoder_blocks)
encoder = Encoder(encoder_blocks)
x = encoder(x, src_mask)
print(x.shape)

decoder_blocks = []
d_hidden = 2048
for _ in range(5):
    mha = MultiHeadAttentionBlock(h, d_k)
    ca = MultiHeadAttentionBlock(h, d_k)
    ffl = FeedForwardLayer(d_model, d_hidden)
    decoder_block = DecoderBlock(x.shape, mha, ca, ffl)
    decoder_blocks.append(decoder_block)
    
decoder_blocks = nn.ModuleList(decoder_blocks)
decoder = Decoder(decoder_blocks)
z = decoder(x, y, tgt_mask)
print(z.shape)

torch.Size([1, 10, 512])
torch.Size([1, 10, 512])


In [13]:
class ProjectionLayer(nn.Module):
    def __init__(self, d_model:int, voc_size:int):
        super().__init__()
        self.ff1 = nn.Linear(d_model, d_model)
        self.ff2 = nn.Linear(d_model, voc_size)

    def forward(self, x):
        return torch.log_softmax(self.ff2(self.ff1(x)), -1)
    
ffl = ProjectionLayer(d_model, 2048)
op = ffl(x)
print(op.shape)

torch.Size([1, 10, 2048])


In [14]:
class Transformer(nn.Module):
    def __init__(self, encoder: Encoder, decoder:Decoder, projection_layer: ProjectionLayer, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding):
        super().__init__()
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.encoder = encoder
        self.decoder = decoder
        self.projection_layer = projection_layer

    def encode(self, x, src_mask):
        x = self.src_pos(x)
        return self.encoder(x, src_mask)
    
    def decode(self, x, y, tgt_mask):
        y = self.tgt_pos(y)
        return self.decoder(x, y, tgt_mask)
    
    def project(self, x):
        return self.projection_layer(x)
    
    def forward(self, x, src_mask, tgt_mask):
        y = self.encode(x, src_mask)
        z = self.decode(x, y, tgt_mask)
        z = nn.functional.softmax(self.project(z), -1)
        return z

def build_transformer():
    n = 6
    voc_size = 1000
    encoder_blocks = []
    for _ in range(n):
        mha = MultiHeadAttentionBlock(h, d_k)
        ffl = FeedForwardLayer(d_model, d_hidden)
        encoder_block = EncoderBlock((seq_len, d_model), mha, ffl)
        encoder_blocks.append(encoder_block)

    encoder_blocks = nn.ModuleList(encoder_blocks)
    encoder = Encoder(encoder_blocks)

    decoder_blocks = []
    for _ in range(n):
        mha = MultiHeadAttentionBlock(h, d_k)
        ca = MultiHeadAttentionBlock(h, d_k)
        ffl = FeedForwardLayer(d_model, d_hidden)
        decoder_block = DecoderBlock((seq_len, d_model), mha, ca, ffl)
        decoder_blocks.append(decoder_block)

    decoder_blocks = nn.ModuleList(decoder_blocks)
    decoder = Decoder(decoder_blocks)
    
    src_pos = PositionalEncoding(seq_len, d_model)
    tgt_pos = PositionalEncoding(seq_len, d_model)
    projection_layer = ProjectionLayer(d_model, voc_size)
    transfomer = Transformer(encoder, decoder, projection_layer, src_pos, tgt_pos)
    return transfomer

transformer = build_transformer()
predicted_token = transformer(x, src_mask, tgt_mask)
print(predicted_token.shape)

torch.Size([1, 10, 1000])


In [15]:
class Tokenizer(nn.Module):
    def __init__(self, seq_len, padd_token):
        super().__init__()
        self.padd_token = padd_token
        self.seq_len = seq_len
        self.embedding = {
            padd_token: 0
        }
        self.counter = 0

    def tokenize(self, sentences):
        tokens = []
        for sentence in sentences:
            sentence = list(set(sentence.lower().split()))
            tokens.append(sentence)

        return tokens
    
    def unique_ids(self, tokens):
        embeddings = []
        for token_arr in tokens:
            if len(token_arr) < self.seq_len:
                token_arr = token_arr + [self.padd_token]*(self.seq_len-len(token_arr))
            else:
                token_arr = token_arr[:self.seq_len]

            for i, token in enumerate(token_arr):
                if token not in self.embedding:
                    self.counter = self.counter+1
                    self.embedding[token] = self.counter
                token_arr[i] = self.embedding[token]
            embeddings.append(token_arr)
        
        return torch.tensor(embeddings, dtype=torch.int)
    
    def token_to_id(self, token: str):
        return self.embedding[token]
    
    def id_to_token(self, id: int):
        id_index = list(self.embedding.values()).index(id)
        token = list(self.embedding.keys())[id_index]
        return token

padd_token = '<PAD>'
vocab_size = 1000

tokenizer = Tokenizer(seq_len, padd_token)
tokens = tokenizer.tokenize(['hi how are you', 'hello world hi I am santhosh'])
ids = tokenizer.unique_ids(tokens)
embedding = nn.Embedding(vocab_size, d_model)
embeddings = embedding(ids)
print(tokens, ids.shape, embeddings.shape)
 

[['how', 'hi', 'are', 'you'], ['world', 'hello', 'am', 'santhosh', 'hi', 'i']] torch.Size([2, 10]) torch.Size([2, 10, 512])


In [16]:
src_mask = (ids != 0).unsqueeze(1).unsqueeze(-1)
tgt_mask = torch.ones((seq_len, seq_len)).tril(diagonal=0)
tgt_mask = (src_mask * tgt_mask) == 0
res = transformer(embeddings, src_mask, tgt_mask)
print(res.shape, torch.argmax(res, dim=-1))

torch.Size([2, 10, 1000]) tensor([[145, 270, 385, 385, 270, 270, 270, 270, 270, 270],
        [109, 645, 197, 385, 385, 210, 385, 385, 385, 385]])
