In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Scaled Dot-Product Attention

class ScaledDotProductAttention(nn.Module):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_k: Width used for the ``1 / sqrt(d_k)`` score scaling.
    """

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
        """Attend from queries ``Q`` to keys ``K`` and mix the values ``V``.

        Args:
            Q: Queries, shape ``(..., L_q, d)``.
            K: Keys, shape ``(..., L_k, d)`` (same ``d`` as ``Q``).
            V: Values, shape ``(..., L_k, d_v)``.
            mask: Optional tensor broadcastable to the score shape
                ``(..., L_q, L_k)``. Key positions where ``mask == 0`` (or
                ``False``) are excluded. ``None`` (default) attends to all keys,
                which is the original behavior.

        Returns:
            Tensor of shape ``(..., L_q, d_v)``.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the most negative finite value instead of -inf so a row whose
            # keys are all masked yields uniform weights rather than NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Inputs are linearly projected and split into ``num_heads`` heads of width
    ``d_k = d_model // num_heads``. Each head is attended separately, then the
    heads are concatenated and passed through an output projection.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of heads; must be positive and divide ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``d_model`` is not a
            multiple of ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Validate up front. With a non-divisible d_model, the reshapes in
        # forward() either fail with an opaque shape error or silently fold
        # leftover channels into the sequence axis and return wrong shapes.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k)

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, d_k)``."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V):
        """Apply multi-head attention.

        Args:
            Q: Queries, shape ``(batch, L_q, d_model)``.
            K: Keys, shape ``(batch, L_k, d_model)``.
            V: Values, shape ``(batch, L_k, d_model)``.

        Returns:
            Tensor of shape ``(batch, L_q, d_model)``.
        """
        batch_size = Q.size(0)

        # Linear projections, then split into heads.
        Q = self._split_heads(self.W_q(Q), batch_size)
        K = self._split_heads(self.W_k(K), batch_size)
        V = self._split_heads(self.W_v(V), batch_size)

        attn_values = self.attention(Q, K, V)

        # Merge heads: (batch, heads, L_q, d_k) -> (batch, L_q, d_model).
        # contiguous() is needed because view() generally cannot follow the
        # strides left by transpose().
        concat = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(concat)

# Position-wise Feed-Forward Network
class PositionwiseFeedforward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Expands the last dimension from ``d_model`` to ``d_ff``, applies ReLU,
    then projects back down to ``d_model``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # linear1 is created before linear2 so weight initialization draws
        # from the RNG in the same order as before.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = torch.relu(hidden)
        return self.linear2(hidden)

# Transformer Block
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then a position-wise feed-forward.

    Each sub-layer uses a residual connection followed by LayerNorm, i.e.
    ``x = norm(x + sublayer(x))``.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Submodules are registered in this order so parameter initialization
        # consumes the RNG in the same sequence.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = PositionwiseFeedforward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention: queries, keys and values all come from x.
        attended = self.attention(x, x, x)
        x = self.norm1(x + attended)

        transformed = self.ff(x)
        x = self.norm2(x + transformed)
        return x

# Transformer Model
class Transformer(nn.Module):
    """Stack of ``num_layers`` TransformerBlock layers applied in sequence."""

    def __init__(self, num_layers, d_model, num_heads, d_ff):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(d_model, num_heads, d_ff))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        # Each layer's output feeds the next; with zero layers x is returned as-is.
        out = x
        for block in self.layers:
            out = block(out)
        return out

# Example Usage
# Hyperparameters for the stack.
d_model, num_heads, num_layers, d_ff = 512, 8, 6, 2048

model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
)

# Random input of shape (batch=64, seq_len=10, d_model). It is drawn after the
# model is built, so the RNG sequence matches the original ordering.
x = torch.rand((64, 10, d_model))
output = model(x)
print(output)

tensor([[[ 1.1781e+00,  1.4841e-01, -7.0080e-01,  ...,  1.1856e+00,
          -3.4810e-01,  2.2296e-01],
         [ 6.8155e-01,  8.4364e-01, -1.5005e+00,  ..., -8.6762e-02,
          -1.8138e+00,  6.2113e-01],
         [ 4.5144e-01,  1.1330e+00, -1.0829e+00,  ...,  8.0591e-01,
          -5.7017e-01,  6.7763e-02],
         ...,
         [ 7.5145e-01, -7.6074e-01, -1.0524e+00,  ..., -4.8734e-01,
          -1.5585e-01, -1.3068e+00],
         [ 7.9490e-01, -1.0910e+00, -1.2889e-01,  ..., -1.1822e-01,
          -2.6666e+00, -1.2966e+00],
         [ 1.2134e+00, -8.0573e-01, -1.6540e+00,  ...,  5.2009e-01,
          -1.4772e+00, -5.4820e-01]],

        [[ 4.7008e-02, -2.9697e-01, -4.8584e-01,  ...,  5.1817e-01,
          -1.8642e+00,  1.8510e+00],
         [-6.3696e-01, -1.3792e+00, -7.9070e-01,  ...,  1.2614e+00,
          -1.8337e+00,  3.6326e-01],
         [ 1.3495e+00, -1.3591e-01, -1.1504e+00,  ..., -2.0536e-02,
          -9.1363e-01,  2.8693e-01],
         ...,
         [ 9.1070e-01, -8