<a href="https://colab.research.google.com/github/Zerldas/Python_Exercise/blob/main/Encoder-Decoder/encoder_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install spaCy and its English and German tokenizer models

!pip install spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import random

In [None]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print("Device:", DEVICE)

In [None]:
# Load the dataset.
# spaCy-backed tokenizers for English and German.
tokenizer_en = get_tokenizer('spacy', language='en_core_web_sm')
tokenizer_de = get_tokenizer('spacy', language='de_core_news_sm')
# Get an iterator over the Multi30k training split, requested as ('en', 'de') pairs.
train_iter = Multi30k(split='train', language_pair=('en','de'))

In [None]:
def yield_tokens(data_iter, tokenizer, index):
    """Lazily tokenize one side of every (source, target) pair.

    Args:
        data_iter: iterable of 2-item (source, target) pairs.
        tokenizer: callable applied to the selected text of each pair.
        index: 0 selects the source side (EN input); any other value
            selects the target side (DE).

    Yields:
        Whatever ``tokenizer`` returns for the selected text, one item per pair.
    """
    use_source = index == 0
    for source_text, target_text in data_iter:
        chosen_text = source_text if use_source else target_text
        yield tokenizer(chosen_text)

In [None]:
# Build the English vocabulary from the first element of every training pair;
# the four special tokens are supplied through `specials`.
vocab_en = build_vocab_from_iterator(yield_tokens(train_iter, tokenizer_en, 0),
                                    specials=["<unk>", "<pad>", "<sos>", "<eos>"])
# Use the index of "<unk>" as the vocabulary's default index.
vocab_en.set_default_index(vocab_en["<unk>"])
# Re-create the training iterator before the second pass
# (presumably because the pass above may have consumed it).
train_iter = Multi30k(split='train', language_pair=('en','de'))  # reload iterator

# Build the German vocabulary from the second element of every training pair,
# with the same special tokens.
vocab_de = build_vocab_from_iterator(yield_tokens(train_iter, tokenizer_de, 1),
                                    specials=["<unk>", "<pad>", "<sos>", "<eos>"])
# Use the index of "<unk>" as the vocabulary's default index.
vocab_de.set_default_index(vocab_de["<unk>"])

In [None]:
# Hyperparameters and model dimensions.
INPUT_DIM = len(vocab_en)  # English vocabulary size; passed to Encoder as input_dim
OUTPUT_DIM = len(vocab_de)  # German vocabulary size; passed to Decoder as output_dim
ENC_EMB_DIM = 256  # encoder embedding size
DEC_EMB_DIM = 256  # decoder embedding size
HID_DIM = 512  # LSTM hidden size, used for both encoder and decoder
BATCH_SIZE = 32  # number of sentence pairs per training batch
N_EPOCHS = 10  # number of passes over the training data
LEARNING_RATE = 0.01  # learning rate given to the Adam optimizer

In [None]:
# Special tokens -- <unk>: unknown token, <pad>: batch padding,
# <sos>: start of sentence, <eos>: end of sentence
def tensorize(sentence, vocab, tokenizer):
    """Convert a sentence into a 1-D LongTensor of vocabulary indices.

    The sentence is tokenized, each token is looked up in ``vocab``, and the
    result is wrapped with the "<sos>" index in front and the "<eos>" index
    at the end.

    Args:
        sentence: text to encode.
        vocab: mapping supporting ``vocab[token]`` lookups.
        tokenizer: callable that splits ``sentence`` into tokens.

    Returns:
        ``torch.long`` tensor of length ``number_of_tokens + 2``.
    """
    indices = [vocab["<sos>"]]
    indices.extend(vocab[token] for token in tokenizer(sentence))
    indices.append(vocab["<eos>"])
    return torch.tensor(indices, dtype=torch.long)

