In [84]:
import torch
import torch.nn as nn
import numpy as np

In [85]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        divterm = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        position = torch.arange(max_len).float().unsqueeze(1)
        pe[:, 0::2] = torch.sin(divterm.unsqueeze(0) * position)
        pe[:, 1::2] = torch.cos(divterm.unsqueeze(0) * position)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)

In [86]:
import torch.nn.functional as F
import math

def attention_mechanism(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [87]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model, num_heads=1):
    super(MultiHeadAttention, self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model

    self.W_q = nn.Linear(d_model, d_model * num_heads)
    self.W_k = nn.Linear(d_model, d_model * num_heads)
    self.W_v = nn.Linear(d_model, d_model * num_heads)
    self.W_o = nn.Linear(d_model * num_heads, d_model)

  def forward(self, x, mask=None):
    Q = self.W_q(x)
    K = self.W_k(x)
    V = self.W_v(x)

    Q = Q.view(Q.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(K.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(V.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)

    output, attention_weights = attention_mechanism(Q, K, V, mask)

    return self.W_o(output.transpose(1, 2).contiguous().view(output.size(0), output.size(2), -1))

In [88]:
class Encoder(nn.Module):
  def __init__(self, d_model, num_heads, d_ff):
    super(Encoder, self).__init__()
    self.multihead_attention = MultiHeadAttention(d_model, num_heads)
    self.layer_norm1 = nn.LayerNorm(d_model)
    self.layer_norm2 = nn.LayerNorm(d_model)
    self.d_ff = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model)
    )

  def forward(self, x):
    attention_output = self.multihead_attention(x)
    x = self.layer_norm1(x + attention_output)
    ff_output = self.d_ff(x)
    x = self.layer_norm2(x + ff_output)

    return x

In [89]:
class EncoderOnlyTransformer(nn.Module):
  def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
    super(EncoderOnlyTransformer, self).__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model)
    self.encoder_layers = nn.ModuleList([
        Encoder(d_model, num_heads, d_ff) for _ in range(num_layers)
    ])

  def forward(self, x):
    x = self.embedding(x)
    x = self.positional_encoding(x)
    for layer in self.encoder_layers:
      x = layer(x)

    return x

In [90]:
class CrossAttention(nn.Module):
  def __init__(self, d_model, num_heads=1):
    super(CrossAttention, self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model

    self.W_q = nn.Linear(d_model, d_model * num_heads)
    self.W_k = nn.Linear(d_model, d_model * num_heads)
    self.W_v = nn.Linear(d_model, d_model * num_heads)
    self.W_o = nn.Linear(d_model * num_heads, d_model)

  def forward(self, src, tgt, mask=None):
    Q = self.W_q(tgt)
    K = self.W_k(src)
    V = self.W_v(src)

    Q = Q.view(Q.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(K.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(V.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)

    output, attention_weights = attention_mechanism(Q, K, V, mask)

    return self.W_o(output.transpose(1, 2).contiguous().view(output.size(0), output.size(2), -1))

In [91]:
class Decoder(nn.Module):
  def __init__(self, d_model, num_heads, d_ff):
    super(Decoder, self).__init__()
    self.cross_attention = CrossAttention(d_model, num_heads)
    self.layer_norm1 = nn.LayerNorm(d_model)
    self.layer_norm2 = nn.LayerNorm(d_model)
    self.layer_norm3 = nn.LayerNorm(d_model)
    self.d_ff = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model)
    )
    self.masked_multihead_attention = MultiHeadAttention(d_model, num_heads)

  def forward(self, x, encoder_output, mask):
    masked_multihead_attention_output = self.masked_multihead_attention(x, mask)
    x = self.layer_norm1(x + masked_multihead_attention_output)
    cross_attention_output = self.cross_attention(encoder_output, x)
    x = self.layer_norm2(x + cross_attention_output)
    ff_output = self.d_ff(x)
    x = self.layer_norm3(x + ff_output)

    return x

In [92]:
class DecoderPart(nn.Module):
  def __init__(self, vocab_size, d_model, num_heads, num_layers, dff):
    super(DecoderPart, self).__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model)
    self.decoder_layers = nn.ModuleList([Decoder(d_model, num_heads, dff) for _ in range(num_layers)])

  def forward(self, x, encoder_output, mask):
    x = self.embedding(x)
    x = self.pos_encoding(x)
    for layer in self.decoder_layers:
      x = layer(x, encoder_output, mask)

    return x

