# Example for the Answer Generator Split model

## Imports + model initialization

In [None]:
import time
import numpy as np
import pandas as pd
import tensorflow as tf

from tqdm import tqdm

from models.qa import AnswerGeneratorSplit
from utils.text import TextEncoder
from utils import plot, plot_multiple, plot_embedding, plot_confusion_matrix, set_display_options
from datasets import get_dataset, prepare_dataset, train_test_split, test_dataset_time

# Configure the project's display helpers (defined in utils).
set_display_options()

# Pretrained checkpoint id the model / tokenizer are built from, and the name
# under which the model is created (and later reloaded).
bert_base = 'facebook/bart-large'
model_name = 'test_nq_qa_generator_split_5_mean_2'

print(f"Tensorflow version : {tf.__version__}")

In [None]:
# Build-time configuration of AnswerGeneratorSplit: input / output text formats,
# tokenizer, maximum input length and encoder-subsampling settings (their exact
# semantics are defined by AnswerGeneratorSplit).
config = dict(
    lang             = 'en',
    input_format     = ['{question}', '{context}'],
    output_format    = '{answer}',
    text_encoder     = TextEncoder.from_transformers_pretrained(bert_base),
    max_input_length = 512,

    pretrained                 = bert_base,
    encoder_subsampling_step   = 5,
    encoder_subsampling_offset = 1,
    subsample_after            = True,
    encoder_subsampling_mode   = 'mean',
)

# Hide every GPU from TensorFlow so the model is built on CPU.
# NOTE(review): TensorFlow does not allow changing visible devices once the
# runtime is initialised, so this presumably also keeps the training cell below
# on CPU for the rest of the session — confirm this is intended.
tf.config.set_visible_devices([], 'GPU')

model = AnswerGeneratorSplit(nom = model_name, ** config)
print(model)

In [None]:
# Print the architecture summary of the full model, then of its encoder sub-model.
model.summary()
model.model.encoder.summary()

## Model instantiation + dataset loading

In [None]:
# Re-create the model from its name only (no build config; presumably it is
# restored from what the previous cell saved) and compile it.
model = AnswerGeneratorSplit(nom = model_name)

# Learning rate: either a constant or a WarmupScheduler config dict. The dict
# used to be assigned and then immediately overwritten by the constant; the
# choice is now an explicit switch (default keeps the constant 5e-5).
use_warmup_schedule = False
if use_warmup_schedule:
    lr = {'name' : 'WarmupScheduler', 'maxval' : 5e-5, 'minval' : 1e-5, 'factor' : 512, 'warmup_steps' : 8192}
else:
    lr = 5e-5

model.compile(optimizer = 'adam', optimizer_config = {'lr' : lr}, metrics = ['TextAccuracy'])
print(model)

In [None]:
# Dataset to use. Re-enabling the commented-out conditional would select 'nq'
# whenever the model name contains 'nq'.
datasets = 'squad'  # if 'nq' not in model_name else 'nq'

dataset = get_dataset(datasets, clean_text=True, skip_impossible=True, keep_only_first=True)
train = dataset['train']
valid = dataset['valid']

print(f"Dataset length :\n  Training set : {len(train)}\n  Validation set : {len(valid)}")

In [None]:
# Remove every row containing a missing value, in place, from both splits.
train.dropna(axis='index', inplace=True)
valid.dropna(axis='index', inplace=True)

# Report the split sizes after filtering.
print(f"Dataset length :\n  Training set : {len(train)}\n  Validation set : {len(valid)}")

In [None]:
# For every training record, take encode_data(record)[1][1] (presumably the
# encoded answer length — confirm against model.encode_data) and plot its histogram.
freqs = np.array([
    model.encode_data(record)[1][1]
    for record in tqdm(train.to_dict('records'))
])
print(freqs)
plot(freqs, plot_type = 'hist')


In [None]:
# Count records whose value exceeds 128 (the max_output_length used in the training cell).
print((freqs > 128).sum())

## Training

In [None]:
# Training run. `fine_tuning` switches between several small-batch epochs at a
# lowered learning rate and a single large-batch epoch.
fine_tuning = True

if fine_tuning:
    # Override the optimizer's current learning rate for fine-tuning.
    model.get_optimizer().learning_rate.assign(1e-5)

epochs = 5 if fine_tuning else 1
if not fine_tuning:
    batch_size = 32
elif datasets == 'squad':
    batch_size = 8
else:
    batch_size = 6
