Fine-tuning best T5 Transformer ü§ñ
-----------------------------------

In this notebook, we will continue the fine-tuning of T5 transformer on the new extracted sentences from the bool **Grammaire de Wolof Moderne**. We obtained, after a hyperparameter tuning with `wandb`, a best bleu score of **2.47** for french to wolof translation model. We provide, bellow, the main evaluation figures, obtained from the hyperparameter search step.

- Parallel coordinates from panel:


`Parameter importance char` (from [panel]():

![parameter_importance]()

In [1]:
# let us extend the paths of the system
import sys

# path = "/content/drive/MyDrive/Memoire/subject2/T5/"

# sys.path.extend([path, f"{path}new_data"])

In [2]:
# define environment
# %env WANDB_LOG_MODEL=true
# %env WANDB_NOTEBOOK_NAME=fw_training_t5_small_best_model_v2.ipynb
# %env WANDB_API_KEY=237a8450cd2568ea1c8e1f8e0400708e79b6b4ee

In [3]:
# !pip install -qq wandb --upgrade

In [4]:
# !pip install evaluate -qq
# !pip install sacrebleu -qq
# !pip install optuna -qq
# !pip install transformers -qq 
# !pip install tokenizers -qq
# !pip install nlpaug -qq
# !pip install ray[tune] -qq
# !python -m spacy download fr_core_news_lg 

In [5]:
# let us import all necessary libraries
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5TokenizerFast, set_seed
from wolof_translate.utils.sent_transformers import TransformerSequences
from wolof_translate.data.dataset_v2 import T5SentenceDataset
from wolof_translate.utils.sent_corrections import *
from sklearn.model_selection import train_test_split
from nlpaug.augmenter import char as nac
from torch.utils.data import DataLoader
# from datasets  import load_metric # make pip install evaluate instead
# and pip install sacrebleu for instance
from functools import partial
from tqdm import tqdm
import pandas as pd
import numpy as np
import evaluate
import wandb
import torch

# wandb.login(key="237a8450cd2568ea1c8e1f8e0400708e79b6b4ee")


  from .autonotebook import tqdm as notebook_tqdm


--------------

## French to wolof

### Configure dataset üî†

In [6]:
def split_data(random_state: int = 50):
  """Split data between train, validation and test sets

  Args:
    random_state (int): the seed of the splitting generator. Defaults to 50
  """
  # load the corpora and split into train and test sets
  corpora = pd.read_csv(f"data/additional_documents/diagne_sentences/extractions.csv")

  train_set, test_set = train_test_split(corpora, test_size=0.1, random_state=random_state)

  # let us save the final training set when performing

  train_set, valid_set = train_test_split(train_set, test_size=0.1, random_state=random_state)

  train_set.to_csv(f"data/additional_documents/diagne_sentences/final_train_set.csv", index=False)

  # let us save the sets
  train_set.to_csv(f"data/additional_documents/diagne_sentences/train_set.csv", index=False)

  valid_set.to_csv(f"data/additional_documents/diagne_sentences/valid_set.csv", index=False)

  test_set.to_csv(f"data/additional_documents/diagne_sentences/test_set.csv", index=False)

In [7]:
# recuperate the tokenizer from a json file
tokenizer = T5TokenizerFast(tokenizer_file=f"wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json")


In [8]:
def recuperate_datasets(fr_char_p: float, fr_word_p: float):

  # Create augmentation to add on French sentences
  fr_augmentation = TransformerSequences(nac.KeyboardAug(aug_char_p=fr_char_p, aug_word_p=fr_word_p),
                                        remove_mark_space, delete_guillemet_space)

  # Recuperate the train dataset
  train_dataset_aug = T5SentenceDataset(f"data/additional_documents/diagne_sentences/final_train_set.csv",
                                        tokenizer,
                                        truncation = True,
                                        cp1_transformer = fr_augmentation)

  # Recuperate the test dataset
  test_dataset = T5SentenceDataset(f"data/additional_documents/diagne_sentences/test_set.csv",
                                        tokenizer,
                                        truncation = True)
  
  # Return the datasets
  return train_dataset_aug, test_dataset

### Configure the model and the evaluation function ‚öôÔ∏è

Let us recuperate the model and resize the token embeddings.

In [9]:
def t5_model_init(tokenizer):

    # Initialize the model name
    model_name = 't5-small'
    # model_name = 'data/checkpoints/vf_t5_small_v2_checkpoints_2/' # from checkpoint

    # import the model with its pre-trained weights
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, add_cross_attention = True)

    # resize the token embeddings
    model.resize_token_embeddings(len(tokenizer))

    return model

