In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import nltk
from nltk.corpus import gutenberg
from collections import Counter

In [2]:
pairs = [
    ("a man is playing guitar", "ein mann spielt gitarre"),
    ("a woman is cooking", "eine frau kocht"),
    ("children are playing football", "kinder spielen fußball"),
    ("he is reading a book", "er liest ein buch"),
    ("she is riding a bike", "sie fährt fahrrad")
]

In [3]:
SPECIALS = ["<unk>", "<pad>", "<bos>", "<eos>"]

In [4]:
def tokenize_text(text):
    text = text.lower()
    tokens = nltk.word_tokenize(text)
    return tokens

In [None]:
def build_vocab(sentences, tokenizer=tokenize_text):
    counter = Counter()
    for s in sentences:
        counter.update(tokenizer(s))
    
    vocab = {word: i+len(SPECIALS) for i, word in enumerate(counter.keys())}
    
    for i, sp in enumerate(SPECIALS):
        vocab[sp] = i
    
    return vocab

In [6]:
vocab_en = build_vocab([src for src, _ in pairs], tokenizer=tokenize_text)
vocab_de = build_vocab([tgt for _, tgt in pairs], tokenizer=tokenize_text)

In [7]:
print("English vocab:", vocab_en)
print("German vocab:", vocab_de)

English vocab: {'a': 4, 'man': 5, 'is': 6, 'playing': 7, 'guitar': 8, 'woman': 9, 'cooking': 10, 'children': 11, 'are': 12, 'football': 13, 'he': 14, 'reading': 15, 'book': 16, 'she': 17, 'riding': 18, 'bike': 19, '<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3}
German vocab: {'ein': 4, 'mann': 5, 'spielt': 6, 'gitarre': 7, 'eine': 8, 'frau': 9, 'kocht': 10, 'kinder': 11, 'spielen': 12, 'fußball': 13, 'er': 14, 'liest': 15, 'buch': 16, 'sie': 17, 'fährt': 18, 'fahrrad': 19, '<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3}


In [8]:
def encode_sentence(sentence, vocab, tokenizer, add_bos=True, add_eos=True):
    tokens = tokenizer(sentence)
    ids = []
    if add_bos:
        ids.append(vocab["<bos>"])
    ids += [vocab.get(tok, vocab["<unk>"]) for tok in tokens]
    if add_eos:
        ids.append(vocab["<eos>"])
    return torch.tensor(ids, dtype=torch.long)

In [9]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, 
                 nhead, src_vocab_size, tgt_vocab_size, dim_feedforward=512):
        super(Seq2SeqTransformer, self).__init__()
        
        self.transformer = nn.Transformer(
            d_model=emb_size, 
            nhead=nhead, 
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward
        )
        
        self.src_embed = nn.Embedding(src_vocab_size, emb_size)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, emb_size)
        self.fc_out = nn.Linear(emb_size, tgt_vocab_size)

    def forward(self, src, tgt):
        src_emb = self.src_embed(src).transpose(0, 1)
        tgt_emb = self.tgt_embed(tgt).transpose(0, 1)

        output = self.transformer(src_emb, tgt_emb)
        return self.fc_out(output)


In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SRC_VOCAB_SIZE = len(vocab_en)
TGT_VOCAB_SIZE = len(vocab_de)
EMB_SIZE = 128
NHEAD = 4
FFN_HID_DIM = 256
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD,
                           SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM).to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=vocab_de["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(20):
    total_loss = 0
    for src_text, tgt_text in pairs:
        src = encode_sentence(src_text, vocab_en, tokenize_text).unsqueeze(0).to(device)
        tgt = encode_sentence(tgt_text, vocab_de, tokenize_text).unsqueeze(0).to(device)

        optimizer.zero_grad()
        output = model(src, tgt[:, :-1])
        loss = loss_fn(output.reshape(-1, output.shape[-1]), tgt[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}, Loss: {total_loss:.4f}")


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 0, Loss: 15.6510
Epoch 1, Loss: 10.1689
Epoch 2, Loss: 6.9444
Epoch 3, Loss: 4.8464
Epoch 4, Loss: 2.9222
Epoch 5, Loss: 2.0507
Epoch 6, Loss: 1.4303
Epoch 7, Loss: 0.9970
Epoch 8, Loss: 0.9022
Epoch 9, Loss: 0.5630
Epoch 10, Loss: 0.4751
Epoch 11, Loss: 0.4398
Epoch 12, Loss: 0.3492
Epoch 13, Loss: 0.3261
Epoch 14, Loss: 0.2589
Epoch 15, Loss: 0.2519
Epoch 16, Loss: 0.2344
Epoch 17, Loss: 0.2186
Epoch 18, Loss: 0.1935
Epoch 19, Loss: 0.1535


In [11]:
def translate(model, sentence, max_len=20):
    model.eval()
    src = encode_sentence(sentence, vocab_en, tokenize_text).unsqueeze(0).to(device)
    tgt = torch.tensor([[vocab_de["<bos>"]]], dtype=torch.long).to(device)

    for _ in range(max_len):
        output = model(src, tgt)
        next_token = output.argmax(2)[:, -1]  # Shape: (1, output_length)
        if next_token.dim() == 0:
            next_token_value = next_token.item()  # Get the actual value
        else:
            next_token_value = next_token.argmax()  # Get the index of the highest value
        tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
        if next_token_value == vocab_de["<eos>"]:
            break

    return " ".join([list(vocab_de.keys())[list(vocab_de.values()).index(tok)] 
                     for tok in tgt.squeeze().tolist()])
print(translate(model, "a man is playing guitar"))


<bos> ein ein mann ein mann mann spielt
