<a href="https://colab.research.google.com/github/00SamYun/simple_chabot_model/blob/main/input_model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Set the Colab runtime type to TPU (Runtime > Change runtime type) before running this notebook.

#### Setup

In [None]:
# Authenticate this Colab session with Google Cloud. This is needed for Cloud
# Storage access, e.g. writing the trained weights to the gs:// bucket at the
# end of the notebook.
from google.colab import auth
auth.authenticate_user()

In [None]:
import os

import json
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

from official.modeling import tf_utils 
from official import nlp 
from official.nlp import bert

import official.nlp.bert.configs
import official.nlp.bert.bert_models
import official.nlp.bert.tokenization
import official.nlp.optimization

tf.get_logger().setLevel('ERROR')

#### Prepare Data

In [None]:
# Cloud Storage folder with the pretrained (uncased) BERT files this notebook reads:
# vocab.txt (tokenizer), bert_config.json (model config) and bert_model.ckpt (weights).
gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12"

In [None]:
def extract_info(element):
    """Turn one DART example into per-word training rows.

    Each distinct (word, role) pair in the example's table content becomes one
    row. The role is the word's position modulo 3 in the flattened content.
    NOTE(review): this presumably encodes subject / predicate / object and
    assumes the cells of each triple are stored consecutively -- confirm
    against the TFDS `dart` feature spec.

    Args:
        element: A single (eager) example from the TFDS `dart` dataset.

    Returns:
        A tuple ``(words, contexts, labels)`` of equal-length lists:
        ``words`` are the decoded table strings, ``contexts`` repeats the
        decoded ``target_text`` once per word, and ``labels`` are one-hot
        lists of length 3 encoding each word's role. Rows appear in the order
        their (word, role) pair is first seen in the table.
    """
    context = element['target_text'].numpy().decode()
    words = element['input_text']['table']['content'].numpy()

    # De-duplicate (word, role) pairs while keeping first-seen order. A plain
    # `set` would order rows by string hash, which changes between Python
    # processes (PYTHONHASHSEED) and makes the training data order
    # non-reproducible across runs.
    unique_pairs = dict.fromkeys((w.decode(), i % 3) for i, w in enumerate(words))

    out_words, out_labels = [], []
    for word, role in unique_pairs:
        out_words.append(word)
        out_labels.append([int(n == role) for n in range(3)])  # one-hot role

    return out_words, [context] * len(out_words), out_labels


def create_data(ds):
    """Flatten an iterable of DART examples into parallel column lists.

    Args:
        ds: Iterable of DART examples (e.g. one split of the TFDS dataset).

    Returns:
        A dict with keys ``'word'``, ``'context'`` and ``'label'``. Each value
        is the concatenation, in iteration order, of the corresponding list
        returned by ``extract_info`` for every example.
    """
    words, contexts, labels = [], [], []
    for example in ds:
        example_words, example_contexts, example_labels = extract_info(example)
        words.extend(example_words)
        contexts.extend(example_contexts)
        labels.extend(example_labels)
    return {'word': words, 'context': contexts, 'label': labels}

In [None]:
# Load the DART dataset (all splits) plus its metadata from TensorFlow Datasets.
# The op is placed on '/job:localhost' -- presumably so the data is read on the
# Colab VM rather than a remote TPU worker; TODO confirm this is required.
# NOTE(review): shuffle_files=True without a seed makes file order vary across
# runs, and `info` is not used by any later cell.
with tf.device('/job:localhost'):
    dart, info = tfds.load('dart', with_info=True, shuffle_files=True)

In [None]:
# Convert the train and validation splits into in-memory column dicts via
# `create_data`, on the same local host device used to load the dataset.
# (Iterating a split directly already visits every element.)
with tf.device('/job:localhost'):
    train_dataset = create_data(dart['train'])
    valid_dataset = create_data(dart['validation'])

In [None]:
# BERT tokenizer built from the pretrained checkpoint's vocabulary file;
# lower-casing is enabled to match the uncased model.
tokenizer = bert.tokenization.FullTokenizer(vocab_file=os.path.join(gs_folder_bert, "vocab.txt"), do_lower_case=True)

def encode(inp, tokenizer):
    """Tokenize ``inp`` and map it to vocabulary ids, terminated by ``[SEP]``.

    Args:
        inp: Text to encode.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.

    Returns:
        Whatever ``tokenizer.convert_tokens_to_ids`` returns for the token
        list ``tokenize(inp) + ['[SEP]']``.
    """
    pieces = [*tokenizer.tokenize(inp), '[SEP]']
    return tokenizer.convert_tokens_to_ids(pieces)


