<a href="https://colab.research.google.com/github/00SamYun/simple_chabot_model/blob/main/output_model_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Setup

In [None]:
from google.colab import auth
# Authenticate the Colab session; later cells write TFRecord files to gs:// paths.
auth.authenticate_user()

In [None]:
from IPython.display import clear_output

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from transformers import T5Tokenizer

# Raise TensorFlow's logger threshold to WARNING to keep cell output quiet.
tf.get_logger().setLevel('WARNING')

#### Prepare Data

In [None]:
# Pretrained 't5-base' tokenizer; encode() below reads this module-level name.
tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [None]:
# Load the 'web_nlg' dataset from TensorFlow Datasets along with its metadata
# object (`info`); `dataset` is indexed by split name in the next cell.
dataset, info = tfds.load('web_nlg', shuffle_files=True, with_info=True)

# Clear whatever the load step printed into this cell's output.
clear_output()

In [None]:
# Pick out the three splits that get converted; the test data lives under
# the 'test_all' key.
train_dataset, valid_dataset, test_dataset = (
    dataset[split_name] for split_name in ('train', 'validation', 'test_all')
)

In [None]:
# A count of -1 keeps every element, so each split stays whole; replace it
# with a positive count to run the notebook on a subset.
train_dataset, valid_dataset, test_dataset = [
    split.take(-1) for split in (train_dataset, valid_dataset, test_dataset)
]

In [None]:
def _to_pair(record):
    """Reduce a dataset record to (table cell contents, target text)."""
    return record['input_text']['table']['content'], record['target_text']


# Apply the same projection to every split.
train_dataset = train_dataset.map(_to_pair)
valid_dataset = valid_dataset.map(_to_pair)
test_dataset = test_dataset.map(_to_pair)

In [None]:
def encode(example, encoder_max_len=100, decoder_max_len=100):
    """Tokenize one (table cells, sentence) pair into fixed-length model inputs.

    Args:
        example: A pair of eager tensors. ``example[0]`` holds byte-string
            table cells, regrouped here three at a time (presumably
            subject / predicate / object -- confirm against the dataset);
            ``example[1]`` is a scalar byte-string target sentence.
        encoder_max_len: Token length the encoded triples are padded or
            truncated to.
        decoder_max_len: Token length the encoded sentence is padded or
            truncated to.

    Returns:
        A dict of NumPy arrays with keys ``input_ids``, ``attention_mask``,
        ``labels`` and ``decoder_attention_mask``.
    """
    # The cells arrive as one flat sequence; regroup them into rows of three.
    rows = example[0].numpy().reshape((-1, 3)).tolist()
    triples = b' | '.join(b' ; '.join(row) for row in rows).decode()
    sentence = example[1].numpy().decode()

    triples_plus = f'webNLG: {triples} </s>'
    sentence_plus = f'{sentence} </s>'

    # padding='max_length' replaces the deprecated pad_to_max_length=True.
    # truncation=True is required for max_length to be an upper bound as well:
    # without it, over-long inputs come back longer than requested and the
    # stored examples no longer share one length.
    encoder_inputs = tokenizer(triples_plus, padding='max_length', truncation=True,
                               max_length=encoder_max_len, return_tensors='tf')
    decoder_inputs = tokenizer(sentence_plus, padding='max_length', truncation=True,
                               max_length=decoder_max_len, return_tensors='tf')

    # The tokenizer returns a batch of one; drop the batch axis.
    input_ids = encoder_inputs['input_ids'][0].numpy()
    input_attention = encoder_inputs['attention_mask'][0].numpy()
    target_ids = decoder_inputs['input_ids'][0].numpy()
    target_attention = decoder_inputs['attention_mask'][0].numpy()

    outputs = {'input_ids': input_ids, 'attention_mask': input_attention,
               'labels': target_ids, 'decoder_attention_mask': target_attention}

    return outputs

In [None]:
# Run every example of each split through encode(). This iterates the
# datasets eagerly and keeps all encoded examples in memory.
train_data = [encode(pair) for pair in train_dataset]

valid_data = [encode(pair) for pair in valid_dataset]

test_data = [encode(pair) for pair in test_dataset]

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value (or eager tensor) in a bytes-list Feature."""
    eager_tensor_type = type(tf.constant(0))
    # An eager tensor is converted with .numpy() before being placed in the
    # BytesList; any other value is passed through unchanged.
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    wrapped = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=wrapped)

In [None]:
def serialize_example(element):
    """Serialize one encoded example into a tf.train.Example byte string.

    Args:
        element: A mapping that contains the keys ``input_ids``,
            ``attention_mask``, ``labels`` and ``decoder_attention_mask``,
            each holding a tensor-like value.

    Returns:
        The serialized ``tf.train.Example``. Every feature is a bytes
        feature holding the output of ``tf.io.serialize_tensor``.

    Raises:
        KeyError: If ``element`` is missing one of the expected keys.
    """
    feature_names = ('input_ids', 'attention_mask', 'labels',
                     'decoder_attention_mask')

    # Look each field up by name rather than unpacking element.values():
    # positional unpacking silently stores tensors under the wrong feature
    # name if the dict was built in a different key order.
    feature = {
        name: _bytes_feature(tf.io.serialize_tensor(element[name]))
        for name in feature_names
    }

    example = tf.train.Example(features=tf.train.Features(feature=feature))

    return example.SerializeToString()

In [None]:
# Write the encoded training examples, one serialized record each.
train_record_path = 'gs://PATH_TO_BUCKET/output_model/train.tfrecord'
with tf.io.TFRecordWriter(train_record_path) as record_writer:
    for encoded in train_data:
        record_writer.write(serialize_example(encoded))

In [None]:
# Write the encoded validation examples, one serialized record each.
valid_record_path = 'gs://PATH_TO_BUCKET/output_model/validation.tfrecord'
with tf.io.TFRecordWriter(valid_record_path) as record_writer:
    for encoded in valid_data:
        record_writer.write(serialize_example(encoded))

In [None]:
# Write the encoded test examples, one serialized record each.
test_record_path = 'gs://PATH_TO_BUCKET/output_model/test.tfrecord'
with tf.io.TFRecordWriter(test_record_path) as record_writer:
    for encoded in test_data:
        record_writer.write(serialize_example(encoded))