In [None]:
# For automatically reload import package
%load_ext autoreload
%autoreload 2

# Set Hugging Face cache dir
import os
# Default Hugging Face cache location on the cluster. An HF_HOME that is already set
# in the environment takes precedence, so the notebook also runs on other machines.
cache_dir = os.environ.setdefault('HF_HOME', '/lustre/umt3/user/manyuan/CourseWork/huggingface')

# System library
import random
import gc

# External library
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange
from matplotlib import pyplot as plt

# Local library
from dataset import DataSet
import transformer as tfr
import seq2seq as s2s

# Prepare Translation DataSet

## (1) Tatoeba

In [None]:
# Build the English -> French Tatoeba dataset (max_length=128),
# then read its sentence pairs from disk.
data = DataSet(max_length=128, source='en', target='fr')
datafile = './data/fra.txt'
data.read_file(datafile)

In [None]:
# Tokenize the Tatoeba pairs: `sample` holds input_ids / attention_mask / labels,
# `sample_dec` holds the matching decoder inputs.
tatoeba_tokenized = data.tokenize()
sample, sample_dec = tatoeba_tokenized

In [None]:
# Inspect one example: encoder ids, label ids, decoder inputs,
# and finally the labels decoded back to text.
sample_id = 0
for key in ('input_ids', 'labels'):
    print(sample[key][sample_id])
print(sample_dec[sample_id])
label_text = data.tokenizer.decode(sample['labels'][sample_id], skip_special_tokens=True)
print(label_text)

## (2) Ted talks

In [None]:
# Same settings as the Tatoeba dataset (en -> fr, max_length=128).
data_ted = DataSet(source='en', target='fr', max_length=128)

In [None]:
# IWSLT17 TED test-set files, passed French file first, then English file.
ted_fr_xml = './data/IWSLT17.TED.tst2017.fr-en.fr.xml'
ted_en_xml = './data/IWSLT17.TED.tst2017.en-fr.en.xml'
data_ted.read_xml(ted_fr_xml, ted_en_xml)

In [None]:
# Tokenize the TED pairs into the same (sample, decoder-input) structure as Tatoeba.
ted_tokenized = data_ted.tokenize()
sample_ted, sample_ted_dec = ted_tokenized

In [None]:
# Inspect the first TED example: encoder ids, label ids, decoder inputs, decoded labels.
for key in ('input_ids', 'labels'):
    print(sample_ted[key][0])
print(sample_ted_dec[0])
ted_label_text = data_ted.tokenizer.decode(sample_ted['labels'][0], skip_special_tokens=True)
print(ted_label_text)

Both datasets now share the same tokenized format, so we can process them in the same way.

## Prepare Dataloader for training

## Train dataloader

In [None]:
# One (input_ids, decoder_inputs, valid_len, labels) tuple per Tatoeba example;
# valid_len is the number of non-padding source tokens (sum of the attention mask).
tensors = [
    (
        sample['input_ids'][idx],
        sample_dec[idx],
        sample['attention_mask'][idx].sum(),
        sample['labels'][idx],
    )
    for idx in tqdm(range(len(sample_dec)))
]

In [None]:
# Split the Tatoeba tuples 80/20 into train/dev sets.
# Shuffle a *copy* with a fixed seed: the previous in-place, unseeded shuffle produced a
# different split every time this cell was re-run, which leaks already-trained examples
# into the dev set.
# NOTE(review): `tensors` is reassigned by the test-set cell below; run this cell first.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8

num_train_samples = len(tensors)
split = int(num_train_samples * TRAIN_FRACTION)
shuffled_tensors = random.Random(SPLIT_SEED).sample(tensors, k=num_train_samples)

train_dataloader = DataLoader(shuffled_tensors[:split], batch_size=128, shuffle=True)
dev_dataloader = DataLoader(shuffled_tensors[split:], batch_size=8)

# Peek at one training batch to confirm tensor shapes.
for batch in train_dataloader:
    enc_inputs, dec_inputs, valid_lens, labels = batch
    print(enc_inputs.shape)
    print(dec_inputs.shape)
    print(valid_lens.shape)
    print(labels.shape)
    break

## Test dataloader

In [None]:
# Same (input_ids, decoder_inputs, valid_len, labels) tuples, built from the TED test data.
# Note: this rebinds `tensors`, which previously held the Tatoeba tuples.
tensors = [
    (
        sample_ted['input_ids'][idx],
        sample_ted_dec[idx],
        sample_ted['attention_mask'][idx].sum(),
        sample_ted['labels'][idx],
    )
    for idx in tqdm(range(len(sample_ted_dec)))
]