# Shuffling is enabled only once model.epochs (presumably epochs already
# trained) plus this run's epochs reaches 3.
shuffle_size = batch_size * 32 if model.epochs + epochs >= 3 else 0

# Masking-augmentation settings (augment_prct = 0, so presumably inactive).
augment_prct = 0.
nb_mask = 1
min_mask_length = 1
max_mask_length = 1

in_batch_negatives = True

max_input_length = 512
max_output_length = 128

# Validation batches are counted as batch_size * 2, presumably matching
# valid_batch_size = 2. (a multiplier) passed below.
print(f"Training samples   : {len(train)} - {len(train) // batch_size} batches")
print(f"Validation samples : {len(valid)} - {len(valid) // (batch_size * 2)} batches")

model.model.freeze(trainable = fine_tuning)

hist = model.train(
    train,
    validation_data = valid,
    epochs = epochs,
    batch_size = batch_size,
    valid_batch_size = 2.,
    shuffle_size = shuffle_size,
    max_input_length = max_input_length,
    max_output_length = max_output_length,
    in_batch_negatives = in_batch_negatives,
    augment_prct = augment_prct,
    nb_mask = nb_mask,
    min_mask_length = min_mask_length,
    max_mask_length = max_mask_length,
)

In [None]:
# Plot the training curves, then print the raw history object.
model.plot_history()
print(model.history)

## Evaluate

In [None]:
# Run the model's evaluation routine on the validation split; as the last
# expression of the cell, its return value (if any) is displayed.
model.test(valid)

## Prediction

In [None]:
# Append ' ?' to every validation question (matching the "question ?" style of
# the hand-written demo questions further down).
# NOTE(review): not idempotent — re-running this cell appends another ' ?', and
# questions that already end with '?' become "...? ?". Confirm this is intended.
valid['question'] = valid['question'].apply(lambda q: q + ' ?')

In [None]:
# Build a small validation pipeline: 10 rows sampled with a fixed seed, batches
# of 5, no shuffling. Note: this rebinds `config` (the model build config above).
config = model.get_dataset_config(batch_size = 5, is_validation = True, shuffle_size = 0)
ds = prepare_dataset(valid.sample(10, random_state = 0), ** config, debug = True)

# For each batch, show 5 predictions (presumably alongside the targets).
for batch in ds:
    model.predict_with_target(batch, n_pred = 5)


In [None]:
def predict(self, question, context = None, ** kwargs):
    """Generate one answer per question, optionally conditioned on a context.

    Arguments :
        - self     : the model (called as ``predict(model, ...)``); must provide
                     ``get_input``, ``infer`` and ``decode_text``.
        - question : a single question or a list of questions; each question is
                     a string or a dict (merged into the row given to get_input).
        - context  : None, a single context shared by every question, or a list
                     with exactly one context per question. Each context is a
                     string, a list of paragraphs or a dict. A list whose length
                     differs from the number of questions is treated as ONE
                     context (e.g. a list of paragraphs) shared by all questions.
        - kwargs   : currently unused (kept for API compatibility).
    Return :
        - answers  : list of decoded answers, one per question.
    """
    def _as_dict(value, key):
        # Wrap a raw value into a one-field row; dicts are used as-is.
        return value if isinstance(value, dict) else {key : value}

    questions = question if isinstance(question, list) else [question]

    if context is None:
        # Without context the questions themselves are the rows fed to get_input.
        # (Previously len(None) was evaluated here and raised TypeError.)
        data = questions
    else:
        contexts = context
        if not isinstance(contexts, list) or len(contexts) != len(questions):
            contexts = [contexts]
        # A single context is shared by every question.
        if len(contexts) == 1 and len(questions) > 1:
            contexts = contexts * len(questions)

        data = [
            {** _as_dict(q, 'question'), ** _as_dict(c, 'context')}
            for q, c in zip(questions, contexts)
        ]

    answers = []
    for row in data:
        # get_input encodes a single row; add a leading batch axis of size 1.
        inputs = [tf.expand_dims(inp, axis = 0) for inp in self.get_input(row)]

        pred = self.infer(inputs, training = False)

        # Decode the first (only) element of the batch.
        answers.append(self.decode_text(pred[0], remove_tokens = True))

    return answers

# Demo inputs: each question is paired with the context at the same index.
question = [
    'How is the night vision of cat ?',
    'How is the night vision of cat ?',
    'What is the anoatomy of a cat ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'What is the origin of life ?'
]

