In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, where ``w_i = 1 / 10000 ** (2i / d_model)``.

    Args:
        d_model: feature size of the embeddings the encoding is added to
            (may be odd).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        divterm = torch.arange(0, d_model, 2).float() / d_model
        divterm = 1 / (10000 ** divterm)
        position = torch.arange(max_len).float().unsqueeze(1)
        pe[:, 0::2] = torch.sin(divterm * position)
        # When d_model is odd there is one fewer odd column than there are
        # frequencies, so only the first d_model // 2 frequencies are used.
        pe[:, 1::2] = torch.cos(divterm[: d_model // 2] * position)
        # A buffer follows the module through .to()/.cuda(); persistent=False
        # keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is laid out as (batch, seq, d_model) — TODO confirm
        with callers.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len].to(x.device)

In [13]:
import torch.nn.functional as F
import math

def attention_mechanism(Q, K, V, mask=None):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Args:
        Q: query tensor; its last dimension is ``d_k``.
        K: key tensor with the same last dimension as ``Q``.
        V: value tensor; ``V.size(-2)`` must equal ``K.size(-2)``.
        mask: optional tensor broadcastable to the score shape
            ``(..., Q.size(-2), K.size(-2))``. Positions where ``mask == 0``
            receive (effectively) zero attention weight.

    Returns:
        ``(output, attention_weights)``. ``attention_weights`` is the softmax
        of the scores over the last (key) dimension; ``output`` is
        ``attention_weights @ V``.
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Use the most negative finite value instead of -inf. A query row whose
        # keys are all masked then gets uniform weights rather than NaN
        # (softmax over all -inf), which would otherwise propagate through
        # every subsequent layer. Rows with at least one unmasked key are
        # unaffected: exp(min - max) underflows to exactly 0.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [19]:
class MultiHeadAttention(nn.Module):
  """Self-attention where every head attends at the full ``d_model`` width.

  ``x`` is projected to queries, keys and values of size
  ``num_heads * d_model`` each, which are split into ``num_heads`` heads of
  width ``d_model``. The heads are attended independently, concatenated, and
  projected back to ``d_model`` by ``W_o``.

  Args:
    d_model: feature size of the input and of the output.
    num_heads: number of attention heads (default 1).
  """

  def __init__(self, d_model, num_heads=1):
    super(MultiHeadAttention, self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    # Per-head width equals the model width (heads are not d_model // num_heads).
    self.d_k = d_model

    projected = d_model * num_heads
    self.W_q = nn.Linear(d_model, projected)
    self.W_k = nn.Linear(d_model, projected)
    self.W_v = nn.Linear(d_model, projected)
    self.W_o = nn.Linear(projected, d_model)

  def _split_heads(self, t):
    """Reshape ``(B, T, num_heads * d_k)`` into ``(B, num_heads, T, d_k)``."""
    return t.view(t.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)

  def forward(self, x, mask=None):
    """Attend ``x`` to itself; ``mask`` is forwarded to ``attention_mechanism``."""
    q, k, v = (self._split_heads(proj(x)) for proj in (self.W_q, self.W_k, self.W_v))

    context, _ = attention_mechanism(q, k, v, mask)

    # (B, H, T, d_k) -> (B, T, H * d_k) before the output projection.
    batch, seq_len = context.size(0), context.size(2)
    merged = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
    return self.W_o(merged)

In [26]:
class Encoder(nn.Module):
  """One post-norm Transformer encoder layer.

  Runs self-attention, then a position-wise feed-forward network. Each
  sub-layer is wrapped in a residual connection followed by ``LayerNorm``.

  Args:
    d_model: feature size of the input and of the output.
    num_heads: number of heads passed to ``MultiHeadAttention``.
    d_ff: hidden size of the feed-forward network.
  """

  def __init__(self, d_model, num_heads, d_ff):
    super(Encoder, self).__init__()
    self.multihead_attention = MultiHeadAttention(d_model, num_heads)
    self.layer_norm1 = nn.LayerNorm(d_model)
    self.layer_norm2 = nn.LayerNorm(d_model)
    # Despite the name, this is the feed-forward network itself, not its size.
    # The attribute name is kept so state_dict keys stay the same.
    self.d_ff = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model)
    )

  def forward(self, x, mask=None):
    """Apply the layer to ``x``.

    Args:
      x: input features whose last dimension is ``d_model``.
      mask: optional attention mask forwarded to self-attention, e.g. to hide
        padding. Positions where ``mask == 0`` are not attended to. The
        default ``None`` attends everywhere, as before.

    Returns:
      Tensor with the same shape as ``x``.
    """
    attention_output = self.multihead_attention(x, mask)
    x = self.layer_norm1(x + attention_output)
    ff_output = self.d_ff(x)
    x = self.layer_norm2(x + ff_output)

    return x

