https://pytorch.org/tutorials/beginner/transformer_tutorial.html

https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int = 100, dropout: float = 0.1):
        """Initialize the PositionalEncoding module."""
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        positional_encoding = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        division_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        positional_encoding[:, 0::2] = torch.sin(position * division_term)
        positional_encoding[:, 1::2] = torch.cos(position * division_term)
        positional_encoding = positional_encoding.unsqueeze(0)
        self.register_buffer("positional_encoding", positional_encoding)

    def forward(self, x):
        """Perform the forward pass of the PositionalEncoding module."""
        x = x + self.positional_encoding[:, : x.size(1)].requires_grad_(False)
        x = self.dropout(x)
        return x

In [3]:
class TransformerModel(nn.Module):
    def __init__(
        self,
        src_vocab_size:int,
        tgt_vocab_size:int,
        src_seq_len:int,
        tgt_seq_len:int,
        d_model:int=512, 
        nhead:int=8,
        num_encoder_layers:int=6, 
        num_decoder_layers:int=6, 
        dim_feedforward:int=2048, 
        dropout:float=0.1,
        ) -> None:
        super().__init__()

        self.model_type="Transformer"
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        self.src_positional_encoding = PositionalEncoding(d_model, src_seq_len, dropout)
        self.tgt_positional_encoding = PositionalEncoding(d_model, tgt_seq_len, dropout)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True, # (batch, seq, d_model)
            )
        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src:torch.Tensor, tgt:torch.Tensor) -> torch.Tensor:
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        src = self.src_positional_encoding(src)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.tgt_positional_encoding(tgt)
        x = self.transformer(src, tgt)
        x = self.linear(x)
        x = F.log_softmax(x, dim=-1)
        return x

In [4]:
batch_size=0
d_model=512

src_vocab_size=10000
tgt_vocab_size=10000

src_seq_length=100
tgt_seq_length=100

model = TransformerModel(
    src_vocab_size= src_vocab_size,
    tgt_vocab_size = tgt_vocab_size,
    src_seq_len = src_seq_length,
    tgt_seq_len = tgt_seq_length,
    )

In [5]:
src = torch.rand(batch_size, src_seq_length).long()
tgt = torch.rand(batch_size, tgt_seq_length).long()

output = model(src, tgt)
print(output.shape) # torch.Size([batch, tgt_seq_len, tgt_vocab_size]) Seq2Seq

torch.Size([0, 100, 10000])
