In [None]:
!pip3 install --upgrade transformers accelerate datasets ctranslate2 sentencepiece -q

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Local Hugging Face cache so the weights are downloaded only once.
cache_dir = "/workspace/models/cache"

# Pre-trained NLLB-200 checkpoint (Hugging Face Hub id) that will be fine-tuned.
model_name = "facebook/nllb-200-3.3B"

# device_map="auto" lets Accelerate decide where to place the weights on the
# available devices. use_cache=False turns off the decoder's key/value cache,
# which only speeds up generation and is not needed for training.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    device_map='auto',
    use_cache=False,
)

In [None]:
# Language codes supported by NLLB-200 ("<language>_<script>", e.g. "eng_Latn").
# Pick ONE source/target pair. Alternatives used with this notebook:
#   SRC_LANG, TGT_LANG = "eng_Latn", "por_Latn"
#   SRC_LANG, TGT_LANG = "eng_Latn", "swh_Latn"
#   SRC_LANG, TGT_LANG = "swh_Latn", "eng_Latn"
SRC_LANG, TGT_LANG = "eng_Latn", "fra_Latn"

# src_lang / tgt_lang tell the NLLB tokenizer which language-code tokens to add
# to the source and target sequences it encodes.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    src_lang=SRC_LANG,
    tgt_lang=TGT_LANG,
)

# Loading the data

In [None]:
# Load the training dataset

import os

# Training-split selection for the en-fr / en-pt pairs: the "trainsmall" files
# when True, otherwise the "trainmedium" files. Other pairs always use the
# "mixed.filtered ... train" files.
trainsmall = True
trainmedium = not trainsmall

# Short language codes used in the dataset directory and file names. They are
# derived from SRC_LANG / TGT_LANG, so the data on disk can't get out of sync
# with the tokenizer's language settings. Add an entry to support another language.
NLLB_TO_SHORT_CODE = {
    "eng_Latn": "en",
    "fra_Latn": "fr",
    "por_Latn": "pt",
    "swh_Latn": "sw",
}
src = NLLB_TO_SHORT_CODE[SRC_LANG]
tgt = NLLB_TO_SHORT_CODE[TGT_LANG]

# Change the path to your datasets
train_dir = f"/workspace/data/{src}-{tgt}/train"

if tgt in ("fr", "pt"):
    train_split = "trainsmall" if trainsmall else "trainmedium"
    source_train_file = os.path.join(train_dir, f"all-filtered.{src}.real.{train_split}")
    target_train_file = os.path.join(train_dir, f"all-filtered.{tgt}.real.{train_split}")
else:
    source_train_file = os.path.join(train_dir, f"mixed.filtered.{src}.real.train")
    target_train_file = os.path.join(train_dir, f"mixed.filtered.{tgt}.real.train")

with open(source_train_file, encoding="utf-8") as source, open(target_train_file, encoding="utf-8") as target:
    source_train_sentences = [sent.strip() for sent in source]
    target_train_sentences = [sent.strip() for sent in target]

# Source and target files are paired line by line. A length mismatch means the
# pairs are misaligned, and zip() later would silently drop the extra lines.
if len(source_train_sentences) != len(target_train_sentences):
    raise ValueError(
        f"Misaligned training data: {len(source_train_sentences)} source lines in "
        f"{source_train_file} vs {len(target_train_sentences)} target lines in {target_train_file}"
    )
if not source_train_sentences:
    raise ValueError(f"No training sentences found in {source_train_file}")

print(source_train_file, target_train_file, sep="\n")
print(len(source_train_sentences))
# Show one sample pair; fall back to the last pair for very small files.
sample_idx = min(10, len(source_train_sentences) - 1)
print(source_train_sentences[sample_idx])
print(target_train_sentences[sample_idx])

In [None]:
# Load the test dataset

from itertools import islice

# Number of leading test pairs used as the validation set.
MAX_TEST_SENTENCES = 1000

test_dir = f"/workspace/data/{src}-{tgt}/test"

if tgt in ("fr", "pt"):
    source_test_file = os.path.join(test_dir, f"all-filtered.{src}.real.test")
    target_test_file = os.path.join(test_dir, f"all-filtered.{tgt}.real.test")
else:
    source_test_file = os.path.join(test_dir, f"medical.filtered.{src}.real.test")
    target_test_file = os.path.join(test_dir, f"medical.filtered.{tgt}.real.test")

with open(source_test_file, encoding="utf-8") as source, open(target_test_file, encoding="utf-8") as target:
    # islice stops after MAX_TEST_SENTENCES lines instead of reading the whole file.
    source_test_sentences = [sent.strip() for sent in islice(source, MAX_TEST_SENTENCES)]
    target_test_sentences = [sent.strip() for sent in islice(target, MAX_TEST_SENTENCES)]

# Pairs are aligned by line. Unequal lengths mean one file is shorter than the
# other, and zip() later would silently drop the extra lines.
if len(source_test_sentences) != len(target_test_sentences):
    raise ValueError(
        f"Misaligned test data: {len(source_test_sentences)} source lines in "
        f"{source_test_file} vs {len(target_test_sentences)} target lines in {target_test_file}"
    )
if not source_test_sentences:
    raise ValueError(f"No test sentences found in {source_test_file}")

print(source_test_file, target_test_file, sep="\n")
print(len(source_test_sentences))
print(source_test_sentences[0])
print(target_test_sentences[0])

