In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The module holds no parameters; it only combines the tensors passed to
    ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, mask=None):
        """Weight ``value`` by the softmax-normalised similarity of ``query`` and ``key``.

        Args:
            query: Tensor whose last dimension is the feature size ``d_k``.
            key: Tensor whose last two dimensions are swapped and then
                matrix-multiplied with ``query``.
            value: Tensor that is matrix-multiplied with the attention weights.
            mask: Optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attention)`` where ``attention`` is the softmax
            over the last dimension of the (masked) scores and ``output`` is
            ``attention @ value``.
        """
        d_k = query.size(-1)  # feature size used for the 1/sqrt(d_k) scaling
        scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5

        if mask is not None:
            # Fill with the most negative value the score dtype can hold.
            # A hard-coded -1e9 does not fit in float16, which makes
            # masked_fill raise when running in half precision / autocast.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        attention = F.softmax(scores, dim=-1)
        output = torch.matmul(attention, value)
        return output, attention


In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Queries, keys and values are each passed through their own
    ``d_model -> d_model`` linear layer, reshaped into ``num_heads`` slices of
    width ``d_k = d_model // num_heads``, attended independently with
    ``ScaledDotProductAttention`` and finally concatenated and projected by
    ``out_fc``.
    """

    def __init__(self, d_model, num_heads):
        """Build the four projections and the attention core.

        Args:
            d_model: Feature size of the inputs and outputs.
            num_heads: Number of attention heads; must be positive and divide
                ``d_model`` evenly.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``d_model``.
        """
        # Without this check a non-divisible d_model makes the later
        # view(batch, -1, num_heads, d_k) either fail with an opaque size error
        # or silently reshape into the wrong sequence length.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )

        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

        self.query_fc = nn.Linear(d_model, d_model)
        self.key_fc = nn.Linear(d_model, d_model)
        self.value_fc = nn.Linear(d_model, d_model)
        self.out_fc = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention()

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: Tensor of shape ``(batch, q_len, d_model)``.
            key: Tensor of shape ``(batch, k_len, d_model)``.
            value: Tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor handed unchanged to the attention core, so
                it must broadcast against scores of shape
                ``(batch, num_heads, q_len, k_len)``; zeros mark positions to
                ignore.

        Returns:
            A tuple ``(output, attention)``: ``output`` has shape
            ``(batch, q_len, d_model)`` and ``attention`` holds the per-head
            attention weights returned by the attention core.
        """
        batch_size = query.size(0)

        # Project, then split the feature axis into (num_heads, d_k).
        query = self.query_fc(query).view(batch_size, -1, self.num_heads, self.d_k)
        key = self.key_fc(key).view(batch_size, -1, self.num_heads, self.d_k)
        value = self.value_fc(value).view(batch_size, -1, self.num_heads, self.d_k)

        # Move the head axis forward: (batch, num_heads, seq_len, d_k).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        output, attention = self.attention(query, key, value, mask)

        # Undo the head split; contiguous() is needed before view() because
        # transpose() leaves the tensor non-contiguous.
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out_fc(output)

        return output, attention


In [4]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature indices hold sines and
    odd feature indices hold cosines of ``position * 10000^(-i / d_model)``
    for ``i = 0, 2, 4, ...``.
    """

    def __init__(self, d_model, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Feature size of the embeddings (even or odd).
            max_len: Largest sequence length the table supports.
        """
        super().__init__()

        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the last frequency has no cosine slot and is dropped here.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built with.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding table size {max_len}"
            )
        return x + self.pe[:, :seq_len]


In [5]:
class PositionwiseFeedforward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``, expanding from
    ``d_model`` to ``d_ff`` features and back to ``d_model``.
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """Create the expanding and contracting projections.

        Args:
            d_model: Input and output feature size.
            d_ff: Hidden feature size between the two linear layers.
            dropout: Dropout probability applied after the ReLU.
        """
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map ``x`` (last dimension ``d_model``) to a tensor of the same shape."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))


