In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torchtext.vocab import build_vocab_from_iterator
import nltk
import random



In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Toy parallel corpus used for both training and the final sanity check.
# Each tuple is (English source sentence, French target sentence); all
# sentences are lowercase and contain no punctuation.
pairs = [
    ("i am a student", "je suis un étudiant"),
    ("he is a teacher", "il est un professeur"),
    ("she likes pizza", "elle aime la pizza"),
    ("we are friends", "nous sommes amis"),
]


In [4]:
def tokenize(text):
    """Split `text` into lowercase word tokens.

    The text is split with NLTK's `word_tokenize`; only tokens for which
    `str.isalnum()` is true are kept (so punctuation tokens are dropped),
    and each kept token is lowercased.

    Returns:
        list[str]: the kept tokens, in their original order.
    """
    words = []
    for raw_token in nltk.tokenize.word_tokenize(text):
        if raw_token.isalnum():
            words.append(raw_token.lower())
    return words

In [5]:
# Both vocabularies start with the same three reserved entries, in this id order:
# 0 -> "<pad>", 1 -> "<sos>", 2 -> "<eos>". Real words are appended after them.
src_vocab = {special: index for index, special in enumerate(("<pad>", "<sos>", "<eos>"))}
trg_vocab = dict(src_vocab)  # independent copy, so the two vocabularies grow separately

In [None]:
# Grow each vocabulary in first-seen order: a new token gets the next free id
# (the current size of the dict); tokens that are already known keep their id,
# which also makes this cell safe to re-run.
for en, fr in pairs:
    for token in tokenize(en):
        src_vocab.setdefault(token, len(src_vocab))
    for token in tokenize(fr):
        trg_vocab.setdefault(token, len(trg_vocab))

In [7]:
# Reverse lookup (target id -> token), used to turn decoder predictions back into words.
inv_trg_vocab = dict(zip(trg_vocab.values(), trg_vocab.keys()))

In [8]:
def encode(sentence, vocab):
    """Convert a sentence into a list of token ids framed by <sos> and <eos>.

    Args:
        sentence: raw text; it is split with `tokenize`.
        vocab: dict mapping token -> id; must contain "<sos>" and "<eos>".

    Returns:
        list[int]: [<sos> id, token ids..., <eos> id].

    Raises:
        KeyError: if a token of `sentence` is missing from `vocab`.
    """
    ids = [vocab["<sos>"]]
    ids.extend(vocab[token] for token in tokenize(sentence))
    ids.append(vocab["<eos>"])
    return ids

In [None]:
# Encode the whole corpus once: one (source ids, target ids) tuple per sentence pair.
data = []
for en, fr in pairs:
    data.append((encode(en, src_vocab), encode(fr, trg_vocab)))

In [10]:
# Inspect the encoded corpus: each entry is (source ids, target ids), and every
# id list starts with 1 (<sos>) and ends with 2 (<eos>).
data

[([1, 3, 4, 5, 6, 2], [1, 3, 4, 5, 6, 2]),
 ([1, 7, 8, 5, 9, 2], [1, 7, 8, 5, 9, 2]),
 ([1, 10, 11, 12, 2], [1, 10, 11, 12, 13, 2]),
 ([1, 13, 14, 15, 2], [1, 14, 15, 16, 2])]