In [None]:
# # Test inference
# Optional smoke test: translate one test sentence with the model as currently
# loaded (this cell sits before fine-tuning). Uncomment to run.

# from transformers import pipeline

# translator = pipeline("translation",
#                       model=model,
#                       tokenizer=tokenizer,
#                       src_lang=SRC_LANG,
#                       tgt_lang=TGT_LANG)

# translator(source_test_sentences[0])[0]["translation_text"]

# Fine-tuning

In [None]:
from datasets import Dataset, DatasetDict, load_dataset

# Column names of the parallel datasets, e.g. "sentence_eng_Latn".
src_key = f"sentence_{SRC_LANG}"
tgt_key = f"sentence_{TGT_LANG}"

# One {src_key: source sentence, tgt_key: target sentence} record per pair.
data_train = [
    {src_key: source_line, tgt_key: target_line}
    for source_line, target_line in zip(source_train_sentences, target_train_sentences)
]
data_test = [
    {src_key: source_line, tgt_key: target_line}
    for source_line, target_line in zip(source_test_sentences, target_test_sentences)
]

data_finetune = Dataset.from_list(data_train)
data_validate = Dataset.from_list(data_test)

for dataset_split in (data_finetune, data_validate):
    print(dataset_split)

In [None]:
# Maximum length (in tokens) of source and target sequences. Longer sentences
# are truncated; shorter ones are padded up to this length.
MAX_LENGTH = 512


def tokenize_fn(examples):
    """Tokenize a batch of parallel sentences for seq2seq fine-tuning.

    Args:
        examples: a batch of dataset rows whose ``src_key`` column holds source
            sentences and whose ``tgt_key`` column holds reference translations.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ``labels``) with
        every sequence padded/truncated to ``MAX_LENGTH``. Padding positions in
        ``labels`` are replaced by -100 so the loss ignores them.
    """
    encoded = tokenizer(
        examples[src_key],
        text_target=examples[tgt_key],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    # padding="max_length" also pads the labels with pad_token_id, which would
    # make the loss train the model to emit padding. -100 is the label value the
    # model's loss ignores, so mask those positions. Fixed-length padding is kept
    # because the Trainer below gets no data collator and needs equal-length rows.
    pad_id = tokenizer.pad_token_id
    encoded["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in encoded["labels"]
    ]
    return encoded

tokenized_finetune = data_finetune.map(tokenize_fn, batched=True)
tokenized_validate = data_validate.map(tokenize_fn, batched=True)

In [None]:
# Sanity check: columns and row counts of the tokenized splits.
for tokenized_split in (tokenized_finetune, tokenized_validate):
    print(tokenized_split)

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, logging

# Core optimisation settings. learning_rate is also embedded in the names of
# the saved-model directories further down.
epochs = 1
learning_rate = 5e-5

# Effective train batch per device = 8 (per_device_train_batch_size)
# x 4 (gradient_accumulation_steps) = 32 sentence pairs per optimizer step.
training_args = Seq2SeqTrainingArguments(
    output_dir=f"/workspace/models/{src}-{tgt}",
    report_to='none',
    # --- batching ---
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    eval_accumulation_steps=4,
    # gradient_checkpointing=True,
    # --- precision ---
    fp16=True,
    fp16_full_eval=True,
    # --- optimisation ---
    num_train_epochs=epochs,
    learning_rate=learning_rate,
    lr_scheduler_type='constant',  # "constant", "linear", "cosine"
    # --- evaluation / checkpointing / logging ---
    eval_strategy="steps",  # or "epoch"
    eval_steps=100,
    save_strategy="epoch",
    logging_steps=50,
)

# No data_collator / tokenizer is passed, so the Trainer uses its default
# collator. This works because every tokenized row was padded to the same length.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_finetune,
    eval_dataset=tokenized_validate,
)

In [None]:
# training_args

In [None]:
# Start training
# Runs the fine-tuning loop configured in training_args. With eval_strategy="steps",
# the validation loss is computed every eval_steps; checkpoints are written to
# output_dir at the end of each epoch (save_strategy="epoch").
trainer.train()

# Save the fine-tuned model and convert it to CTranslate2

In [None]:
# Epoch count used in the directory name. It is tied to `epochs` (set in the
# training-arguments cell) rather than hard-coded separately, so the name can't
# drift from the run that produced the model. "constant" mirrors the
# lr_scheduler_type used for training; update both together.
epoch = epochs

output_dir = f"/workspace/models/{src}-{tgt}/saved_model_{learning_rate}_constant_epoch-{epoch}"

# The tokenizer was not given to the Trainer, so it is saved explicitly. This
# makes output_dir a self-contained checkpoint for the conversion step.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
# Weight type passed to the converter's --quantization option.
quantization = "float16"  # or "int8"

# Destination of the converted model; the name records lr, scheduler, epoch and quantization.
ct2_output_dir = f"/workspace/models/{src}-{tgt}/ct2_model_{learning_rate}_constant_epoch-{epoch}_{quantization}" 

# Convert the saved Hugging Face checkpoint in output_dir to CTranslate2 format.
# --force overwrites an existing ct2_output_dir; the echo runs only on success (&&).
!ct2-transformers-converter --model {output_dir} \
--output_dir {ct2_output_dir} \
--quantization {quantization} --force \
 && echo "CTranslate2 model saved at: {ct2_output_dir}"