# Shared paragraph, written once instead of three times. The text is identical
# to the original backslash-continued literals.
cat_paragraph = (
    'The cat is similar in anatomy to the other felid species: it has a strong flexible body, '
    'quick reflexes, sharp teeth and retractable claws adapted to killing small prey. '
    'Its night vision and sense of smell are well developed. '
    'Cat communication includes vocalizations like meowing, purring, trilling, hissing, growling '
    'and grunting as well as cat-specific body language. '
    'A predator that is most active at dawn and dusk (crepuscular), the cat is a solitary hunter '
    'but a social species. '
    'It can hear sounds too faint or too high in frequency for human ears, such as those made by '
    'mice and other small mammals.[7] It secretes and perceives pheromones.'
)

context  = [
    # Plain string context.
    cat_paragraph,
    # List-of-paragraphs context.
    # NOTE(review): cat_paragraph contains no '...', so this yields a single
    # element (the paragraph followed by an extra '.'); split('.') may have been
    # intended to get one entry per sentence — confirm.
    [p.strip() + '.' for p in cat_paragraph.split('...') if len(p) > 0],
    # Relevant paragraph plus an unrelated distractor paragraph.
    [cat_paragraph, 'The answer to everything is 42'],
    'A cat is an animal which has 4 paws and whiskers.',
    'A cat is an animal which has 4 paws and whiskers. However, everyone knows that the answer to everything is 42 !',
    'The answer to everything is 42.'
]

# Uncomment to run only the first (question, context) pair.
#question, context = question[0], context[0]

# Wrap single values in lists so the zip below pairs them (only relevant when
# the line above is enabled).
if not isinstance(question, list): question = [question]
if not isinstance(context, list): context = [context]

answers = predict(model, question, context)

for q, c, a in zip(question, context, answers):
    print("Question : {}\nContext : {}\nAnswer : {}\n".format(q, c, a))

In [None]:
# Tokenizer sanity checks: symbol for id 2 (compare with eos_token_idx printed
# next), decoding of a hand-picked id sequence, and the byte mapped to '/'.
print(model.text_encoder._id_to_symbol[2])
print(model.text_encoder.eos_token_idx)
print(model.text_encoder.decode([0, 3714, 2]))
print(chr(model.text_encoder.byte_encoder_inv['/']))

# Last expression: displayed as the cell output.
model.text_encoder.special_tokens

## Tests

In [None]:
# Time the training-style input pipeline (is_validation = False, batches of 16,
# no shuffling) built on the validation split, over 1000 steps.
config = model.get_dataset_config(batch_size = 16, is_validation = False, shuffle_size = 0)
ds = prepare_dataset(valid, ** config, debug = True)

test_dataset_time(ds, steps = 1000)

In [None]:
# Visualise a candidate warmup learning-rate schedule (plotted over 25000,
# presumably steps).
from custom_train_objects.optimizers import WarmupScheduler

lr = WarmupScheduler(maxval = 1e-3, minval = 1e-4, factor = 256, warmup_steps = 4096)
lr.plot(25000)

In [None]:
# Manually override the optimizer's current learning rate (the result of
# assign is displayed as the cell output).
lr = model.get_optimizer().learning_rate
lr.assign(5e-4)

In [None]:
# Inspect the optimizer's current learning rate.
lr = model.get_optimizer().learning_rate
print(lr)

In [None]:
# Preview the first rows of the validation DataFrame.
valid.head()

In [None]:
print(model.text_encoder)
# NOTE(review): predict() calls model.get_input with a single row; confirm this
# two-positional-argument form is supported.
print(model.get_input("Hello", "World !"))

In [None]:
# Scratch check: float mask for lengths [4, 3] padded to length 4.
tf.cast(tf.sequence_mask([4, 3], maxlen = 4), tf.float32)

In [None]:
# Scratch tensor: values 0..8 reshaped to shape (3, 1, 3).
x = tf.reshape(tf.range(9), [3, 1, 3])

In [None]:
import os
import pandas as pd
import tensorflow as tf

from datasets import get_dataset, prepare_dataset, test_dataset_time
from utils.text import TextEncoder

# Standalone experiment: multi-paragraph input pipeline on Natural Questions.
# NOTE: this rebinds `valid` (previously the SQuAD validation split).
valid = get_dataset('nq', modes = 'valid', include_document = True)['valid']

text_encoder = TextEncoder.from_transformers_pretrained('facebook/bart-large')
# Format strings consumed by format_question / format_context below.
question_format = '{question}'
context_format  = '{context}'

