In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# !pip install datasets sentencepiece transformers

In [None]:
# from datasets import load_dataset
# from google.colab import drive
# from IPython.display import display
# from IPython.html import widgets
# import matplotlib.pyplot as plt
# import numpy as np
# import seaborn as sns
# import torch
# from torch import optim
# from torch.nn import functional as F
# from transformers import AdamW, AutoModelForSeq2SeqLM, AutoTokenizer
# from transformers import get_linear_schedule_with_warmup
# from tqdm import tqdm_notebook

# sns.set()

# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# print("Using device: %s" % (device))

In [None]:
from datasets import load_dataset
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch import optim
from torch.nn import functional as F
from transformers import AdamW, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm_notebook

sns.set()

# Fix the random seeds so a Restart & Run All reproduces the same run.
# The notebook draws each example's translation direction with np.random.choice.
# datasets' Dataset.shuffle() without an explicit seed derives its order from
# NumPy's global RNG, so seeding NumPy fixes the shuffle order too.
# torch.manual_seed covers dropout and other torch randomness.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: %s" % (device))

In [None]:
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer

# Load the model and tokenizer from one checkpoint name so their
# vocabularies always match; the model is moved onto the selected device.
model_name = "facebook/m2m100_418M"
model = M2M100ForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = M2M100Tokenizer.from_pretrained(model_name)

In [None]:
# ALT parallel corpus; every row carries a 'translation' dict that is indexed
# by language code ('vi', 'lo', ...) further down.
dataset = load_dataset('alt')
train_dataset, test_dataset = dataset['train'], dataset['test']
train_dataset[0]

In [None]:
# Earlier experiments marked the translation direction by hand, using one of:
#   - plain prefixes such as '<vi> ' / '<lo> ';
#   - T5-style prompts such as 'translate Lao to Vietnamese: ';
#   - tokenizer.get_lang_token("vi") / tokenizer.get_lang_token("lo")
#     (i.e. '__vi__' / '__lo__').
# These are no longer used: the cells below set tokenizer.src_lang /
# tokenizer.tgt_lang and let the M2M100 tokenizer insert the language tokens.

# Every sequence is padded / truncated to the model's configured max length.
max_seq_len = model.config.max_length

In [None]:
# Sanity-check tokenization, the loss and generation on one hand-written
# Vietnamese -> Lao sentence pair.
tokenizer.src_lang = "vi"
tokenizer.tgt_lang = "lo"
sampleInputSentence = 'Phiên dịch tiếng Lào: câu này sẽ được dịch thành tiếng Lào.'
sampleOutputSentence = 'ການ​ແປ​ພາ​ສາ​ລາວ​: ປະ​ໂຫຍກ​ນີ້​ຈະ​ຖືກ​ແປ​ເປັນ​ພາ​ສາ​ລາວ​.'

tokenizerOutput = tokenizer(
    text = sampleInputSentence,
    text_target = sampleOutputSentence,
    return_tensors = 'pt',
    padding = 'max_length',
    truncation = True,
    max_length = max_seq_len).to(device)
print(tokenizerOutput['input_ids'])
print(tokenizer.convert_ids_to_tokens(tokenizerOutput['input_ids'][0]))
print(tokenizer.decode(tokenizerOutput['input_ids'][0]))

# Labels are padded with pad_token_id, but the loss only ignores -100;
# mask the padding so the reported loss covers real target tokens only.
sample_labels = tokenizerOutput['labels'].masked_fill(
    tokenizerOutput['labels'] == tokenizer.pad_token_id, -100)

model.eval()
# Inference only: no_grad avoids building an autograd graph that is never
# back-propagated.
with torch.no_grad():
    modelOutput = model(tokenizerOutput['input_ids'],
                        attention_mask = tokenizerOutput['attention_mask'],
                        labels = sample_labels)
print('Loss on the sample pair:', modelOutput.loss.item())

# Pass the attention mask explicitly so the max_length padding is ignored,
# rather than relying on generate() to infer it from the pad token.
modelGenerate = model.generate(tokenizerOutput['input_ids'],
                               attention_mask = tokenizerOutput['attention_mask'],
                               max_new_tokens = max_seq_len,
                               forced_bos_token_id = tokenizer.get_lang_id("lo"))
print(modelGenerate)

output_text = tokenizer.decode(modelGenerate[0])
print(tokenizer.convert_ids_to_tokens(modelGenerate[0]))
print(output_text)

In [None]:
# sorted(tokenizer.vocab.items(), key=lambda x: x[1])

