In [7]:
#pip install torch

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does matrix multiplications for query*keys for all batches and heads at once
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)
        out = self.fc_out(out)
        return out


In [10]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


In [11]:
class Transformer(nn.Module):
    def __init__(self, embed_size, num_layers, heads, forward_expansion, dropout, device):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(
                embed_size,
                heads,
                dropout=dropout,
                forward_expansion=forward_expansion,
            )
            for _ in range(num_layers)
        ])

        self.device = device

    def forward(self, x, mask):
        out = x
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer = Transformer(embed_size=512, num_layers=6, heads=8, forward_expansion=4, dropout=0.1, device=device)

In [13]:

# Example tensor (batch_size, sequence_length, embed_size)
x = torch.randn((32, 10, 512))
mask = None  # Define mask if needed

transformer.to(device)
x = x.to(device)

# Forward pass through the transformer
out = transformer(x, mask)

# Printing the output tensor
print("Output Tensor:")
print(out)
print("Output Shape:", out.shape)


Output Tensor:
tensor([[[-2.4236e+00,  2.3581e-01, -3.6038e-01,  ..., -9.7450e-02,
          -1.1645e-01,  3.3492e-01],
         [-2.1384e-02,  1.0151e+00, -3.0742e+00,  ...,  1.0237e-01,
          -1.3232e+00,  4.4146e-01],
         [ 0.0000e+00, -7.6741e-02, -4.2155e-02,  ..., -1.3690e+00,
          -2.1432e+00,  8.8490e-01],
         ...,
         [-0.0000e+00, -1.3314e+00,  3.2071e-01,  ..., -8.1643e-01,
           1.9602e+00, -7.5400e-03],
         [-1.6196e-01, -3.0135e-01, -2.4847e-01,  ..., -2.5513e+00,
          -9.2246e-01,  0.0000e+00],
         [-9.4961e-02,  8.6847e-02,  3.7584e-01,  ..., -0.0000e+00,
          -3.1179e-01,  0.0000e+00]],

        [[ 2.1049e+00, -4.5234e-01, -0.0000e+00,  ..., -8.3429e-01,
          -9.1045e-01,  1.7835e-01],
         [-2.8866e-01, -4.8122e-02,  4.4013e-01,  ..., -5.4069e-01,
          -1.2793e+00, -0.0000e+00],
         [-5.6460e-01, -4.9123e-03,  5.9178e-01,  ..., -0.0000e+00,
           2.7341e-01,  4.8724e-01],
         ...,
         [