# ByT5 Generic

Use a TPU runtime, and create a Google Cloud Storage (GCS) bucket containing your training and validation sets as plain-text files with one example per line.

## Dependencies

In [None]:
%%capture
! pip install t5 sentencepiece apache_beam --upgrade

In [None]:
! git clone https://github.com/google-research/byt5 byt5-repo
! git clone https://github.com/google-research/multilingual-t5
! git clone https://github.com/google-research/text-to-text-transfer-transformer

Cloning into 'byt5-repo'...
remote: Enumerating objects: 58, done.[K
remote: Counting objects: 100% (58/58), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 58 (delta 26), reused 54 (delta 22), pack-reused 0[K
Unpacking objects: 100% (58/58), done.
Cloning into 'multilingual-t5'...
remote: Enumerating objects: 246, done.[K
remote: Counting objects: 100% (69/69), done.[K
remote: Compressing objects: 100% (46/46), done.[K
remote: Total 246 (delta 32), reused 32 (delta 23), pack-reused 177[K
Receiving objects: 100% (246/246), 63.44 KiB | 2.64 MiB/s, done.
Resolving deltas: 100% (144/144), done.
Cloning into 'text-to-text-transfer-transformer'...
remote: Enumerating objects: 3314, done.[K
remote: Counting objects: 100% (259/259), done.[K
remote: Compressing objects: 100% (113/113), done.[K
remote: Total 3314 (delta 154), reused 200 (delta 145), pack-reused 3055[K
Receiving objects: 100% (3314/3314), 5.18 MiB | 13.52 MiB/s, done.
Resolving deltas: 100% (

## Rearrange the cloned code

In [None]:
! mv multilingual-t5/multilingual_t5 ./
! mv byt5-repo/byt5 ./

In [None]:
! mkdir models
! cp byt5/gin/models/* models/
! cp text-to-text-transfer-transformer/t5/models/gin/models/* models/

## Custom Text Dataset

In [None]:
# Set up the train and validation files: plain text, one example per line,
# uploaded to the GCS paths used in the task registration below
# (gs://BUCKET/train_lines.txt and gs://BUCKET/validation_lines.txt).

In [None]:
## Mirrors the definitions already present in byt5/tasks.py, so the task
## registration cell below can also run in this kernel.
import functools

import t5

# "inputs" and "targets" each get their own Feature backed by t5's ByteVocabulary.
DEFAULT_BYTE_OUTPUT_FEATURES = {
    feature_name: t5.data.Feature(vocabulary=t5.data.ByteVocabulary())
    for feature_name in ("inputs", "targets")
}
# Mean noise-span length handed to the span_corruption token preprocessor.
MEAN_NOISE_SPAN_LENGTH = 20

In [None]:
### MANUAL EDITS ###
## upload files to google storage
## update paths
## add this to byt5/tasks.py
####################
# Registers the `byt5_ex` span-corruption pre-training task over plain-text files
# with one example per line. The training command runs as a separate process and
# only sees tasks defined in byt5/tasks.py (via --module_import), which is why this
# code must also be added there; running it here registers it in this kernel only.


def _line_to_text(dataset):
    """Wrap each raw text line of `dataset` as a {"text": line} example.

    Used instead of `parse_tsv(field_delim='~')`, which errors on any line that
    happens to contain the delimiter character.
    """
    return dataset.map(lambda line: {"text": line})


# Drop any earlier registration so re-running this cell does not raise a
# duplicate-registration error.
t5.data.TaskRegistry.remove("byt5_ex")
t5.data.TaskRegistry.add(
    "byt5_ex",
    t5.data.TextLineTask,
    split_to_filepattern={
        "train": "gs://BUCKET/train_lines.txt",
        "validation": "gs://BUCKET/validation_lines.txt",
    },
    text_preprocessor=[
        _line_to_text,
        # Pre-training is unsupervised: empty inputs, the raw line as targets.
        functools.partial(
            t5.data.preprocessors.rekey,
            key_map={
                "inputs": None,
                "targets": "text"
            }),
    ],
    token_preprocessor=functools.partial(
        t5.data.preprocessors.span_corruption,
        mean_noise_span_length=MEAN_NOISE_SPAN_LENGTH),
    output_features=DEFAULT_BYTE_OUTPUT_FEATURES,
    metric_fns=[])

## TPU Training

In [None]:
# Read the TPU ip:port from the COLAB_TPU_ADDR environment variable; the training
# command's --tpu flag takes it as grpc://<ip:port>.
import os

tpu_address = os.environ.get('COLAB_TPU_ADDR')
if not tpu_address:
    # Fail with an actionable message instead of a bare KeyError.
    raise RuntimeError('COLAB_TPU_ADDR is not set; switch the runtime type to TPU.')
'grpc://' + tpu_address

In [None]:
! python -m t5.models.mesh_transformer_main \
  --gin_file="./models/byt5.base.gin" \
  --gin_param="MIXTURE_NAME = 'byt5_ex'" \
  --gin_param="utils.run.sequence_length = {'inputs': 128, 'targets': 128}" \
  --gin_param="utils.run.batch_size = ('tokens_per_batch', 32768)" \
  --gin_param="run.train_steps = 100000" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = 'v3-8'" \
  --gin_param="run.train_dataset_fn = @t5.models.mesh_transformer.mesh_train_dataset_fn" \
  --gin_param="mesh_train_dataset_fn.mixture_or_task_name = 'byt5_ex'" \
  --module_import="byt5.tasks" \
  --tpu="grpc://TPU_LOCATION" \
  --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
  --model_dir="gs://BUCKET/byt5_base_ex" \
  --gcp_project="mapmeld-hrd"

  # will save every 5000 steps

  #--t5_tfds_data_dir="${BUCKET}/t5-tfds" \
    # --eval_mode="perplexity_eval" \
#  --eval_gin_param="mesh_eval_dataset_fn.num_eval_examples = 10000" \

  #--tpu_zone="${ZONE}" \
#   --gin_param="utils.run.learning_rate_schedule=@learning_rate_schedules.rsqrt_no_ramp_down" \


## Convert the TensorFlow checkpoint to Hugging Face Transformers

In [None]:
%%capture
! pip install transformers

In [None]:
# NOTE(review): `checkpoint` is normally TensorFlow's checkpoint-state index file
# (it names the latest model.ckpt-<step>), so copying it to `model.ckpt` looks
# unusual — confirm the conversion below really needs this step.
# The Drive path is specific to the author's setup and differs from the training
# --model_dir (gs://BUCKET/byt5_base_ex); copy the trained checkpoint here first
# or point this at your own directory.
! cp ./drive/MyDrive/mlin/dvcorpus/dv-t5/checkpoint ./drive/MyDrive/mlin/dvcorpus/dv-t5/model.ckpt

In [None]:
# Convert the TF checkpoint in this directory to PyTorch weights, written back to
# the same directory (the next cell loads them with from_pt=True).
# NOTE(review): --config must point at an existing config.json for the model;
# no cell in this notebook creates it — supply one matching the byt5.base
# architecture before running (TODO confirm source).
! transformers-cli convert --model_type t5 \
  --tf_checkpoint ./drive/MyDrive/mlin/dvcorpus/dv-t5/ \
  --config ./drive/MyDrive/mlin/dvcorpus/dv-t5/config.json \
  --pytorch_dump_output ./drive/MyDrive/mlin/dvcorpus/dv-t5/

In [None]:
# Load the converted PyTorch weights into the TensorFlow T5 model class.
from transformers import TFT5ForConditionalGeneration

t_model = TFT5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path='./drive/MyDrive/mlin/dvcorpus/dv-t5',
    from_pt=True,
)

In [None]:
# Save the TensorFlow weights into the same directory the model was loaded from.
t_model.save_pretrained(save_directory='./drive/MyDrive/mlin/dvcorpus/dv-t5/')