In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention.

    Scores are ``query @ key^T / sqrt(d_k)``, where ``d_k`` is the size of the
    last dimension of ``query``. A softmax over the last axis turns the scores
    into weights, which are then applied to ``value``.

    Args:
        query: Tensor whose last two dimensions are (query_len, d_k).
        key: Tensor whose last two dimensions are (key_len, d_k).
        value: Tensor whose last two dimensions are (key_len, d_v).
        mask: Optional tensor broadcastable to the score tensor. Positions
            where ``mask == 0`` are excluded from attention.

    Returns:
        The attention-weighted combination of ``value``; its last two
        dimensions are (query_len, d_v).
    """
    dim_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_k)
    if mask is not None:
        # Fill with the most negative finite value of the score dtype instead
        # of a hard-coded -1e9: that constant is not representable in float16,
        # while the dtype minimum still drives the softmax weight to zero.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    attention = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(attention, value)


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned query/key/value/output projections.

    The embedding dimension is split evenly across ``num_heads`` heads, each of
    width ``embed_size // num_heads``. Attention is computed per head by
    ``scaled_dot_product_attention`` and the heads are concatenated and passed
    through a final linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        num_heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Without this check the floor division below silently drops features
        # and the reshapes in forward() either fail with an obscure message or
        # succeed with a wrong sequence length.
        if num_heads <= 0 or embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape (batch, query_len, embed_size).
            key: Tensor of shape (batch, key_len, embed_size).
            value: Tensor of shape (batch, key_len, embed_size).
            mask: Optional tensor broadcastable to
                (batch, num_heads, query_len, key_len); positions where it is
                zero are not attended to.

        Returns:
            Tensor of shape (batch, query_len, embed_size).
        """
        batch_size = query.shape[0]
        # Project, then split the last dimension into heads:
        # (batch, len, embed) -> (batch, heads, len, head_dim).
        Q = self.query(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        out = scaled_dot_product_attention(Q, K, V, mask)
        # Merge the heads back: (batch, heads, len, head_dim) -> (batch, len, embed).
        # transpose() returns a non-contiguous view, so contiguous() is needed before view().
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        return self.fc_out(out)

In [None]:
class FeedForward(nn.Module):
    """Two-layer perceptron applied to the last dimension: Linear -> ReLU -> Linear.

    Args:
        embed_size: Width of the input and of the output.
        hidden_dim: Width of the intermediate activation.
    """

    def __init__(self, embed_size, hidden_dim):
        super().__init__()
        # Expand to the hidden width, then project back to the input width.
        self.fc1 = nn.Linear(in_features=embed_size, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=embed_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the output has the same shape as ``x``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / embed_size)``.

    Args:
        embed_size: Size of the last dimension of the inputs. Odd sizes are
            supported; the final sine column then has no cosine partner.
        max_len: Number of positions to precompute.
    """

    def __init__(self, embed_size, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embed_size there is one fewer odd column than even
        # columns, so trim the cosine term to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        # Register as a buffer so the table follows the module across
        # .to()/.cuda() calls. persistent=False keeps it out of state_dict,
        # so the set of checkpoint keys is unchanged.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch, seq_len, embed_size) with
                ``seq_len <= max_len``.
        """
        # The .to() is a no-op once the module lives on the same device as x;
        # it is kept so inputs on another device are still handled.
        return x + self.pe[:, :x.size(1)].to(x.device)

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to its input
    (residual connection) and is then layer-normalised.

    Args:
        embed_size: Size of the last dimension of the input.
        num_heads: Number of attention heads.
        hidden_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, embed_size, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ffn = FeedForward(embed_size, hidden_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention module."""
        # Self-attention sub-layer: queries, keys and values are all `x`.
        attended = self.dropout(self.attention(x, x, x, mask))
        x = self.norm1(x + attended)

        # Feed-forward sub-layer.
        transformed = self.dropout(self.ffn(x))
        return self.norm2(x + transformed)

In [None]:
class TransformerDecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual add and a
    layer normalisation.

    Args:
        embed_size: Size of the last dimension of the input.
        num_heads: Number of attention heads in both attention modules.
        hidden_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, embed_size, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.attention1 = MultiHeadAttention(embed_size, num_heads)
        self.attention2 = MultiHeadAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        self.ffn = FeedForward(embed_size, hidden_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        """Run the block.

        Args:
            x: Decoder-side input.
            enc_out: Encoder output used as keys and values in the second
                attention module.
            src_mask: Mask passed to the attention over ``enc_out``.
            tgt_mask: Mask passed to the self-attention over ``x``.
        """
        # 1) Self-attention over the decoder input.
        self_attended = self.dropout(self.attention1(x, x, x, tgt_mask))
        x = self.norm1(x + self_attended)

        # 2) Attention with queries from the decoder, keys/values from the encoder.
        cross_attended = self.dropout(self.attention2(x, enc_out, enc_out, src_mask))
        x = self.norm2(x + cross_attended)

        # 3) Feed-forward network.
        transformed = self.dropout(self.ffn(x))
        return self.norm3(x + transformed)

In [None]:
def train_transformer(model, dataset, epochs=10, batch_size=32, lr=0.001):
    """Train a sequence-to-sequence model with teacher forcing.

    For every batch the decoder receives the target sequence without its last
    token and is trained to predict the target sequence without its first
    token. A lower-triangular ``tgt_mask`` is passed so a position cannot
    attend to later positions. Feeding the full target and scoring against
    the same target, without a mask, would let the model copy the answer.

    Args:
        model: Module called as ``model(src, decoder_input, tgt_mask=mask)``
            that returns logits of shape (batch, tgt_len - 1, vocab).
        dataset: Dataset yielding ``(src, tgt)`` pairs of integer token ids.
        epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Raises:
        ValueError: If target sequences are shorter than two tokens.

    Prints the mean batch loss after every epoch (``nan`` if the loader
    produced no batches).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Batches are moved to wherever the model's parameters live.
    device = next(model.parameters()).device

    model.train()  # enable training-mode behaviour such as dropout
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for src, tgt in dataloader:
            src = src.to(device)
            tgt = tgt.to(device)
            if tgt.size(1) < 2:
                raise ValueError("target sequences need at least 2 tokens for next-token training")

            # Shifted teacher forcing: predict token t+1 from tokens <= t.
            decoder_input = tgt[:, :-1]
            labels = tgt[:, 1:]
            steps = decoder_input.size(1)
            # Entry (i, j) is True when position i may attend to position j.
            causal_mask = torch.tril(torch.ones(steps, steps, dtype=torch.bool, device=device))

            optimizer.zero_grad()
            output = model(src, decoder_input, tgt_mask=causal_mask)
            # reshape() rather than view(): the sliced labels are not contiguous.
            loss = criterion(output.reshape(-1, output.shape[-1]), labels.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

In [None]:
# Hyper-parameters handed to the Transformer constructor.
input_dim = 10_000   # Vocabulary size (number of rows in the embedding table)
embed_size = 512     # Width of the token embeddings and of every layer
num_heads = 8        # Attention heads per layer (512 / 8 = 64 features each)
num_layers = 6       # Number of encoder layers and of decoder layers
hidden_dim = 2_048   # Hidden width of each feed-forward network
output_dim = 10_000  # Size of the final projection's output

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on token ids.

    Source and target tokens share one embedding table and one positional
    encoding. The decoder output is projected to ``output_dim`` logits.

    Args:
        input_dim: Number of rows in the embedding table.
        embed_size: Embedding width used by every layer.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers and of decoder layers.
        hidden_dim: Hidden width of each feed-forward network.
        output_dim: Size of the last dimension of the returned logits.
    """

    def __init__(self, input_dim, embed_size, num_heads, num_layers, hidden_dim, output_dim):
        super().__init__()
        encoder_layers = [
            TransformerEncoderLayer(embed_size, num_heads, hidden_dim)
            for _ in range(num_layers)
        ]
        self.encoder = nn.ModuleList(encoder_layers)
        decoder_layers = [
            TransformerDecoderLayer(embed_size, num_heads, hidden_dim)
            for _ in range(num_layers)
        ]
        self.decoder = nn.ModuleList(decoder_layers)
        self.embedding = nn.Embedding(input_dim, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        self.fc_out = nn.Linear(embed_size, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it and return the logits.

        Args:
            src: Source token ids.
            tgt: Target token ids fed to the decoder.
            src_mask: Mask given to the encoder layers and to the decoder's
                attention over the encoder output.
            tgt_mask: Mask given to the decoder's self-attention.
        """
        # Embed both sequences with the shared table and add position information.
        memory = self.pos_encoding(self.embedding(src))
        decoded = self.pos_encoding(self.embedding(tgt))

        for encoder_layer in self.encoder:
            memory = encoder_layer(memory, src_mask)

        # Every decoder layer attends to the final encoder output.
        for decoder_layer in self.decoder:
            decoded = decoder_layer(decoded, memory, src_mask, tgt_mask)

        return self.fc_out(decoded)

In [None]:
# Seed PyTorch's global RNG so weight initialisation, the synthetic token data,
# DataLoader shuffling and dropout repeat across Restart & Run All
# (on the same machine and PyTorch build).
torch.manual_seed(42)

model = Transformer(input_dim, embed_size, num_heads, num_layers, hidden_dim, output_dim)

# Synthetic data: 1000 (source, target) pairs of 20 random token ids each.
src_tokens = torch.randint(0, input_dim, (1000, 20))
tgt_tokens = torch.randint(0, input_dim, (1000, 20))
dataset = TensorDataset(src_tokens, tgt_tokens)

train_transformer(model, dataset)