In [None]:
# Re-encode the sample pair and list the source-side tokens: the language
# token the tokenizer adds must appear as ONE token, not as sub-word pieces.
tokenizerOutput = tokenizer(
    text=sampleInputSentence, text_target=sampleOutputSentence,
    return_tensors='pt', padding='max_length',
    truncation=True, max_length=max_seq_len,
).to(device)

tokens = tokenizer.convert_ids_to_tokens(tokenizerOutput['input_ids'][0])
print(tokens)

In [None]:
def encode_str(text, text_target, tokenizer, seq_len):
    """Tokenize one source/target sentence pair into fixed-length tensors.

    The language tokens come from ``tokenizer.src_lang`` / ``tokenizer.tgt_lang``
    as currently set on the tokenizer; callers set them before calling this.

    Args:
        text: source-language sentence.
        text_target: target-language sentence (encoded as ``labels``).
        tokenizer: tokenizer accepting ``text`` / ``text_target`` keywords.
        seq_len: length every sequence is padded or truncated to.

    Returns:
        ``(input_ids, labels, attention_mask)`` for the single pair, with the
        leading batch dimension removed, moved to the global ``device``.
    """
    encoded = tokenizer(
        text=text,
        text_target=text_target,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=seq_len,
    ).to(device)

    # return_tensors='pt' yields a batch of one; take that single row.
    input_ids, labels, attention_mask = (
        encoded[key][0] for key in ('input_ids', 'labels', 'attention_mask'))
    return input_ids, labels, attention_mask


def format_translation_data(translations, tokenizer, seq_len=max_seq_len, langs=('vi', 'lo')):
    """Encode one translation dict as a randomly-directed source/target pair.

    Two distinct languages are drawn from ``langs``: the first is the input,
    the second the target. Both directions are therefore seen during training.
    NOTE: this sets ``tokenizer.src_lang`` / ``tokenizer.tgt_lang`` on the
    shared tokenizer as a side effect.

    Args:
        translations: mapping from language code to sentence.
        tokenizer: tokenizer exposing ``src_lang`` / ``tgt_lang`` attributes.
        seq_len: pad/truncate length forwarded to ``encode_str``.
        langs: language codes to sample the direction from (at least two).

    Returns:
        ``(input_ids, labels, attention_mask)`` from ``encode_str``, or ``None``
        when either chosen side is missing or empty.
    """
    # np.random.choice returns numpy strings; convert to plain str for the tokenizer.
    input_lang, target_lang = map(
        str, np.random.choice(list(langs), size=2, replace=False))

    input_text = translations.get(input_lang)
    target_text = translations.get(target_lang)

    # A pair with a missing or empty side carries no training signal.
    if not input_text or not target_text:
        return None

    # Language tokens are inserted by the tokenizer from these attributes.
    tokenizer.src_lang = input_lang
    tokenizer.tgt_lang = target_lang

    return encode_str(input_text, target_text, tokenizer, seq_len)


def transform_batch(batch, tokenizer):
    """Encode a raw dataset slice into stacked, model-ready tensors.

    Args:
        batch: dataset slice (dict of columns) with a ``'translation'`` column.
        tokenizer: forwarded to ``format_translation_data``.

    Returns:
        ``(input_ids, labels, attention_mask)``, each stacked row-wise over the
        examples that were not skipped and placed on the global ``device``.
        Padding positions in ``labels`` are set to -100 so the loss ignores them.

    Raises:
        ValueError: if every example in the slice was skipped.
    """
    encoded = [
        pair
        for pair in (format_translation_data(translation_set, tokenizer, max_seq_len)
                     for translation_set in batch['translation'])
        if pair is not None
    ]
    if not encoded:
        raise ValueError('transform_batch: every example in this batch was skipped '
                         '(missing translations); nothing to stack.')

    # Stack each column (1-D per example) into a 2-D batch tensor. Use .to(device)
    # instead of a hard-coded .cuda() so CPU-only runs work too.
    input_ids, labels, attention_mask = (
        torch.stack(column).to(device) for column in zip(*encoded))

    # The tokenizer pads labels with pad_token_id, but the seq2seq loss only
    # ignores -100. Without this, most of every max-length label row would be
    # trained/evaluated as "predict <pad>".
    labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

    return input_ids, labels, attention_mask