In [None]:
test_dataloader = DataLoader(tensors, batch_size=8)

# Print the shape of each component of the first test batch.
for batch in test_dataloader:
    enc_inputs, dec_inputs, valid_lens, labels = batch
    for part in (enc_inputs, dec_inputs, valid_lens, labels):
        print(part.shape)
    break

# Create NMT model

In [None]:
# Create the Transformer Seq2Seq model.
# Encoder and decoder take the same positional arguments:
# (vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, use_bias=False)
# (For the TED data, use data_ted.vocab_size and data_ted.tokenizer.pad_token_id instead.)
vocab_size = data.vocab_size + 1  # one extra id for the added bos token
num_hiddens, ffn_hiddens = 256, 64
num_heads, num_blks = 4, 2
dropout = 0.5
transformer_args = (vocab_size, num_hiddens, ffn_hiddens, num_heads, num_blks, dropout)

# Transformer encoder/decoder; the next cell builds the GRU alternative.
encoder = tfr.TransformerEncoder(*transformer_args)
decoder = tfr.TransformerDecoder(*transformer_args)

# Wrap both halves; the pad id is presumably used to mask padding in the loss — confirm in s2s.
padding_index = data.tokenizer.pad_token_id
lr = 5e-4
model = s2s.Seq2Seq(encoder, decoder, padding_index, lr)

In [None]:
## Create GRU Seq2Seq model (alternative to the Transformer cell above; rebinds `model`)
# (For the TED data, use data_ted.vocab_size and data_ted.tokenizer.pad_token_id instead.)
vocab_size = data.vocab_size + 1  # one extra id for the added <bos> special token
embed_size = num_hiddens = 256
num_layers = 2
dropout = 0.5
gru_args = (vocab_size, embed_size, num_hiddens, num_layers, dropout)

encoder = s2s.Seq2SeqEncoder(*gru_args)
decoder = s2s.Seq2SeqDecoder(*gru_args)

padding_index = data.tokenizer.pad_token_id
lr = 5e-4
model = s2s.Seq2Seq(encoder, decoder, padding_index, lr)

# Training our NMT models

In [None]:
# Training setup: use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

epochs = 10  # passes over train_dataloader

# Load a previous checkpoint (TODO: wandb monitoring is not wired up in this notebook yet)

In [None]:
# Load the Transformer weights from a previous checkpoint.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine;
# the training loop and evaluate() move the model to `device` afterwards.
# NOTE(review): `model` must currently be the Transformer (the GRU cell rebinds it).
model.load_state_dict(torch.load('models/transformer.pt', map_location='cpu'))
model.eval()

In [None]:
# Load the GRU weights from a previous checkpoint.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine;
# the training loop and evaluate() move the model to `device` afterwards.
# NOTE(review): `model` must currently be the GRU Seq2Seq model built above.
model.load_state_dict(torch.load('models/gru.pt', map_location='cpu'))
model.eval()

In [None]:
# Get evaluation set loss
def get_loss(model, device, dataloader):
    """Return the summed loss of `model` over every batch of `dataloader`.

    The model is put in eval mode for the pass (so dropout is disabled), and its
    previous train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: module called as model(enc_inputs, dec_inputs, valid_lens) that also
            exposes model.loss(logits, labels).
        device: torch device each batch is moved to.
        dataloader: iterable of (enc_inputs, dec_inputs, valid_lens, labels) batches.

    Returns:
        float: sum of the per-batch losses (not averaged); 0.0 for an empty loader.
    """
    was_training = model.training
    model.eval()
    dev_loss = 0.0
    try:
        with torch.no_grad():
            for batch in dataloader:
                enc_inputs, dec_inputs, valid_lens, labels = (t.to(device) for t in batch)

                Y_hat = model(enc_inputs, dec_inputs, valid_lens)

                # Move the vocabulary axis to dim 1 (assumes Y_hat is (batch, seq, vocab)
                # — confirm against s2s.Seq2Seq.loss).
                loss = model.loss(Y_hat.transpose(1, 2), labels)
                dev_loss += loss.item()
    finally:
        model.train(was_training)

    return dev_loss

In [None]:
# Train for `epochs` epochs, recording the loss of every optimization step and
# checkpointing every second epoch.
checkpoint_prefix = 'transformer'  # change to 'gru' when training the GRU model
os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories

losses = list()