def prepare_inputs(ds, tokenizer):
    """Build padded BERT input tensors from parallel word/context columns.

    Each row is laid out as ``[CLS] <word tokens> [SEP] <context tokens> [SEP]``
    (the ``[SEP]`` markers come from ``encode``). All rows are padded with 0
    to the length of the longest row in ``ds``.

    NOTE(review): sequences are not truncated. BERT's position embeddings cap
    the usable length, so confirm every example fits.

    Args:
        ds: Dict with list-valued ``'word'`` and ``'context'`` entries of equal
            length (as produced by ``create_data``).
        tokenizer: BERT tokenizer used by ``encode`` and for the ``[CLS]`` id.

    Returns:
        Dict with dense ``'input_word_ids'``, ``'input_mask'`` (1 for real
        tokens, 0 for padding) and ``'input_type_ids'`` (0 for ``[CLS]`` and
        the word segment, 1 for the context segment).
    """
    word_ids = tf.ragged.constant([encode(word, tokenizer) for word in np.array(ds['word'])])
    context_ids = tf.ragged.constant([encode(ctx, tokenizer) for ctx in np.array(ds['context'])])

    n_rows = word_ids.shape[0]
    cls_ids = [tokenizer.convert_tokens_to_ids(['[CLS]'])] * n_rows

    ragged_ids = tf.concat([cls_ids, word_ids, context_ids], axis=-1)
    ragged_types = tf.concat(
        [tf.zeros_like(cls_ids), tf.zeros_like(word_ids), tf.ones_like(context_ids)],
        axis=-1)

    return {
        'input_word_ids': ragged_ids.to_tensor(),
        'input_mask': tf.ones_like(ragged_ids).to_tensor(),
        'input_type_ids': ragged_types.to_tensor(),
    }

In [None]:
# Tokenise both splits into dense BERT feature dicts.
train_ds = prepare_inputs(train_dataset, tokenizer)
valid_ds = prepare_inputs(valid_dataset, tokenizer)

# Label tensors: one one-hot role vector of length 3 per row, i.e. shape (num_rows, 3).
train_label_ds = tf.convert_to_tensor(train_dataset['label'])
valid_label_ds = tf.convert_to_tensor(valid_dataset['label'])

#### Create Strategy

In [None]:
# Standard TPU bootstrap: resolve the TPU, connect to it, initialise it, then
# create the distribution strategy. tpu='' presumably lets the resolver pick
# up the Colab-provided TPU address -- TODO confirm.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Number of replicas (TPU cores) the global batch will be split across.
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

#### Set Batch Size and Epochs

In [None]:
# Per-replica batch size. The global batch passed to `fit` is scaled by the
# number of replicas so each TPU core gets BATCH_SIZE_PER_REPLICA examples.
BATCH_SIZE_PER_REPLICA = 32
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 8  # number of full passes over the training rows

#### Create Model

In [None]:
# `gs_folder_bert` (the pretrained-checkpoint folder) is defined once in the
# "Prepare Data" section; it is not redefined here, so the two copies cannot drift apart.

# TF-Hub handle that appears to name the same pretrained model.
# NOTE(review): not referenced by any later cell.
hub_url_bert = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3"

# Read the pretrained model's architecture hyper-parameters. The `with` block
# closes the Cloud Storage file handle once the JSON has been read.
bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
with tf.io.gfile.GFile(bert_config_file) as config_file:
    config_dict = json.loads(config_file.read())

bert_config = bert.configs.BertConfig.from_dict(config_dict)

In [None]:
def create_model(num_labels=3, config=None):
    """Build a BERT classifier on top of the pretrained architecture.

    Args:
        num_labels: Number of output classes. Defaults to 3, matching the
            length-3 one-hot role labels produced by ``extract_info``.
        config: ``BertConfig`` describing the encoder. Defaults to the
            module-level ``bert_config`` read from the checkpoint folder.

    Returns:
        ``(bert_classifier, bert_encoder)``: the Keras classifier model (this
        notebook's loss treats its output as logits) and the underlying
        encoder, whose weights are restored from the pretrained checkpoint.
    """
    if config is None:
        config = bert_config
    bert_classifier, bert_encoder = bert.bert_models.classifier_model(config, num_labels=num_labels)
    return bert_classifier, bert_encoder

In [None]:
# Create the loss, metric, optimizer and model inside the TPU strategy scope,
# so their variables are created under the distribution strategy.
with strategy.scope():
    # Labels from `extract_info` are one-hot vectors of length 3, giving label
    # tensors of shape (N, 3). Use the *categorical* loss/accuracy: the sparse
    # variants expect integer class ids of shape (N,) and fail on one-hot targets.
    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    metrics = tf.keras.metrics.CategoricalAccuracy('accuracy', dtype=tf.float32)
    optimizer = tf.keras.optimizers.SGD()  # NOTE(review): uses Keras' default learning rate

    bert_classifier, bert_encoder = create_model()

    # Restore pretrained weights into the encoder only; the classification head
    # keeps its fresh initialisation. assert_consumed() raises if the checkpoint
    # contents and the encoder's variables do not fully match.
    checkpoint = tf.train.Checkpoint(encoder=bert_encoder)
    checkpoint.read(os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()

    bert_classifier.compile(optimizer=optimizer, loss=loss, metrics=metrics)

#### Train Model

In [None]:
# Fine-tune on the tokenised training rows, evaluating on the validation split
# after every epoch. batch_size is the *global* batch across all TPU replicas.
bert_classifier.fit(train_ds, train_label_ds, 
                    validation_data=(valid_ds, valid_label_ds), 
                    batch_size=GLOBAL_BATCH_SIZE,
                    epochs=EPOCHS)

# Training for 8 epochs took approximately 1 h 30 min on the Colab TPU runtime.

In [None]:
# Persist the fine-tuned weights to Cloud Storage.
# TODO: replace PATH_TO_BUCKET with a bucket this Colab account can write to.
# NOTE(review): without an `.h5` suffix Keras writes TensorFlow-checkpoint files
# that use this path as a file prefix (it is not created as a directory).
weights_dir = 'gs://PATH_TO_BUCKET/input_model/training_weights'
bert_classifier.save_weights(weights_dir)

In [None]:
# Note: reload the saved weights and evaluate the model on a GPU runtime.