In [6]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input and is then layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Create the two sub-layers with their norms and dropouts.

        Args:
            d_model: Feature size of the layer's input and output.
            num_heads: Number of heads for the self-attention sub-layer.
            d_ff: Hidden size of the feed-forward sub-layer.
            dropout: Dropout probability used inside the feed-forward network
                and on both residual branches.
        """
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, dropout)
        self.layer_norm1, self.layer_norm2 = (nn.LayerNorm(d_model) for _ in range(2))
        self.dropout1, self.dropout2 = (nn.Dropout(dropout) for _ in range(2))

    def forward(self, x, mask):
        """Apply the block to ``x``; ``mask`` is forwarded to the self-attention."""
        # Sub-layer 1: self-attention (query, key and value are all x).
        attended, _ = self.attention(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout1(attended))

        # Sub-layer 2: position-wise feed-forward network.
        return self.layer_norm2(x + self.dropout2(self.ffn(x)))


In [7]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input and is then layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Create the three sub-layers with their norms and dropouts.

        Args:
            d_model: Feature size of the layer's input and output.
            num_heads: Number of heads for both attention sub-layers.
            d_ff: Hidden size of the feed-forward sub-layer.
            dropout: Dropout probability used inside the feed-forward network
                and on all three residual branches.
        """
        super().__init__()
        self.attention1 = MultiHeadAttention(d_model, num_heads)
        self.attention2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, dropout)
        self.layer_norm1, self.layer_norm2, self.layer_norm3 = (
            nn.LayerNorm(d_model) for _ in range(3)
        )
        self.dropout1, self.dropout2, self.dropout3 = (
            nn.Dropout(dropout) for _ in range(3)
        )

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply the block to the decoder state ``x``.

        Args:
            x: Decoder-side input.
            memory: Encoder output, used as key and value of the second
                attention sub-layer.
            src_mask: Mask forwarded to the encoder-decoder attention.
            tgt_mask: Mask forwarded to the decoder self-attention.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_attended, _ = self.attention1(x, x, x, tgt_mask)
        x = self.layer_norm1(x + self.dropout1(self_attended))

        # Sub-layer 2: queries come from the decoder, keys/values from memory.
        cross_attended, _ = self.attention2(x, memory, memory, src_mask)
        x = self.layer_norm2(x + self.dropout2(cross_attended))

        # Sub-layer 3: position-wise feed-forward network.
        return self.layer_norm3(x + self.dropout3(self.ffn(x)))


In [8]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that maps token ids to target-vocabulary logits.

    Source and target ids are embedded, summed with sinusoidal position
    encodings and regularised with dropout. The source passes through
    ``num_layers`` encoder layers to form ``memory``; the target passes
    through ``num_layers`` decoder layers that attend to ``memory``; a final
    linear layer produces one score per target-vocabulary entry.

    Note: the embeddings are used as-is (no ``sqrt(d_model)`` scaling is
    applied before the position encodings are added).
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):
        """Build embeddings, both layer stacks and the output projection.

        Args:
            src_vocab_size: Number of entries in the source embedding table.
            tgt_vocab_size: Number of entries in the target embedding table
                and size of the output logits.
            d_model: Feature size used throughout the model.
            num_heads: Attention heads per attention sub-layer.
            num_layers: Number of encoder layers and of decoder layers.
            d_ff: Hidden size of each feed-forward sub-layer.
            dropout: Dropout probability used in every layer and on the
                embedded inputs.
        """
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # One encoding table is shared by the source and target sides.
        self.positional_encoding = PositionalEncoding(d_model)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Compute target-vocabulary logits for every target position.

        Args:
            src: Integer tensor of source token ids, shape ``(batch, src_len)``.
            tgt: Integer tensor of target token ids, shape ``(batch, tgt_len)``.
            src_mask: Optional mask used by the encoder self-attention and by
                the decoder's encoder-decoder attention, so it must broadcast
                against both ``(batch, heads, src_len, src_len)`` and
                ``(batch, heads, tgt_len, src_len)``. Zeros mark ignored
                positions.
            tgt_mask: Optional mask for the decoder self-attention,
                broadcastable against ``(batch, heads, tgt_len, tgt_len)``.

        Returns:
            Tensor of shape ``(batch, tgt_len, tgt_vocab_size)`` holding
            unnormalised scores (no softmax is applied).
        """
        # self.dropout was previously constructed but never applied; it is
        # applied here to the embedding + position sums of both stacks.
        # It is a no-op in eval() mode.
        src = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        memory = src
        for encoder in self.encoder_layers:
            memory = encoder(memory, src_mask)

        output = tgt
        for decoder in self.decoder_layers:
            output = decoder(output, memory, src_mask, tgt_mask)

        return self.fc_out(output)


In [10]:
import math


In [11]:
# Smoke test: build a default-sized Transformer and push one random batch through it.
VOCAB_SIZE = 10000          # shared by the source and target vocabularies here
BATCH_SIZE = 32
SRC_LEN, TGT_LEN = 10, 12   # source / target sequence lengths

model = Transformer(src_vocab_size=VOCAB_SIZE, tgt_vocab_size=VOCAB_SIZE)

# Random token ids stand in for real source and target batches.
src = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SRC_LEN))
tgt = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, TGT_LEN))

# One forward pass; the logits should have shape (BATCH_SIZE, TGT_LEN, VOCAB_SIZE).
output = model(src, tgt)
print(output.shape)


torch.Size([32, 12, 10000])
