In [1]:
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
import numpy as np
metric = evaluate.load("bleu")
source_lang = "dyu_Latn"
target_lang = "fra_Latn"
checkpoint = "facebook/nllb-200-distilled-600M"
zindi_ds = load_dataset("uvci/Koumankan_mt_dyu_fr")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#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
# !huggingface-cli login

In [4]:
import re
import sys
import unicodedata
from sacremoses import MosesPunctNormalizer

mpn = MosesPunctNormalizer(lang="en")
mpn.substitutions = [
    (re.compile(r), sub) for r, sub in mpn.substitutions
]

def get_non_printing_char_replacer(replace_by: str = " "):
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in perl
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char

replace_nonprint = get_non_printing_char_replacer(" ")

def preproc(text):
    clean = mpn.normalize(text)
    clean = replace_nonprint(clean)
    # replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
    clean = unicodedata.normalize("NFKC", clean)
    return clean

def preprocess_function(examples):
    inputs = [preproc(example["dyu"]) for example in examples["translation"]]
    targets = [preproc(example["fr"]) for example in examples["translation"]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True, padding="max_length")
    # Check for None values in input_ids and labels
    if None in model_inputs["input_ids"] or None in model_inputs["labels"]:
        print("Warning: None values found in tokenized output")
        # Remove examples with None values
        valid_indices = [i for i, (inp, lab) in enumerate(zip(model_inputs["input_ids"], model_inputs["labels"]))
                         if inp is not None and lab is not None]
        for key in model_inputs.keys():
            model_inputs[key] = [model_inputs[key][i] for i in valid_indices]
    return model_inputs

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    print(result)
    result = {"bleu": result["bleu"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [5]:
# source_lang = "dyu_Latn"
# target_lang = "fra_Latn"

tokenizer = AutoTokenizer.from_pretrained(checkpoint, src_lang=source_lang, tgt_lang=target_lang)
# Apply preprocessing to the dataset
tokenized_zds = zindi_ds.map(
    preprocess_function,
    batched=True,
    remove_columns=zindi_ds["train"].column_names  # Remove original columns
)

Map: 100%|██████████| 8065/8065 [00:02<00:00, 3334.19 examples/s]
Map: 100%|██████████| 1471/1471 [00:00<00:00, 2621.41 examples/s]
Map: 100%|██████████| 1393/1393 [00:00<00:00, 3717.07 examples/s]


In [6]:
concat_ds = concatenate_datasets([tokenized_zds['train'], tokenized_zds['test']])

In [8]:
tokenizer.convert_ids_to_tokens(concat_ds[100]['labels'][:10])

['fra_Latn', '▁J', "'", 'hab', 'ite', '▁à', '▁Londres', '.', '</s>', '<pad>']

In [9]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
checkpoint = ("/root/zindi/models/nllb/nllb_output/checkpoint-400_bkp")
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [8]:
import gc, torch
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

In [10]:

training_args = Seq2SeqTrainingArguments(
    output_dir="models/nllb/nllb_output",
    eval_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
    weight_decay=0.01,
    num_train_epochs=200000,
    predict_with_generate=True,
    fp16=True,
    # push_to_hub=False,
    do_train=True,
    do_eval=True,
    gradient_accumulation_steps=10,
    logging_dir= "models/nllb/nllb_output/logs",
    logging_steps = 10,
    save_strategy = 'steps',
    save_steps = 100,
    save_total_limit = 3,
    seed = 42,
    dataloader_drop_last = False,
    eval_steps = 1,
    # label_smoothing_factor: float = 0.0,
    # optim: Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch',
    # resume_from_checkpoint: Optional[str] = None,
    # fp16_backend: str = 'auto',
    # batch_eval_metrics: bool = False,
    # eval_on_start=True,
    generation_max_length= 128,
    generation_num_beams=2,
    # generation_config: Union[str, pathlib.Path, transformers.generation.configuration_utils.GenerationConfig, NoneType] = None,
    
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=concat_ds,
    eval_dataset=tokenized_zds["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
0,6.6478,5.448466,0.0589,9.5887
1,3.1878,2.380498,0.0702,9.9001
2,0.4941,0.359103,0.0787,10.0639
3,0.1805,0.202435,0.0848,9.6587
4,0.1447,0.196484,0.1041,10.4181
5,0.1269,0.196208,0.1125,10.291
6,0.1118,0.200018,0.1209,10.8232
7,0.0984,0.204165,0.1263,10.6329
8,0.0868,0.20982,0.1193,11.0891


{'bleu': 0.05885899356016952, 'precisions': [0.3165684941851684, 0.12213592233009708, 0.06925795053003533, 0.04007177033492823], 'brevity_penalty': 0.5783064201166958, 'length_ratio': 0.6461403337562214, 'translation_length': 6621, 'reference_length': 10247}
{'bleu': 0.07024193217256376, 'precisions': [0.3432579444603493, 0.13635539437896646, 0.07719688542825362, 0.043590478921709204], 'brevity_penalty': 0.6270112949069837, 'length_ratio': 0.6817605152727627, 'translation_length': 6986, 'reference_length': 10247}


Non-default generation parameters: {'max_length': 200}


{'bleu': 0.07866510167033189, 'precisions': [0.35406498396318503, 0.14578947368421052, 0.08354810996563573, 0.049379310344827586], 'brevity_penalty': 0.6511925288555784, 'length_ratio': 0.6998145798770372, 'translation_length': 7171, 'reference_length': 10247}
{'bleu': 0.08481167008800912, 'precisions': [0.37483166242705374, 0.17613200306983884, 0.1068235294117647, 0.061930783242258654], 'brevity_penalty': 0.5866696284750242, 'length_ratio': 0.6521908851371133, 'translation_length': 6683, 'reference_length': 10247}


Non-default generation parameters: {'max_length': 200}


{'bleu': 0.10407974210647918, 'precisions': [0.39852448021462106, 0.18031417112299467, 0.10985626283367557, 0.06648936170212766], 'brevity_penalty': 0.6876235550434502, 'length_ratio': 0.7275300087830584, 'translation_length': 7455, 'reference_length': 10247}
{'bleu': 0.11253189706170484, 'precisions': [0.40888104291146116, 0.19497709146444936, 0.1232504700229789, 0.0781334780249593], 'brevity_penalty': 0.6760426013168009, 'length_ratio': 0.7186493607885235, 'translation_length': 7364, 'reference_length': 10247}


Non-default generation parameters: {'max_length': 200}


{'bleu': 0.1208528904508842, 'precisions': [0.41355760718350826, 0.18924798011187072, 0.11885167464114832, 0.07491289198606271], 'brevity_penalty': 0.7438318289490451, 'length_ratio': 0.7716404801405289, 'translation_length': 7907, 'reference_length': 10247}
{'bleu': 0.12634031382728828, 'precisions': [0.4166234439834025, 0.20076910751482135, 0.13146466640347415, 0.08628659476117104], 'brevity_penalty': 0.7198528187838598, 'length_ratio': 0.7526105201522397, 'translation_length': 7712, 'reference_length': 10247}


Non-default generation parameters: {'max_length': 200}


{'bleu': 0.1193132229888814, 'precisions': [0.4010221889803042, 0.1839413829949626, 0.11566535654126894, 0.07203287406333092], 'brevity_penalty': 0.7577799387196862, 'length_ratio': 0.7828632770566996, 'translation_length': 8022, 'reference_length': 10247}


KeyboardInterrupt: 