<a href="https://colab.research.google.com/github/00SamYun/simple_chabot_model/blob/main/output_model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Set the Colab runtime type to TPU before running this notebook.

#### Setup

In [None]:
# Authenticate this Colab session with the user's Google account.
# NOTE(review): presumably required so the gs:// paths used below can be read and written — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
from IPython.display import clear_output

In [None]:
import os
import numpy as np
import tensorflow as tf
from transformers import TFT5ForConditionalGeneration

# Limit TensorFlow's Python logger to ERROR-level messages.
tf.get_logger().setLevel('ERROR')

#### Load Dataset

In [None]:
# Open the serialized training and validation splits as TFRecord datasets.
# NOTE(review): PATH_TO_BUCKET looks like a placeholder for a real bucket name — confirm before running.
train_dataset = tf.data.TFRecordDataset('gs://PATH_TO_BUCKET/output_model/train.tfrecord')
valid_dataset = tf.data.TFRecordDataset('gs://PATH_TO_BUCKET/output_model/validation.tfrecord')

In [None]:
def read_tfrecord(example, seq_len=100):
    """Decode one serialized example into a dict of fixed-length int32 tensors.

    Every feature is stored in the record as a scalar string; that string is
    decoded with ``tf.io.parse_tensor`` as int32 and reshaped to ``(seq_len,)``.

    Args:
        example: One serialized record, as yielded by the ``TFRecordDataset``.
        seq_len: Number of elements each decoded tensor is reshaped to.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        dict mapping "attention_mask", "decoder_attention_mask", "input_ids"
        and "labels" to int32 tensors of shape ``(seq_len,)``.
    """
    # Named `feature_spec` rather than `format` so the builtin is not shadowed.
    feature_spec = {
        name: tf.io.FixedLenFeature([], tf.string)
        for name in ("attention_mask", "decoder_attention_mask", "input_ids", "labels")
    }

    parsed = tf.io.parse_single_example(example, feature_spec)

    # The reshape gives every tensor a known static shape.
    # NOTE(review): presumably needed because parse_tensor alone does not
    # provide one inside a tf.data graph — confirm.
    return {
        name: tf.reshape(tf.io.parse_tensor(serialized, tf.int32), (seq_len,))
        for name, serialized in parsed.items()
    }

In [None]:
# Decode the serialized records of both splits with read_tfrecord.
train_dataset, valid_dataset = (
    split.map(read_tfrecord) for split in (train_dataset, valid_dataset)
)

#### Create Strategy

In [None]:
# Locate the TPU cluster.
# NOTE(review): tpu='' presumably relies on the Colab runtime to supply the TPU address — confirm.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
# Connect to the cluster and initialize the TPU system before the strategy is built.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# Distribution strategy used below for variable creation and for running the steps.
strategy = tf.distribute.TPUStrategy(resolver)

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

#### Setup Input Pipeline

In [None]:
# Shuffle buffer size passed to Dataset.shuffle when the training pipeline is built.
# NOTE(review): this is a tenth of the stated dataset size, so examples are only
# shuffled within a sliding window rather than across the whole set — confirm this is intended.
BUFFER_SIZE = 1812 # total number of elements is 18120

# The global batch size is the per-replica batch size multiplied by the number of replicas.
BATCH_SIZE_PER_REPLICA = 16
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Number of passes over the training set made by the training loop.
EPOCHS = 11

In [None]:
# Training data: shuffle individual examples, then group them into global batches.
train_ds = train_dataset.shuffle(BUFFER_SIZE)
train_ds = train_ds.batch(GLOBAL_BATCH_SIZE)

# Validation data is batched without shuffling.
valid_ds = valid_dataset.batch(GLOBAL_BATCH_SIZE)

# Hand both pipelines to the strategy so they can be iterated as distributed datasets.
train_dist_ds, valid_dist_ds = (
    strategy.experimental_distribute_dataset(ds) for ds in (train_ds, valid_ds)
)

#### Create Model

In [None]:
def create_model():
    """Return a ``TFT5ForConditionalGeneration`` loaded from the 't5-base' pretrained weights."""
    return TFT5ForConditionalGeneration.from_pretrained('t5-base')

#### Define Metrics & Loss

