In [1]:
# !python -m spacy download fr_core_news_sm
# !python -m spacy download en_core_web_sm

import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import load_dataset, Dataset, DatasetDict
import spacy
# import gc

In [2]:
data_files = {'validation': 'fr-en/validation-00000-of-00001.parquet'}
dataset = load_dataset(path='wmt/wmt14', trust_remote_code=True, data_files=data_files)
data = pd.DataFrame(dataset['validation'])

train, temp = train_test_split(data, test_size=0.4, random_state=0)
test, validation = train_test_split(temp, test_size=0.5, random_state=0)

def process_translations(df):
    en_texts = [item['en'] for item in df['translation']]
    fr_texts = [item['fr'] for item in df['translation']]
    
    return pd.DataFrame({
        'en': en_texts,
        'fr': fr_texts
    })

train_processed = process_translations(train)
test_processed = process_translations(test)
validation_processed = process_translations(validation)

train_dataset = Dataset.from_pandas(train_processed.reset_index(drop=True))
test_dataset = Dataset.from_pandas(test_processed.reset_index(drop=True))
validation_dataset = Dataset.from_pandas(validation_processed.reset_index(drop=True))

ds = DatasetDict({
    'train': train_dataset,
    'test': test_dataset,
    'validation': validation_dataset
})

ds

DatasetDict({
    train: Dataset({
        features: ['en', 'fr'],
        num_rows: 1800
    })
    test: Dataset({
        features: ['en', 'fr'],
        num_rows: 600
    })
    validation: Dataset({
        features: ['en', 'fr'],
        num_rows: 600
    })
})

In [3]:
# del dataset, data
# del train, temp, test, validation
# del process_translations, train_processed, test_processed, validation_dataset
# gc.collect()

In [4]:
en_nlp = spacy.load('en_core_web_sm')
fr_nlp = spacy.load('fr_core_news_sm')

def tokenize_example(example, en_nlp, fr_nlp, max_length, sos_token, eos_token):
    en_tokens = [token.text.lower() for token in en_nlp.tokenizer(example['en'])][:max_length]
    fr_tokens = [token.text.lower() for token in fr_nlp.tokenizer(example['fr'])][:max_length]

    en_tokens = [sos_token] + en_tokens + [eos_token]
    fr_tokens = [sos_token] + fr_tokens + [eos_token]

    return {'en_tokens': en_tokens, 'fr_tokens': fr_tokens} 

    
max_length = 1000
sos_token = '<sos>'
eos_token = '<eos>'
pad_token = '<pad>'

fn_kwargs = {
    'en_nlp': en_nlp,
    'fr_nlp': fr_nlp,
    'max_length': max_length,
    'sos_token': sos_token,
    'eos_token': eos_token,
}

train_data, test_data, validation_data = (
    ds['train'],
    ds['test'],
    ds['validation'],
)

train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)
validation_data = validation_data.map(tokenize_example, fn_kwargs=fn_kwargs)

Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

In [5]:
from collections import Counter

def lang_str_int(lang, nlp):
    lang_vocab = []
    special_vocab = ['<unk>', '<pad>', '<sos>', '<eos>'] 

    flattened_list = [token.text.lower() for sentence in lang for token in nlp.tokenizer(sentence)]
    lang_count = Counter(flattened_list)
    lang_words = [string for string, freq in lang_count.items() if freq >= 2]

    lang_vocab = special_vocab + lang_words
    # lang_vocab.extend(special_vocab)
    # lang_vocab.extend(lang_words)

    lang_str2int = {ch: i for i, ch in enumerate(lang_vocab)}
    lang_int2str = {i: ch for i, ch in enumerate(lang_vocab)}

    return lang_str2int, lang_int2str

en = process_translations(data)['en'].tolist()
fr = process_translations(data)['fr'].tolist()

fr_str2int, fr_int2str = lang_str_int(fr, fr_nlp)
en_str2int, en_int2str = lang_str_int(en, en_nlp)

