Machine Translation (English → Russian) | Seq2Seq with Attention

In [None]:
!pip install tqdm



In [None]:
import random
import re
import string
import unicodedata

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Seed every RNG the notebook draws from (weight init, dropout, DataLoader
# shuffling, teacher-forcing coin flips) so Restart & Run All is reproducible.
# Note: some cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Data Preprocessing

In [None]:
# Tab-separated English/Russian sentence pairs, one pair per line.
text_file_path = '/content/rus.txt'
# Explicit UTF-8: the file contains Cyrillic text, and open()'s default encoding
# is platform-dependent (e.g. cp1252 on Windows), which would fail or garble it.
with open(text_file_path, encoding='utf-8') as corpus_file:
    text = corpus_file.read()

def preprocess_text(text):
    """Normalise raw text: drop punctuation and ASCII digits, then lowercase.

    Removes:
      * every character in ``string.punctuation`` (ASCII punctuation and
        symbols, including the apostrophe, so "It's" -> "its");
      * every Unicode punctuation character (category ``P*``), e.g. the em
        dash "—", guillemets «», "…" and the curly apostrophe "’". These are
        not in ``string.punctuation`` and would otherwise survive as separate
        vocabulary tokens;
      * the ASCII digits 0-9.

    Whitespace, including tabs and newlines, is preserved, so the result can
    still be split into lines and tab-separated columns.

    Args:
        text: arbitrary input string.

    Returns:
        The filtered, lower-cased string.
    """
    kept_chars = (
        ch for ch in text
        if ch not in string.punctuation
        and ch not in string.digits
        and not unicodedata.category(ch).startswith('P')
    )
    return ''.join(kept_chars).lower()

def return_sentences(text, num_lines=20000):
    """Parse tab-separated English/Russian pairs and collect both vocabularies.

    Only the first ``num_lines`` lines of ``text`` are examined. Blank lines and
    lines without a tab still count toward that budget. Each line is normalised
    with ``preprocess_text``; column 0 is taken as the English sentence and
    column 1 as the Russian one, and any further columns are ignored.

    Args:
        text: the whole corpus as one string, lines separated by '\\n'.
        num_lines: maximum number of lines to read from the start of ``text``.

    Returns:
        A tuple ``(english_texts, russian_texts, english_words, russian_words)``:
        * the two sentence lists, aligned by index, with each Russian sentence
          wrapped as ``'<sos> ... <eos>'``;
        * the sorted word lists, each including ``'<sos>'`` and ``'<eos>'``.
    """
    english_texts, russian_texts = [], []
    english_words, russian_words = set(), set()

    lines = text.split('\n')
    # A non-positive limit reads nothing.
    limit = max(min(len(lines), num_lines), 0)

    for raw_line in tqdm(lines[:limit]):
        if not raw_line.strip():
            continue
        columns = preprocess_text(raw_line).split('\t')
        if len(columns) < 2:
            continue

        english_sentence, russian_sentence = columns[0], columns[1]
        english_texts.append(english_sentence)
        russian_texts.append('<sos> ' + russian_sentence + ' <eos>')
        english_words.update(english_sentence.split())
        russian_words.update(russian_sentence.split())

    # Both vocabularies carry the sequence-boundary markers.
    for vocabulary in (english_words, russian_words):
        vocabulary.update(('<sos>', '<eos>'))

    return english_texts, russian_texts, sorted(english_words), sorted(russian_words)

# Parse the corpus into aligned sentence lists and sorted vocabularies.
english_texts, russian_texts, english_words, russian_words = return_sentences(text)

# One row per sentence pair plus whitespace-token counts, shuffled reproducibly.
text_df = (
    pd.DataFrame({'English': english_texts, 'Russian': russian_texts})
    .assign(**{
        'English Length': lambda frame: frame['English'].apply(lambda s: len(s.split())),
        'Russian Length': lambda frame: frame['Russian'].apply(lambda s: len(s.split())),
    })
    .sample(frac=1, random_state=42)
)

100%|██████████| 20000/20000 [00:01<00:00, 14413.00it/s]


Vocabulary & Lookup Tables

In [None]:
# English ids: position in the sorted vocabulary, 0..len(english_words)-1.
# NOTE(review): id 0 doubles as the pad value (pad_sequence default) and as the
# unknown-word fallback in translate_sentence; it belongs to whichever word
# sorts first (presumably '<eos>', which never occurs in English input).
num_encoder_tokens = len(english_words)
english_lookup = {word: idx for idx, word in enumerate(english_words)}

# Russian ids: '<sos>' -> 0, ordinary words -> 1..k, '<eos>' -> k + 1.
# The special tokens are removed before numbering the ordinary words so that
# no two tokens share an id. Otherwise '<eos>' would land on the last sorted
# word's id and be decoded as that word.
# Id 0 for '<sos>' is also the pad value and the loss's ignore_index. That is
# safe because '<sos>' never appears in decoder targets (they start after it).
russian_special_tokens = ('<sos>', '<eos>')
russian_plain_words = [w for w in russian_words if w not in russian_special_tokens]

russian_lookup = {'<sos>': 0}
russian_lookup.update({word: idx for idx, word in enumerate(russian_plain_words, start=1)})
russian_lookup['<eos>'] = len(russian_lookup)

num_decoder_tokens = len(russian_lookup)
russian_token_lookup = {idx: word for word, idx in russian_lookup.items()}
assert len(russian_token_lookup) == num_decoder_tokens, "Russian ids must be unique"

Dataset & Dataloader

In [None]:
class TranslationDataset(Dataset):
    """Map-style dataset turning (English, Russian) sentence pairs into id tensors.

    Each item is a triple of 1-D ``torch.long`` tensors:
      * ``encoder_input``  - ids of the whitespace-split English sentence;
      * ``decoder_input``  - Russian ids without the final token;
      * ``decoder_target`` - Russian ids without the first token.
    A word missing from the relevant lookup raises ``KeyError``.

    Args:
        english_texts: list of English sentences.
        russian_texts: list of Russian sentences, aligned with ``english_texts``.
        english_lookup: word -> id mapping for English.
        russian_lookup: word -> id mapping for Russian.
    """

    def __init__(self, english_texts, russian_texts, english_lookup, russian_lookup):
        self.english_texts = english_texts
        self.russian_texts = russian_texts
        self.english_lookup = english_lookup
        self.russian_lookup = russian_lookup

    def __len__(self):
        return len(self.english_texts)

    @staticmethod
    def _encode(tokens, lookup):
        # Words -> LongTensor of ids (KeyError on out-of-vocabulary words).
        return torch.tensor([lookup[tok] for tok in tokens], dtype=torch.long)

    def __getitem__(self, idx):
        source_ids = self._encode(self.english_texts[idx].split(), self.english_lookup)
        target_tokens = self.russian_texts[idx].split()
        # Shift by one: feed tokens[:-1], predict tokens[1:].
        shifted_input = self._encode(target_tokens[:-1], self.russian_lookup)
        shifted_target = self._encode(target_tokens[1:], self.russian_lookup)
        return source_ids, shifted_input, shifted_target

def collate_fn(batch):
    """Right-pad each field of a batch of dataset triples to a common length.

    Args:
        batch: sequence of ``(encoder_input, decoder_input, decoder_target)``
            1-D tensors, as produced by ``TranslationDataset``.

    Returns:
        A 3-tuple of ``[batch, max_len]`` tensors padded with 0 (the
        ``pad_sequence`` default), one per field, in the same order.
    """
    source_seqs, target_in_seqs, target_out_seqs = zip(*batch)
    padded_fields = [
        pad_sequence(field, batch_first=True)
        for field in (source_seqs, target_in_seqs, target_out_seqs)
    ]
    return tuple(padded_fields)


# 80/20 train/validation split of the sentence pairs (fixed seed).
X_train, X_valid, y_train, y_valid = train_test_split(
    text_df['English'], text_df['Russian'], test_size=0.2, random_state=42
)

batch_size = 32
train_dataset, valid_dataset = (
    TranslationDataset(english.tolist(), russian.tolist(), english_lookup, russian_lookup)
    for english, russian in ((X_train, y_train), (X_valid, y_valid))
)

# Shared loader settings; only the training loader is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=2)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

Model Definition

In [None]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder.

    Args:
        input_dim: source vocabulary size.
        emb_dim: embedding size.
        hidden_dim: LSTM hidden size per direction.
        n_layers: number of stacked LSTM layers.
        dropout: dropout applied to the embeddings and between LSTM layers.

    ``forward(src)`` takes a ``[batch, src_len]`` LongTensor and returns
    ``(outputs, hidden, cell)``:
      * ``outputs`` is ``[batch, src_len, 2 * hidden_dim]``, with the forward
        and backward features concatenated;
      * ``hidden`` and ``cell`` are ``[n_layers, batch, hidden_dim]``, each the
        sum of the two directions' final states, so they can initialise a
        unidirectional decoder LSTM.
    Padded positions are run through the LSTM like real tokens (no packing).
    """

    def __init__(self, input_dim, emb_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(
            emb_dim, hidden_dim,
            num_layers=n_layers, dropout=dropout,
            batch_first=True, bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

    def _merge_directions(self, state):
        # [n_layers * 2, batch, H] -> [n_layers, 2, batch, H], then add the
        # forward and backward slices -> [n_layers, batch, H].
        per_direction = state.view(self.n_layers, 2, -1, self.hidden_dim)
        forward_dir, backward_dir = per_direction.unbind(dim=1)
        return (forward_dir + backward_dir).contiguous()

    def forward(self, src):
        outputs, (hidden, cell) = self.rnn(self.dropout(self.embedding(src)))
        return outputs, self._merge_directions(hidden), self._merge_directions(cell)


class Attention(nn.Module):
    """Additive attention over bidirectional encoder outputs.

    Each source position is scored as ``v · tanh(W [hidden; encoder_output])``,
    and the scores are softmax-normalised over the source length.

    Args:
        hidden_dim: decoder hidden size. Encoder outputs are expected to have
            width ``2 * hidden_dim``, so ``W`` maps ``3 * hidden_dim -> hidden_dim``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim * 3, hidden_dim)
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden, encoder_outputs, mask=None):
        """Compute attention weights.

        Args:
            hidden: ``[batch, hidden_dim]`` decoder state.
            encoder_outputs: ``[batch, src_len, 2 * hidden_dim]``.
            mask: optional ``[batch, src_len]``; positions equal to 0 get a
                score of -1e10, i.e. (numerically) zero weight.

        Returns:
            ``[batch, 1, src_len]`` weights that sum to 1 over ``src_len``.
        """
        src_len = encoder_outputs.size(1)
        # Broadcast the decoder state to every source position.
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)
        features = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))
        scores = (features * self.v).sum(dim=2)  # [batch, src_len]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)

        return scores.softmax(dim=1).unsqueeze(1)