Embedding(2961, 512)

Let us evaluate the predictions with the `bleu` metric.

In [10]:
# %%writefile wolof-translate/wolof_translate/utils/evaluation.py
from tokenizers import Tokenizer
from typing import *
import numpy as np
import evaluate

class TranslationEvaluation:
    
    def __init__(self, 
                 tokenizer: Tokenizer,
                 decoder: Union[Callable, None] = None,
                 metric = evaluate.load('sacrebleu'),
                 ):
        
        self.tokenizer = tokenizer
        
        self.decoder = decoder
        
        self.metric = metric
    
    def postprocess_text(self, preds, labels):
        
        preds = [pred.strip() for pred in preds]
        
        labels = [[label.strip()] for label in labels]
        
        return preds, labels

    def compute_metrics(self, eval_preds):

        preds, labels = eval_preds

        if isinstance(preds, tuple):
        
            preds = preds[0]
        
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels)

        result = self.metric.compute(predictions=decoded_preds, references=decoded_labels)
        
        result = {"bleu": result["score"]}

        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        
        result["gen_len"] = np.mean(prediction_lens)
        
        result = {k: round(v, 4) for k, v in result.items()}
        
        return result

In [11]:
# %run wolof-translate/wolof_translate/utils/evaluation.py

Let us initialize the evaluation object.

In [12]:
evaluation = TranslationEvaluation(tokenizer)


### Searching for the best parameters üïñ

Let us define the data collator.

In [13]:
def data_collator(batch):
    """Generate a batch of data to provide to trainer

    Args:
        batch (_type_): The batch

    Returns:
        dict: A dictionary containing the ids, the attention mask and the labels
    """
    input_ids = torch.stack([b[0].squeeze(0) for b in batch])
    
    attention_mask = torch.stack([b[1].squeeze(0) for b in batch])
    
    labels = torch.stack([b[2].squeeze(0) for b in batch])
    
    return {'input_ids': input_ids, 'attention_mask': attention_mask,
            'labels': labels}

Let us initialize the training arguments and make random search.

In [None]:
# %%wandb

"""Best parameters
learning_rate = 0.00339254221933438
weight_decay = 0.4268760048515646
train_batch_size = 32
random_state = 25
fr_char_p = 0.7040605028532431
fr_word_p = 0.2651043507758306
eval/bleu = 2.2637
"""
"""Best parameters
learning_rate = 0.0023118906467202416
weight_decay = 0.4
train_batch_size = 16
random_state = 1
fr_char_p = 0.5914542632321074
fr_word_p = 0.4495405182231499
eval/loss = 0.5561882257461548
"""

# Initialize wandb
# wandb.init(project = "fw_small_t5_fine_tuning_v2")

# seed
set_seed(0)

# split the data
split_data(random_state=1)

# let us recuperate the datasets
train_dataset, test_dataset = recuperate_datasets(0.5914542632321074, 0.4495405182231499)

# set training arguments
training_args = Seq2SeqTrainingArguments(f"data/checkpoints/t5_results_fw_v2",
                                    logging_dir="data/logs/results_fw_v2",
                                    num_train_epochs=300,
                                    load_best_model_at_end=True,
                                    save_strategy="epoch",
                                    evaluation_strategy="epoch",
                                    logging_strategy="epoch",
                                    per_device_train_batch_size=16, 
                                    per_device_eval_batch_size=16,
                                    learning_rate=0.0023118906467202416,
                                    # learning_rate=0.00003113,
                                    weight_decay=0.4,
                                    predict_with_generate=True, # we will use predict with generate in order to obtain more valuable test results
                                    fp16 = True,
                                    metric_for_best_model = 'bleu', # a bleu score will be used to find the best model
                                    greater_is_better = True,
                                    save_total_limit = 1, # we will save only the best model
                                    )   