In [93]:
class Transformer(nn.Module):
  def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, dff):
    super(Transformer, self).__init__()
    self.encoder = EncoderOnlyTransformer(src_vocab_size, d_model, num_heads, dff, num_layers)
    self.decoder = DecoderPart(tgt_vocab_size, d_model, num_heads, dff, num_layers)
    self.final_layer = nn.Linear(d_model, tgt_vocab_size)

  def forward(self, src, tgt, mask):
    encoder_output = self.encoder(src)
    decoder_output = self.decoder(tgt, encoder_output, mask)
    output = self.final_layer(decoder_output)

    return output

In [94]:
toy_data = [
    ("i love you", "je t'aime"),
    ("how are you", "comment ça va"),
    ("good morning", "bonjour"),
    ("thank you", "merci"),
    ("good night", "bonne nuit"),
    ("see you soon", "à bientôt"),
]



In [95]:
class Vocab:
    def __init__(self, sentences):
        self.special = ['<pad>', '<sos>', '<eos>', '<unk>']
        words = set(w for s in sentences for w in s.split())
        self.itos = self.special + sorted(words)
        self.stoi = {w: i for i, w in enumerate(self.itos)}

    def encode(self, sentence, max_len=10):
        tokens = ['<sos>'] + sentence.split() + ['<eos>']
        tokens += ['<pad>'] * (max_len - len(tokens))
        return [self.stoi.get(t, self.stoi['<unk>']) for t in tokens]

    def decode(self, indices):
        return ' '.join([self.itos[i] for i in indices if self.itos[i] not in ['<sos>', '<eos>', '<pad>']])

src_vocab = Vocab([s for s, _ in toy_data])
tgt_vocab = Vocab([t for _, t in toy_data])

In [96]:
def get_tensor_pairs(data, src_vocab, tgt_vocab, max_len=10):
    src, tgt = [], []
    for s, t in data:
        src.append(torch.tensor(src_vocab.encode(s, max_len)))
        tgt.append(torch.tensor(tgt_vocab.encode(t, max_len)))
    return torch.stack(src), torch.stack(tgt)

src_tensor, tgt_tensor = get_tensor_pairs(toy_data, src_vocab, tgt_vocab)

In [97]:
model = Transformer(
    src_vocab_size=len(src_vocab.itos),
    tgt_vocab_size=len(tgt_vocab.itos),
    d_model=128,
    num_heads=4,
    num_layers=2,
    dff=256
)

In [98]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=src_vocab.stoi['<pad>'])

def create_mask(seq_len):
    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)

In [99]:
for epoch in range(10):
    model.train()
    optimizer.zero_grad()

    tgt_input = tgt_tensor[:, :-1]
    tgt_out = tgt_tensor[:, 1:]
    mask = create_mask(tgt_input.shape[1])

    logits = model(src_tensor, tgt_input, mask)
    loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

Epoch 1, Loss: 2.6658
Epoch 2, Loss: 2.8954
Epoch 3, Loss: 2.2919
Epoch 4, Loss: 2.2657
Epoch 5, Loss: 2.1421
Epoch 6, Loss: 2.0681
Epoch 7, Loss: 1.9790
Epoch 8, Loss: 1.8050
Epoch 9, Loss: 1.6190
Epoch 10, Loss: 1.5253


In [103]:
def translate(sentence, max_len=10):
    model.eval()
    # Assuming device is defined elsewhere, e.g., device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    src = torch.tensor(src_vocab.encode(sentence, max_len)).unsqueeze(0).to(device)
    tgt = torch.tensor([[tgt_vocab.stoi['<sos>']]], dtype=torch.long).to(device)

    for _ in range(max_len):
        mask = create_mask(tgt.shape[1]).to(device)
        out = model(src, tgt, mask)
        # next_token shape after argmax is (batch_size,) which is (1,) in this case
        # unsqueeze(0) makes it (1, 1) which matches the last dimension of tgt for concatenation along dim=1
        next_token = out[:, -1].argmax(-1).unsqueeze(0)
        tgt = torch.cat([tgt, next_token], dim=1) # Concatenate (1, seq_len) and (1, 1) along dim 1
        if next_token.item() == tgt_vocab.stoi['<eos>']:
            break

    return tgt_vocab.decode(tgt.squeeze().tolist())

In [104]:
print("English: good night")
print("French : ", translate("good night", 10))

English: good night
French :  bonne bonne bonne bonne bonne bonne bonne bonne bonne bonne