class Decoder(nn.Module):
    """Single-step LSTM decoder conditioned on an attention context vector.

    Args:
        output_dim: target vocabulary size (embedding rows and logit width).
        emb_dim: target embedding size.
        hidden_dim: LSTM hidden size. The LSTM input is the token embedding
            concatenated with a context vector of width ``2 * hidden_dim``.
        n_layers: number of stacked LSTM layers.
        dropout: dropout on the embeddings and between LSTM layers.
        attention: callable ``(top_hidden, encoder_outputs) -> weights`` that
            returns ``[batch, 1, src_len]`` weights (e.g. ``Attention``).

    No attention mask is passed, so padded source positions also receive weight.
    """

    def __init__(self, output_dim, emb_dim, hidden_dim, n_layers, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        context_dim = hidden_dim * 2
        self.rnn = nn.LSTM(
            emb_dim + context_dim, hidden_dim,
            num_layers=n_layers, dropout=dropout, batch_first=True,
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one time step.

        Args:
            input: ``[batch]`` token ids for this step.
            hidden: ``[n_layers, batch, hidden_dim]`` LSTM hidden state.
            cell: ``[n_layers, batch, hidden_dim]`` LSTM cell state.
            encoder_outputs: ``[batch, src_len, 2 * hidden_dim]``.

        Returns:
            ``(logits, hidden, cell)``, where ``logits`` is ``[batch, output_dim]``
            and ``hidden``/``cell`` are the updated LSTM states.
        """
        token_emb = self.dropout(self.embedding(input.unsqueeze(1)))  # [B, 1, emb]
        # Attend from the top LSTM layer's state.
        weights = self.attention(hidden[-1], encoder_outputs)          # [B, 1, src_len]
        context = weights.bmm(encoder_outputs)                         # [B, 1, 2H]

        step_input = torch.cat([token_emb, context], dim=2)
        step_out, (hidden, cell) = self.rnn(step_input, (hidden, cell))
        logits = self.fc_out(step_out.squeeze(1))                      # [B, output_dim]
        return logits, hidden, cell


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over a target sequence.

    ``trg`` is the decoder *input* sequence: it starts with ``<sos>`` and is the
    target shifted right by one, as built by ``TranslationDataset``.
    ``outputs[:, t]`` holds the logits produced after feeding step ``t``'s input.
    That input is ``trg[:, t]`` at t == 0, and otherwise either ``trg[:, t]``
    (teacher forcing) or the model's own previous argmax. ``outputs`` is
    therefore aligned position-for-position with the next-token targets
    ``decoder_target`` / ``trg_output``.

    Args:
        encoder: module returning ``(encoder_outputs, hidden, cell)`` for ``src``.
        decoder: single-step module ``(input, hidden, cell, encoder_outputs) ->
            (logits, hidden, cell)`` exposing ``fc_out.out_features``.
        device: device on which the ``outputs`` buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Return ``[batch, trg_len, vocab]`` logits, one row per decoder input step.

        Args:
            src: ``[batch, src_len]`` source ids.
            trg: ``[batch, trg_len]`` decoder input ids (``<sos>``-prefixed).
            teacher_forcing_ratio: probability, per step, of feeding the
                ground-truth next input instead of the model's own prediction.
        """
        batch_size, trg_len = src.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=self.device)
        encoder_outputs, hidden, cell = self.encoder(src)

        step_input = trg[:, 0]  # always start from <sos>
        # Start at t == 0 so that outputs[:, t] is the prediction made after
        # consuming trg[:, t]; this lines up with trg_output[:, t] in the loss.
        for t in range(trg_len):
            logits, hidden, cell = self.decoder(step_input, hidden, cell, encoder_outputs)
            outputs[:, t] = logits
            if t + 1 < trg_len:
                teacher_force = torch.rand(1).item() < teacher_forcing_ratio
                step_input = trg[:, t + 1] if teacher_force else logits.argmax(1)

        return outputs

Training Setup

In [None]:
# Model hyper-parameters.
input_dim = num_encoder_tokens   # English vocabulary size
output_dim = num_decoder_tokens  # Russian vocabulary size
embedding_dim, hidden_dim = 256, 512
n_layers = 2
dropout = 0.5

# Construction order matters for reproducibility: each constructor draws its
# initial weights from the global RNG. The attention module is moved to
# `device` as a submodule of the decoder.
attention = Attention(hidden_dim)
encoder = Encoder(input_dim, embedding_dim, hidden_dim, n_layers, dropout).to(device)
decoder = Decoder(output_dim, embedding_dim, hidden_dim, n_layers, dropout, attention).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
# collate_fn pads with 0, so ignore_index=0 keeps padded target positions out of the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

def train(model, dataloader, optimizer, criterion, clip=1.0):
    """Run one training epoch and return the mean of the per-batch losses.

    Each batch ``(src, trg_input, trg_output)`` is moved to the notebook-level
    ``device``. ``model(src, trg_input)`` yields ``[batch, len, vocab]`` logits,
    which are scored against ``trg_output`` after flattening both. Gradients
    are clipped to a total norm of ``clip`` before each optimizer step.

    Args:
        model: the Seq2Seq model (switched to training mode here).
        dataloader: iterable of 3-tuples of LongTensors.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss taking ``(logits[N, vocab], targets[N])``.
        clip: max gradient norm.

    Returns:
        Sum of ``loss.item()`` over batches divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0

    for batch in tqdm(dataloader):
        src, trg_input, trg_output = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        logits = model(src, trg_input)
        vocab_size = logits.shape[-1]
        loss = criterion(logits.view(-1, vocab_size), trg_output.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)

def evaluate(model, dataloader, criterion):
    """Return the mean per-batch loss on ``dataloader`` without updating weights.

    Teacher forcing is disabled (``teacher_forcing_ratio=0.0``) so the decoder
    is fed its own predictions, as it is at inference time. Batches are moved
    to the notebook-level ``device``.

    Args:
        model: the Seq2Seq model (switched to eval mode here).
        dataloader: iterable of ``(src, trg_input, trg_output)`` LongTensors.
        criterion: loss taking ``(logits[N, vocab], targets[N])``.

    Returns:
        Sum of per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for src, trg_input, trg_output in dataloader:
            src, trg_input, trg_output = src.to(device), trg_input.to(device), trg_output.to(device)
            output = model(src, trg_input, teacher_forcing_ratio=0.0)
            loss = criterion(output.reshape(-1, output.shape[-1]), trg_output.reshape(-1))
            total_loss += loss.item()
    return total_loss / len(dataloader)


num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_dataloader, optimizer, criterion)
    # Held-out loss on the validation split, so over-/under-fitting is visible.
    valid_loss = evaluate(model, valid_dataloader, criterion)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}")