def get_data_generator(dataset, tokenizer, batch_size = 32, seed = None):
    """Yield encoded batches covering a shuffled copy of ``dataset`` once.

    Args:
        dataset: dataset with a ``'translation'`` column (supports ``shuffle``
            and slicing).
        tokenizer: forwarded to ``transform_batch``.
        batch_size: raw examples per batch; the last batch may be smaller.
        seed: optional shuffle seed for a reproducible order. ``None`` keeps
            the previous behaviour (``shuffle()`` without a seed).

    Yields:
        ``(input_ids, labels, attention_mask)`` tuples from ``transform_batch``.
    """
    shuffled = dataset.shuffle(seed=seed)
    for start in range(0, len(shuffled), batch_size):
        yield transform_batch(shuffled[start:start + batch_size], tokenizer)

In [None]:
# Smoke-test `format_translation_data` on one example: show the encoded
# source and target tokens.
in_ids, out_ids, attention_mask = format_translation_data(
    train_dataset[0]['translation'], tokenizer)
for ids in (in_ids, out_ids):
    print(' '.join(tokenizer.convert_ids_to_tokens(ids)))

# Smoke-test the batch generator: pull one batch of 8 and show tensor shapes.
data_gen = get_data_generator(train_dataset, tokenizer, 8)
data_batch = next(data_gen)
for caption, tensor in zip(('Input shape:', 'Output shape:', 'Attention mask shape:'), data_batch):
    print(caption, tensor.shape)

In [None]:
# Training hyper-parameters
n_epochs = 14
batch_size = 8
print_freq = 50          # batches between training-loss printouts
checkpoint_freq = 500    # batches between validation runs (and possible checkpoint saves)
lr = 7.5e-4              # NOTE(review): high for fine-tuning a pretrained model; confirm intended

# Batches per epoch, via integer ceiling division.
n_batches = -(-len(train_dataset) // batch_size)
total_steps = n_epochs * n_batches
n_warmup_steps = int(total_steps * 0.01)   # warm up over the first 1% of steps

In [None]:
# Optimizer: PyTorch's AdamW. `transformers.AdamW` is deprecated (dropped in
# newer transformers releases). Its defaults were eps=1e-6 and
# weight_decay=0.0; they are passed explicitly so the optimisation matches the
# original setup (torch's own defaults are eps=1e-8, weight_decay=1e-2).
optimizer = optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)

# Linear warm-up over n_warmup_steps, then linear decay until total_steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer, n_warmup_steps, total_steps)

In [None]:
# losses: one training loss per batch; valLosses: one entry per validation run.
# valLosses is seeded with a 1e18 sentinel so the first `min(valLosses)`
# comparison always saves a checkpoint; the training cell pops it afterwards.
losses, valLosses = [], [1e18]

In [None]:
def eval_model(model, gdataset, max_iters = 16):
    """Estimate the model's loss on a random sample of ``gdataset``.

    Draws at most ``max_iters`` batches from ``get_data_generator``, using the
    global ``tokenizer`` and ``batch_size``. The generator reshuffles on every
    call, so successive evaluations see different examples.

    The model is left in eval mode on return. Later cells rely on this for
    generation; training code must call ``model.train()`` again.

    Args:
        model: seq2seq model whose forward pass returns an object with ``.loss``.
        gdataset: dataset to evaluate on.
        max_iters: maximum number of batches to average over.

    Returns:
        Mean of the per-batch losses, or ``nan`` if no batch was evaluated.
    """
    import itertools

    model.eval()
    test_generator = get_data_generator(gdataset, tokenizer, batch_size)
    eval_losses = []
    # No gradients are needed for evaluation; skipping autograd saves memory
    # and time. islice also stops before encoding an extra, unused batch.
    with torch.no_grad():
        for input_batch, label_batch, attention_mask_batch in itertools.islice(
                test_generator, max(max_iters, 0)):
            model_out = model(
                input_ids = input_batch,
                labels = label_batch,
                attention_mask = attention_mask_batch)
            eval_losses.append(model_out.loss.item())

    if not eval_losses:
        return float('nan')
    return np.mean(eval_losses)

In [None]:
# Best-validation-loss weights are written here during training.
model_checkpoint = 'm2m100_418M_Checkpoint.pt'
# Weights saved at the end of training; '{}' is filled with an epoch number.
model_path = 'm2m100_418M_FineTunedEpoch{}.pt'