# define training loop
trainer = Seq2SeqTrainer(model_init=partial(t5_model_init, tokenizer = train_dataset.tokenizer),
                  args=training_args,
                  train_dataset=train_dataset, 
                  eval_dataset=test_dataset,
                  data_collator=data_collator,
                  compute_metrics=evaluation.compute_metrics
                  )

# load last checkpoint
# trainer._load_from_checkpoint("data/training2/results/checkpoint-147")

# start training loop
trainer.train()
# trainer.train('data/checkpoints/vf_t5_small_v2_checkpoints/') # from the searching best model
# trainer.train('data/checkpoints/results_fw_v2/last_checkpoint/') # from last checkpoint

# finish wandb
# wandb.finish()


In [35]:
# let us get the best model
model = AutoModelForSeq2SeqLM.from_pretrained('data/checkpoints/results_fw_v2/checkpoint-29898/')

### Predictions

Let us generate texts and store into a DataFrame.

In [36]:

# set the model to eval mode
_ = model.eval()

# run model inference on all test data
original_translations, predicted_translations, original_texts, scores = [], [], [], {}

for data, attention_mask, labels in tqdm(DataLoader(test_dataset)):
    
    # Traduce the sentences
    original_text = tokenizer.decode(data[0], skip_special_tokens=True)
    
    original_translation = tokenizer.decode(labels[0], skip_special_tokens=True)
    
    # get tokens
    generated = torch.tensor(data)
    
    attention_mask = torch.tensor(attention_mask)
    
    # recuperate the pad token id
    pad_token_id = tokenizer.pad_token_id
    
    # perform prediction
    predictions = model.generate(generated, do_sample = False, top_k = 50, max_length = test_dataset.max_len, top_p = 0.90,
                                    temperature = 0, num_return_sequences = 0, attention_mask = attention_mask, pad_token_id = pad_token_id)
    
    # calculate the score and add it to the score
    result = evaluation.compute_metrics((predictions, torch.tensor(labels)))
    
    if not scores: scores.update({k: v for k, v in result.items()})
    
    else: scores.update({k: round(scores[k] + v, 4) for k, v in result.items()})
    
    # decode the predicted tokens into texts
    predicted_translation = list(test_dataset.decode(predictions))
    
    print(predicted_translation[0])
    
    # append results
    original_translations.append(original_translation)
    
    predicted_translations.extend(predicted_translation)
    
    original_texts.append(original_text)

# transform result into data frame
df_ft_to_wf = pd.DataFrame({'original_text': original_texts,
                            'original_label': original_translations,
                            'predicted_label': predicted_translations})

# print the result
df_ft_to_wf.head()

  generated = torch.tensor(data)
  attention_mask = torch.tensor(attention_mask)
  result = evaluation.compute_metrics((predictions, torch.tensor(labels)))


daaw




Moom daal, √±√©pp √±ibbisi na√±u.




kuu




√ëun √±ii lay set.




mbay




Dem √µga te √±√´w na.




Dem naa ci keneen ku jig√©en.




da√µal




r√´bb




x√©ewlu




pas




Demuma




sawwu




da√µal




Wut√´l leneen.




Moo di Lawbe bi




Menn la.




Kooy waxal?




Faatim la, mu ni.




L√´f lan a r√©er?




Y√©en b√´gg √µgeen woon.




degu √†ll




Nit kii ci sama wet.




jaambur jaamburlu




fenn




√±oo√±u




ca




Dafa di dem.




Gis √µga nit kookee?




Fee la.




Demuma fa woon




Da√µga gis kan?




te




waaye nag




te itam




noonu itam




arafu geen




√ïgoor dem.




Wool g√≥or gi dul dem




gise




wax√©




B√´gg naa g√≥or gi √±√´w, xale yi √±√´w, jig√©en √±i toog!




B√´gg naa




waaye nag




Nit ku baax √µga.




√Äna √µgoor?




gis√©




doj




keneen kule




Foofu, g√≥or gi dem fu rafet la.




Koo gis




Ma √µgii dem.




Nit √±enn √±i yegseegu√±u.




Seeteegul




Menn la.




Kile la.




nobaate




ca




