# ByT5 pre-training (generic)

Run this notebook on a TPU runtime with a High-RAM instance, and create a GCS bucket for the model checkpoints before starting.

## Dependencies

In [None]:
%%capture
! pip install t5 sentencepiece apache_beam --upgrade

In [None]:
! git clone https://github.com/google-research/byt5 byt5-repo
! git clone https://github.com/google-research/multilingual-t5
! git clone https://github.com/google-research/text-to-text-transfer-transformer

Cloning into 'byt5-repo'...
remote: Enumerating objects: 58, done.
remote: Counting objects: 100% (58/58), done.
remote: Compressing objects: 100% (35/35), done.
remote: Total 58 (delta 26), reused 54 (delta 22), pack-reused 0
Unpacking objects: 100% (58/58), done.
Cloning into 'multilingual-t5'...
remote: Enumerating objects: 246, done.
remote: Counting objects: 100% (69/69), done.
remote: Compressing objects: 100% (46/46), done.
remote: Total 246 (delta 32), reused 32 (delta 23), pack-reused 177
Receiving objects: 100% (246/246), 63.44 KiB | 2.64 MiB/s, done.
Resolving deltas: 100% (144/144), done.
Cloning into 'text-to-text-transfer-transformer'...
remote: Enumerating objects: 3314, done.
remote: Counting objects: 100% (259/259), done.
remote: Compressing objects: 100% (113/113), done.
remote: Total 3314 (delta 154), reused 200 (delta 145), pack-reused 3055
Receiving objects: 100% (3314/3314), 5.18 MiB | 15.45 MiB/s, done.
Resolving deltas: 100% (

### Verify

In [None]:
from t5.models import mesh_transformer

In [None]:
from seqio.dataset_providers import MixtureRegistry, TaskRegistry

In [None]:
! python -c "import t5; print(t5.data.MixtureRegistry.names())"

2021-07-06 13:35:52.793782: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
dict_keys([])


## Reorganize files

In [None]:
! mv multilingual-t5/multilingual_t5 ./
! mv byt5-repo/byt5 ./

In [None]:
! mkdir -p models
! cp byt5/gin/models/* models/
! cp text-to-text-transfer-transformer/t5/models/gin/models/* models/

## TPU Training

In [None]:
# copy this address into the --tpu flag of the training command below (after grpc://)
import os
os.environ['COLAB_TPU_ADDR']

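The address can also be assembled programmatically instead of copied by hand. This is a minimal sketch, assuming `COLAB_TPU_ADDR` holds a bare `host:port` value (as is usual on Colab TPU runtimes) that needs a `grpc://` prefix; the helper name `colab_tpu_grpc_address` is hypothetical.

```python
import os


def colab_tpu_grpc_address(env=None):
    """Return the TPU address in the grpc://host:port form used by --tpu.

    ASSUMPTION: COLAB_TPU_ADDR is set by the Colab TPU runtime and normally
    holds "host:port" without a scheme.
    """
    env = os.environ if env is None else env
    addr = env.get("COLAB_TPU_ADDR", "")
    if not addr:
        raise RuntimeError("COLAB_TPU_ADDR is not set; select a TPU runtime first")
    # Avoid doubling the prefix if the value already carries one.
    return addr if addr.startswith("grpc://") else "grpc://" + addr


# Only print when a TPU runtime is actually attached.
if "COLAB_TPU_ADDR" in os.environ:
    print(colab_tpu_grpc_address())
```

The printed value is what replaces `grpc://TPU` in the training command.
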
In [None]:
# In byt5/tasks.py (around line 53), check the WIKI_LANGS list.
# You may need to add your language code (e.g. "dv"): TFDS has the other
# languages; they just weren't included in the repo.
#
# Checkpoints appear to be saved at steps 0, 5100, 10100, ... (unverified).
#
# Values to edit before running:
#   ./models/byt5.small.gin   could be byt5.large or another .gin template
#   byt5_wiki.LANG            replace LANG with your language code (appears twice)
#   tokens_per_batch 262144   1/4 of the README value
#   train_steps 100000        1/10 of the README value
#   tpu_topology 'v3-8'       Colab only gives you 8 cores
#   grpc://TPU                replace TPU with the address from the cell above
#   gs://BUCKET               your GCS bucket name
#   GCP                       your GCP project
# A comment after a trailing backslash breaks the line continuation,
# so the notes are collected here instead of inline.

! python -m t5.models.mesh_transformer_main \
  --gin_file="./models/byt5.small.gin" \
  --gin_param="MIXTURE_NAME = 'byt5_wiki.LANG'" \
  --gin_param="mesh_train_dataset_fn.mixture_or_task_name = 'byt5_wiki.LANG'" \
  --gin_param="utils.run.sequence_length = {'inputs': 1024, 'targets': 189}" \
  --gin_param="utils.run.batch_size = ('tokens_per_batch', 262144)" \
  --gin_param="run.train_steps = 100000" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = 'v3-8'" \
  --gin_param="run.train_dataset_fn = @t5.models.mesh_transformer.mesh_train_dataset_fn" \
  --module_import="byt5.tasks" \
  --tpu="grpc://TPU" \
  --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
  --model_dir="gs://BUCKET/byt5_model" \
  --gcp_project="GCP"

# Optional flags, not used in this run:
#   --t5_tfds_data_dir="${BUCKET}/t5-tfds"
#   --eval_mode="perplexity_eval"
#   --eval_gin_param="mesh_eval_dataset_fn.num_eval_examples = 10000"
#   --tpu_zone="${ZONE}"
#   --gin_param="utils.run.learning_rate_schedule=@learning_rate_schedules.rsqrt_no_ramp_down"

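The `WIKI_LANGS` check mentioned in the training cell can be done without opening the file. This is a sketch only, assuming `WIKI_LANGS` is assigned at module level in `byt5/tasks.py` as a plain list of string literals; the helper names are hypothetical and not part of the byt5 repo.

```python
import ast


def wiki_langs_from_source(source):
    """Extract the WIKI_LANGS list from Python source text without importing it.

    ASSUMPTION: WIKI_LANGS is a top-level assignment to a literal list/tuple.
    """
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.Assign):
            names = [t.id for t in node.targets if isinstance(t, ast.Name)]
            if "WIKI_LANGS" in names:
                # literal_eval only accepts literals, so no code is executed.
                return list(ast.literal_eval(node.value))
    raise ValueError("no top-level WIKI_LANGS assignment found")


def language_is_listed(lang, tasks_path="byt5/tasks.py"):
    """Return True if `lang` already appears in WIKI_LANGS in the given file."""
    with open(tasks_path, encoding="utf-8") as f:
        return lang in wiki_langs_from_source(f.read())
```

For example, `language_is_listed("dv")` returning `False` would mean the list needs editing before `byt5_wiki.dv` can be used as the mixture name.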

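The "1/4" and "1/10" notes in the training cell can be made concrete with a little arithmetic. The sketch below only restates the numbers from the command; the sequences-per-batch figure assumes the batch is derived by dividing `tokens_per_batch` by the longer of the two sequence lengths, which is an assumption about the trainer rather than something shown in this notebook.

```python
# Values taken from the training command.
TOKENS_PER_BATCH = 262144
TRAIN_STEPS = 100000
SEQUENCE_LENGTH = {"inputs": 1024, "targets": 189}

# Implied README settings: this run uses 1/4 of the batch and 1/10 of the steps.
readme_tokens_per_batch = TOKENS_PER_BATCH * 4
readme_train_steps = TRAIN_STEPS * 10

# ASSUMPTION: sequences per batch = tokens_per_batch // longest sequence length.
sequences_per_batch = TOKENS_PER_BATCH // max(SEQUENCE_LENGTH.values())

# Nominal token positions processed over the whole run (padding included).
nominal_token_budget = TOKENS_PER_BATCH * TRAIN_STEPS

print(readme_tokens_per_batch, readme_train_steps)
print(sequences_per_batch, nominal_token_budget)
```

Under that assumption each step covers 256 sequences, and the run as a whole covers 1/40 of the README's nominal token budget.
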
## Convert checkpoint to HF / PyTorch model

In [None]:
%%capture
! pip install transformers

In [None]:
! cp ./drive/MyDrive/mlin/dvcorpus/dv-t5/checkpoint ./drive/MyDrive/mlin/dvcorpus/dv-t5/model.ckpt

In [None]:
! transformers-cli convert --model_type t5 \
  --tf_checkpoint ./drive/MyDrive/mlin/dvcorpus/dv-t5/ \
  --config ./drive/MyDrive/mlin/dvcorpus/dv-t5/config.json \
  --pytorch_dump_output ./drive/MyDrive/mlin/dvcorpus/dv-t5/

2021-07-06 20:45:17.975878: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
Building PyTorch model from configuration: T5Config {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 3584,
  "d_kv": 64,
  "d_model": 1472,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "gradient_checkpointing": false,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 4,
  "num_heads": 6,
  "num_layers": 12,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "ByT5Tokenizer",
  "transformers_version": "4.8.2",
  "use_cache": true,
  "vocab_size": 384
}

Converting TensorFlow checkpoint from /content/drive/MyDrive/mlin/dvcorpus/dv-t5
Loading TF weight decoder/block_000/layer_000/SelfAttention/k with shape [147

In [None]:
from transformers import TFT5ForConditionalGeneration
t_model = TFT5ForConditionalGeneration.from_pretrained('./drive/MyDrive/mlin/dvcorpus/dv-t5', from_pt=True)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFT5ForConditionalGeneration: ['decoder.embed_tokens.weight', 'encoder.embed_tokens.weight']
- This IS expected if you are initializing TFT5ForConditionalGeneration from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFT5ForConditionalGeneration from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.

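The config printed during conversion reports `vocab_size` 384, `pad_token_id` 0, `eos_token_id` 1, and `tokenizer_class` `ByT5Tokenizer`. The pure-Python sketch below illustrates where 384 comes from, assuming the usual ByT5 layout of 3 special ids, 256 byte values, and 125 sentinel ids. It is an illustration of the byte-level scheme, not the `transformers` implementation.

```python
PAD_ID, EOS_ID, UNK_ID = 0, 1, 2  # pad and eos ids match the printed config
NUM_SPECIAL = 3                   # ASSUMPTION: pad, eos, unk come first
NUM_BYTES = 256                   # one id per possible UTF-8 byte value
NUM_EXTRA_IDS = 125               # ASSUMPTION: sentinel ids for span corruption

VOCAB_SIZE = NUM_SPECIAL + NUM_BYTES + NUM_EXTRA_IDS


def encode(text):
    """Map text to ids: each UTF-8 byte shifted past the special ids, then EOS."""
    return [b + NUM_SPECIAL for b in text.encode("utf-8")] + [EOS_ID]


def decode(ids):
    """Invert encode(), skipping special and sentinel ids."""
    raw = bytes(i - NUM_SPECIAL for i in ids
                if NUM_SPECIAL <= i < NUM_SPECIAL + NUM_BYTES)
    return raw.decode("utf-8", errors="ignore")


print(VOCAB_SIZE)
print(encode("hi"))
```

Because every character is split into its UTF-8 bytes, no language-specific vocabulary file is needed, which is why the same setup works for any `byt5_wiki.LANG` mixture.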

In [None]:
t_model.save_pretrained('./drive/MyDrive/mlin/dvcorpus/dv-t5/')