model.to(device)
model.train()
for epoch in trange(epochs):
    for batch in tqdm(train_dataloader):
        enc_inputs, dec_inputs, valid_lens, labels = (t.to(device) for t in batch)

        Y_hat = model(enc_inputs, dec_inputs, valid_lens)

        # Move the vocabulary axis to dim 1 (assumes Y_hat is (batch, seq, vocab)
        # — confirm against s2s.Seq2Seq.loss).
        loss = model.loss(Y_hat.transpose(1, 2), labels)

        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

        losses.append(loss.item())

    if (epoch + 1) % 2 == 0:
        # `epoch` is 0-based, so e.g. <prefix>_1.pt holds the weights after 2 epochs.
        torch.save(model.state_dict(), f'models/{checkpoint_prefix}_{epoch}.pt')

In [None]:
# Per-step training loss curve, also saved to disk as a PDF.
fig, ax = plt.subplots()
ax.plot(losses)
ax.grid()
ax.set(xlabel='Step', ylabel='Loss', title='Training Loss Over Time')
fig.savefig('transformer_loss.pdf')

# Save checkpoint for later use

In [None]:
# Persist the trained Transformer weights for later reuse.
transformer_ckpt = 'models/transformer.pt'
torch.save(model.state_dict(), transformer_ckpt)

In [None]:
# Persist the trained GRU weights for later reuse.
gru_ckpt = 'models/gru.pt'
torch.save(model.state_dict(), gru_ckpt)

# Create an SMT (Statistical Machine Translation)-style baseline: word-by-word Swadesh dictionary lookup

In [None]:
import nltk
nltk.download('swadesh')
from nltk.corpus import swadesh

# Lower-cased (English, French) word pairs from the Swadesh lists; for a repeated
# English entry, the last pair wins.
en2fr = [(en_word.lower(), fr_word.lower()) for en_word, fr_word in swadesh.entries(['en', 'fr'])]
translation_dict = {en_word: fr_word for en_word, fr_word in en2fr}

In [None]:
import re

# Word tokens (keeping internal apostrophes, e.g. "don't") or single punctuation marks,
# so "far." is looked up as "far" + "." instead of the never-matching "far.".
_SMT_TOKEN_RE = re.compile(r"[\w']+|[^\w\s]")


def translate_sentence(sentence, dictionary=None):
    """
    Translate a sentence word by word using a bilingual dictionary.

    Args:
    sentence (str): Input sentence in English.
    dictionary (dict, optional): lower-cased English -> French word mapping.
        Defaults to the global Swadesh-based ``translation_dict``.

    Returns:
    str: Translated sentence in French. Tokens without a dictionary entry are kept
    (lower-cased). Tokens are joined by single spaces, so punctuation comes out
    space-separated (e.g. "word." -> "<translation> .").
    """
    if dictionary is None:
        dictionary = translation_dict

    # Lower-case to match the dictionary keys, then split off punctuation.
    tokens = _SMT_TOKEN_RE.findall(sentence.lower())

    # Translate each token if the dictionary knows it, otherwise keep it as-is.
    return ' '.join(dictionary.get(token, token) for token in tokens)

# Example usage: translate a short sentence and print both sides.
english_sentence = "far ."
french_translation = translate_sentence(english_sentence)
for label, text in (("English:", english_sentence), ("French:", french_translation)):
    print(label, text)

# Try pretrained Marian MT model (HuggingFace)

In [None]:
from transformers import MarianTokenizer, MarianMTModel
src = 'en'  # source language
tgt = 'fr'  # target language
sample_text = "We are connecting."

mname = f'Helsinki-NLP/opus-mt-{src}-{tgt}'
# Distinct names so this cell does not overwrite the trained NMT `model` (and global
# `batch`) that the evaluation cells below rely on.
marian_model = MarianMTModel.from_pretrained(mname)
tok = MarianTokenizer.from_pretrained(mname)
# `prepare_seq2seq_batch` is deprecated; calling the tokenizer directly builds the
# input_ids / attention_mask needed for generation.
marian_batch = tok([sample_text], truncation=True, return_tensors='pt')
gen = marian_model.generate(**marian_batch)  # for forward pass: marian_model(**marian_batch)
words = tok.batch_decode(gen, skip_special_tokens=True)
words

# Evaluation of different models

## (1) BLEU and BERT Score