Dinaa dem




Gis na keneen ki.




xasaw




xar




reew




gis√©




Jambaar du bare wax.




yan




wante itam?




waaye nag




amal




soppe




dawal




mb√´gg




warugar




B√´gg na soo demee




Leneen lan?




Foofu, dem fu.




te




lemu




fenn




su




mbubb




b√´gg




Doo nitu jamm?




duggo




√±ennu nit




te√µxu




Gaynd√© genn r√©erul.




rattkat




odd




fasu




G√≥or gii b√´gg√≥on




K√´r g√©pp




Wool xale yi dul dem




waaye nag




addooku




foofu




Ku dem?




Noonu itam, ma ba√± g√©nn




Gis naa sa xarit y√©pp.




Ku def lii




Demal rekk




Fu mu b√´gg.




ba




doj




Du woon.




amal




da√µda√µluji




sawwu




fenn




dab




ma√µkoo




Su b√´ggul




Y Y Y Y Y daw  Laobe Laobe Laobe fa woon woon woon.




Aminta √±√´w?




wante itam




dawal




G√≥or gee ni nit la, soo demee




Demuma fa woon




nile




Wool xale yi dul dem




dawal




Xar mi m√©pp.




Gis naa la yaak moom.




Gis √µga coroom la woon.




tis




Nit ku g√≥or la.




Jile jig√©en jan √µgeen wax?




Dem √µga te




bootaay




na ni




Defe naa du kii.




waaye nag




Ma may √±i.




lenn loolu woon




G√≥or gi dem na, ma defe




dind√©eku




fenn




gasax




Bu mu dem




Lan la wax.




Toogal ci f√©pp, fu leer




Giskoonu√±u ka




Moo doon dem




fenn




gaddaay




Dafa di nitu tay.




te




√ïgor dem.




jaambur




Demal ci k√´r gu yaa.




Su g√≥or gi dee √±√´w




Dem na√±u




Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw Yaw




√±enn nit




doc




Foo wax?




taf




ng√©ew




Fas yan √±oo r√©er?




doj




Du√µgeen woon √±√´w




te itam




keneen kule




jig




araf




waa Ndar




rafet




Yeneen xar yi y√©pp daw.




Mi √µgi fi.




Giskoon la




ndax




wuutal




reeaatle




seggaat




Ndax yaa beg?




seggaat




L?




warugar




doc




looluu




Laobe l√´ woon.




Jig√©en √±i dana√±u na√µgu




Ci ki.




ma√µkoo




ile jig√©en jan √µgeen wax?




Gis naa Samba.




Buleen dem




jaambur jaamburlu




xar




Benn bi r√´cc laay seet.




Gis naa √±oo√±ale woon.




Foo ja√µge x√´r√´m?




Da√±u doonkoon nitu alal.




Yeewal yan?




araf√©ef ra√±√±ee




na ni




Xar yan √±oo √±√´w?




warugar




waaye sax




fi




Lawbe la.




am na




Ki kan la?




Wool xale yi dul dem




ndimo




daaw




walla




Gis naa am xar.




√ëeneen la√±u.




Yaa ka gisul




far




Wax na koralal b√©pp nit.




Gisoon naa jig√©en ji √±√©pp?




Ki kan la?




√±oom




wex




te




doc




su




te




Moontin, dem naa




Gis na ma man mi.




R√©ew mi am na alal ndax?




waaye nag




ci




Fattu




fal




saxaar




Mu di weneen wu moo.




gistal




Soo demee, mu √±√´w.




ni




Mi √µgi noonu rekk.




Wante itam, la rax su.




Gisoon naa nit √±oo√±a √±√©pp.




Dem √µga...




Gisu leen leen woon




gistal




Xar yooyu yan √µga moom?




Mi √µgi fi, soo demee




Nileen ka




nobaate




√±j√´g




Y Y Y Y Y daw  Laobe Laobe Laobe fa woon woon woon.




deeo




daaw




Gis naa doomi jig√©en ja.




Wool xale yi dul dem




do√±j




doc




Bale xale laa wax?




jaambur




fal




am ak




Na dugg ci biir su b√´gg√©e




√±aari goro√µ




Naa...




Ndax kan dem?




