In [5]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


import os
from dataclasses import dataclass

In [6]:
@dataclass
class configs:
    """Training hyper-parameters (epochs, batch size, learning rate).

    NOTE(review): no code in this notebook reads these values; presumably a
    training loop elsewhere consumes them -- confirm.
    """
    num_epochs: int = 10  # presumably the number of passes over the training data
    batch_size: int = 64  # presumably examples per optimisation step
    learning_rate: float = 0.001  # presumably the optimiser step size




In [7]:
@dataclass
class model_configs:
    """Architecture hyper-parameters read by `Transformer.__init__`.

    The example cell passes the class itself (`Transformer(configs=model_configs)`),
    so the defaults below are read as class attributes; an instance with
    overridden fields is read the same way (plain attribute access).
    """
    src_n_vocab: int = 10_000  # size of the source-side embedding table
    target_n_vocab: int = 5_000  # size of the target embedding table and output projection width
    # max_len: int = 512
    d_model: int = 512  # model / embedding width
    num_layers: int = 6  # layers in each of the encoder and decoder stacks
    num_heads: int = 8  # attention heads per MultiHeadAttention
    dropout: float = 0.1  # NOTE(review): Transformer never reads this field
    # Per-head width. Evaluated ONCE, when the class body runs, from the default
    # d_model and num_heads; it does not follow overrides such as
    # model_configs(d_model=256) -- pass a matching d_k explicitly in that case.
    # num_heads * d_k must equal d_model so W_o maps (h * d_k) back to d_model.
    d_k: int = d_model // num_heads
    d_ff: int = 2048  # hidden width of each position-wise feed-forward network