In [None]:
def evaluate_smt(dataloader, batch_size, lang="fr", is_RNN=False, batch_total=0):
    """Score the dictionary-based SMT baseline with BLEU and BERTScore.

    Args:
        dataloader: yields (enc_inputs, dec_inputs, valid_lens, labels) batches;
            only the source ids (index 0) and label ids (index 3) are used, decoded
            with `data.tokenizer`.
        batch_size: kept for backward compatibility. The actual length of each batch
            is used, so a final partial batch no longer raises IndexError.
        lang: language code forwarded to `s2s.bert_score`.
        is_RNN: unused; kept so the signature mirrors `evaluate`.
        batch_total: number of batches to score. When 0, print per-sentence results
            for the first batch only instead of accumulating totals.

    Returns:
        (total_bleu, total_bertscore, total_count): summed BLEU and BERTScore F1 over
        the scored sentences, and the number of sentences seen. In print mode the two
        totals stay 0.
    """
    total_bleu = 0
    total_bertscore = 0
    total_count = 0
    batch_count = 0
    tokenizer = data.tokenizer

    with torch.no_grad():
        for batch in tqdm(dataloader):
            n_in_batch = len(batch[0])
            total_count += n_in_batch
            batch_count += 1

            srcs = [tokenizer.convert_ids_to_tokens(batch[0][i], skip_special_tokens=True) for i in range(n_in_batch)]
            tgts = [tokenizer.convert_ids_to_tokens(batch[3][i], skip_special_tokens=True) for i in range(n_in_batch)]

            for src, tgt in zip(srcs, tgts):
                str_src = tokenizer.convert_tokens_to_string(src)
                str_tgt = tokenizer.convert_tokens_to_string(tgt)
                pred = translate_sentence(str_src)

                # NOTE(review): `pred.split(" ")` yields whole words while `tgt` holds the
                # tokenizer's subword pieces, so BLEU compares different token spaces and
                # is likely to come out near zero; consider tokenizing both sides the same way.
                bleu = s2s.bleu(pred.split(" "), tgt, k=2)

                if not batch_total:
                    print(pred)
                    # Show the reference translation (not a dictionary "translation" of it).
                    print(f'{str_src} => {str_tgt}, bleu, {bleu:.3f}')
                    print(f'{str_src} => {str_tgt}, bert score, '
                          f'{s2s.bert_score(pred, str_tgt, lang=lang)["f1"][0]:.3f}')
                else:
                    total_bleu += bleu
                    total_bertscore += s2s.bert_score(pred, str_tgt, lang=lang)["f1"][0]

                gc.collect()

            if batch_count >= batch_total:
                break

    return total_bleu, total_bertscore, total_count

In [None]:
# SMT baseline on the TED test set: 20 batches of 8 sentences.
evaluate_smt(test_dataloader, batch_size=8, lang="fr", is_RNN=False, batch_total=20)

In [None]:
def _tokens_before_eos(tokenizer, ids, eos_token='</s>'):
    """Convert `ids` to tokens, stopping at (and excluding) the first `eos_token`."""
    tokens = []
    for token in tokenizer.convert_ids_to_tokens(ids):
        if token == eos_token:
            break
        tokens.append(token)
    return tokens


