# Fine-tuning ByT5 models with simpleT5

Run this notebook on a GPU instance.

## Dependencies

In [1]:
%%capture
! rm -rf simpleT5
! git clone https://github.com/mapmeld/simpleT5
! cd simpleT5 && pip install .

In [8]:
# need transformers 4.7.0; now updated in requirements.txt
%%capture
! pip install transformers --upgrade

## Basque data (Basque is included in mC4, the dataset used to pre-train ByT5)

In [1]:
import pandas as pd

# Both splits are read as headerless TSVs: label in the first column,
# input text in the second. The evaluation split comes from test.tsv.
TSV_COLUMNS = ["target_text", "source_text"]

train_df = pd.read_csv("./train.tsv", sep="\t", names=TSV_COLUMNS)
eval_df = pd.read_csv("./test.tsv", sep="\t", names=TSV_COLUMNS)

In [4]:
# Peek at the first training rows to confirm the label/text column order.
# NOTE(review): in the saved output, row 0's source_text is Devanagari script
# rather than Basque despite its "Euskara" label; worth checking the data.
train_df.head(n=5)

Unnamed: 0,target_text,source_text
0,Euskara,म र च १९ क द न ब स क न गर कहर ल क न पन द शल आफ...
1,Politika,agiri baten bitartez adierazi dute mugimenduko...
2,Euskara,ekainaren 14an heldu den igandean behaskaneko ...
3,Politika,ekineko zuzendaritzako kide izatea leporatuta ...
4,Ingurumena,energi trantsizioa landuko dute larunbatean ir...


In [5]:
# Peek at the first evaluation rows (same column layout as train_df).
eval_df.head(n=5)

Unnamed: 0,target_text,source_text
0,Nazioartea,pau llonch sabadell 1982 musikaria eta ekintza...
1,Politika,elkarrekin podemoseko kide den monica monteagu...
2,Politika,polizia maitea hurbilekoa eta errespetatua da ...
3,Gizartea,federico addiechi k diskriminazioaren aurkako ...
4,Ingurumena,60 000 pertsona baino gehiagok hartu dute part...


## Fine-tuning google/byt5-small

In [4]:
from simplet5 import SimpleT5

Global seed set to 42


In [6]:
# Baseline: start from the public ByT5-small checkpoint on the Hugging Face Hub.
BASE_CHECKPOINT = "google/byt5-small"

model = SimpleT5()
model.from_pretrained("byt5", BASE_CHECKPOINT)

Global seed set to 42


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2503.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2593.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=698.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1198627927.0, style=ProgressStyle(descr…




In [7]:
# Fine-tune the baseline for 3 epochs; checkpoints are saved under ./outputs.
# early_stopping_patience_epochs=0 is passed straight to simpleT5
# (presumably meaning "no early stopping"; confirm against the library).
model.train(
    train_df=train_df,  # pandas DataFrame with columns source_text & target_text
    eval_df=eval_df,    # pandas DataFrame with columns source_text & target_text
    source_max_token_len=512,
    target_max_token_len=128,
    batch_size=8,
    max_epochs=3,
    use_gpu=True,
    outputdir="outputs",
    early_stopping_patience_epochs=0,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 299 M 
-----------------------------------------------------
299 M     Trainable params
0         Non-trainable params
299 M     Total params
1,198.551 Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [5]:
## restoring from saved
# Checkpoint directory names embed the final train loss (for example
# "SimpleT5-epoch-2-train-loss-0.1916"), and that value can change whenever
# training is re-run. Locate the epoch-2 checkpoint by pattern instead of
# hard-coding the full directory name.
from pathlib import Path

epoch2_checkpoints = list(Path("./outputs").glob("SimpleT5-epoch-2-*"))
assert epoch2_checkpoints, "no epoch-2 checkpoint under ./outputs; run the training cell first"
# If several runs left checkpoints behind, use the most recently written one.
latest_checkpoint = max(epoch2_checkpoints, key=lambda p: p.stat().st_mtime)

model = SimpleT5()
model.from_pretrained('byt5', str(latest_checkpoint))

In [7]:
# Example passage taken from dev.tsv. The label predicted for it is wrong,
# but the cell is kept to illustrate how `predict` is called.
dev_example = 'el salvadorreko gerran izan zen gerrillaren ondoan bertatik bertara zuzenean bizi izan zituen gerrak dakartzan oinaze min heriotza izu eta beldurra hala ere merezi izan duela dio dudarik gabe salvadortarrak duintasuna irabazi duelakoan gerra hotsa aditzen da miren odriozolaren ahotsa aditzean'
model.predict(dev_example)

['Nazioartea']

## Fine-tuning monsoon-nlp/byt5-basque
This model was pre-trained overnight on Basque Wikipedia.

In [8]:
from simplet5 import SimpleT5

# Second model: a ByT5 checkpoint pre-trained on Basque Wikipedia.
BASQUE_CHECKPOINT = "monsoon-nlp/byt5-basque"

model2 = SimpleT5()
model2.from_pretrained("byt5", BASQUE_CHECKPOINT)

In [None]:
# Fine-tune the Basque checkpoint with the same settings as the baseline;
# checkpoints go to ./outputs_2 so they don't mix with the baseline's.
# early_stopping_patience_epochs=0 is passed straight to simpleT5
# (presumably meaning "no early stopping"; confirm against the library).
model2.train(
    train_df=train_df,  # pandas DataFrame with columns source_text & target_text
    eval_df=eval_df,    # pandas DataFrame with columns source_text & target_text
    source_max_token_len=512,
    target_max_token_len=128,
    batch_size=8,
    max_epochs=3,
    use_gpu=True,
    outputdir="outputs_2",
    early_stopping_patience_epochs=0,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 299 M 
-----------------------------------------------------
299 M     Trainable params
0         Non-trainable params
299 M     Total params
1,198.551 Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [None]:
# Classify a sample passage with the Basque-pretrained model.
# The literal is kept on one line: a single-quoted string cannot span a
# physical newline (that was a SyntaxError), and the stray trailing newline
# was never meant to be part of the input text.
model2.predict('nafarroako gobernuak toponimia normalizatzeko lana ez duela egin zioen eh bilduk aurkeztutako mozioak eta euskaltzaindiaren jarraibideei segitzeko eskatu geroa bai abstenitu egin da eta bertan behera geratu da asmoa')