In [27]:
class EncoderOnlyTransformer(nn.Module):
  """Token embedding plus sinusoidal positions, fed through stacked ``Encoder`` layers.

  Args:
    vocab_size: number of token ids accepted by the embedding.
    d_model: embedding / hidden feature size.
    num_heads: attention heads per encoder layer.
    d_ff: feed-forward hidden size per encoder layer.
    num_layers: number of stacked encoder layers.
  """

  def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
    super(EncoderOnlyTransformer, self).__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model)
    layers = [Encoder(d_model, num_heads, d_ff) for _ in range(num_layers)]
    self.encoder_layers = nn.ModuleList(layers)

  def forward(self, x):
    """Embed token ids ``x`` and return the output of the last encoder layer."""
    hidden = self.positional_encoding(self.embedding(x))
    for encoder_layer in self.encoder_layers:
      hidden = encoder_layer(hidden)
    return hidden

In [28]:
class CrossAttention(nn.Module):
  """Attention with queries taken from ``tgt`` and keys/values from ``src``.

  Each of the ``num_heads`` heads has width ``d_model``. Projections map
  ``d_model -> num_heads * d_model``, and ``W_o`` maps back to ``d_model``.

  Args:
    d_model: feature size of ``src``, ``tgt`` and the output.
    num_heads: number of attention heads (default 1).
  """

  def __init__(self, d_model, num_heads=1):
    super(CrossAttention, self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    # Each head keeps the full model width.
    self.d_k = d_model

    width = d_model * num_heads
    self.W_q = nn.Linear(d_model, width)
    self.W_k = nn.Linear(d_model, width)
    self.W_v = nn.Linear(d_model, width)
    self.W_o = nn.Linear(width, d_model)

  def forward(self, src, tgt, mask=None):
    """Let every ``tgt`` position attend over ``src``.

    ``mask`` is forwarded unchanged to ``attention_mechanism``. The result has
    one row per ``tgt`` position.
    """
    def split(t):
      # (B, T, H * d_k) -> (B, H, T, d_k)
      return t.view(t.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)

    queries = split(self.W_q(tgt))
    keys = split(self.W_k(src))
    values = split(self.W_v(src))

    attended, _ = attention_mechanism(queries, keys, values, mask)

    batch, tgt_len = attended.size(0), attended.size(2)
    concatenated = attended.transpose(1, 2).contiguous().view(batch, tgt_len, -1)
    return self.W_o(concatenated)

In [29]:
class Decoder(nn.Module):
  """One post-norm Transformer decoder layer.

  Sub-layers, in order, each followed by a residual connection and
  ``LayerNorm``:
    1. masked self-attention over the target sequence,
    2. cross-attention from the target to the encoder output,
    3. a position-wise feed-forward network.

  Args:
    d_model: feature size of the target, the encoder output and the result.
    num_heads: number of heads for both attention sub-layers.
    d_ff: hidden size of the feed-forward network.
  """

  def __init__(self, d_model, num_heads, d_ff):
    super(Decoder, self).__init__()
    self.cross_attention = CrossAttention(d_model, num_heads)
    self.layer_norm1 = nn.LayerNorm(d_model)
    self.layer_norm2 = nn.LayerNorm(d_model)
    self.layer_norm3 = nn.LayerNorm(d_model)
    # nn.Sequential takes modules as separate positional arguments. Passing a
    # list raised TypeError at construction time.
    self.d_ff = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model)
    )
    self.masked_multihead_attention = MultiHeadAttention(d_model, num_heads)

  def forward(self, x, encoder_output, mask, src_mask=None):
    """Apply the layer.

    Args:
      x: target-side features (last dimension ``d_model``).
      encoder_output: encoder features the target attends to.
      mask: mask for target self-attention (e.g. causal). Positions where
        ``mask == 0`` are not attended to.
      src_mask: optional mask for cross-attention. Its scores have shape
        (tgt_len x src_len), so it is separate from ``mask``. The default
        ``None`` attends to every encoder position.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Self-attention comes first, so each target position is contextualised
    # (under `mask`) before it queries the encoder output.
    self_attention_output = self.masked_multihead_attention(x, mask)
    x = self.layer_norm1(x + self_attention_output)
    # `mask` was shaped for tgt x tgt scores and would not fit the tgt x src
    # cross-attention scores when the lengths differ, so a separate mask is used.
    cross_attention_output = self.cross_attention(encoder_output, x, src_mask)
    x = self.layer_norm2(x + cross_attention_output)
    ff_output = self.d_ff(x)
    x = self.layer_norm3(x + ff_output)

    return x

In [30]:
class DecoderPart(nn.Module):
  """Target token embedding plus sinusoidal positions, fed through stacked ``Decoder`` layers.

  Args:
    vocab_size: number of target token ids accepted by the embedding.
    d_model: embedding / hidden feature size.
    num_heads: attention heads per decoder layer.
    num_layers: number of stacked decoder layers.
    dff: feed-forward hidden size per decoder layer.
  """

  def __init__(self, vocab_size, d_model, num_heads, num_layers, dff):
    super(DecoderPart, self).__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    # PositionalEncoding takes (d_model, max_len). Passing vocab_size first
    # built an encoding of the wrong width that could not be added to x.
    self.pos_encoding = PositionalEncoding(d_model)
    # Decoder takes (d_model, num_heads, d_ff). The extra vocab_size argument
    # raised TypeError.
    self.decoder_layers = nn.ModuleList(
        [Decoder(d_model, num_heads, dff) for _ in range(num_layers)]
    )

  def forward(self, x, encoder_output, mask):
    """Embed target ids ``x`` and run every decoder layer.

    ``encoder_output`` and ``mask`` are passed unchanged to each layer.
    """
    x = self.embedding(x)
    x = self.pos_encoding(x)
    for layer in self.decoder_layers:
      x = layer(x, encoder_output, mask)

    return x

In [31]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer that returns per-position target-vocabulary logits.

  Args:
    src_vocab_size: number of source token ids.
    tgt_vocab_size: number of target token ids (also the output size).
    d_model: embedding / hidden feature size.
    num_heads: attention heads per layer.
    num_layers: number of layers in both the encoder and the decoder.
    dff: feed-forward hidden size per layer.
  """

  def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, dff):
    super(Transformer, self).__init__()
    self.encoder = EncoderOnlyTransformer(
        src_vocab_size, d_model, num_heads, d_ff=dff, num_layers=num_layers
    )
    # DecoderPart's signature orders these as (num_layers, dff), the reverse of
    # the encoder's. The previous positional call swapped them, so they are
    # passed by keyword here.
    self.decoder = DecoderPart(
        tgt_vocab_size, d_model, num_heads, num_layers=num_layers, dff=dff
    )
    self.final_layer = nn.Linear(d_model, tgt_vocab_size)

  def forward(self, src, tgt, mask):
    """Encode ``src``, decode ``tgt`` against it, and project to logits.

    ``mask`` is forwarded to the decoder only. The encoder runs unmasked.
    """
    encoder_output = self.encoder(src)
    decoder_output = self.decoder(tgt, encoder_output, mask)
    output = self.final_layer(decoder_output)

    return output