In [6]:
import torch
import numpy as np
import torch.nn as nn

def token_to_int(example, str2int):
    return [str2int.get(token, str2int['<unk>']) for token in example]

def tokens_to_ids(example):
    example['en_ids'] = token_to_int(example['en_tokens'], en_str2int)
    example['fr_ids'] = token_to_int(example['fr_tokens'], fr_str2int)
    return example

train_data = train_data.map(tokens_to_ids)
test_data = test_data.map(tokens_to_ids)
validation_data = validation_data.map(tokens_to_ids)

Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

In [7]:
def reverse_source_lang(example):
    example['en_ids'] = example['en_ids'][::-1]
    return example

train_data = train_data.map(reverse_source_lang)
test_data = test_data.map(reverse_source_lang)
validation_data = validation_data.map(reverse_source_lang)

train_data.set_format(
    type='torch',
    columns=['en_ids', 'fr_ids'],
    output_all_columns=False
)
test_data.set_format(
    type='torch',
    columns=['en_ids', 'fr_ids'],
    output_all_columns=False
)
validation_data.set_format(
    type='torch',
    columns=['en_ids', 'fr_ids'],
    output_all_columns=False
)

Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

In [8]:
def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example['en_ids'] for example in batch]
        batch_fr_ids = [example['fr_ids'] for example in batch]
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_fr_ids = nn.utils.rnn.pad_sequence(batch_fr_ids, padding_value=pad_index)
        batch = {
            'en_ids': batch_en_ids,
            'fr_ids': batch_fr_ids
        }
        return batch
    return collate_fn

In [9]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    collate_fn = get_collate_fn(pad_index)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle
    )
    return data_loader

In [10]:
batch_size = 128
PAD_INDEX = en_str2int[pad_token]

train_data_loader = get_data_loader(train_data, batch_size, PAD_INDEX, shuffle=True)
test_data_loader = get_data_loader(test_data, batch_size, PAD_INDEX, shuffle=False)
validation_data_loader = get_data_loader(validation_data, batch_size, PAD_INDEX, shuffle=False)

In [11]:
result = next(iter(train_data_loader))
result['en_ids'].shape, result['fr_ids'].shape

(torch.Size([76, 128]), torch.Size([86, 128]))

In [12]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.rnn(embedded)
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        input_ = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input_))
        output, (hidden, cell) = self.rnn(embedded, (hidden,cell))
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, trg_len):
        ''' 
        src: [src_len, batch_size]
        trg: [trg_len, batch_size]
        trg_len: length o
        '''
        batch_size = src.shape[1]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        hidden, cell = self.encoder(src)

        input_ = trg[0, :]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input_, hidden, cell)
            outputs[t] = output
            
            top1 = output.argmax(1)
            input_ = top1
        
        return outputs

In [13]:
lang_tensors = next(iter(validation_data_loader))
en_sample = lang_tensors['en_ids']
fr_sample = lang_tensors['fr_ids']

input_dim = len(en_str2int)
output_dim = len(fr_str2int)
encoder_embedding_dim = 256
decoder_embedding_dim = 256
hidden_dim = 512
n_layers = 2
encoder_dropout = 0.5
decoder_dropout = 0.5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CLIP = 1

encoder = Encoder(
    input_dim, 
    encoder_embedding_dim, 
    hidden_dim, 
    n_layers,
    encoder_dropout
)

decoder = Decoder(
    output_dim, 
    decoder_embedding_dim, 
    hidden_dim, 
    n_layers,
    decoder_dropout
)

for param in encoder.parameters():
    torch.nn.init.uniform_(param, a=-0.08, b=0.08)

for param in decoder.parameters():
    torch.nn.init.uniform_(param, a=-0.08, b=0.08)
    
model = Seq2Seq(encoder, decoder, device).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.7)
criterion = nn.CrossEntropyLoss(ignore_index=fr_str2int['<pad>'])