In [None]:
class Encoder(nn.Module):
    """Embeds source token indices and encodes them with a single LSTM."""

    def __init__(self, input_dim, emb_dim, hid_dim):
        """Create the embedding table and the LSTM.

        Args:
            input_dim: number of entries in the embedding table.
            emb_dim: embedding vector size (also the LSTM input size).
            hid_dim: LSTM hidden size.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=hid_dim)

    def forward(self, src):
        """Encode ``src`` and return the LSTM's final ``(hidden, cell)`` state.

        ``src`` is a tensor of token indices; the callers in this notebook
        pass it as (src_len, batch). The per-step LSTM outputs are discarded.
        """
        _, final_state = self.rnn(self.embedding(src))
        hidden, cell = final_state
        return hidden, cell

In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder that scores the next target token."""

    def __init__(self, output_dim, emb_dim, hid_dim):
        """Create the embedding table, the LSTM and the output projection.

        Args:
            output_dim: target vocabulary size (embedding rows and output scores).
            emb_dim: embedding vector size (also the LSTM input size).
            hid_dim: LSTM hidden size.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=emb_dim)
        self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=hid_dim)
        self.fc_out = nn.Linear(in_features=hid_dim, out_features=output_dim)
        # Kept as an attribute, but forward() does not apply it: the method
        # returns the raw fc_out scores.
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input, hidden, cell):
        """Run one decoding step.

        Args:
            input: token indices for the current step (one per batch element).
            hidden: LSTM hidden state from the previous step (or the encoder).
            cell: LSTM cell state from the previous step (or the encoder).

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds the
            unnormalized scores produced by ``fc_out`` and ``hidden``/``cell``
            are the updated LSTM state.
        """
        # Add a leading length-1 sequence axis so the LSTM sees a single step.
        step_tokens = input.unsqueeze(0)
        step_output, next_state = self.rnn(self.embedding(step_tokens), (hidden, cell))
        next_hidden, next_cell = next_state
        scores = self.fc_out(step_output.squeeze(0))
        return scores, next_hidden, next_cell

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time."""

    def __init__(self, encoder, decoder, device):
        """Keep the two sub-modules and the device used for the output buffer."""
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg`` conditioned on ``src`` and return per-step scores.

        Args:
            src: source batch handed to the encoder.
            trg: target batch indexed as (trg_len, batch); row 0 is the
                first decoder input.
            teacher_forcing_ratio: at each step, the ground-truth token is fed
                next when ``random.random()`` falls below this value;
                otherwise the decoder's own argmax prediction is fed.

        Returns:
            Tensor of shape (trg_len, batch, trg_vocab_size). Row 0 is never
            written and stays all zeros.
        """
        num_steps, batch_size = trg.shape[0], trg.shape[1]
        vocab_size = self.decoder.fc_out.out_features
        step_scores = torch.zeros(num_steps, batch_size, vocab_size).to(self.device)

        hidden, cell = self.encoder(src)
        decoder_input = trg[0, :]
        for step, truth in enumerate(trg[1:], start=1):
            scores, hidden, cell = self.decoder(decoder_input, hidden, cell)
            step_scores[step] = scores
            if random.random() < teacher_forcing_ratio:
                decoder_input = truth
            else:
                decoder_input = scores.argmax(1)
        return step_scores

In [None]:
# Instantiate the encoder/decoder pair, wrap them in Seq2Seq, then create the
# optimizer and the loss.
encoder = Encoder(input_dim=INPUT_DIM, emb_dim=ENC_EMB_DIM, hid_dim=HID_DIM).to(DEVICE)
decoder = Decoder(output_dim=OUTPUT_DIM, emb_dim=DEC_EMB_DIM, hid_dim=HID_DIM).to(DEVICE)
model = Seq2Seq(encoder=encoder, decoder=decoder, device=DEVICE).to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Target positions equal to the German "<pad>" index are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab_de["<pad>"])

In [None]:
# Materialize the training pairs in a list so they can be reshuffled every epoch.
train_data = list(Multi30k(split='train', language_pair=('en','de')))
# pad_sequence pads with 0 by default, and index 0 is "<unk>" in both
# vocabularies (the specials are registered in the order <unk>, <pad>, <sos>,
# <eos>). Pad with the real "<pad>" index instead, so that the loss's
# ignore_index skips padded target positions rather than scoring them as "<unk>".
src_pad_idx = vocab_en["<pad>"]
trg_pad_idx = vocab_de["<pad>"]
for epoch in range(N_EPOCHS):
    # translate_sentence() leaves the model in eval mode; switch back to
    # training mode in case this cell is re-run after translating.
    model.train()
    total_loss = 0
    random.shuffle(train_data)
    for i in range(0, len(train_data), BATCH_SIZE):
        batch = train_data[i:i+BATCH_SIZE]
        optimizer.zero_grad()
        # Encode every sentence as <sos> ... <eos> index tensors, then pad each
        # side to the longest sentence in the batch (result is (seq_len, batch)).
        src_batch = [tensorize(pair[0], vocab_en, tokenizer_en) for pair in batch]
        trg_batch = [tensorize(pair[1], vocab_de, tokenizer_de) for pair in batch]
        src_batch = nn.utils.rnn.pad_sequence(src_batch, padding_value=src_pad_idx).to(DEVICE)
        trg_batch = nn.utils.rnn.pad_sequence(trg_batch, padding_value=trg_pad_idx).to(DEVICE)
        # forward + loss
        output = model(src_batch, trg_batch)
        output_dim = output.shape[-1]
        # Row 0 of `output` is never filled by the model and row 0 of the
        # target is <sos>, so both are dropped before computing the loss.
        loss = criterion(output[1:].view(-1, output_dim), trg_batch[1:].view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported value is the sum of the per-batch losses for the epoch.
    print(f"Epoch {epoch+1}/{N_EPOCHS}, Loss: {total_loss:.2f}")

In [None]:
def translate_sentence(sentence, max_len=50):
    """Greedily translate an English sentence with the notebook's global model.

    The sentence is encoded with ``model.encoder``; ``model.decoder`` is then
    run one step at a time, always feeding back its own argmax prediction,
    until it emits "<eos>" or ``max_len`` tokens have been produced.

    Uses the notebook globals ``model``, ``vocab_en``, ``vocab_de``,
    ``tokenizer_en``, ``tensorize`` and ``DEVICE``. As a side effect the
    model is left in eval mode.

    Args:
        sentence: English source text.
        max_len: maximum number of decoding steps (default 50).

    Returns:
        The predicted target tokens joined with single spaces, without the
        "<sos>"/"<eos>" markers.
    """
    model.eval()
    eos_idx = vocab_de["<eos>"]
    with torch.no_grad():
        # unsqueeze(1) adds a batch axis of size 1: (src_len,) -> (src_len, 1).
        src_tensor = tensorize(sentence, vocab_en, tokenizer_en).unsqueeze(1).to(DEVICE)
        hidden, cell = model.encoder(src_tensor)
        input_tok = torch.tensor([vocab_de["<sos>"]], device=DEVICE)
        trg_indices = []
        for _ in range(max_len):
            output, hidden, cell = model.decoder(input_tok, hidden, cell)
            top1 = output.argmax(1).item()
            # Stop before recording <eos>. The collected list then never needs
            # its last element trimmed, so the final real token is kept even
            # when the step limit is reached without an <eos>.
            if top1 == eos_idx:
                break
            trg_indices.append(top1)
            input_tok = torch.tensor([top1], device=DEVICE)
        trg_tokens = [vocab_de.lookup_token(idx) for idx in trg_indices]
        return " ".join(trg_tokens)

In [None]:
# Smoke test: translate one English sentence with the model and print the result.
print("Translate 'a man is playing a guitar':")
print(translate_sentence("a man is playing a guitar"))