Fine-tuning the best T5 Transformer 🤖
-----------------------------------

In this notebook, we continue fine-tuning the T5 transformer on the newly extracted sentences from the book **Grammaire de Wolof Moderne**. After hyperparameter tuning with `wandb`, we obtained a best BLEU score of **2.47** for the French-to-Wolof translation model. The main evaluation figures from the hyperparameter search are listed below.

- Parallel coordinates plot (from the wandb panel)
- Parameter importance chart (from the wandb panel)

In [1]:
# let us extend the paths of the system
import sys

# path = "/content/drive/MyDrive/Memoire/subject2/T5/"

# sys.path.extend([path, f"{path}new_data"])

In [2]:
# define environment
# %env WANDB_LOG_MODEL=true
# %env WANDB_NOTEBOOK_NAME=fw_training_t5_small_best_model_v2.ipynb
# %env WANDB_API_KEY=<your-wandb-api-key>

In [3]:
# !pip install -qq wandb --upgrade

In [4]:
# !pip install evaluate -qq
# !pip install sacrebleu -qq
# !pip install optuna -qq
# !pip install transformers -qq 
# !pip install tokenizers -qq
# !pip install nlpaug -qq
# !pip install ray[tune] -qq
# !python -m spacy download fr_core_news_lg 

In [5]:
# let us import all necessary libraries
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5TokenizerFast, set_seed
from wolof_translate.utils.sent_transformers import TransformerSequences
from wolof_translate.data.dataset_v2 import T5SentenceDataset
from wolof_translate.utils.sent_corrections import *
from sklearn.model_selection import train_test_split
from nlpaug.augmenter import char as nac
from torch.utils.data import DataLoader
# from datasets  import load_metric # deprecated: run pip install evaluate instead
# and pip install sacrebleu for the BLEU metric
from functools import partial
from tqdm import tqdm
import pandas as pd
import numpy as np
import evaluate
import wandb
import torch

# wandb.login(key="<your-wandb-api-key>")




--------------

## French to Wolof

### Configure dataset 🔠

In [6]:
def split_data(random_state: int = 50):
  """Split the extracted sentences into training, validation and test sets

  Args:
    random_state (int): the seed of the splitting generator. Defaults to 50
  """
  # load the corpus and hold out 10% of it as the test set
  corpora = pd.read_csv("data/additional_documents/diagne_sentences/extractions.csv")

  train_set, test_set = train_test_split(corpora, test_size=0.1, random_state=random_state)

  # hold out 10% of the remaining sentences as the validation set
  train_set, valid_set = train_test_split(train_set, test_size=0.1, random_state=random_state)

  # save the training set used for the final training run
  # (written after the validation split, so it excludes the validation rows)
  train_set.to_csv("data/additional_documents/diagne_sentences/final_train_set.csv", index=False)

  # save the three sets
  train_set.to_csv("data/additional_documents/diagne_sentences/train_set.csv", index=False)

  valid_set.to_csv("data/additional_documents/diagne_sentences/valid_set.csv", index=False)

  test_set.to_csv("data/additional_documents/diagne_sentences/test_set.csv", index=False)
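The two chained splits should leave about 81% of the corpus for training, 9% for validation and 10% for testing. The sketch below computes the exact sizes. It assumes scikit-learn's rule for a float `test_size`, under which the held-out part gets `ceil(test_size * n)` rows. `split_sizes` is a hypothetical helper and is not part of the notebook.

```python
import math


def split_sizes(n_samples: int, test_size: float = 0.1, valid_size: float = 0.1):
    """Return (n_train, n_valid, n_test) for two chained train_test_split calls.

    Assumes scikit-learn's rule for a float test_size: the held-out part
    gets ceil(test_size * n) rows and the remainder goes to training.
    """
    # first split: hold out the test set from the full corpus
    n_test = math.ceil(test_size * n_samples)
    n_rest = n_samples - n_test
    # second split: hold out the validation set from what remains
    n_valid = math.ceil(valid_size * n_rest)
    n_train = n_rest - n_valid
    return n_train, n_valid, n_test


print(split_sizes(1000))  # (810, 90, 100)
```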

In [7]:
# load the tokenizer from a JSON file
tokenizer = T5TokenizerFast(tokenizer_file="wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json")


In [8]:
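def recuperate_datasets(fr_char_p: float, fr_word_p: float):
  """Load the augmented training dataset and the test dataset

  Args:
    fr_char_p (float): proportion of characters modified inside each augmented word
    fr_word_p (float): proportion of words augmented in each French sentence
  """
  # keyboard-typo augmentation applied to the French sentences,
  # followed by two spacing corrections from wolof_translate.utils.sent_corrections
  fr_augmentation = TransformerSequences(nac.KeyboardAug(aug_char_p=fr_char_p, aug_word_p=fr_word_p),
                                         remove_mark_space, delete_guillemet_space)

  # load the (augmented) training dataset
  train_dataset_aug = T5SentenceDataset("data/additional_documents/diagne_sentences/final_train_set.csv",
                                        tokenizer,
                                        truncation = True,
                                        cp1_transformer = fr_augmentation)

  # load the test dataset (no augmentation)
  test_dataset = T5SentenceDataset("data/additional_documents/diagne_sentences/test_set.csv",
                                   tokenizer,
                                   truncation = True)

  # return the datasets
  return train_dataset_aug, test_dataset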
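`KeyboardAug` simulates typing errors by replacing characters with neighbouring keys. Here `fr_word_p` sets the fraction of words to corrupt, and `fr_char_p` sets the fraction of characters changed inside each chosen word. The hypothetical `keyboard_typos` function below is a minimal, dependency-free sketch of how the two probabilities interact. It uses a tiny hand-written neighbour map, so it is a simplified stand-in and not nlpaug's implementation, which relies on a full keyboard layout and has more options.

```python
import random

# Hypothetical, tiny neighbour map; nlpaug uses a complete keyboard layout instead.
NEIGHBOURS = {
    "a": "zqs", "e": "zrd", "i": "uok", "o": "ipl", "u": "yij",
    "s": "qdz", "t": "ryg", "n": "bhj", "r": "etf", "l": "kmo",
}


def keyboard_typos(sentence: str, word_p: float, char_p: float, seed: int = 0) -> str:
    """Replace characters by keyboard neighbours in a random subset of words.

    word_p: fraction of words selected for augmentation.
    char_p: fraction of characters replaced inside each selected word.
    """
    rng = random.Random(seed)
    words = sentence.split()
    # choose which words to corrupt
    chosen = set(rng.sample(range(len(words)), round(word_p * len(words))))
    out = []
    for i, word in enumerate(words):
        chars = list(word)
        if i in chosen:
            # only characters that have neighbours in the map can be replaced
            candidates = [j for j, c in enumerate(chars) if c.lower() in NEIGHBOURS]
            n_chars = min(len(candidates), max(1, round(char_p * len(chars))))
            for j in rng.sample(candidates, n_chars):
                chars[j] = rng.choice(NEIGHBOURS[chars[j].lower()])
        out.append("".join(chars))
    return " ".join(out)


print(keyboard_typos("Le chat dort sur le tapis", word_p=0.5, char_p=0.3))
```

With `word_p=0` the sentence comes back unchanged. Higher values of either probability produce noisier source sentences.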

### Configure the model and the evaluation function ⚙️

Let us load the pretrained `t5-small` model and resize its token embeddings to the size of our tokenizer's vocabulary.

In [9]:
def t5_model_init(tokenizer):

    # Initialize the model name
    model_name = 't5-small'
    # model_name = 'data/checkpoints/vf_t5_small_v2_checkpoints_2/' # from checkpoint

    # import the model with its pre-trained weights
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, add_cross_attention = True)

    # resize the token embeddings
    model.resize_token_embeddings(len(tokenizer))

    return model

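Resizing the token embeddings to our tokenizer gives `Embedding(2961, 512)`: 2,961 tokens, each with a 512-dimensional embedding.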
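`t5-small` comes with 32,128 embedding rows, so resizing to this tokenizer shrinks the matrix. Conceptually, resizing keeps the first `min(old, new)` rows and initialises any extra rows. The pure-Python sketch below shows this on a list of rows. The hypothetical helper's Gaussian initialisation of new rows is an assumption for illustration; Transformers applies the model's own initialisation. Assuming `tokenizer_v3` was trained from scratch on this corpus, the rows that survive are pretrained vectors for `t5-small`'s original token ids, not for the new tokens.

```python
import random


def resize_embedding_matrix(weights, new_num_tokens, std=0.02, seed=0):
    """Return a copy of `weights` (one row per token) with `new_num_tokens` rows.

    Existing rows are copied unchanged. When growing, extra rows are drawn
    from N(0, std^2), an illustrative choice. When shrinking, trailing rows are dropped.
    """
    rng = random.Random(seed)
    dim = len(weights[0])
    kept = [list(row) for row in weights[:new_num_tokens]]
    extra = [[rng.gauss(0.0, std) for _ in range(dim)]
             for _ in range(new_num_tokens - len(kept))]
    return kept + extra


old = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(resize_embedding_matrix(old, 2))  # [[1.0, 2.0], [3.0, 4.0]]
```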

Let us evaluate the predictions with the BLEU metric, as computed by `sacrebleu`.
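For reference, `sacrebleu` computes the standard corpus-level BLEU. It combines clipped n-gram precisions with a brevity penalty and reports the result on a 0 to 100 scale, which is the value stored in `result["score"]`:

$$
\mathrm{BLEU} = 100 \cdot \mathrm{BP} \cdot \exp\left(\sum_{n=1}^{4} \frac{1}{4} \log p_n\right),
\qquad
\mathrm{BP} =
\begin{cases}
1 & \text{if } c > r,\\
e^{\,1 - r/c} & \text{if } c \le r,
\end{cases}
$$

In this formula:

- $p_n$ is the clipped n-gram precision over all sentences.
- $c$ is the total length of the predictions.
- $r$ is the total (effective) reference length.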

In [10]:
# %%writefile wolof-translate/wolof_translate/utils/evaluation.py
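from transformers import PreTrainedTokenizerBase
from typing import Callable, Optional
import numpy as np
import evaluate

class TranslationEvaluation:

    def __init__(self,
                 tokenizer: PreTrainedTokenizerBase,
                 decoder: Optional[Callable] = None,
                 metric = None,
                 ):

        self.tokenizer = tokenizer

        # optional custom decoder (stored but not used by compute_metrics)
        self.decoder = decoder

        # load sacrebleu when the object is created rather than when the class is defined
        self.metric = metric if metric is not None else evaluate.load('sacrebleu')

    def postprocess_text(self, preds, labels):

        preds = [pred.strip() for pred in preds]

        # sacrebleu expects a list of references for each prediction
        labels = [[label.strip()] for label in labels]

        return preds, labels

    def compute_metrics(self, eval_preds):
        """Compute the BLEU score and the mean generated length of the predictions."""

        preds, labels = eval_preds

        # keep only the generated ids if the model returned a tuple
        if isinstance(preds, tuple):

            preds = preds[0]

        pad_token_id = self.tokenizer.pad_token_id

        # use the tokenizer given at initialization instead of the global one
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)

        # -100 marks label positions ignored by the loss; restore the pad id so they can be decoded
        labels = np.where(labels != -100, labels, pad_token_id)

        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels)

        result = self.metric.compute(predictions=decoded_preds, references=decoded_labels)

        result = {"bleu": result["score"]}

        # generated length = number of non-pad ids in each prediction
        prediction_lens = [np.count_nonzero(pred != pad_token_id) for pred in preds]

        result["gen_len"] = np.mean(prediction_lens)

        result = {k: round(v, 4) for k, v in result.items()}

        return result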
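Besides BLEU, `compute_metrics` performs two bookkeeping steps that can be checked on small arrays. First, it restores the pad id wherever the labels contain `-100`, the value Hugging Face models use for positions ignored by the loss. Second, it computes the mean generated length as the number of non-pad ids per row. The sketch below mirrors both steps; the helper name and the toy ids are hypothetical.

```python
import numpy as np


def restore_pad_and_lengths(preds, labels, pad_token_id):
    """Mirror two steps of compute_metrics on plain arrays.

    - labels: positions set to -100 are mapped back to the pad id so they can be decoded;
    - preds: the mean generated length is the mean number of non-pad ids per row.
    """
    labels = np.where(labels != -100, labels, pad_token_id)
    gen_lens = [int(np.count_nonzero(row != pad_token_id)) for row in preds]
    return labels, float(np.mean(gen_lens))


preds = np.array([[0, 5, 6, 1, 0], [0, 7, 1, 0, 0]])
labels = np.array([[5, 6, 1, -100], [7, 1, -100, -100]])
print(restore_pad_and_lengths(preds, labels, pad_token_id=0))
```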

In [11]:
# %run wolof-translate/wolof_translate/utils/evaluation.py

Let us initialize the evaluation object.

In [12]:
evaluation = TranslationEvaluation(tokenizer)


### Searching for the best parameters 🕖

Let us define the data collator.

In [13]:
def data_collator(batch):
    """Generate a batch of data to provide to the trainer

    Args:
        batch (list): samples from the dataset, each a tuple (input_ids, attention_mask, labels)

    Returns:
        dict: A dictionary containing the ids, the attention mask and the labels
    """
    input_ids = torch.stack([b[0].squeeze(0) for b in batch])
    
    attention_mask = torch.stack([b[1].squeeze(0) for b in batch])
    
    labels = torch.stack([b[2].squeeze(0) for b in batch])
    
    return {'input_ids': input_ids, 'attention_mask': attention_mask,
            'labels': labels}
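The collator above passes the labels through unchanged. If `T5SentenceDataset` pads them with the tokenizer's pad id, the padded positions also count in the cross-entropy loss. A common convention is to replace them with `-100`, the index ignored by the loss of Hugging Face seq2seq models. T5 also maps `-100` back to the pad id when it shifts the labels into decoder inputs.

The sketch below shows this variant with NumPy arrays so that it runs without a model; with PyTorch tensors the masking step is `labels[labels == pad_token_id] = -100`. The function name is hypothetical, and the sketch assumes the pad id differs from the end-of-sequence id.

```python
import numpy as np


def collate_with_label_mask(batch, pad_token_id):
    """Stack (input_ids, attention_mask, labels) samples into a batch dictionary.

    Each field is flattened to 1-D first (like squeeze(0) on a [1, L] tensor),
    and padded label positions are replaced by -100 so the loss ignores them.
    """
    input_ids = np.stack([np.asarray(b[0]).reshape(-1) for b in batch])
    attention_mask = np.stack([np.asarray(b[1]).reshape(-1) for b in batch])
    labels = np.stack([np.asarray(b[2]).reshape(-1) for b in batch])
    labels = np.where(labels == pad_token_id, -100, labels)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```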

Let us initialize the training arguments with the best hyperparameters found during the random search, then train the model.
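The notebook records only the best configurations found by the search, not the search space itself. As a sketch, a random search over the six hyperparameters listed in the cell below could be declared as a wandb sweep configuration like this one. Every range is hypothetical, chosen only so that both recorded best configurations fall inside it.

```python
# Hypothetical search space: the notebook records only the best values, not the ranges.
sweep_config = {
    "method": "random",                                   # random search
    "metric": {"name": "eval/bleu", "goal": "maximize"},  # metric name as recorded in the notes below
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
        "weight_decay": {"distribution": "uniform", "min": 0.0, "max": 0.5},
        "train_batch_size": {"values": [8, 16, 32]},
        "random_state": {"values": [1, 25, 50]},
        "fr_char_p": {"distribution": "uniform", "min": 0.0, "max": 0.8},
        "fr_word_p": {"distribution": "uniform", "min": 0.0, "max": 0.5},
    },
}
```

The dictionary would then be passed to `wandb.sweep(sweep_config, project=...)`, and the training function run with `wandb.agent(sweep_id, function=..., count=...)`.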

In [None]:
# %%wandb

"""Best parameters
learning_rate = 0.00339254221933438
weight_decay = 0.4268760048515646
train_batch_size = 32
random_state = 25
fr_char_p = 0.7040605028532431
fr_word_p = 0.2651043507758306
eval/bleu = 2.2637
"""
"""Best parameters
learning_rate = 0.0023118906467202416
weight_decay = 0.4
train_batch_size = 16
random_state = 1
fr_char_p = 0.5914542632321074
fr_word_p = 0.4495405182231499
eval/loss = 0.5561882257461548
"""

# Initialize wandb
# wandb.init(project = "fw_small_t5_fine_tuning_v2")

# seed
set_seed(0)

# split the data
split_data(random_state=1)

# let us load the datasets
train_dataset, test_dataset = recuperate_datasets(0.5914542632321074, 0.4495405182231499)

# set training arguments
training_args = Seq2SeqTrainingArguments("data/checkpoints/t5_results_fw_v2",
                                    logging_dir="data/logs/results_fw_v2",
                                    num_train_epochs=300,
                                    load_best_model_at_end=True,
                                    save_strategy="epoch",
                                    evaluation_strategy="epoch", # named `eval_strategy` in recent transformers releases
                                    logging_strategy="epoch",
                                    per_device_train_batch_size=16,
                                    per_device_eval_batch_size=16,
                                    learning_rate=0.0023118906467202416,
                                    # learning_rate=0.00003113,
                                    weight_decay=0.4,
                                    predict_with_generate=True, # generate full sequences at evaluation time so that BLEU can be computed
                                    fp16 = True,
                                    metric_for_best_model = 'bleu', # the BLEU score is used to select the best model
                                    greater_is_better = True,
                                    save_total_limit = 1, # keeps the best checkpoint (plus the latest one if they differ)
                                    )

# define the trainer (note: the test set also serves for evaluation and best-model selection)
trainer = Seq2SeqTrainer(model_init=partial(t5_model_init, tokenizer = train_dataset.tokenizer),
                  args=training_args,
                  train_dataset=train_dataset, 
                  eval_dataset=test_dataset,
                  data_collator=data_collator,
                  compute_metrics=evaluation.compute_metrics
                  )

# load last checkpoint
# trainer._load_from_checkpoint("data/training2/results/checkpoint-147")

# start training loop
trainer.train()
# trainer.train('data/checkpoints/vf_t5_small_v2_checkpoints/') # resume from the best model of the hyperparameter search
# trainer.train('data/checkpoints/results_fw_v2/last_checkpoint/') # from last checkpoint

# finish wandb
# wandb.finish()


In [35]:
# let us get the best model
model = AutoModelForSeq2SeqLM.from_pretrained('data/checkpoints/results_fw_v2/checkpoint-29898/')
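The checkpoint number above is hard-coded. After training, `trainer.state.best_model_checkpoint` holds the path of the best checkpoint, and with `load_best_model_at_end=True` the `trainer.model` object already contains those weights. The per-evaluation metrics are stored in `trainer.state.log_history` and in each checkpoint's `trainer_state.json`. The hypothetical helper below picks the best evaluation entry from such a list; the toy history in the test is made up.

```python
def best_checkpoint(log_history, metric="eval_bleu", greater_is_better=True):
    """Return (step, value) of the evaluation entry with the best metric.

    `log_history` is a list of dicts like Trainer's state.log_history;
    training-loss entries (without `metric`) are skipped.
    """
    evals = [entry for entry in log_history if metric in entry]
    pick = max if greater_is_better else min
    best = pick(evals, key=lambda entry: entry[metric])
    return best["step"], best[metric]
```

In the notebook this would be called as `best_checkpoint(trainer.state.log_history)`.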

### Predictions

Let us generate translations for the test set and store them in a DataFrame.

In [36]:

# set the model to eval mode
_ = model.eval()

# run model inference on all test data
original_translations, predicted_translations, original_texts, scores = [], [], [], {}

# recuperate the pad token id
pad_token_id = tokenizer.pad_token_id

# no gradients are needed for inference
with torch.no_grad():

    for input_ids, attention_mask, labels in tqdm(DataLoader(test_dataset)):

        # decode the source sentence and its reference translation
        original_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

        original_translation = tokenizer.decode(labels[0], skip_special_tokens=True)

        # greedy decoding: top_k, top_p and temperature have no effect when do_sample=False
        predictions = model.generate(input_ids, attention_mask = attention_mask, do_sample = False,
                                     max_length = test_dataset.max_len, num_return_sequences = 1,
                                     pad_token_id = pad_token_id)

        # compute the sentence-level scores and add them to the running totals
        result = evaluation.compute_metrics((predictions, labels))

        if not scores: scores.update(result)

        else: scores.update({k: round(scores[k] + v, 4) for k, v in result.items()})

        # decode the predicted tokens into texts
        predicted_translation = list(test_dataset.decode(predictions))

        print(predicted_translation[0])

        # append results
        original_translations.append(original_translation)

        predicted_translations.extend(predicted_translation)

        original_texts.append(original_text)

# transform result into data frame
df_ft_to_wf = pd.DataFrame({'original_text': original_texts,
                            'original_label': original_translations,
                            'predicted_label': predicted_translations})

# display the first rows
df_ft_to_wf.head()


daaw
Moom daal, ñépp ñibbisi nañu.
kuu
Ñun ñii lay set.
mbay
Dem õga te ñëw na.
Dem naa ci keneen ku jigéen.
daõal
rëbb
xéewlu
pas
Demuma
sawwu
daõal
Wutël leneen.
Moo di Lawbe bi
Menn la.
Kooy waxal?
Faatim la, mu ni.
Lëf lan a réer?
Yéen bëgg õgeen woon.
degu àll
Nit kii ci sama wet.
jaambur jaamburlu
fenn
ñooñu
ca
Dafa di dem.
Gis õga nit kookee?
Fee la.
Demuma fa woon
Daõga gis kan?
te
waaye nag
te itam
noonu itam
arafu geen
Õgoor dem.
Wool góor gi dul dem
gise
waxé
Bëgg naa góor gi ñëw, xale yi ñëw, jigéen ñi toog!
Bëgg naa
waaye nag
Nit ku baax õga.
Àna õgoor?
gisé
doj
keneen kule
Foofu, góor gi dem fu rafet la.
Koo gis
Ma õgii dem.
Nit ñenn ñi yegseeguñu.
Seeteegul
Menn la.
Kile la.
nobaate
ca
Dinaa dem
Gis na keneen ki.
xasaw
xar
reew
gisé
Jambaar du bare wax.
yan
wante itam?
waaye nag
amal
soppe
dawal
mbëgg
warugar
Bëgg na soo demee
Leneen lan?
Foofu, dem fu.
te
lemu
fenn
su
mbubb
bëgg
Doo nitu jamm?
duggo
ñennu nit
teõxu
Gayndé genn réerul.
rattkat
odd
fasu
Góor gii bëggóon
Kër gépp
Wool xale yi dul dem
waaye nag
addooku
foofu
Ku dem?
Noonu itam, ma bañ génn
Gis naa sa xarit yépp.
Ku def lii
Demal rekk
Fu mu bëgg.
ba
doj
Du woon.
amal
daõdaõluji
sawwu
fenn
dab
maõkoo
Su bëggul
Y Y Y Y Y daw  Laobe Laobe Laobe fa woon woon woon.
Aminta ñëw?
wante itam
dawal
Góor gee ni nit la, soo demee
Demuma fa woon
nile
Wool xale yi dul dem
dawal
Xar mi mépp.
Gis naa la yaak moom.
Gis õga coroom la woon.
tis
Nit ku góor la.
Jile jigéen jan õgeen wax?
Dem õga te
bootaay
na ni
Defe naa du kii.
waaye nag
Ma may ñi.
lenn loolu woon
Góor gi dem na, ma defe
dindéeku
fenn
gasax
Bu mu dem
Lan la wax.
Toogal ci fépp, fu leer
Giskoonuñu ka
Moo doon dem
fenn
gaddaay
Dafa di nitu tay.
te
Õgor dem.
jaambur
Demal ci kër gu yaa.
Su góor gi dee ñëw
Dem nañu
Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw
ñenn nit
doc
Foo wax?
taf
ngéew
Fas yan ñoo réer?
doj
Duõgeen woon ñëw
te itam
keneen kule
jig
araf
waa Ndar
rafet
Yeneen xar yi yépp daw.
Mi õgi fi.
Giskoon la
ndax
wuutal
reeaatle
seggaat
Ndax yaa beg?
seggaat
L?
warugar
doc
looluu
Laobe lë woon.
Jigéen ñi danañu naõgu
Ci ki.
maõkoo
ile jigéen jan õgeen wax?
Gis naa Samba.
Buleen dem
jaambur jaamburlu
xar
Benn bi rëcc laay seet.
Gis naa ñooñale woon.
Foo jaõge xërëm?
Dañu doonkoon nitu alal.
Yeewal yan?
araféef raññee
na ni
Xar yan ñoo ñëw?
warugar
waaye sax
fi
Lawbe la.
am na
Ki kan la?
Wool xale yi dul dem
ndimo
daaw
walla
Gis naa am xar.
Ñeneen lañu.
Yaa ka gisul
far
Wax na koralal bépp nit.
Gisoon naa jigéen ji ñépp?
Ki kan la?
ñoom
wex
te
doc
su
te
Moontin, dem naa
Gis na ma man mi.
Réew mi am na alal ndax?
waaye nag
ci
Fattu
fal
saxaar
Mu di weneen wu moo.
gistal
Soo demee, mu ñëw.
ni
Mi õgi noonu rekk.
Wante itam, la rax su.
Gisoon naa nit ñooña ñépp.
Dem õga...
Gisu leen leen woon
gistal
Xar yooyu yan õga moom?
Mi õgi fi, soo demee
Nileen ka
nobaate
ñjëg
Y Y Y Y Y daw  Laobe Laobe Laobe fa woon woon woon.
deeo
daaw
Gis naa doomi jigéen ja.
Wool xale yi dul dem
doñj
doc
Bale xale laa wax?
jaambur
fal
am ak
Na dugg ci biir su bëggée
ñaari goroõ
Naa...
Ndax kan dem?
Gis naa gaynde.
Daõga dem?
Bëgg naa õga ñëw mu dem su õgeen noppée.
Ci fii : ba ci
lax
Na?
Xar menn man õga wax?
Xar yan ñoo ñëw?
bu
wàl
naqadil
Gis naa Samba.
Nit ñi daawuñu coow.
wërsëglóo
reeaatle
lépp loolu woon
waaye nag
waxaalelu
lekkadi
waru gar
Fu mu dem?
ci biir
addooku
Wande it, moom am waaye moomu
waa joor
kañaan
Séen naa ay nit.
Ndaw si réccu na amoon.
nawle
Gis naa jeeg bi.
Bumi bi du fi buur.
Bawal mépp xar.
Ki kan la?
Fu mu bëgg?
Jox na doom ji jigéen jan a réer.
Góor gii bëggóon
bareedi
gisé
Du woon dem

100%|██████████| 297/297 [05:57<00:00,  1.20s/it]

| | original_text | original_label | predicted_label |
|---|---|---|---|
| 0 | C'est peut-être Fatim! | Soo demee, Faatim la! | daaw |
| 1 | Celui qui est en haut! | Kenn ki ci kaw! | Moom daal, ñépp ñibbisi nañu. |
| 2 | jusqu'à | ba | kuu |
| 3 | C'est ces livres-là qu'il m'a donnés. | Yee téere la ma jóox. | Ñun ñii lay set. |
| 4 | marché de bétail | dogal | mbay |
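The prediction loop keeps running totals in `scores` and rounds after every addition. The division by the number of test rows happens in a later cell. One alternative is to append each `result` dict to a list inside the loop and average once, rounding only at the end. The hypothetical `mean_metrics` helper below sketches this.

```python
def mean_metrics(per_example):
    """Average a list of metric dicts key by key, rounding only at the end."""
    keys = per_example[0].keys()
    n = len(per_example)
    return {k: round(sum(d[k] for d in per_example) / n, 4) for k in keys}


print(mean_metrics([{"bleu": 10.0, "gen_len": 3.0}, {"bleu": 5.0, "gen_len": 4.0}]))
# {'bleu': 7.5, 'gen_len': 3.5}
```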


In [37]:
df_ft_to_wf.tail(10)

| | original_text | original_label | predicted_label |
|---|---|---|---|
| 287 | J'ai vu la femme. | Gis naa jigéen ju. | Gis naa jeeg bi. |
| 288 | Le fait que tu ne viennes pas est préjudiciable. | Õga bañ ñëw rafetul. | Bumi bi du fi buur. |
| 289 | « me voici, regardez-moi » n'est pas le fait d... | Maa õgii gis leen ma, du jëfu nit ku yiw. | Bawal mépp xar. |
| 290 | C'est bien? | Baax na? | Ki kan la? |
| 291 | Où est celui dont il parle? | Ana waa ji muy wax? | Fu mu bëgg? |
| 292 | Il a confié à la femme une valise. | Deõk na jigéen ji walis. | Jox na doom ji jigéen jan a réer. |
| 293 | Cet homme qui vint autrefois, le voilà! | Góor googule ñëwóon, mi õgi! | Góor gii bëggóon |
| 294 | De quelle manière? | Fan? | bareedi |
| 295 | attacher | takk | gisé |
| 296 | Il ne t'avait pas remis la taxe | Joxu la woon juuti bi | Du woon dem |


In [24]:
# let us display 100 samples
pd.options.display.max_rows = 100
df_ft_to_wf.sample(100)

| | original_text | original_label | predicted_label |
|---|---|---|---|
| 167 | Leurs lions c'est des moutons. | Seni gaynde yi ay xar lañu. | Nit ku baax la. |
| 211 | Il ne dit rien de cela. | Waxul cu li dara. | Wax na Musa dara. |
| 63 | bile | wextan | gisé |
| 154 | confier | déey | wàyaan |
| 5 | Tu as été et il a été et moi aussi j'ai été. | Dem õga, te dem na, te dem naa. | Dem õga te ñëw na. |
| 77 | se pincer | tëccu | lemu |
| 183 | êtres humains de sexe masculin | nit ñu góor | maõkoo |
| 158 | Quel mouton s'est égaré? | Xar man a réer? | Fan a réer? |
| 9 | bénéficier d'un sort | xéewlu | xéewlu |
| 139 | Il dit que c'est Fatim. | Mu ni Faatim la. | Li ñu wax la. |


In [39]:
# mean of the per-sentence BLEU scores over the test set (not a corpus-level BLEU)
scores['bleu']/df_ft_to_wf.shape[0]

7.642814814814815
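This value is the mean of per-sentence BLEU scores. It is not directly comparable with corpus-level BLEU, such as the scores `compute_metrics` produces during training, where sacrebleu pools the n-gram counts of all sentences before forming the precisions.

The toy example below shows the effect on made-up sentences, using clipped unigram precision only rather than full BLEU. In the average, a short wrong prediction weighs as much as a long correct one, but it weighs much less once counts are pooled.

For a corpus-level score on the test set, one option is to pass the whole `predicted_label` and `original_label` columns at once to `evaluation.metric.compute`, with each reference wrapped in its own list. Another is `sacrebleu.corpus_bleu`.

```python
from collections import Counter


def clipped_unigram_counts(hyp: str, ref: str):
    """Return (clipped matches, hypothesis length) for unigram precision."""
    h, r = Counter(hyp.split()), Counter(ref.split())
    matches = sum(min(count, r[word]) for word, count in h.items())
    return matches, sum(h.values())


# made-up (hypothesis, reference) pairs: one long and correct, one short and wrong
pairs = [("the cat sat down", "the cat sat down"), ("dog", "cat")]
counts = [clipped_unigram_counts(h, r) for h, r in pairs]

# average of per-sentence precisions vs precision of pooled counts
sentence_mean = sum(m / n for m, n in counts) / len(counts)
corpus_pooled = sum(m for m, _ in counts) / sum(n for _, n in counts)

print(sentence_mean, corpus_pooled)  # 0.5 0.8
```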

## Colab download and remove step

In [None]:
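import shutil

# shutil.rmtree('/content/drive/MyDrive/Memoire/subject2/training2/results2')
# remove the local wandb run files (no error if the folder does not exist)
shutil.rmtree('wandb', ignore_errors=True)
# shutil.make_archive('wandb', 'zip', 'wandb') # archive the folder instead of deleting it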