Gis naa gaynde.




Da√µga dem?




B√´gg naa √µga √±√´w mu dem su √µgeen nopp√©e.




Ci fii : ba ci




lax




Na?




Xar menn man √µga wax?




Xar yan √±oo √±√´w?




bu




w√†l




naqadil




Gis naa Samba.




Nit √±i daawu√±u coow.




w√´rs√´gl√≥o




reeaatle




l√©pp loolu woon




waaye nag




waxaalelu




lekkadi




waru gar




Fu mu dem?




ci biir




addooku




Wande it, moom am waaye moomu




waa joor




ka√±aan




S√©en naa ay nit.




Ndaw si r√©ccu na amoon.




nawle




Gis naa jeeg bi.




Bumi bi du fi buur.




Bawal m√©pp xar.




Ki kan la?




Fu mu b√´gg?




Jox na doom ji jig√©en jan a r√©er.




G√≥or gii b√´gg√≥on




bareedi




gis√©


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 297/297 [05:57<00:00,  1.20s/it]

Du woon dem





Unnamed: 0,original_text,original_label,predicted_label
0,C'est peut-√™tre Fatim!,"Soo demee, Faatim la!",daaw
1,Celui qui est en haut!,Kenn ki ci kaw!,"Moom daal, √±√©pp √±ibbisi na√±u."
2,jusqu'√†,ba,kuu
3,C'est ces livres-l√† qu'il m'a donn√©s.,Yee t√©ere la ma j√≥ox.,√ëun √±ii lay set.
4,march√© de b√©tail,dogal,mbay


In [37]:
df_ft_to_wf.tail(10)

Unnamed: 0,original_text,original_label,predicted_label
287,J'ai vu la femme.,Gis naa jig√©en ju.,Gis naa jeeg bi.
288,Le fait que tu ne viennes pas est pr√©judiciable.,√ïga ba√± √±√´w rafetul.,Bumi bi du fi buur.
289,"¬´ me voici, regardez-moi ¬ª n'est pas le fait d...","Maa √µgii gis leen ma, du j√´fu nit ku yiw.",Bawal m√©pp xar.
290,C'est bien?,Baax na?,Ki kan la?
291,O√π est celui dont il parle?,Ana waa ji muy wax?,Fu mu b√´gg?
292,Il a confi√© √† la femme une valise.,De√µk na jig√©en ji walis.,Jox na doom ji jig√©en jan a r√©er.
293,"Cet homme qui vint autrefois, le voil√†!","G√≥or googule √±√´w√≥on, mi √µgi!",G√≥or gii b√´gg√≥on
294,De quelle mani√®re?,Fan?,bareedi
295,attacher,takk,gis√©
296,Il ne t'avait pas remis la taxe,Joxu la woon juuti bi,Du woon dem


In [24]:
# let us display 100 samples
pd.options.display.max_rows = 100
df_ft_to_wf.sample(100)

Unnamed: 0,original_text,original_label,predicted_label
167,Leurs lions c'est des moutons.,Seni gaynde yi ay xar la√±u.,Nit ku baax la.
211,Il ne dit rien de cela.,Waxul cu li dara.,Wax na Musa dara.
63,bile,wextan,gis√©
154,confier,d√©ey,w√†yaan
5,Tu as √©t√© et il a √©t√© et moi aussi j'ai √©t√©.,"Dem √µga, te dem na, te dem naa.",Dem √µga te √±√´w na.
77,se pincer,t√´ccu,lemu
183,√™tres humains de sexe masculin,nit √±u g√≥or,ma√µkoo
158,Quel mouton s'est √©gar√©?,Xar man a r√©er?,Fan a r√©er?
9,b√©n√©ficier d'un sort,x√©ewlu,x√©ewlu
139,Il dit que c'est Fatim.,Mu ni Faatim la.,Li √±u wax la.


In [39]:
# let us print the scores
scores['bleu']/df_ft_to_wf.shape[0]

7.642814814814815

## Colab download and remove step

In [None]:
import shutil

# shutil.rmtree('/content/drive/MyDrive/Memoire/subject2/training2/results2')
shutil.rmtree('wandb')
# shutil.make_archive('wandb', 'zip', 'wanbd')