### Transformers 102: Building the Entire Transformer Architecture ###

This notebook builds on the "Transformers 101" notebook that I created previously. There, I built a small version of the Transformer decoder, mostly for understanding, without paying much attention to computational efficiency or clean code. In this notebook, I will write the code for the entire Transformer model, with both the encoder and the decoder. I will also release a Python file along with the notebook so the model can be reused later.

So, let's just begin.

In [1]:
## Importing necessary packages ##

import torch
import torch.nn as nn
import torch.nn.functional as F

We are going to build everything piece by piece based on the architectural diagram given in this [paper](https://arxiv.org/pdf/1706.03762.pdf).

The first thing we will build is the **Scaled Dot Product Attention**. This is pretty simple, as we saw in the previous notebook. The mechanism deals with the interaction between three matrices: Query, Key and Value.

The calculation is very simple : $softmax(\frac{Q \cdot K^{T}}{\sqrt{d_{k}}}) \cdot V$ .
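
To make the formula concrete, here is a minimal, dependency-free sketch that evaluates it by hand on a tiny example. The helper names (`softmax`, `matmul`, `transpose`, `attention`) and the toy matrices are assumptions chosen for illustration; they are not part of the notebook's code.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Plain-Python matrix product of two lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for small list-of-lists matrices."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = [softmax(row) for row in scaled]
    return matmul(weights, v), weights


# Toy example: 2 tokens, d_k = d_v = 2 (values chosen only for illustration).
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]

out, weights = attention(Q, K, V)
print(weights)  # every row sums to 1
print(out)      # each output row is a weighted average of the rows of V
```

Each query puts more weight on the key it matches, so each output row lands closer to the corresponding row of `V`.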

In the previous notebook we implemented the decoder of the Transformer, where we needed to apply a mask after the dot product of the Query and Key so that a position cannot attend to positions in the future. The encoder does not need this masking. Hence, we will build a generic Scaled Dot Product Attention module in which masking is optional.
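
The optional masking step can be sketched on its own as follows. This is a minimal illustration that assumes a boolean upper-triangular mask built with `torch.triu`; the function name `causal_softmax` is hypothetical and not part of the notebook.

```python
import torch
import torch.nn.functional as F


def causal_softmax(scores: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dim with future positions masked out.

    `scores` has shape (..., L, L); entry (i, j) scores query i against key j.
    """
    L = scores.shape[-1]
    # True strictly above the diagonal, i.e. where key j lies in the future of query i.
    future = torch.triu(
        torch.ones(L, L, dtype=torch.bool, device=scores.device), diagonal=1
    )
    return F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)


scores = torch.zeros(3, 3)  # all-zero scores, chosen to make the result easy to read
weights = causal_softmax(scores)
print(weights)
```

With all-zero scores, row 0 attends only to position 0, row 1 splits evenly over positions 0 and 1, and row 2 spreads over all three. Building the mask from positions rather than from the score values matters in a case like this: a check such as `scores == 0` could not tell masked entries apart from scores that are genuinely zero.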

So, let's implement this.

In [2]:
## Scaled Dot Product Attention ##


class ScaledDotProductAttention(nn.Module):
    """Implements the scaled dot product attention."""

    def __init__(
        self,
        d_embed: int,
        d_k: int,
        d_v: int = None,
        mask: bool = False,
    ):
        """Constructor"""

        super().__init__()

        # If d_v is not specified set to same as d_k #
        if d_v is None:
            d_v = d_k

        # Query, key and values linear layers #
        self.query_ffn = nn.Linear(d_embed, d_k, bias=False)
        self.key_ffn = nn.Linear(d_embed, d_k, bias=False)
        self.value_ffn = nn.Linear(d_embed, d_v, bias=False)

        self.mask = mask

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor = None,
        z: torch.Tensor = None,
        return_wt: bool = False,
    ):
        """Applies a forward pass through the Scaled Dot Product Attention.

        Keys are computed from x, queries from y and values from z.
        When y or z is not given, x is used instead (self-attention).
        """

        if y is None:
            y = x

        if z is None:
            z = x

        # Getting the key, query and the value
        key = self.key_ffn(x)
        query = self.query_ffn(y)
        value = self.value_ffn(z)

        # Getting d_k from key
        d_k = key.shape[-1]

        # Calculating weights
        weight = query @ key.transpose(-2, -1) / d_k**0.5

        # Setting mask: each query may only attend to keys at or before its own position.
        # The mask is built from positions, so genuine zero scores are not masked by accident.
        if self.mask:
            future = torch.ones_like(weight, dtype=torch.bool).triu(diagonal=1)
            weight = weight.masked_fill(future, float("-inf"))

        # Pass through softmax
        weight = F.softmax(weight, dim=-1)

        # Finally product with the values and return
        if return_wt:
            return weight @ value, weight

        return weight @ value


One thing worth noting is that the forward pass optionally accepts three input tensors, which are used to generate the keys (`x`), the queries (`y`) and the values (`z`) respectively. This makes the class general enough to be used for cross-attention as well as self-attention. Moreover, the optional `mask` flag makes the module usable for masked as well as unmasked attention.
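
To illustrate the difference between self-attention and cross-attention, here is a minimal sketch in plain tensor operations. The projection layers `w_q`, `w_k`, `w_v`, the helper `attend` and all sizes are assumptions for illustration; the sketch only shows which tensor feeds which projection and what shape comes out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d_embed, d_k = 16, 8
w_q = nn.Linear(d_embed, d_k, bias=False)  # query projection
w_k = nn.Linear(d_embed, d_k, bias=False)  # key projection
w_v = nn.Linear(d_embed, d_k, bias=False)  # value projection


def attend(query_src, key_src, value_src):
    """softmax(Q K^T / sqrt(d_k)) V with separate sources for Q, K and V."""
    q, k, v = w_q(query_src), w_k(key_src), w_v(value_src)
    weights = F.softmax(q @ k.transpose(-2, -1) / d_k**0.5, dim=-1)
    return weights @ v


encoder_out = torch.randn(2, 7, d_embed)    # (batch, source length, d_embed)
decoder_state = torch.randn(2, 5, d_embed)  # (batch, target length, d_embed)

# Self-attention: queries, keys and values all come from the same sequence.
self_attn = attend(decoder_state, decoder_state, decoder_state)

# Cross-attention: queries from the decoder, keys and values from the encoder.
cross_attn = attend(decoder_state, encoder_out, encoder_out)

print(self_attn.shape)   # torch.Size([2, 5, 8])
print(cross_attn.shape)  # torch.Size([2, 5, 8])
```

In both cases the output has one row per query, so cross-attention keeps the target length even when the source sequence is longer.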

Let's test it out...

In [3]:
## Test ScaledDotProductAttention ##

## For reproducibility ##
torch.manual_seed(97)

## Without mask ##
print(f'Without Mask')
print(f'--------------')
test_attn_1 = ScaledDotProductAttention(d_embed=2, d_k=2)

x = torch.randn(2,2)

print(f'Input :\n\t{x}')
print(f'--------------')

result_1, wt = test_attn_1(x, return_wt=True)

print(f'--------------')
print(f'Output : {result_1}')

print(f'--------------')
print(f'Weights after softmax for sanity check : {wt}')

print(f'\n\nWith Mask')
print(f'--------------')
test_attn_2 = ScaledDotProductAttention(d_embed=2, d_k=2,mask=True)

result_2, wt = test_attn_2(x, return_wt=True)

print(f'Output : {result_2}')

print(f'--------------')
print(f'Weights after softmax for sanity check : {wt}')

Without Mask
--------------
Input :
	tensor([[-2.1115, -0.4614],
        [ 0.9178,  1.7334]])
--------------
--------------
Output : tensor([[-0.5837, -0.1434],
        [-0.6758, -0.1538]], grad_fn=<MmBackward0>)
--------------
Weights after softmax for sanity check : tensor([[0.3377, 0.6623],
        [0.2804, 0.7196]], grad_fn=<SoftmaxBackward0>)


With Mask
--------------
Output : tensor([[-0.8411, -1.1339],
        [ 0.0678, -0.2628]], grad_fn=<MmBackward0>)
--------------
Weights after softmax for sanity check : tensor([[1.0000, 0.0000],
        [0.4882, 0.5118]], grad_fn=<SoftmaxBackward0>)


Now that we are done with the basic building block of our Transformer model, we can set up its other core piece, the **Multi-head Attention** module.

In [4]:
class MultiHeadAttention(nn.Module):
    """Implements the Multihead Attention."""

    def __init__(
        self,
        d_embed: int,
        num_heads: int,
        d_k: int = None,
        mask: bool = False,
    ):
        """Constructor"""

        super().__init__()

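        # Each head outputs d_embed // num_heads features, so the concatenated
        # heads are d_embed wide when d_embed is divisible by num_heads
        d_v = d_embed // num_heads
        if d_k is None:
            d_k = d_v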

        self.multi_heads = nn.ModuleList(
            [
                ScaledDotProductAttention(
                    d_embed=d_embed,
                    d_k=d_k,
                    d_v=d_v,
                    mask=mask,
                )
                for _ in range(num_heads)
            ]
        )

        self.mask = mask

    def forward(self, x: torch.Tensor, y: torch.Tensor = None, z: torch.Tensor = None):
        """Forward Pass"""

        return torch.cat([head(x, y, z) for head in self.multi_heads], dim=-1)


In [5]:
## Testing ##

## For reproducibility ##
torch.manual_seed(97)

x = torch.randn(4, 8, 64)

mult_attn = MultiHeadAttention(d_embed=64, num_heads=8)

print(f'Output shape is : {mult_attn(x).shape}')

#print(f'Multihead attention module :\n{mult_attn}')

print(f'Type of mult_attn is {type(mult_attn)}')

Output shape is : torch.Size([4, 8, 64])
Type of mult_attn is <class '__main__.MultiHeadAttention'>
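
For comparison, the paper's multi-head attention also applies a learned output projection $W^O$ after concatenating the heads, which the class above leaves out. Below is a minimal sketch of that variant for self-attention, computing all heads in one batched matrix product. The class name `FusedMultiHeadAttention` and its layout are assumptions for illustration, not the notebook's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedMultiHeadAttention(nn.Module):
    """Self-attention with all heads computed in one batched matmul (hypothetical variant)."""

    def __init__(self, d_embed: int, num_heads: int):
        super().__init__()
        assert d_embed % num_heads == 0, "d_embed must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_embed // num_heads
        self.qkv = nn.Linear(d_embed, 3 * d_embed, bias=False)   # joint Q, K, V projection
        self.out_proj = nn.Linear(d_embed, d_embed, bias=False)  # W^O from the paper

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, C = x.shape
        # (B, L, 3C) -> three tensors of shape (B, num_heads, L, d_head)
        q, k, v = (
            t.reshape(B, L, self.num_heads, self.d_head).transpose(1, 2)
            for t in self.qkv(x).chunk(3, dim=-1)
        )
        weights = F.softmax(q @ k.transpose(-2, -1) / self.d_head**0.5, dim=-1)
        # (B, num_heads, L, d_head) -> (B, L, C), i.e. the heads concatenated
        heads = (weights @ v).transpose(1, 2).reshape(B, L, C)
        return self.out_proj(heads)


x = torch.randn(4, 8, 64)
mha = FusedMultiHeadAttention(d_embed=64, num_heads=8)
print(mha(x).shape)  # torch.Size([4, 8, 64])
```

The output projection lets the model mix information across heads, and batching the heads avoids a Python loop over `num_heads` separate modules.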


Looks perfect!! Now moving on...

In the encoder and decoder layers, the attention sublayers are followed by a simple feed-forward network that maps the output of the attention mechanism to a richer representation. This sub-module is called the **Position Wise Feed Forward Network**. It does the following:

1. Pass the output of the attention through a linear layer with increased features. (In the original paper it is 4 x d_model).
2. Pass it through a ReLU non-linearity.
3. Finally pass it through another linear layer with reduced features (map it back to d_model).
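
"Position-wise" means that the same two-layer network is applied to every token independently. The steps above, and that property, can be sketched as follows; the layer sizes are chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 8
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),  # step 1: expand to 4 x d_model
    nn.ReLU(),                        # step 2: non-linearity
    nn.Linear(4 * d_model, d_model),  # step 3: project back to d_model
)

x = torch.randn(2, 5, d_model)  # (batch, sequence length, d_model)

# Applied to the full sequence at once ...
whole = ffn(x)

# ... versus one position at a time, then stacked back along the sequence axis.
per_token = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)

print(whole.shape)
print(torch.allclose(whole, per_token, atol=1e-5))
```

Because no information flows between positions here, all mixing across tokens has to happen in the attention sublayers.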

In [6]:
## Position wise Feed Forward Networks ##

class PositionWiseFFN(nn.Module):
    """Implements the Position Wise Feed Forward Networks"""
    
    def __init__(self, d_embed : int):
        """Constructor"""
        
        super().__init__()
        
        self.pwffn = nn.Sequential(
            nn.Linear(in_features=d_embed, out_features=4*d_embed, bias=False),
            nn.ReLU(),
            nn.Linear(in_features=4*d_embed, out_features=d_embed, bias=False),
        )
        
    def forward(self, x):
        """Forward Pass"""
        
        return self.pwffn(x)

In [7]:
## testing ##

## For reproducibility ##
torch.manual_seed(97)

pwffn = PositionWiseFFN(512)

inp = torch.randn(4, 8 , 512)

print(f'Output shape : {pwffn(inp).shape}')

Output shape : torch.Size([4, 8, 512])


Perfect!!

We are slowly building our work.

Next up is the position embedding. We are going to use the sinusoidal embedding from the original paper. Since it doesn't have any learnable parameters, we will implement it as a simple Python function.

In [8]:
## Position Encoding ##

def PositionalEncoding(seq_length : int, d_embed : int, device = 'cpu'):
    """Sinusoidal positional encoding of shape (seq_length, d_embed)."""
    
    pos = torch.arange(seq_length, device=device).unsqueeze(-1)
    # Even indices 0, 2, 4, ...: each sin/cos pair shares one frequency
    i = torch.arange(0, d_embed, 2, device=device)
    # The paper uses base 10000
    angle = pos / 10000 ** (i / d_embed)
    
    position_embedding = torch.empty((seq_length, d_embed), device=device)
    
    position_embedding[: , ::2] = torch.sin(angle)
    position_embedding[: , 1::2] = torch.cos(angle[:, : d_embed // 2])
    
    return position_embedding

In [9]:
## Testing ##

position_embedding = PositionalEncoding(8,32)

print(f'Shape : {position_embedding.shape}')
print(f'Position 0, first four dims : {position_embedding[0, :4]}')

Shape : torch.Size([8, 32])
Position 0, first four dims : tensor([0., 1., 0., 1.])

Again... Looks perfect!!
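
As a cross-check, the paper's formula, $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$, can be evaluated without any tensor library. Here is a small sketch for a single position; the function name `sinusoidal_row` is hypothetical.

```python
import math


def sinusoidal_row(pos: int, d_model: int) -> list:
    """Positional encoding for a single position, following the paper's formula."""
    row = []
    for j in range(d_model):
        # Dimensions 2i and 2i+1 share the same frequency 1 / 10000^(2i / d_model).
        angle = pos / 10000 ** ((j - j % 2) / d_model)
        row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
    return row


print(sinusoidal_row(0, 4))  # position 0: sin(0) = 0 and cos(0) = 1 alternate
print(sinusoidal_row(1, 4))
```

Low dimensions oscillate quickly with position while high dimensions change slowly, which is what lets the model tell nearby and distant positions apart.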

We have built almost all the necessary modules. What remains is the Add and Norm step that follows every sublayer in the Encoder and Decoder layers. It is a residual connection around a sublayer (attention or position-wise FFN), with dropout applied to the sublayer output, followed by a LayerNorm: $\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x)))$.
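
The Add and Norm step can be worked through by hand on a single token vector. This is a minimal sketch that assumes dropout is disabled (as in evaluation mode) and uses a layer norm without learned scale and shift; the helper names and the toy numbers are hypothetical.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Normalize a vector to zero mean and (almost) unit variance."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def add_and_norm(x, sublayer_out):
    """LayerNorm(x + Sublayer(x)) for one token, with dropout omitted."""
    return layer_norm([a + b for a, b in zip(x, sublayer_out)])


x = [1.0, 2.0, 3.0, 4.0]                # token representation entering the sublayer
sublayer_out = [0.5, -0.5, 0.5, -0.5]   # stand-in for the attention or FFN output

y = add_and_norm(x, sublayer_out)
print(y)
```

The residual keeps the original token representation in the sum, and the normalization keeps the scale of the activations stable from layer to layer.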

So, our next goal is to prepare this module called the **ResidualDropoutNorm**. 

In [10]:
class ResidualDropoutNorm(nn.Module):
    """Implements the Residual Dropout Norm Layer."""

    def __init__(
        self, sublayer: nn.Module, d_embed: int = 512, dropout_rate: float = 0.1
    ):
        """Constructor."""

        super().__init__()

        self.sublayer = nn.ModuleList([sublayer, nn.Dropout(dropout_rate)])
        self.layer_norm = nn.LayerNorm(d_embed)

    def forward(self, x, y: torch.Tensor = None, z: torch.Tensor = None):
        """Forward Pass."""

        out = x

        for layer in self.sublayer:
            # The optional y and z sources are forwarded only to masked attention
            if isinstance(layer, MultiHeadAttention) and layer.mask:
                out = layer(out, y, z)
            else:
                out = layer(out)

        # Residual connection followed by layer normalization
        return self.layer_norm(x + out)


In [11]:
## testing ##

## For reproducibility ##
torch.manual_seed(97)

x = torch.randn(3 , 2 , 64)

rdn = ResidualDropoutNorm(MultiHeadAttention(d_embed=64, num_heads=8), 64)

print(f'Output shape : {rdn(x).shape}')

rdn = ResidualDropoutNorm(PositionWiseFFN(d_embed=64), 64)

print(f'Output shape : {rdn(x).shape}')

Output shape : torch.Size([3, 2, 64])
Output shape : torch.Size([3, 2, 64])


And it matches...!! 

I am really happy to be honest, and if you are reading this and trying to build it out yourself and you are at this point now, you should be very proud of yourself, because we are almost done and just need to plug in all these modular pieces together to make our final architecture.

Now it's time to make the layers of the Encoder and the Decoder, which are repeated N times to make the encoder and decoder of the Transformer.

In [12]:
## Encoder Layer ##

class EncoderLayer(nn.Module):
    """Encoder Layer."""

    def __init__(
        self,
        d_embed: int,
        num_heads: int,
        dropout_rate: float,
        d_k: int = None,
        mask: bool = False,
    ):
        """Constructor."""

        super().__init__()

        self.encoder_layer = nn.ModuleList(
            [
                ResidualDropoutNorm(
                    d_embed=d_embed,
                    sublayer=MultiHeadAttention(
                        d_embed=d_embed,
                        num_heads=num_heads,
                        d_k=d_k,
                        mask=mask,
                    ),
                    dropout_rate=dropout_rate,
                ),
                ResidualDropoutNorm(
                    d_embed=d_embed,
                    sublayer=PositionWiseFFN(d_embed),
                    dropout_rate=dropout_rate,
                ),
            ]
        )

    def forward(self, x, y: torch.Tensor = None, z: torch.Tensor = None):
        """Forward pass through a single encoder layer"""

        # Chain the sublayers: the attention output feeds the feed-forward network
        out = x
        for layer in self.encoder_layer:
            out = layer(out, y, z)

        return out

In [13]:
## Testing ##

## For reproducibility ##
torch.manual_seed(97)

enc_layer = EncoderLayer(d_embed = 512, num_heads = 8, dropout_rate = 0.1)

x = torch.randn(4, 8, 512)

print(f'Output shape : {enc_layer(x).shape}')

Output shape : torch.Size([4, 8, 512])


Perfect... Looks awesome...

Next up is the decoder layer!! Here we must be a bit careful, since we need two attention sublayers: one masked self-attention and one cross-attention without masking. Also, we cannot put everything in `nn.Sequential`, since the second attention sublayer receives different inputs. So, let's slowly and carefully build it up.

In [14]:
## Decoder Layer ##

class DecoderLayer(nn.Module):
    """Implements a single decoder layer."""

    def __init__(
        self,
        d_embed: int,
        num_heads: int,
        dropout_rate: float,
        d_k_sa: int = None,
        use_ca: bool = False,
        d_k_ca: int = None,
    ):

        """Constructor."""

        super().__init__()

        self.masked_attn = ResidualDropoutNorm(
            sublayer=MultiHeadAttention(
                d_embed=d_embed,
                num_heads=num_heads,
                d_k=d_k_sa,
                mask=True,
            ),
            d_embed=d_embed,
            dropout_rate=dropout_rate,
        )

        self.cross_attn = (
            ResidualDropoutNorm(
                sublayer=MultiHeadAttention(
                    d_embed=d_embed,
                    num_heads=num_heads,
                    d_k=d_k_ca,
                    mask=False,
                ),
                d_embed=d_embed,
                dropout_rate=dropout_rate,
            )
            if use_ca
            else None
        )

        self.pwfn = ResidualDropoutNorm(
            d_embed=d_embed,
            sublayer=PositionWiseFFN(d_embed),
            dropout_rate=dropout_rate,
        )

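    def forward(self, z: torch.Tensor, x: torch.Tensor = None, y: torch.Tensor = None):
        """Forward Pass through a single decoder layer.

        z is the decoder stream and x is the encoder output used for cross attention.
        """

        z = self.masked_attn(x=z)

        if self.cross_attn is not None:
            # Cross attention: keys and values from the encoder output (x),
            # queries from the decoder stream (z); the residual stays on z.
            attn, dropout = self.cross_attn.sublayer
            z = self.cross_attn.layer_norm(z + dropout(attn(x, z, x)))

        return self.pwfn(z)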

Let's test it out!!

In [15]:
## Testing ##

## For reproducing ##
torch.manual_seed(97)

decoder_layer = DecoderLayer(d_embed = 64, num_heads = 8, dropout_rate = 0.1, use_ca = True, d_k_ca = 128)

from_encoder = torch.randn(3 , 8 , 64)
from_decoder = torch.randn(3 , 8 , 64)

print(f'Output shape : {decoder_layer(z = from_decoder, x= from_encoder).shape}')

Output shape : torch.Size([3, 8, 64])


Wow!! That's a lot of good work... Now let's move on and make our encoder and decoder blocks, and finally the full Transformer block.

In [16]:
## Encoder block ##

class Encoder(nn.Module):
    """Encoder Block"""

    def __init__(
        self,
        d_embed: int,
        num_heads: int,
        dropout_rate: float,
        d_k: int = None,
        mask: bool = False,
        num_layers: int = 6,
    ):
        """Constructor"""

        super().__init__()
        self.encoder = nn.ModuleList(
            [
                EncoderLayer(d_embed, num_heads, dropout_rate, d_k, mask)
                for _ in range(num_layers)
            ]
        )

    def forward(self, inp, return_intermediate=False):
        """Forward Pass"""

        # Collect the output of every layer; the list keeps them on the input's device
        intermediate_val = []

        for layer in self.encoder:
            inp = layer(inp)
            intermediate_val.append(inp)

        if return_intermediate:
            # Stacked shape: (num_layers, batch, sequence length, d_embed)
            return inp, torch.stack(intermediate_val, dim=0)

        return inp

In [17]:
## Testing ##

torch.manual_seed(97)

enc = Encoder(d_embed=64, num_heads=4, dropout_rate=0.1)

inp = torch.randn(2,2,64)

print(f'Output is {enc(inp).shape}')

print(f'Intermediate shape is {enc(inp,True)[-1].shape}')

Output is torch.Size([2, 2, 64])
Intermediate shape is torch.Size([6, 2, 2, 64])

Perfect... Looks nice!

Now we will build the Decoder Block.

In [18]:
## Decoder Block ##

class Decoder(nn.Module):
    """Decoder Block"""

    def __init__(
        self,
        d_embed: int,
        num_heads: int,
        dropout_rate: float,
        d_k_sa: int = None,
        use_ca: bool = False,
        d_k_ca: int = None,
        num_layers: int = 6,
    ):

        super().__init__()

        self.decoder = nn.ModuleList(
            [
                DecoderLayer(
                    d_embed,
                    num_heads,
                    dropout_rate,
                    d_k_sa,
                    use_ca,
                    d_k_ca,
                )
                for _ in range(num_layers)
            ]
        )

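    def forward(self, z, intermediate_value=None):
        """Forward Pass"""

        for i, layer in enumerate(self.decoder):
            if intermediate_value is not None:
                # Decoder layer i reads the output of encoder layer i, so the encoder
                # needs at least as many layers as the decoder (the original paper
                # instead feeds the final encoder output to every decoder layer)
                z = layer(z=z, x=intermediate_value[i])

            else:
                z = layer(z=z)

        return z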

In [19]:
## testing ##

torch.manual_seed(97)

inp = torch.randn(2,2,64)

encoder = Encoder(d_embed=64, num_heads=4, dropout_rate=0.1)

_ , intermediate_val = encoder(inp, return_intermediate = True)

decoder = Decoder(d_embed=64, num_heads=4, dropout_rate=0.1, use_ca = True)

print(inp.shape)

print(f'Output shape : {decoder(inp, intermediate_val).shape}')

torch.Size([2, 2, 64])
Output shape : torch.Size([2, 2, 64])


Amazing, we are done with almost everything. Now we just need to plug everything into our **Transformer** module.

In [20]:
class Transformer(nn.Module):
    """Transformer module."""

    def __init__(
        self,
        vocab_size: int,
        sequence_length: int,
        d_embed: int,
        use_encoder: bool,
        use_decoder: bool,
        num_heads: int,
        d_k_encoder: int = None,
        encoder_mask: bool = False,
        encoder_num_layers: int = 6,
        decoder_num_layers: int = 6,
        dropout_rate: float = 0.1,
        classification: bool = False,
        num_classes: int = None,
        use_ca: bool = False,
        d_k_sa: int = None,
        d_k_ca: int = None,
        mask: bool = False,
        device: torch.device = torch.device("cpu"),
    ):
        """Constructor"""

        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_embed)

        # self.pos_embed = PositionalEncoding(sequence_length, d_embed, device)

        self.final_output = (
            nn.Linear(d_embed, num_classes)
            if classification
            else nn.Linear(d_embed, vocab_size)
        )

        self.encoder = None
        self.decoder = None

        if use_encoder:
            # NOTE: encoder masking is controlled by `mask`; `encoder_mask` is currently unused
            self.encoder = Encoder(
                d_embed,
                num_heads,
                dropout_rate,
                d_k_encoder,
                mask,
                num_layers=encoder_num_layers,
            )

        if use_decoder:
            self.decoder = Decoder(
                d_embed,
                num_heads,
                dropout_rate,
                d_k_sa,
                use_ca,
                d_k_ca,
                num_layers=decoder_num_layers,
            )

        self.device = device
        self.register_buffer("use_ca", torch.tensor([use_ca]))
        self.register_buffer("decoder_num_layers", torch.tensor([decoder_num_layers]))
        self.register_buffer("sequence_length", torch.tensor([sequence_length]))
        self.register_buffer("classification", torch.tensor([classification]))

    def forward(self, inputs=None, outputs=None):
        """Forward Pass."""

        intermediate = None

        if self.encoder is not None:
            inputs = self.embed(inputs)
            B, L, C = inputs.shape
            pos_embed = PositionalEncoding(L, C, self.device)
            inputs = inputs + pos_embed
            inputs, intermediate = self.encoder(inputs, return_intermediate=True)

            # Encoder-only models and classifiers stop here
            if self.decoder is None or self.classification.item():
                return self.final_output(inputs)

        if self.decoder is not None:
            outputs = self.embed(outputs)
            B, L, C = outputs.shape
            pos_embed = PositionalEncoding(L, C, self.device)
            outputs = outputs + pos_embed

            if intermediate is not None:
                # Encoder present: pass its per-layer outputs to the decoder
                outputs = self.decoder(outputs, intermediate)

            elif self.use_ca.item():
                # No encoder: cross attention attends to the decoder input itself
                intermediate = torch.cat(
                    [
                        outputs.unsqueeze(0)
                        for _ in range(self.decoder_num_layers.item())
                    ],
                    dim=0,
                )
                outputs = self.decoder(outputs, intermediate)

            else:
                outputs = self.decoder(outputs)

        return self.final_output(outputs)

    ## Utility function to show data ##

    def _show_data(self, data, idx_2_char_map, verbose=True):
        """Given a data tensor, maps them to string and prints them."""
        str_data = [idx_2_char_map[each_word.item()] for each_word in data.data]
        if verbose:
            print(str_data)
        else:
            return str_data

    def generate(
        self,
        max_length: int,
        idx_2_char_map: dict,
        device: torch.device = torch.device("cpu"),
    ):
        """Generative bit."""

        idx = torch.zeros((1, 1), device=device).int()

        assert self.decoder is not None, "Decoder must be present!!"

        for _ in range(max_length):
            logits = self(outputs=idx[:, -self.sequence_length :])
            logits = logits[:, -1, :]
            prob = torch.nn.functional.softmax(logits, dim=-1)
            next_idx = torch.multinomial(prob, 1).int()
            idx = torch.cat([idx, next_idx], dim=1)

        print("".join(self._show_data(idx[0, 1:].data, idx_2_char_map, verbose=False)))

In [21]:
## Final test for Transformer module ##

torch.manual_seed(97)

print(f'------------------------------------')
print(f'Testing just decoder without encoder')
print(f'------------------------------------')

transformer_1 = Transformer(vocab_size = 6000,
                          sequence_length = 8,
                          d_embed = 32,
                          use_encoder=False,
                          use_decoder=True,
                          num_heads=8,
                          )


outputs_1 = torch.randint(0, 6000, (2, 8))

print(f'Outputs shape : {outputs_1.shape}')

print(f'Outputs shape : {transformer_1(outputs=outputs_1).shape}')

print(f'------------------------------------')
print(f'Testing just decoder without encoder \nbut with cross attention in between')
print(f'------------------------------------')

transformer_2 = Transformer(vocab_size = 6000,
                          sequence_length = 8,
                          d_embed = 32,
                          use_encoder=False,
                          use_decoder=True,
                          num_heads=8,
                          use_ca=True
                          )

outputs_2 = torch.randint(0, 6000, (2, 8))

print(f'Outputs shape : {outputs_2.shape}')

print(f'Outputs shape : {transformer_2(outputs=outputs_2).shape}')

print(f'------------------------------------')
print(f'Testing encoder and decoder')
print(f'------------------------------------')

inputs = torch.randint(0, 6000, (2, 8))
outputs = torch.randint(0, 6000, (2, 8))

transformer_2 = Transformer(vocab_size = 6000,
                          sequence_length = 8,
                          d_embed = 32,
                          use_encoder=True,
                          use_decoder=True,
                          num_heads=8,
                          use_ca=True
                          )

print(f'Inputs shape : {inputs.shape}')
print(f'Outputs shape : {outputs.shape}')

print(f'Outputs shape : {transformer_2(inputs=inputs,outputs=outputs).shape}')

------------------------------------
Testing just decoder without encoder
------------------------------------
Outputs shape : torch.Size([2, 8])
Outputs shape : torch.Size([2, 8, 6000])
------------------------------------
Testing just decoder without encoder 
but with cross attention in between
------------------------------------
Outputs shape : torch.Size([2, 8])
Outputs shape : torch.Size([2, 8, 6000])
------------------------------------
Testing encoder and decoder
------------------------------------
Inputs shape : torch.Size([2, 8])
Outputs shape : torch.Size([2, 8])
Outputs shape : torch.Size([2, 8, 6000])


Perfect... We have coded up everything that we need to build our Transformer model. Next, I will put all this code in a Python module so that in the next notebook we can test our Transformer on real-world data... See ya!!