In [None]:
with strategy.scope():

    def compute_loss(per_example_loss):
        """Scale the given loss values with ``tf.nn.compute_average_loss``.

        The scaling uses ``GLOBAL_BATCH_SIZE`` rather than the size of the
        tensor that is passed in; the distributed train step sums the
        per-replica results it gets back.

        Args:
            per_example_loss: Loss tensor produced by the model on one replica.

        Returns:
            The scaled loss for this replica.
        """
        return tf.nn.compute_average_loss(
            per_example_loss,
            global_batch_size=GLOBAL_BATCH_SIZE,
        )

    # Accuracy metrics for the training and validation passes.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    # Running mean of the validation loss values.
    test_loss = tf.keras.metrics.Mean(name='test_loss')

#### Training Loop

In [None]:
with strategy.scope():
    # Build the model inside the strategy scope so its variables are created under the strategy.
    model = create_model()

    # Adam with the Keras default hyper-parameters (no learning rate is passed).
    optimizer = tf.keras.optimizers.Adam()

    # Groups optimizer and model state for checkpointing.
    # NOTE(review): no save/restore call on `checkpoint` is visible in this notebook — confirm whether it is still needed.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [None]:
def train_step(inputs):
    """Run one optimization step on a single replica's share of a batch.

    Args:
        inputs: dict of feature tensors; it must contain a 'labels' entry and
            is passed to the model unchanged.

    Returns:
        This replica's loss after scaling by ``compute_loss``.
    """
    labels = inputs['labels']

    with tf.GradientTape() as tape:
        model_outputs = model(inputs, training=True)
        # The first two model outputs are used as the loss values and the logits.
        per_example_loss = model_outputs[0]
        logits = model_outputs[1]
        scaled_loss = compute_loss(per_example_loss)

    # Back-propagate the scaled loss and apply the update to every trainable variable.
    variables = model.trainable_variables
    grads = tape.gradient(scaled_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    train_accuracy.update_state(labels, logits)
    return scaled_loss

In [None]:
def test_step(inputs):
    """Evaluate one replica's share of a validation batch.

    Runs the model with ``training=False`` and updates the ``test_loss`` and
    ``test_accuracy`` metrics; nothing is returned.

    Args:
        inputs: dict of feature tensors; it must contain a 'labels' entry and
            is passed to the model unchanged.
    """
    labels = inputs['labels']

    model_outputs = model(inputs, training=False)
    # The first two model outputs are used as the loss values and the logits.
    batch_loss = model_outputs[0]
    logits = model_outputs[1]

    test_loss.update_state(batch_loss)
    test_accuracy.update_state(labels, logits)

In [None]:
@tf.function
def distributed_train_step(dataset_inputs):
    """Run ``train_step`` on every replica and return the sum of the per-replica losses.

    Args:
        dataset_inputs: One element taken from the distributed training dataset.

    Returns:
        The per-replica losses combined with a SUM reduction.
    """
    replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    summed_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None
    )
    return summed_loss

@tf.function
def distributed_test_step(dataset_inputs):
    """Run ``test_step`` on every replica for one distributed validation batch.

    Args:
        dataset_inputs: One element taken from the distributed validation dataset.

    Returns:
        Whatever ``strategy.run`` returns for ``test_step``.
    """
    step_result = strategy.run(test_step, args=(dataset_inputs,))
    return step_result

# Per-epoch report line; defined once because it never changes between epochs.
template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}")

for epoch in range(EPOCHS):
    # Training pass: accumulate the globally reduced loss of every batch.
    total_loss = 0.0
    num_batches = 0

    for x in train_dist_ds:
        total_loss += distributed_train_step(x)
        num_batches += 1

    # Mean training loss for the epoch, computed once after the pass instead of
    # on every batch. max(..., 1) keeps train_loss defined (as 0.0) when the
    # training dataset yields no batches, instead of failing at the print below.
    train_loss = total_loss / max(num_batches, 1)

    # Validation pass: only updates the test_loss / test_accuracy metrics.
    for x in valid_dist_ds:
        distributed_test_step(x)

    print(template.format(epoch+1, train_loss, train_accuracy.result()*100, test_loss.result(),
                          test_accuracy.result()*100))

    # Clear the metric state so the next epoch's numbers are not mixed with this one's.
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()

# Training for 11 epochs took approximately 30 minutes.

In [None]:
# Write the trained model weights to the bucket.
# NOTE(review): PATH_TO_BUCKET looks like a placeholder for a real bucket name — confirm before running.
model.save_weights('gs://PATH_TO_BUCKET/output_model/saved_weights')

In [None]:
# Note: model should be reloaded and tested on CPU