In [None]:
def format_question(question, ** kwargs):
    """Encode `question` with the module-level `text_encoder` using `question_format`.

    Extra keyword arguments are forwarded as additional format fields.
    """
    fields = {'question' : question, ** kwargs}
    return text_encoder.format(question_format, ** fields)
    
def format_context(context, title = None, ** kwargs):
    """Encode `context` (and optional `title`) with `text_encoder` using `context_format`.

    Extra keyword arguments are forwarded as additional format fields.
    """
    fields = {'context' : context, 'title' : title, ** kwargs}
    return text_encoder.format(context_format, ** fields)
    
def tf_format_question(data):
    """Tokenize a question through `format_question` wrapped in `tf.py_function`.

    `data` is either the raw question or a dict / pd.Series row whose 'question'
    field ('' when absent) is used. Returns the int32 token ids with static
    shape [None]; the second output of format_question (token types) is dropped.
    """
    if isinstance(data, (dict, pd.Series)):
        text = data.get('question', '')
    else:
        text = data

    tokens, _ = tf.py_function(format_question, [text], Tout = [tf.int32, tf.int32])
    tokens.set_shape([None])

    return tokens

def tf_format_context(data):
    """Tokenize a context through `format_context` wrapped in `tf.py_function`.

    `data` is either the raw context or a dict / pd.Series row; its 'context'
    and 'title' fields ('' when absent) are passed to format_context. Returns the
    int32 token ids with static shape [None]; token types are dropped.
    """
    fields = data if isinstance(data, (dict, pd.Series)) else {'context' : data}

    tokens, _ = tf.py_function(
        format_context,
        [fields.get('context', ''), fields.get('title', '')],
        Tout = [tf.int32, tf.int32]
    )
    tokens.set_shape([None])

    return tokens

        
def get_input(data):
    """Encode a question and its context(s) into token tensors plus their lengths.

    Arguments :
        - data : a dict / pd.Series row with 'question' and 'context' fields.
                 When 'context' is a list, it is read as several paragraphs
                 paired element-wise with data['title'].
    Return :
        - (q_tokens, q_length, c_tokens, c_length) for a single context, or
          (q_tokens, q_length, c1_tokens, c1_length, c2_tokens, c2_length, ...)
          for a list of paragraphs.
    NOTE(review): the multi-paragraph tuple has a variable length, which does
    not match the 4-element padded_shapes built by get_dataset_config — confirm.
    """
    def _with_length(tokens):
        # tf.shape works on eager and traced tensors alike, whereas len() raises
        # TypeError on a symbolic tensor whose first dimension is unknown ([None]).
        return (tokens, tf.shape(tokens)[0])

    q_tokens = tf_format_question(data)

    if isinstance(data['context'], list):
        outputs = _with_length(q_tokens)
        for title, context in zip(data['title'], data['context']):
            # tf_format_context takes a single row argument, so context and
            # title are packed into a dict (passing them positionally raises TypeError).
            outputs += _with_length(tf_format_context({'context' : context, 'title' : title}))
        return outputs

    return _with_length(q_tokens) + _with_length(tf_format_context(data))

def get_dataset_config(** kwargs):
    """Return the prepare_dataset kwargs with padded batching (batched before mapping).

    Caller kwargs are kept, but 'batch_before_map', 'padded_batch' and
    'pad_kwargs' are always overridden. Padding uses text_encoder.blank_token_idx
    for token tensors and 0 for lengths.
    NOTE(review): contexts are declared rank-2 (None, None) while get_input's
    single-context path yields rank-1 tokens — confirm the intended layout.
    """
    blank = text_encoder.blank_token_idx
    pad_kwargs = {
        'padded_shapes'  : ((None,), (), (None, None), (None,)),
        'padding_values' : (blank, 0, blank, 0),
    }
    return {
        ** kwargs,
        'batch_before_map' : True,
        'padded_batch'     : True,
        'pad_kwargs'       : pad_kwargs,
    }

config = get_dataset_config(batch_size = 2)

# Expose the NQ document columns under the names get_input reads
# ('context' / 'title'); 'paragraphs' is presumably a list per row, which
# selects get_input's multi-paragraph branch.
valid['context'] = valid['paragraphs']
valid['title']   = valid['titles']

# Build the pipeline with get_input as encoder and time 250 steps.
dataset = prepare_dataset(valid, is_rectangular = False, encode_fn = get_input, ** config, debug = True)
test_dataset_time(dataset, steps = 250)

In [None]:
# Display the (NQ) validation DataFrame.
valid

In [None]:
# Show the API documentation of tf.data.Dataset.from_generator.
help(tf.data.Dataset.from_generator)