In [11]:
class Encoder(nn.Module):
    """Embeds source token ids and runs them through a bidirectional GRU.

    Args:
        input_dim: number of rows in the embedding table (source vocabulary size).
        embedding_dim: size of each token embedding.
        hidden_dim: GRU hidden size per direction.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, src):
        """Encode a batch of source id sequences.

        Args:
            src: integer tensor of token ids, [batch_size, seq_len].

        Returns:
            (states, final_state) where
            states      : [batch_size, seq_len, hidden_dim * 2]
            final_state : [n_layers * 2, batch_size, hidden_dim]
        """
        token_vectors = self.embedding(src)
        # token_vectors : [batch_size, seq_len, embedding_dim]
        states, final_state = self.rnn(token_vectors)
        return states, final_state

In [12]:
class Attention(nn.Module):
    """Scores every encoder position against the current decoder state.

    The score of a position is v(tanh(attn([decoder_state ; encoder_output]))),
    and the scores are normalised with a softmax over the source positions.

    Args:
        enc_hidden_dim: encoder hidden size per direction (encoder outputs are
            expected to be twice this wide).
        dec_hidden_dim: decoder hidden size.
    """

    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super().__init__()
        combined_dim = enc_hidden_dim * 2 + dec_hidden_dim
        self.attn = nn.Linear(combined_dim, dec_hidden_dim)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_output):
        """Return attention weights over the source positions.

        Args:
            hidden: decoder state, [n_layers, batch_size, dec_hidden_dim];
                only the last layer is used.
            encoder_output: [batch_size, src_len, enc_hidden_dim * 2].

        Returns:
            Tensor [batch_size, src_len]; each row sums to 1.
        """
        n_positions = encoder_output.size(1)

        # Broadcast the last-layer decoder state across all source positions.
        query = hidden[-1].unsqueeze(1).expand(-1, n_positions, -1)
        # query : [batch_size, src_len, dec_hidden_dim]

        features = torch.cat((query, encoder_output), dim=2)
        # features : [batch_size, src_len, dec_hidden_dim + enc_hidden_dim * 2]

        energy = torch.tanh(self.attn(features))
        # energy : [batch_size, src_len, dec_hidden_dim]

        scores = self.v(energy).squeeze(2)
        # scores : [batch_size, src_len]

        return torch.softmax(scores, dim=1)

In [13]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Args:
        output_dim: target vocabulary size (number of logits produced per step).
        embedding_dim: size of each target-token embedding.
        enc_hidden_dim: encoder hidden size per direction (encoder outputs are
            twice this wide).
        dec_hidden_dim: hidden size of the decoder GRU.
        attention: callable mapping (hidden, encoder_output) to weights of
            shape [batch_size, src_len].
    """

    def __init__(self, output_dim, embedding_dim, enc_hidden_dim, dec_hidden_dim, attention):
        super().__init__()
        context_dim = enc_hidden_dim * 2

        self.output_dim = output_dim  # [size of vocab]
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.GRU(context_dim + embedding_dim, dec_hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(context_dim + dec_hidden_dim + embedding_dim, output_dim)

    def forward(self, input, hidden, encoder_output):
        """Run one decoding step.

        Args:
            input: previous target token ids, [batch_size].
            hidden: decoder state, [n_layers, batch_size, dec_hidden_dim];
                only the last layer is fed to the GRU.
            encoder_output: [batch_size, src_len, enc_hidden_dim * 2].

        Returns:
            (logits, hidden) where
            logits : [batch_size, output_dim]
            hidden : [1, batch_size, dec_hidden_dim]
        """
        token_emb = self.embedding(input).unsqueeze(1)
        # token_emb : [batch_size, 1, embedding_dim]

        attn_weights = self.attention(hidden, encoder_output).unsqueeze(1)
        # attn_weights : [batch_size, 1, src_len]

        # Attention-weighted sum of the encoder outputs.
        context = torch.bmm(attn_weights, encoder_output)
        # context : [batch_size, 1, enc_hidden_dim * 2]

        step_input = torch.cat((token_emb, context), dim=2)
        # step_input : [batch_size, 1, embedding_dim + enc_hidden_dim * 2]

        rnn_out, hidden = self.rnn(step_input, hidden[-1].unsqueeze(0))
        # rnn_out : [batch_size, 1, dec_hidden_dim]
        # hidden  : [1, batch_size, dec_hidden_dim]

        # The prediction layer sees the GRU output, the context vector and the
        # input embedding together; drop the length-1 time axis afterwards.
        features = torch.cat((rnn_out, context, token_emb), dim=2).squeeze(1)
        # features : [batch_size, dec_hidden_dim + enc_hidden_dim * 2 + embedding_dim]

        logits = self.fc_out(features)
        # logits : [batch_size, output_dim]

        return logits, hidden

In [14]:
class Seq2Seq(nn.Module):
    """Wires an encoder and an attention decoder into a step-by-step translator.

    Args:
        encoder: module returning (encoder outputs, final hidden state) for a
            batch of source ids.
        decoder: module exposing `output_dim` and performing one decoding step
            per call.
        device: device on which the logits buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode `trg.shape[1] - 1` steps and collect the logits of each step.

        Args:
            src: source ids, [batch_size, src_len].
            trg: target ids, [batch_size, trg_len]; column 0 is the first
                decoder input.
            teacher_forcing_ratio: probability, drawn independently at every
                step, of feeding the ground-truth token rather than the
                model's own prediction as the next input.

        Returns:
            Tensor [trg_len, batch_size, output_dim] (time-major); index 0 is
            never written and stays all zeros.
        """
        n_steps = trg.shape[1]
        outputs = torch.zeros(
            n_steps, src.shape[0], self.decoder.output_dim, device=self.device
        )

        enc_outputs, enc_hidden = self.encoder(src)
        # Sum the states at even and odd indices of the encoder's final hidden
        # state (the two directions of each layer) into one decoder state.
        hidden = enc_hidden[0::2] + enc_hidden[1::2]

        step_input = trg[:, 0]
        for t in range(1, n_steps):
            logits, hidden = self.decoder(step_input, hidden, enc_outputs)
            outputs[t] = logits
            if random.random() < teacher_forcing_ratio:
                step_input = trg[:, t]
            else:
                step_input = logits.argmax(1)

        return outputs

In [15]:
# Vocabulary sizes as (source, target), reserved special tokens included.
len(src_vocab), len(trg_vocab)

(16, 17)

In [16]:
# Seed every RNG this experiment draws from so that Restart & Run All gives the
# same run each time: torch drives the weight initialisation below, and Python's
# `random` drives the teacher-forcing decisions inside Seq2Seq.forward.
# NOTE: bit-identical results on a GPU may additionally depend on backend
# determinism settings.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Model sizes: the vocabularies fix the embedding / output widths.
INPUT_DIM = len(src_vocab)
OUTPUT_DIM = len(trg_vocab)
ENC_EMB_DIM = DEC_EMB_DIM = 32
HID_DIM = 64

# Encoder and decoder share HID_DIM, so the summed encoder state can be used
# directly as the initial decoder state.
attn = Attention(HID_DIM, HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, HID_DIM, attn)

model = Seq2Seq(enc, dec, device).to(device)
# ignore_index=0 keeps <pad> (id 0) targets out of the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters())
EPOCHS = 100

In [17]:
# Train on one sentence pair at a time (batch size 1), reporting the mean loss
# over the corpus every 10 epochs.
for epoch in range(EPOCHS):
    # translate() switches the model to eval mode; re-assert training mode so
    # this cell behaves the same when it is re-run after an inference cell.
    model.train()
    epoch_loss = 0

    for src, trg in data:
        src = torch.tensor(src).unsqueeze(0).to(device)  # [1, src_len]
        trg = torch.tensor(trg).unsqueeze(0).to(device)  # [1, trg_len]

        optimizer.zero_grad()
        output = model(src, trg)  # [trg_len, batch_size, output_dim] (time-major)

        output_dim = output.shape[-1]

        # Skip position 0 (the <sos> slot, never predicted). The model output
        # is time-major while the targets are batch-major, so move the batch
        # axis first before flattening; otherwise logits and targets only line
        # up when batch_size == 1.
        output = output[1:].permute(1, 0, 2).reshape(-1, output_dim)  # [N, C]
        trg = trg[:, 1:].reshape(-1)                                   # [N]

        loss = criterion(output, trg)
        loss.backward()

        # Clip the global gradient norm to 1 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        epoch_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss/len(data):.6f}")

Epoch 10/100, Loss: 1.469619
Epoch 20/100, Loss: 0.284535
Epoch 30/100, Loss: 0.047490
Epoch 40/100, Loss: 0.020250
Epoch 50/100, Loss: 0.011936
Epoch 60/100, Loss: 0.007950
Epoch 70/100, Loss: 0.005689
Epoch 80/100, Loss: 0.004341
Epoch 90/100, Loss: 0.003446
Epoch 100/100, Loss: 0.002814


In [18]:
def translate(sentence, max_len=10):
    """Greedily decode a translation of `sentence` with the trained model.

    Args:
        sentence: source sentence; it is turned into ids by `encode` with
            `src_vocab`.
        max_len: maximum number of target tokens to generate when <eos> is not
            predicted earlier. Defaults to 10, the limit that was previously
            hard-coded.

    Returns:
        str: the generated target tokens joined by single spaces, without the
        <sos>/<eos> markers.

    Raises:
        KeyError: if `sentence` contains a token that is not in `src_vocab`.
    """
    src = torch.tensor(encode(sentence, src_vocab)).unsqueeze(0).to(device)  # [1, src_len]
    eos_id = trg_vocab["<eos>"]

    # Note: the model is left in eval mode after this call.
    model.eval()
    with torch.no_grad():
        # Go through `model` itself so the modules used here are exactly the
        # ones that were just switched to eval mode.
        encoder_outputs, hidden = model.encoder(src)
        # Same state merge as Seq2Seq.forward: sum the even- and odd-indexed
        # encoder states to get the initial decoder state.
        hidden = hidden[::2] + hidden[1::2]

        prev_token = torch.tensor([trg_vocab["<sos>"]]).to(device)  # [1]
        result = []
        for _ in range(max_len):
            output, hidden = model.decoder(prev_token, hidden, encoder_outputs)
            # Greedy choice; the argmax tensor is already on the right device,
            # so it is fed back directly as the next decoder input.
            prev_token = output.argmax(1)
            top1 = prev_token.item()
            if top1 == eos_id:
                break
            result.append(inv_trg_vocab[top1])
    return " ".join(result)

In [None]:
# Sanity check on the training pairs themselves: print the model's greedy
# translation next to the reference for every sentence it was trained on.
print("\nTranslations:")
for en, fr in pairs:
    translation = translate(en)
    print(f"{en} -> {translation} (expected: {fr})")


Translations:
i am a student -> je suis un étudiant (expected: je suis un étudiant)
he is a teacher -> il est un professeur (expected: il est un professeur)
she likes pizza -> elle aime la pizza (expected: elle aime la pizza)
we are friends -> nous sommes amis (expected: nous sommes amis)
