In [1]:
!git clone https://github.com/Kira1108/TransformerReview.git
!cp -R TransformerReview/* .
!rm -rf TransformerReview
!pip install -r requirements.txt

Cloning into 'TransformerReview'...
remote: Enumerating objects: 75, done.[K
remote: Counting objects: 100% (75/75), done.[K
remote: Compressing objects: 100% (52/52), done.[K
remote: Total 75 (delta 31), reused 61 (delta 17), pack-reused 0[K
Unpacking objects: 100% (75/75), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow-text==2.8.*
  Downloading tensorflow_text-2.8.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.9 MB)
[K     |████████████████████████████████| 4.9 MB 5.1 MB/s 
Installing collected packages: tensorflow-text
Successfully installed tensorflow-text-2.8.2


In [4]:
import tensorflow as tf
import tensorflow_text
from model.transformer import Transformer
from preprocess import get_tokenizers, dataset_creator, get_data
from schedule import PaperScheduler
from losses import loss_function, accuracy_function
import time
from functools import partial

Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/ted_hrlr_translate_pt_en_converter.zip


### Set hyperparameters

In [5]:
# Hyperparameters -- tune these to trade translation quality against training time.

# Model shape (all passed to Transformer(...) in the model-setup cell).
D_MODEL = 128        # model width; also drives the PaperScheduler learning-rate schedule
NUM_LAYERS = 4
NUM_HEADS = 8
DFF = 512            # passed as `dff`; presumably the feed-forward inner width -- confirm in model/
DROPOUT_RATE = 0.1
MAX_TOKENS = 128     # shared by the model and the batching pipeline (dataset_creator)

# Optimisation and input pipeline.
EPOCHS = 20
BATCH_SIZE = 64
BUFFER_SIZE = 20_000  # passed as `buffer_size` to dataset_creator (presumably a shuffle buffer)
WARMUP_STEPS = 4000   # warmup length for PaperScheduler

### Get data and tokenizers

In [6]:
# Load the pt->en train/validation splits and the matching pretrained tokenizers.
train_examples, val_examples = get_data()
tokenizers = get_tokenizers()

# Both splits are batched with identical settings, so bind them once.
batching_kwargs = dict(
    tokenizers=tokenizers,
    max_tokens=MAX_TOKENS,
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
)
make_batches = partial(dataset_creator, **batching_kwargs)

# NOTE(review): val_batches is built here but the training loop below never evaluates on it.
train_batches, val_batches = [make_batches(split) for split in (train_examples, val_examples)]

[1mDownloading and preparing dataset 124.94 MiB (download: 124.94 MiB, generated: Unknown size, total: 124.94 MiB) to ./tensorflow_data/ted_hrlr_translate/pt_to_en/1.0.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/51785 [00:00<?, ? examples/s]

Shuffling tensorflow_data/ted_hrlr_translate/pt_to_en/1.0.0.incomplete40V6E6/ted_hrlr_translate-train.tfrecord…

Generating validation examples...:   0%|          | 0/1193 [00:00<?, ? examples/s]

Shuffling tensorflow_data/ted_hrlr_translate/pt_to_en/1.0.0.incomplete40V6E6/ted_hrlr_translate-validation.tfr…

Generating test examples...:   0%|          | 0/1803 [00:00<?, ? examples/s]

Shuffling tensorflow_data/ted_hrlr_translate/pt_to_en/1.0.0.incomplete40V6E6/ted_hrlr_translate-test.tfrecord*…

[1mDataset ted_hrlr_translate downloaded and prepared to ./tensorflow_data/ted_hrlr_translate/pt_to_en/1.0.0. Subsequent calls will reuse this data.[0m


### Model Setup

In [7]:
# Hyperparameters derived from the tokenizers: vocabulary sizes of source (pt) and target (en).
# `.numpy()` yields a NumPy scalar; `int(...)` turns it into a plain Python int, which is
# safe to store in Keras layer configs that may later be JSON-serialized.
INPUT_VOCAB_SIZE = int(tokenizers.pt.get_vocab_size().numpy())
TARGET_VOCAB_SIZE = int(tokenizers.en.get_vocab_size().numpy())

# Transformer model
transformer = Transformer(
    num_layers = NUM_LAYERS,
    input_vocab_size = INPUT_VOCAB_SIZE,
    target_vocab_size = TARGET_VOCAB_SIZE,
    d_model = D_MODEL,
    max_tokens = MAX_TOKENS,
    num_heads = NUM_HEADS,
    dff = DFF,
    rate = DROPOUT_RATE
)

# Learning-rate schedule parameterized by model width and warmup length.
# NOTE(review): presumably the warmup + inverse-sqrt schedule from "Attention Is All You Need".
learning_rate = PaperScheduler(D_MODEL, warmup_steps=WARMUP_STEPS)

# Adam with beta_2=0.98 and epsilon=1e-9, the settings used in the original Transformer paper.
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

# Running means of per-batch loss and per-batch accuracy (reset by the training loop each epoch).
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

checkpoint_path = './checkpoints/train'

# Track both model weights and optimizer state so training can resume where it stopped.
ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

# Keep only the 5 most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint if one exists.
# NOTE(review): the epoch counter in the training loop still starts from 1 after a restore.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')

# Fixed signature for train_step: two rank-2 int64 token-id tensors with any batch size
# and sequence length, so tf.function does not retrace for every new batch shape.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

### Model Training

In [None]:
# tf.function traces this step into a TensorFlow graph, which runs faster than eager
# execution; the fixed input_signature avoids retracing when batch shapes change.
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    """Run one optimisation step (forward pass, backprop, weight update) on a single batch.

    Args:
        inp: rank-2 int64 token ids for the source (Portuguese) sentences.
        tar: rank-2 int64 token ids for the target (English) sentences.

    Side effects: updates the global `train_loss` and `train_accuracy` metrics and the
    transformer weights via `optimizer`. Returns nothing.
    """
    # Teacher forcing: the decoder is fed the target without its last token and must
    # predict the target without its first token (i.e. the next token at every position).
    decoder_input, decoder_labels = tar[:, :-1], tar[:, 1:]

    with tf.GradientTape() as tape:
        predictions, _ = transformer([inp, decoder_input], training=True)
        batch_loss = loss_function(decoder_labels, predictions)

    # Read the variable list after the forward pass so it is populated even if the
    # model creates its weights lazily on first call.
    params = transformer.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    train_loss(batch_loss)
    train_accuracy(accuracy_function(decoder_labels, predictions))



# Print running metrics every LOG_EVERY_N_BATCHES batches. Checkpoint every
# CKPT_EVERY_N_EPOCHS epochs, and always after the final epoch so the last
# weights are saved even when EPOCHS is not a multiple of CKPT_EVERY_N_EPOCHS.
LOG_EVERY_N_BATCHES = 50
CKPT_EVERY_N_EPOCHS = 5

for epoch in range(EPOCHS):
    start = time.perf_counter()  # monotonic clock, appropriate for measuring durations

    # The metrics accumulate across calls; clear them so each epoch reports its own means.
    # `reset_state()` replaces the deprecated `reset_states()` alias.
    train_loss.reset_state()
    train_accuracy.reset_state()

    # inp -> portuguese, tar -> english
    for (batch, (inp, tar)) in enumerate(train_batches):
        train_step(inp, tar)

        if batch % LOG_EVERY_N_BATCHES == 0:
            print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    is_last_epoch = epoch + 1 == EPOCHS
    if (epoch + 1) % CKPT_EVERY_N_EPOCHS == 0 or is_last_epoch:
        ckpt_save_path = ckpt_manager.save()
        print(f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')

    # NOTE(review): no validation pass -- val_batches is built earlier but never evaluated here.
    print(f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    print(f'Time taken for 1 epoch: {time.perf_counter() - start:.2f} secs\n')

Epoch 1 Batch 0 Loss 8.7437 Accuracy 0.0008
Epoch 1 Batch 50 Loss 8.6273 Accuracy 0.0273
Epoch 1 Batch 100 Loss 8.5086 Accuracy 0.0410
Epoch 1 Batch 150 Loss 8.3596 Accuracy 0.0471
Epoch 1 Batch 200 Loss 8.1805 Accuracy 0.0514
Epoch 1 Batch 250 Loss 7.9835 Accuracy 0.0580
Epoch 1 Batch 300 Loss 7.7795 Accuracy 0.0668
Epoch 1 Batch 350 Loss 7.5847 Accuracy 0.0744
Epoch 1 Batch 400 Loss 7.4134 Accuracy 0.0812