In [8]:
class InputEmbedding(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(d_model).

    The sqrt(d_model) factor follows "Attention Is All You Need" (section 3.4).

    Args:
        n_vocab: number of rows in the embedding table.
        d_model: embedding width.
    """

    def __init__(self, n_vocab, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(n_vocab, d_model)

    def forward(self, x):
        """Map integer ids of any shape to vectors of shape (*x.shape, d_model)."""
        token_ids = x.int()
        scale = self.d_model ** 0.5
        return self.embedding(token_ids) * scale


In [9]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding width. Odd widths are supported; the last column
            then only receives a sine term.
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        i_vector = torch.arange(0, d_model, 2, dtype=torch.float)        # the "2i" exponents

        denom = torch.pow(10_000.0, i_vector / d_model)  # (ceil(d_model / 2),)
        angles = pos / denom                              # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # Odd columns must use cosine (previously sine was used for both halves,
        # making each odd column a copy of its even neighbour). For an odd
        # d_model there is one fewer odd column, so trim the angle matrix.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over the batch dim

        # A buffer is saved in state_dict and moved by .to(), but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return x + PE for the first x.size(1) positions.

        Expects x of shape (batch, seq_len, d_model) with seq_len <= max_len.
        """
        return x + self.pe[:, : x.size(1)]


In [10]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension.

        y = weight * (x - mean) / sqrt(var + eps) + bias

    with the biased (population) variance, matching torch.nn.LayerNorm.

    Args:
        eps: added to the variance for numerical stability.
        d_model: optional size of the last dimension. When given, `weight` and
            `bias` are per-feature vectors (like nn.LayerNorm). When omitted
            (the default, kept for backward compatibility), they are single
            scalars shared by every feature.
    """

    def __init__(self, eps: float = 1e-6, d_model=None):
        super(LayerNorm, self).__init__()

        self.eps = eps
        param_shape = () if d_model is None else (d_model,)
        self.weight = nn.Parameter(torch.ones(param_shape))
        # The bias must start at 0 so the layer is initially a pure
        # normalisation; it was previously initialised to 1.
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Population variance (divide by N). x.std() would apply Bessel's
        # correction, which differs from the LayerNorm definition.
        var = (centered ** 2).mean(-1, keepdim=True)
        return self.weight * centered / torch.sqrt(var + self.eps) + self.bias


In [11]:
class Add_and_Norm(nn.Module):
    """Post-norm residual connection: LayerNorm(x + Dropout(sublayer_output)).

    Args:
        dropout: drop probability applied to the sub-layer output before the add.
    """

    def __init__(self, dropout=0.1):
        super().__init__()
        self.layer_norm = LayerNorm()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_output):
        residual_sum = x + self.dropout(sublayer_output)
        normalized = self.layer_norm(residual_sum)
        return normalized


In [12]:
class selfAttentionHead(nn.Module):
    """A single, unmasked scaled dot-product self-attention head.

    Projects x to query/key/value of width d_k and returns
    softmax(Q K^T / sqrt(d_k)) V with shape (bs, N, d_k).

    NOTE(review): not used by any other class in this notebook;
    MultiHeadAttention computes all heads in parallel instead.
    """

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.wq = nn.Linear(d_model, d_k)
        self.wk = nn.Linear(d_model, d_k)
        self.wv = nn.Linear(d_model, d_k)

    def forward(self, x):
        query, key, value = self.wq(x), self.wk(x), self.wv(x)
        scale = self.d_k ** 0.5
        scores = query @ key.transpose(-2, -1) / scale  # (bs, N, N)
        weights = scores.softmax(dim=-1)
        return weights @ value  # (bs, N, d_k)



In [36]:
class MultiHeadAttention(nn.Module):
    '''
    Multi-head scaled dot-product attention with all heads computed in parallel.

    Args:
        num_heads: number of attention heads (h).
        d_k: per-head width; the model width is d_model = num_heads * d_k.
        causal: if True, query position i may only attend to key positions <= i.

    After each forward call, `self.attention_scores` holds the softmax weights
    of shape (batch, num_heads, q_len, k_len).
    '''
    def __init__(self, num_heads, d_k, causal=False):
        super(MultiHeadAttention, self).__init__()

        self.num_heads = num_heads
        self.d_k = d_k
        d_model = num_heads * d_k
        # Stored so forward() does not depend on a variable outside this class.
        self.d_model = d_model

        self.wq = nn.Linear(d_model, d_model)  # all heads' query projections, stacked
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.wo = nn.Linear(d_model, d_model)  # (num_heads * d_k) -> d_model

        self.causal = causal

    def attention(self, q, k, v, mask=None):
        """Scaled dot-product attention on already-split heads.

        Args:
            q: (bs, h, q_len, d_k); k, v: (bs, h, k_len, d_k).
            mask: optional tensor broadcastable to (bs, h, q_len, k_len);
                positions where mask == 0 are excluded. Combined with the
                causal mask when `self.causal` is set.

        Returns:
            (weighted values of shape (bs, h, q_len, d_k), attention weights).
        """
        attention_matrix = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)  # (bs, h, q_len, k_len)

        if self.causal:
            q_len, k_len = attention_matrix.shape[-2], attention_matrix.shape[-1]
            # Built directly on the right device instead of on CPU and then copied.
            allowed = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool, device=q.device))
            attention_matrix = attention_matrix.masked_fill(~allowed, float('-inf'))

        if mask is not None:
            attention_matrix = attention_matrix.masked_fill(mask == 0, float('-inf'))

        attention_scores = torch.softmax(attention_matrix, dim=-1)
        weighted_value = torch.matmul(attention_scores, v)  # (bs, h, q_len, d_k)

        return weighted_value, attention_scores

    def _split_heads(self, x):
        """Reshape (bs, seq_len, d_model) -> (bs, num_heads, seq_len, d_k)."""
        return x.view(-1, x.shape[1], self.num_heads, self.d_k).transpose(1, 2).contiguous()

    def forward(self, q, k, v, mask=None):
        """Project, split into heads, attend, merge heads and project back.

        Args:
            q: (bs, q_len, d_model); k, v: (bs, k_len, d_model).
            mask: optional padding/attention mask, see `attention`.

        Returns:
            Tensor of shape (bs, q_len, d_model).
        """
        q_len = q.shape[1]

        q = self._split_heads(self.wq(q))  # (bs, h, q_len, d_k)
        k = self._split_heads(self.wk(k))  # (bs, h, k_len, d_k)
        v = self._split_heads(self.wv(v))

        weighted_value, self.attention_scores = self.attention(q, k, v, mask)

        # Concatenate heads: (bs, h, q_len, d_k) -> (bs, q_len, h * d_k).
        # Uses self.d_model; this previously read an undefined local `d_model`
        # and only worked when a global of that name happened to exist.
        weighted_value = weighted_value.transpose(1, 2).contiguous().view(-1, q_len, self.d_model)

        return self.wo(weighted_value)  # (bs, q_len, d_model)


In [38]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        d_ff: hidden width.
        dropout: drop probability applied to the hidden activations.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear_lift = nn.Linear(d_model, d_ff)
        self.linear_out = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear_lift(x)))
        return self.linear_out(hidden)

In [39]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network, each
    wrapped in its OWN residual + LayerNorm (Add_and_Norm).

    Args:
        d_model: model width; must equal num_heads * d_k because
            MultiHeadAttention derives its width from those two values.
        d_k: per-head width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout used by the residual connections and the FFN.

    Raises:
        ValueError: if d_model != num_heads * d_k (the layer could not run).
    """
    def __init__(self, d_model: int = 512, d_k: int = 64, num_heads: int = 8, d_ff: int = 2048, dropout=0.1):
        super(EncoderLayer, self).__init__()

        if d_model != num_heads * d_k:
            raise ValueError(
                f"d_model ({d_model}) must equal num_heads * d_k ({num_heads} * {d_k})"
            )

        self.multiheadAttention = MultiHeadAttention(num_heads=num_heads, d_k=d_k)
        # One Add_and_Norm per sub-layer: sharing a single instance (as before)
        # tied the LayerNorm parameters of both sub-layers together.
        self.add_and_norm_attn = Add_and_Norm(dropout)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
        self.add_and_norm_ffn = Add_and_Norm(dropout)

    def forward(self, x):
        """Expects x of shape (batch, seq_len, d_model); returns the same shape."""
        x = self.add_and_norm_attn(x, self.multiheadAttention(x, x, x))  # self-attention
        x = self.add_and_norm_ffn(x, self.ffn(x))
        return x

In [40]:
class Encoder(nn.Module):
    """Stack of `num_layers` EncoderLayers applied in sequence.

    Args:
        num_layers: number of stacked layers.
        d_model, d_k, num_heads, d_ff: forwarded to every EncoderLayer.
        dropout: forwarded to every EncoderLayer. Optional; the default 0.1
            equals the EncoderLayer default that was always used before.
    """
    def __init__(self, num_layers: int = 6, d_model: int = 512, d_k: int = 64, num_heads: int = 8, d_ff: int = 2048, dropout: float = 0.1):
        super(Encoder, self).__init__()

        self.layers = nn.ModuleList([
            EncoderLayer(d_model, d_k, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x):
        """Run x through every layer in order; the output has the same shape as x."""
        for layer in self.layers:
            x = layer(x)
        return x

In [41]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Causal self-attention, cross-attention over the encoder output, and a
    position-wise feed-forward network, each wrapped in its OWN residual +
    LayerNorm (Add_and_Norm).

    Args:
        d_model: model width; must equal num_heads * d_k because
            MultiHeadAttention derives its width from those two values.
        d_k: per-head width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout used by the residual connections and the FFN.

    Raises:
        ValueError: if d_model != num_heads * d_k (the layer could not run).
    """
    def __init__(self, d_model: int = 512, d_k: int = 64, num_heads: int = 8, d_ff: int = 2048, dropout=0.1):
        super(DecoderLayer, self).__init__()

        if d_model != num_heads * d_k:
            raise ValueError(
                f"d_model ({d_model}) must equal num_heads * d_k ({num_heads} * {d_k})"
            )

        self.multiheadcrossAttention = MultiHeadAttention(num_heads=num_heads, d_k=d_k)
        self.causal_multiheadAttention = MultiHeadAttention(num_heads=num_heads, d_k=d_k, causal=True)
        # One Add_and_Norm per sub-layer: sharing a single instance (as before)
        # tied the LayerNorm parameters of all three sub-layers together.
        self.add_and_norm_self_attn = Add_and_Norm(dropout)
        self.add_and_norm_cross_attn = Add_and_Norm(dropout)
        self.add_and_norm_ffn = Add_and_Norm(dropout)
        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

    def forward(self, x, encoder_output):
        """Expects x (batch, tgt_len, d_model) and encoder_output (batch, src_len, d_model).

        Returns a tensor with the same shape as x.
        """
        x = self.add_and_norm_self_attn(x, self.causal_multiheadAttention(x, x, x))  # masked self-attention
        x = self.add_and_norm_cross_attn(
            x, self.multiheadcrossAttention(x, encoder_output, encoder_output)  # cross-attention
        )
        x = self.add_and_norm_ffn(x, self.ffn(x))
        return x

In [42]:
class Decoder(nn.Module):
    """Stack of `num_layers` DecoderLayers that all attend to the same encoder output.

    Args:
        n_vocab: unused; accepted only so existing callers keep working.
        num_layers: number of stacked layers.
        d_model, d_k, num_heads, d_ff: forwarded to every DecoderLayer.
        dropout: forwarded to every DecoderLayer. Optional; the default 0.1
            equals the DecoderLayer default that was always used before.
    """
    def __init__(self, n_vocab: int = 1000, num_layers: int = 6, d_model: int = 512, d_k: int = 64, num_heads: int = 8, d_ff: int = 2048, dropout: float = 0.1):
        super(Decoder, self).__init__()

        self.layers = nn.ModuleList([
            DecoderLayer(d_model, d_k, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x, encoder_output):
        """Run the embedded target x through every layer, each attending to encoder_output."""
        for layer in self.layers:
            x = layer(x, encoder_output)
        return x

In [50]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with learned token embeddings.

    Args:
        configs: object exposing src_n_vocab, target_n_vocab, d_model,
            num_layers, num_heads, d_k and d_ff (e.g. the model_configs class).

    NOTE(review): projection_layer ends in Softmax, so forward() returns
    probabilities, not logits. nn.CrossEntropyLoss expects logits; feeding it
    these outputs applies softmax twice. Use log() + nn.NLLLoss, or drop the
    Softmax, before training.
    """
    def __init__(self, configs):
        super(Transformer, self).__init__()

        self.src_embed = InputEmbedding(n_vocab=configs.src_n_vocab, d_model=configs.d_model)
        self.target_embed = InputEmbedding(n_vocab=configs.target_n_vocab, d_model=configs.d_model)

        # One fixed sinusoidal table, shared by the source and target sides.
        self.pos_embed = PositionalEncoding(d_model=configs.d_model)

        self.encoder = Encoder(num_layers=configs.num_layers, d_model=configs.d_model,
                               d_k=configs.d_k, num_heads=configs.num_heads, d_ff=configs.d_ff)

        self.decoder = Decoder(num_layers=configs.num_layers, d_model=configs.d_model,
                               d_k=configs.d_k, num_heads=configs.num_heads, d_ff=configs.d_ff)

        self.projection_layer = nn.Sequential(nn.Linear(configs.d_model, configs.target_n_vocab),
                                              nn.Softmax(dim=-1)
                                              )

        self.init_weights()

    def embed_src(self, x):
        """Source token ids -> scaled embeddings + positional encoding."""
        x = self.src_embed(x)
        x = self.pos_embed(x)
        return x

    def embed_tgt(self, x):
        """Target token ids -> scaled embeddings + positional encoding."""
        x = self.target_embed(x)
        x = self.pos_embed(x)
        return x

    def init_weights(self):
        """Xavier-uniform init for every parameter with more than one dimension.

        This includes the embedding tables; 1-D biases and LayerNorm
        parameters keep their constructor defaults.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    @torch.no_grad()
    def generate(self, src, max_len, start_token, end_token):
        '''
        Greedy autoregressive decoding for a single source sequence.

        Args:
            src: token ids of shape (1, src_len).
            max_len: maximum number of tokens generated after start_token.
            start_token: id that opens the output sequence.
            end_token: id that stops decoding as soon as it is produced.

        Returns:
            1-D tensor [start_token, t1, ..., tk] with k <= max_len; the last
            element is end_token if it was produced.

        Runs without autograd, so no graph is built across decoding steps.
        Does not change train/eval mode: call model.eval() first so dropout
        is disabled.

        Raises:
            ValueError: if src is not a single sequence of shape (1, src_len).
        '''
        # The step below uses .item() on the argmax, which only works for batch 1.
        if src.dim() != 2 or src.size(0) != 1:
            raise ValueError(f"generate expects src of shape (1, src_len), got {tuple(src.shape)}")

        # The encoder output does not change between steps, so compute it once.
        encoder_output = self.encoder(self.embed_src(src))
        generated_sequence = [start_token]

        for _ in range(max_len):
            target = torch.tensor(generated_sequence, device=src.device).unsqueeze(0)  # (1, t)
            decoder_output = self.decoder(self.embed_tgt(target), encoder_output)
            # Only the last position predicts the next token. argmax is
            # unaffected by the Softmax inside projection_layer.
            probs = self.projection_layer(decoder_output[:, -1, :])
            next_token = torch.argmax(probs, dim=-1).item()
            generated_sequence.append(next_token)

            if next_token == end_token:
                break

        return torch.tensor(generated_sequence)

    def forward(self, src, target):
        """Teacher-forced pass.

        Args:
            src: source token ids (batch, src_len).
            target: target token ids (batch, tgt_len).

        Returns:
            Per-position probabilities over the target vocabulary,
            shape (batch, tgt_len, target_n_vocab).
        """
        src = self.embed_src(src)
        target = self.embed_tgt(target)

        encoder_output = self.encoder(src)
        decoder_output = self.decoder(target, encoder_output)

        output = self.projection_layer(decoder_output)

        return output




In [55]:
# Example usage: one teacher-forced forward pass on random ids, then greedy decoding.
torch.manual_seed(0)  # make the random ids and the weight initialisation reproducible

batch_size = 1
seq_length = 10
# Take sizes from the config instead of repeating the literals.
src_n_vocab = model_configs.src_n_vocab
target_n_vocab = model_configs.target_n_vocab
d_model = model_configs.d_model

src = torch.randint(0, src_n_vocab, (batch_size, seq_length))  # source sequence
target = torch.randint(0, target_n_vocab, (batch_size, seq_length + 7))  # target sequence (decoder input)

start_token = 1
end_token = 2
max_length = 20

model = Transformer(configs=model_configs)

output = model(src, target)  # call the module (not .forward) so hooks run
print("Transformer output: ", output.shape)

model.eval()  # disable dropout for inference
with torch.no_grad():  # inference only: no autograd graph needed
    generated_sequence = model.generate(src, max_length, start_token, end_token)
print(generated_sequence.shape)

Transformer output:  torch.Size([1, 17, 5000])
torch.Size([21])


In [23]:
|