In [14]:
def train(model, dataloader, optimizer, criterion, clip):
    model.train()

    epoch_loss = 0

    for batch in dataloader:
        src = batch['en_ids'].to(device)
        trg = batch['fr_ids'].to(device)

        optimizer.zero_grad()

        output = model(src, trg, trg.shape[0])

        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()

        # torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(dataloader)

def evaluate(model, dataloader, criterion):
    model.eval()

    epoch_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            src = batch['en_ids'].to(device)
            trg = batch['fr_ids'].to(device)

            output = model(src, trg, trg.shape[0])            

            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(dataloader)

In [None]:
import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    mins = int(elapsed_time / 60)
    secs = int(elapsed_time - (mins * 60))
    return mins, secs

N_EPOCHS = 2
CLIP = 1

for epoch in range(N_EPOCHS):
    start_time = time.time()
    
    train_loss = train(model, train_data_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, validation_data_loader, criterion)
    
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

# torch.save(model.state_dict(), 'seq2seq_model_no_teacher_forcing.pth')

In [20]:
torch.save(model.state_dict(), 'seq2seq_model_no_teacher_forcing.pth')

In [21]:
def translate_sentence(sentence, encoder, decoder, src_vocab, trg_vocab, device, max_len=50):
    """
    sentence: list of token indices
    src_vocab: mapping from indices to tokens for source language
    trg_vocab: mapping from indices to tokens for target language
    """
    encoder.eval()
    decoder.eval()
    
    # Convert to tensor and add batch dimension
    src_tensor = torch.LongTensor(sentence).unsqueeze(1).to(device)  # [src_len, 1]
    
    with torch.no_grad():
        hidden, cell = encoder(src_tensor)
    
    # Start with <sos> token
    trg_indexes = [fr_str2int['<sos>']]
    
    for _ in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
        
        with torch.no_grad():
            output, hidden, cell = decoder(trg_tensor, hidden, cell)
        
        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)
        
        if pred_token == fr_str2int['<eos>']:
            break
    
    # Convert indices to tokens
    trg_tokens = [int2fr[token] for token in trg_indexes]
    
    return trg_tokens


In [None]:
sample_sentence = [
    en_str2int['<sos>'], 
    en_str2int['i'], 
    en_str2int['am'], 
    en_str2int['a'], 
    en_str2int['student'], 
    en_str2int['<eos>']
]

# Translate
translation = translate_sentence(sample_sentence, encoder, decoder, en_str2int, fr_str2int, device)

print('Translated French Sentence:', ' '.join(translation))

Translated French Sentence: <sos> , , , , , de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de de


In [35]:
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu

def calculate_bleu(data, model, device, max_len=50):
    trgs = []
    pred_trgs = []
    
    for datum in data:
        src = datum['en_ids']
        trg = datum['fr_ids']
        
        src_sentence = src.tolist()[0]
        trg_sentence = trg.tolist()[0]
        
        pred_tokens = translate_sentence(src_sentence, model.encoder, model.decoder, en_str2int, fr_str2int, device, max_len)
        pred_tokens = pred_tokens[1:-1]  # Remove <sos> and <eos>
        
        trg_tokens = [int2fr[token] for token in trg_sentence if token not in [fr_str2int['<sos>'], fr_str2int['<eos>'], fr_str2int['<pad>']]]
        
        trgs.append([trg_tokens])
        pred_trgs.append(pred_tokens)
    
    return corpus_bleu(trgs, pred_trgs)


In [None]:
calculate_bleu(validation_data_loader, model, device=device)

0

In [87]:
# Saving
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
}, 'seq2seq_model.pth')

# Loading
checkpoint = torch.load('seq2seq_model.pth')
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
model = Seq2Seq(encoder, decoder, device).to(device)
model.eval()


  checkpoint = torch.load('seq2seq_model.pth')


Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(4340, 256)
    (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(4707, 256)
    (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)
    (fc_out): Linear(in_features=512, out_features=4707, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)