def evaluate(model, device, dataloader, batch_size, beam_width=0, lang="fr", is_RNN=False, batch_total=0):
    """Score a trained Seq2Seq model's translations with BLEU and BERTScore.

    Args:
        model: Seq2Seq model exposing `beam_search(batch, device, beam_width,
            max_length, is_RNN)` and `predict_step(batch, device, max_length)`.
        device: torch device used for decoding.
        dataloader: yields (enc_inputs, dec_inputs, valid_lens, labels) batches;
            source ids (index 0) and label ids (index 3) are decoded with `data.tokenizer`.
        batch_size: kept for backward compatibility. The actual length of each batch
            is used, so a final partial batch no longer raises IndexError.
        beam_width: beam width for beam search; 0 selects greedy decoding.
        lang: language code forwarded to `s2s.bert_score`.
        is_RNN: forwarded to `model.beam_search` (presumably True for the GRU model — confirm).
        batch_total: number of batches to score. When 0, print per-sentence results
            for the first batch only instead of accumulating totals.

    Returns:
        (total_bleu, total_bertscore, total_count): summed BLEU and BERTScore F1 over
        the scored sentences, and the number of sentences seen. In print mode the two
        totals stay 0.
    """
    total_bleu = 0
    total_bertscore = 0
    total_count = 0
    batch_count = 0
    tokenizer = data.tokenizer

    model.to(device)
    # Disable dropout while decoding instead of relying on every caller to call model.eval().
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            n_in_batch = len(batch[0])
            total_count += n_in_batch
            batch_count += 1

            if beam_width:
                # Beam search decoding
                preds = model.beam_search(batch, device, beam_width, data.max_length, is_RNN)
            else:
                # Greedy decoding
                preds, _ = model.predict_step(batch, device, data.max_length)

            srcs = [tokenizer.convert_ids_to_tokens(batch[0][i], skip_special_tokens=True) for i in range(n_in_batch)]
            tgts = [tokenizer.convert_ids_to_tokens(batch[3][i], skip_special_tokens=True) for i in range(n_in_batch)]

            for src, tgt, p in zip(srcs, tgts, preds):
                translation = _tokens_before_eos(tokenizer, p)
                str_src = tokenizer.convert_tokens_to_string(src)
                str_tgt = tokenizer.convert_tokens_to_string(tgt)
                pred = tokenizer.convert_tokens_to_string(translation)

                if not batch_total:
                    print(pred)
                    # BLEU Score
                    print(f'{str_src} => {str_tgt}, bleu, '
                          f'{s2s.bleu(translation, tgt, k=2):.3f}')
                    # BERT Score
                    print(f'{str_src} => {str_tgt}, bert score, '
                          f'{s2s.bert_score(pred, str_tgt, lang=lang)["f1"][0]:.3f}')
                else:
                    total_bleu += s2s.bleu(translation, tgt, k=2)
                    total_bertscore += s2s.bert_score(pred, str_tgt, lang=lang)["f1"][0]

                gc.collect()

            if batch_count >= batch_total:
                break

    return total_bleu, total_bertscore, total_count

In [None]:
# Dev set, beam search (width 2), 20 batches of 8.
# NOTE(review): consider is_RNN=True when `model` is the GRU.
model.eval()
evaluate(model, device, dev_dataloader, batch_size=8, beam_width=2, lang="fr", is_RNN=False, batch_total=20)

In [None]:
# TED test set, beam search (width 2), 20 batches of 8.
# NOTE(review): consider is_RNN=True when `model` is the GRU.
model.eval()
evaluate(model, device, test_dataloader, batch_size=8, beam_width=2, lang="fr", is_RNN=False, batch_total=20)

In [None]:
# Dev set, greedy decoding (beam_width=0), 20 batches of 8.
model.eval()
evaluate(model, device, dev_dataloader, batch_size=8, beam_width=0, lang="fr", is_RNN=False, batch_total=20)

In [None]:
# TED test set, greedy decoding (beam_width=0), 20 batches of 8.
model.eval()
evaluate(model, device, test_dataloader, batch_size=8, beam_width=0, lang="fr", is_RNN=False, batch_total=20)

In [None]:
# Decode hand-picked token ids to sanity-check the tokenizer round trip.
fras = [5682, 21, 2137, 19, 6381, 21, 682, 291, 0]
engs = [631, 250, 0, 59513]
eng_tokens = data.tokenizer.convert_ids_to_tokens(engs)
fra_tokens = data.tokenizer.convert_ids_to_tokens(fras)
print(data.tokenizer.convert_tokens_to_string(eng_tokens))
print(fra_tokens)
print(data.tokenizer.decode(fras))
print(data.tokenizer.convert_tokens_to_string(fra_tokens))

In [None]:
# Sanity-check s2s.bleu on two short space-separated strings.
pred_seq, label_seq = "a b c d e", "a b c e f"
s2s.bleu(pred_seq, label_seq, k=2)

In [None]:
# Recorded results as (total BLEU, total BERTScore F1, sentence count) tuples;
# each run scores 20 batches of 8 = 160 sentences.
# Random DEV (0.0, 98.9659451842308, 160)
# Random TEST (0.0, 103.15149623155594, 160)
#
# SMT DEV (0.0, 108.05254489183426, 160)
# SMT TEST (0.0, 114.14204689860344, 160)
#
# Transformer Beam   DEV (53.59059118759569, 142.55582463741302, 160)
# Transformer Greedy DEV (59.013246619166466, 143.2861720919609, 160)
# Transformer Beam   TEST (15.869870679217367, 133.20187187194824, 160)
# Transformer Greedy TEST (15.715572405590233, 132.49793833494186, 160)
#
# GRU Beam   DEV (79.33526214634732, 146.7177917957306, 160)
# GRU Greedy DEV (84.5888387368823, 147.6853220462799, 160)
# GRU Beam   TEST (15.335201627572879, 132.85215973854065, 160)
# GRU Greedy TEST (13.138322543843737, 132.49497658014297, 160)