100%|██████████| 500/500 [00:15<00:00, 32.15it/s]


Epoch 1/10, Loss: 6.0795


100%|██████████| 500/500 [00:15<00:00, 32.24it/s]


Epoch 2/10, Loss: 5.4514


100%|██████████| 500/500 [00:15<00:00, 32.65it/s]


Epoch 3/10, Loss: 4.9839


100%|██████████| 500/500 [00:14<00:00, 33.81it/s]


Epoch 4/10, Loss: 4.5859


100%|██████████| 500/500 [00:14<00:00, 33.50it/s]


Epoch 5/10, Loss: 4.2667


100%|██████████| 500/500 [00:15<00:00, 32.56it/s]


Epoch 6/10, Loss: 4.0027


100%|██████████| 500/500 [00:15<00:00, 33.07it/s]


Epoch 7/10, Loss: 3.7884


100%|██████████| 500/500 [00:14<00:00, 33.74it/s]


Epoch 8/10, Loss: 3.6227


100%|██████████| 500/500 [00:14<00:00, 33.61it/s]


Epoch 9/10, Loss: 3.4879


100%|██████████| 500/500 [00:15<00:00, 33.30it/s]

Epoch 10/10, Loss: 3.3880





In [None]:
def translate_sentence(sentence, model, english_lookup, russian_token_lookup, max_length=50):
    """Greedily translate one English sentence into Russian.

    The sentence is normalised with ``preprocess_text`` and split on
    whitespace. Unknown English words are mapped to id 0. Decoding starts from
    ``<sos>`` and takes the argmax at each step. It stops when ``<eos>`` is
    predicted or after ``max_length`` steps; ``<eos>`` itself is not emitted.
    The ``<sos>``/``<eos>`` ids come from the notebook-level ``russian_lookup``.

    Args:
        sentence: raw English text.
        model: trained Seq2Seq model (switched to eval mode here).
        english_lookup: English word -> id.
        russian_token_lookup: Russian id -> word.
        max_length: maximum number of decoding steps.

    Returns:
        The translated words joined by single spaces, or '' when the sentence
        has no tokens after preprocessing.
    """
    tokens = preprocess_text(sentence).split()
    if not tokens:
        return ''  # an empty source sequence would crash the encoder LSTM

    # Run on whatever device the model lives on, rather than relying on a global.
    model_device = next(model.parameters()).device
    sos_id = russian_lookup['<sos>']
    eos_id = russian_lookup['<eos>']

    src = torch.tensor(
        [[english_lookup.get(tok, 0) for tok in tokens]],
        dtype=torch.long, device=model_device,
    )

    model.eval()
    output_ids = []
    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(src)
        decoder_input = torch.tensor([sos_id], dtype=torch.long, device=model_device)

        for _ in range(max_length):
            logits, hidden, cell = model.decoder(decoder_input, hidden, cell, encoder_outputs)
            next_id = logits.argmax(1).item()
            if next_id == eos_id:
                break  # end of sentence: do not emit the marker itself
            output_ids.append(next_id)
            decoder_input = torch.tensor([next_id], dtype=torch.long, device=model_device)

    return ' '.join(russian_token_lookup[tok] for tok in output_ids)

In [None]:
# Translate a few example sentences with the trained model.
example_sentences = ["hello how are you", "I am working"]

for english_sentence in example_sentences:
    translated_sentence = translate_sentence(english_sentence, model, english_lookup, russian_token_lookup)
    print(f"Input (English): {english_sentence}")
    print(f"Translated (Russian): {translated_sentence}")

Input (English): I am working
Translated (Russian): работаю —
