In [3]:
import os
import sys

from transformers import FlaxT5ForConditionalGeneration, AutoTokenizer, T5Config
from datasets import load_dataset

import numpy as np
import jax.numpy as jnp

In [4]:
# Load the t5-large tokenizer and the t5-large configuration; the latter is
# the starting point for the much smaller student configuration built below.
tokenizer = AutoTokenizer.from_pretrained("t5-large")
trainer_config = T5Config.from_pretrained("t5-large")

# Work on the dict form of the config so individual fields can be overridden,
# then shrink the architecture in a single update.
student_config_dict = trainer_config.to_dict()
student_config_dict.update(
    d_ff=1024,
    d_model=256,
    num_heads=4,
    num_layers=1,
    num_decoder_layers=1,
)

# Rebuild a config object from the edited dict.
student_config = T5Config.from_dict(student_config_dict)

In [5]:
# Directory that receives the tokenizer files and the student config.
MODEL_DIR = "t5_super_tiny"
os.makedirs(MODEL_DIR, exist_ok=True)

# Save the tokenizer first, then the student config, into MODEL_DIR.
for pretrained_artifact in (tokenizer, student_config):
    pretrained_artifact.save_pretrained(MODEL_DIR)

In [10]:
# Load the "realnewslike" configuration of C4 and take its training split.
c4_dataset = load_dataset("c4", "realnewslike")
c4_train = c4_dataset["train"]

# Keep a random subset of the training split, drawn without replacement.
# The generator is seeded so that a fresh kernel selects the same rows on
# every run (the unseeded draw picked a different subset each time).
SAMPLE_SEED = 0
SAMPLE_FRACTION = 0.01
sample_rng = np.random.default_rng(SAMPLE_SEED)
indices = sample_rng.choice(
    len(c4_train),
    size=int(len(c4_train) * SAMPLE_FRACTION),
    replace=False,
)
c4_train = c4_train.select(indices)

# Display the first example of the subsample as the cell output.
next(iter(c4_train))

Reusing dataset c4 (/home/vlialin/.cache/huggingface/datasets/c4/realnewslike/0.0.0/df532b158939272d032cc63ef19cd5b83e9b4d00c922b833e4cb18b2e9869b01)


  0%|          | 0/2 [00:00<?, ?it/s]

{'text': "State Bank of India (SBI) main branch, Sector 17, in association with the Chandigarh Beopar Mandal organised a 'Coin Mela' for traders, in which coins and currency notes amounting to Rs 36 lakh were distributed. The fair will continue till Tuesday as well.\nMohan Ganeshari, division general manager (DGM), and Kamlesh Sekhon, area general manager (AGM), SBI, inaugurated the event by distributing bags of coins and bundle of notes to traders.\nSBI officials said that Rs 1 currency notes had already been printed by the Reserve Bank of India, and the supply would reach the RBI, Chandigarh, soon.\nKK Rana and Venu Gopal, manager and assistant manager, currency issue department, RBI, also made people aware about Notes Return Rules, 2009, and star series notes.",
 'timestamp': '2019-04-22T22:06:46Z',
 'url': 'https://www.hindustantimes.com/chandigarh/chandigarh-state-bank-of-india-holds-coin-mela-for-traders/story-QlwEu2loquMBnrSoGYSr0K.html'}

### Train command

```bash
# Environment for the training run.
# MODEL_DIR is reused below as the output directory, the config location and
# the tokenizer location; "t5_super_tiny" is the directory the notebook's
# Python cell saves the tokenizer and the student config into.
# NOTE(review): WANDB_START_METHOD and TOKENIZERS_PARALLELISM are presumably
# read by wandb and the tokenizers library respectively - confirm.
export WANDB_START_METHOD="thread"
export TOKENIZERS_PARALLELISM=false
export MODEL_DIR=t5_super_tiny

# Run run_t5_mlm_flax.py on the "realnewslike" configuration of C4.
# The last line of the command is a commented-out debug flag.
python run_t5_mlm_flax.py \
	--output_dir=$MODEL_DIR \
	--model_type="t5" \
	--config_name=$MODEL_DIR \
	--tokenizer_name=$MODEL_DIR \
	--dataset_name="c4" \
	--dataset_config_name="realnewslike" \
	--preprocessing_num_workers="8" \
	--max_seq_length="64" \
	--per_device_train_batch_size="512" \
	--per_device_eval_batch_size="512" \
	--adafactor \
	--learning_rate="0.005" \
	--weight_decay="0.001" \
	--warmup_steps="2000" \
	--overwrite_output_dir \
	--logging_steps="10" \
	--save_steps="1000" \
	--eval_steps="500" \
	# --dataset_fraction="0.1" # DEBUG option; make sure the validation set still has more than 1 element
```


# LFOM Distillation

```bash
# Environment for the distillation run.
# MODEL_DIR is the output directory, TEACHER_MODEL supplies both the tokenizer
# and the teacher model, and WEAK_MODEL is currently passed as the config name.
export TOKENIZERS_PARALLELISM=false
export MODEL_DIR=t5_2l_8h_512d_2048ff_lfom_distil_debug
export TEACHER_MODEL=t5-small
export WEAK_MODEL=t5_2l_8h_512d_2048ff

# TODO: replace --config_name=$WEAK_MODEL below with a normal config.
# TODO: add --weak_model_name_or_path to the command below.

# Run run_lfom_distillation_flax.py on the "realnewslike" configuration of C4.
python run_lfom_distillation_flax.py \
	--output_dir=$MODEL_DIR \
	--model_type="t5" \
	--config_name=$WEAK_MODEL \
	--tokenizer_name=$TEACHER_MODEL \
	--teacher_model_name_or_path=$TEACHER_MODEL \
	--dataset_name="c4" \
	--dataset_config_name="realnewslike" \
	--preprocessing_num_workers="8" \
	--max_seq_length="256" \
	--temperature 2.0 \
	--per_device_train_batch_size="128" \
	--per_device_eval_batch_size="128" \
	--adafactor \
	--learning_rate="0.005" \
	--weight_decay="0.001" \
	--warmup_steps="2000" \
	--overwrite_output_dir \
	--logging_steps="10" \
	--save_steps="1000" \
	--eval_steps="500" \
	--dataset_fraction="0.1" # DEBUG option; make sure the validation set still has more than 1 element
```