In [None]:
# Fine-tune for n_epochs. Every print_freq batches, log the mean recent
# training loss. Every checkpoint_freq batches, evaluate and save the weights
# if this is the best validation loss seen so far.
for epoch_idx in range(n_epochs):
    # A fresh generator reshuffles the data (and re-draws each example's
    # translation direction) every epoch.
    data_generator = get_data_generator(train_dataset,
                                      tokenizer, batch_size)

    for batch_idx, (input_batch, label_batch, attention_mask_batch) \
          in tqdm_notebook(enumerate(data_generator), total=n_batches):

        # eval_model() switches the model to eval mode, so restore
        # training-mode behaviour (e.g. dropout) before every step.
        model.train()
        optimizer.zero_grad()

        # Forward pass
        model_out = model(
            input_ids = input_batch,
            labels = label_batch,
            attention_mask = attention_mask_batch)

        # Calculate loss and update weights
        loss = model_out.loss
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Print training update info
        if (batch_idx + 1) % print_freq == 0:
            avg_loss = np.mean(losses[-print_freq:])
            print('Epoch: {} | Step: {}/{} | Avg. loss: {:.3f} | lr: {}'.format(
              epoch_idx+1, batch_idx+1, n_batches, avg_loss, scheduler.get_last_lr()[0]))

        if (batch_idx + 1) % checkpoint_freq == 0:
            test_loss = eval_model(model, test_dataset)
            valLosses.append(test_loss)
            print('Test loss {:.3f}'.format(test_loss))
            if (test_loss <= min(valLosses)):
                print('Saving checkpoint...')
                torch.save(model.state_dict(), model_checkpoint)

# Drop the 1e18 sentinel that valLosses was seeded with, so it does not
# dominate the validation plot. The guard keeps a re-run of this cell from
# discarding a real measurement.
if valLosses and valLosses[0] == 1e18:
    valLosses.pop(0)

# Save the final weights. n_epochs equals epoch_idx + 1 after the loop, and
# unlike epoch_idx it is defined even when n_epochs == 0.
torch.save(model.state_dict(), model_path.format(n_epochs))

In [None]:
# Graph the training loss, smoothed with a window_size-batch moving average.
window_size = 50
# Leading smoothed points left out of the plot; presumably hides the steep
# start of training so later progress is readable.
skip_first = 100

if len(losses) >= window_size:
    # 'valid' keeps only complete windows: len(losses) - window_size + 1 means,
    # including the final window (the previous index loop stopped one short).
    smoothed_losses = np.convolve(losses, np.ones(window_size) / window_size, mode='valid')
else:
    smoothed_losses = np.array([])

plt.plot(np.arange(skip_first, len(smoothed_losses)), smoothed_losses[skip_first:])
plt.title('Training loss ({}-batch moving average)'.format(window_size))
plt.xlabel('Batch (start of averaging window)')
plt.ylabel('Loss');

In [None]:
# Validation loss, one point per evaluation made during training.
plt.plot(valLosses)
plt.title('Validation loss')
plt.xlabel('Evaluation #')
plt.ylabel('Loss');

In [None]:
# Loss on a random sample of the test split (eval_model's default max_iters batches).
test_loss = eval_model(model=model, gdataset=test_dataset)
test_loss

In [None]:
# Inspect the translations of the first test example (keyed by language code).
test_dataset[0]['translation']

In [None]:
# Build the model input for one held-out Lao -> Vietnamese example.
testSrc = 'lo'
testTgt = 'vi'
test_sentence = test_dataset[0]['translation'][testSrc]
test_sentence_target = test_dataset[0]['translation'][testTgt]
print('Raw input text:', test_sentence)

# encode_str takes the language pair from the tokenizer, which the training
# data generator may have left set to either direction.
tokenizer.src_lang = testSrc
tokenizer.tgt_lang = testTgt
input_ids, _, _ = encode_str(
    text = test_sentence,
    text_target = test_sentence_target,
    tokenizer = tokenizer,
    seq_len = max_seq_len)
# Add a batch dimension. Use .to(device) instead of a hard-coded .cuda()
# so this also runs on CPU-only machines.
input_ids = input_ids.unsqueeze(0).to(device)

print('Truncated input text:', tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(input_ids[0])))

In [None]:
# Beam-search translation: keep the 3 best of 20 beams and print them.
model.eval()   # make sure dropout is off, whichever cells ran before
# input_ids is padded to max length; mask the padding explicitly rather than
# relying on generate() to infer the mask from the pad token.
test_attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
output_tokens = model.generate(
    input_ids,
    attention_mask = test_attention_mask,
    num_beams = 20,
    num_return_sequences = 3,
    max_new_tokens = max_seq_len,
    forced_bos_token_id = tokenizer.get_lang_id(testTgt))
for token_set in output_tokens:
    print(tokenizer.decode(token_set, skip_special_tokens=True))