# Tests for generation

## Imports + model initialization

In [1]:
import time
import importlib
import numpy as np
import pandas as pd
import tensorflow as tf

from utils.text import f1_score
from loggers import set_level, add_handler, TIME_LEVEL
from models.qa import MAG
from datasets import get_dataset, prepare_dataset, train_test_split, test_dataset_time

# Name of the pretrained model that the following cells restore and evaluate.
model_name = 'm3_nq_mag_off_entq_ct_wt_ft_doc_3_8_dense'

# Record the TensorFlow version used for this run.
print("Tensorflow version : {}".format(tf.__version__))

Tensorflow version : 2.6.2


## Model instantiation + dataset loading

In [2]:
# Build the MAG model identified by `model_name`.
# NOTE(review): `max_to_keep` presumably bounds the number of checkpoints kept on disk — confirm in the model class.
model = MAG(nom = model_name, max_to_keep = 2)

# Print the model summary.
print(model)

When using token / word-level tokenizer, it can be useful to add 'detach_punctuation' in cleaners
Model restoration...
Initializing submodel : model !
Successfully restored model from pretrained_models/m3_nq_mag_off_entq_ct_wt_ft_doc_3_8_dense/saving/model.json !
Model m3_nq_mag_off_entq_ct_wt_ft_doc_3_8_dense initialized successfully !

Sub model model
- Inputs 	: unknown
- Outputs 	: unknown
- Number of layers 	: 2
- Number of parameters 	: 409.505 Millions
- Model not compiled

Transfer-learning from : m3_nq_mag_off_entq_ct_wt_ib_3_8_dense
Already trained on 1 epochs (41878 steps)

- Input language : en
- Input vocab (size = 50263) : ['<s>', '<pad>', '</s>', '<unk>', '.', 'Ġthe', ',', 'Ġto', 'Ġand', 'Ġof', 'Ġa', 'Ġin', '-', 'Ġfor', 'Ġthat', 'Ġon', 'Ġis', 'âĢ', "'s", 'Ġwith', 'ĠThe', 'Ġwas', 'Ġ"', 'Ġat', 'Ġit', 'Ġas', 'Ġsaid', 'Ļ', 'Ġbe', 's', 'Ġby', 'Ġfrom', 'Ġare', 'Ġhave', 'Ġhas', ':', 'Ġ(', 'Ġhe', 'ĠI', 'Ġhis', 'Ġwill', 'Ġan', 'Ġthis', ')', 'ĠâĢ', 'Ġnot', 'Ŀ', 'Ġyou', 'ľ', 'Ġthei

In [3]:
# Dataset(s) to load. Alternatives used during experiments :
#   datasets = 'squad' if 'nq' not in model_name else 'nq'
#   datasets = ['nq', 'squad']
datasets = 'squad'

# Documents are only requested when 'nq' is part of `datasets` and the model name contains 'doc'
use_doc = 'nq' in datasets and 'doc' in model_name

dataset = get_dataset(
    datasets,
    clean_text       = True,
    skip_impossible  = True,
    keep_mode        = 'all',
    shuffle          = True,
    use_long_answer  = False,
    include_document = use_doc
)
train = dataset['train']
valid = dataset['valid']

print("Dataset length :\n  Training set : {}\n  Validation set : {}".format(len(train), len(valid)))

Loading dataset squad...
Dataset length :
  Training set : 86821
  Validation set : 5928


## Prediction

In [4]:
import custom_architectures.transformers_arch.text_transformer_arch as text_transformer_arch
import custom_architectures.transformers_arch.generation_utils as generation_utils

import models
import models.qa
import models.qa.mag
import models.qa.answer_generator_split
import models.qa.base_generator
import models.qa.base_qa

# Modules whose source may have been edited since the kernel started. They are reloaded
# in the listed order (architecture helpers, then the `models.qa` sub-modules, then the
# enclosing packages), which is the same order as the previous one-call-per-line version.
modules_to_reload = (
    generation_utils,
    text_transformer_arch,
    models.qa.base_qa,
    models.qa.base_generator,
    models.qa.answer_generator_split,
    models.qa.mag,
    models.qa,
    models,
)

# A loop (rather than a trailing bare call) also avoids displaying the repr of the last
# reloaded module as the cell output.
for module in modules_to_reload:
    importlib.reload(module)




In [5]:
import os
import logging

from loggers import timer
from utils import *

# Switch the project loggers to the 'time' level.
# NOTE(review): presumably this is what makes the `timers :` report appear after timed calls — confirm in `loggers`.
set_level('time')

# Logger on which `predict` calls `start_timer` / `stop_timer`.
time_logger = logging.getLogger('timer')

@timer
def predict(self, question, context = None, metrics = None, save = False,
            directory = None, filename = 'map.json', tqdm = lambda x: x, ** kwargs):
    """
        Generate answer candidates for every (question, context) pair

        Arguments :
            - self      : the model (this is a free function, the model is passed explicitly)
            - question  : a `pd.DataFrame` (one row per sample), a list of questions (`str` or
                          `dict`), or a single question
            - context   : optional context(s). A value that is not a list, or a list whose length
                          differs from the number of questions, is treated as ONE context shared
                          by all questions. A list with one entry per question is paired by index.
                          Each entry may be a `dict` (merged into the sample) or a raw value
                          (stored under the 'context' key). If `None`, every item of `question`
                          must already contain a 'context' entry
            - metrics   : optional metrics, compiled with `self.get_compiled_metrics`, computed
                          for each candidate when the sample has an 'answers' entry
            - save      : whether to load previous predictions from / dump predictions to `filename`
            - directory : saving directory (defaults to `self.pred_dir`). If it contains '.json',
                          it is used as the file path itself
            - filename  : name of the prediction file inside `directory`
            - tqdm      : progress-bar wrapper applied on the samples
            - kwargs    : training hparams are extracted (and removed) to configure the model,
                          the remaining ones are forwarded to `self.infer`
        Return :
            - answers   : list of `{'candidates' : [{'text', 'score', ** metrics}, ...]}` dicts
                          (with an additional 'target' entry when metrics are computed)

        Note : samples rejected by `self.filter_inputs` are logged and skipped, so `answers`
        can be shorter than the number of questions.
    """
    time_logger.start_timer('processing')

    pred_config = self.training_hparams.extract(kwargs, pop = True)
    self.init_train_config(** pred_config)

    if metrics is not None: metrics = self.get_compiled_metrics(metrics, add_loss = False)

    # 'records' is the full orient name : the abbreviated 'record' is rejected by recent pandas
    if isinstance(question, pd.DataFrame): question = question.to_dict('records')
    if not isinstance(question, list): question = [question]

    if context is None:
        data = question
    else:
        if not isinstance(context, list) or len(context) != len(question): context = [context]
        if len(context) == 1 and len(question) > 1: context = context * len(question)

        # At this point there is exactly one context per question (or no question at all)
        data = []
        for q, c in zip(question, context):
            if not isinstance(q, dict): q = {'question' : q}
            if not isinstance(c, dict): c = {'context' : c}
            data.append({** q, ** c})

    time_logger.stop_timer('processing')

    infos_pred = {}
    if save:
        if directory is None: directory = self.pred_dir
        if filename is None or '.json' in directory: filename, directory = directory, None
        else: filename = os.path.join(directory, filename)
        if directory is not None: os.makedirs(directory, exist_ok = True)

        # NOTE(review): assumes `load_json` copes with a file that does not exist yet — confirm in `utils`
        infos_pred = load_json(filename)

    answers = []
    for idx, row in enumerate(tqdm(data)):
        q_text, ctx_key = row['question'], row['context']
        # lists are not hashable : use a tuple as dictionary key for multi-paragraph contexts
        if isinstance(ctx_key, list): ctx_key = tuple(ctx_key)

        # Only run the model for (question, context) pairs that are not already stored
        if q_text not in infos_pred or ctx_key not in infos_pred[q_text]:
            inputs = [tf.expand_dims(inp, axis = 0) for inp in self.get_input(row)]

            if not self.filter_inputs(inputs):
                logging.warning('Too long data at index {} : {}'.format(
                    idx, [tuple(inp.shape) for inp in inputs]
                ))
                continue

            pred = self.infer(inputs, training = False, ** kwargs)

            tokens      = pred.tokens
            scores      = pred.score[0].numpy()
            pred_text   = self.decode_text(tokens, remove_tokens = True)[0]
            if not isinstance(pred_text, (list, tuple)):
                # Single-candidate output : wrap everything so that it looks like 1 candidate.
                # `tokens` presumably has no candidate axis in this case (the greedy cell of this
                # notebook adds one the same way), so add it to keep `tokens[:, i]` meaningful
                pred_text, scores = [pred_text], [scores]
                tokens = tf.expand_dims(tokens, axis = 1)

            with_metrics = 'answers' in row and metrics is not None

            result = {'candidates' : []}
            infos_pred.setdefault(q_text, {})[ctx_key] = result
            if with_metrics:
                result['target'] = row['answers']
                target = [
                    tf.expand_dims(out, axis = 0) for out in self.get_output(row)
                ]

            for i, (txt, s) in enumerate(zip(pred_text, scores)):
                metrics_i = {}
                if with_metrics:
                    time_logger.start_timer('metric')
                    # metrics are reset so that each candidate gets its own values
                    metrics.reset_states()
                    metrics.update_state(target, tokens[:, i])
                    metrics_i = dict(zip(metrics.metric_names, metrics.result().numpy()))
                    time_logger.stop_timer('metric')

                result['candidates'].append({
                    'text'  : txt,
                    'score' : s,
                    ** metrics_i
                })

        answers.append(infos_pred[q_text][ctx_key])

    if save:
        # NOTE(review): tuple context keys are not valid JSON keys with the standard encoder —
        # confirm that `dump_json` handles them
        dump_json(filename, infos_pred, indent = 4)

    return answers

# Run the notebook-local `predict` on the first 10 validation samples (beam search + F1 metrics)
pred = predict(
    model,
    valid.iloc[:10],
    method           = 'beam',
    metrics          = ['f1'],
    save             = False,
    max_input_length = 512
)
pred

timers :
  - predict : 25.883 sec
    - processing : 0.009 sec
    - inference executed 10 times : 25.353 sec (2.535 sec / exec)
      - infer executed 10 times : 25.353 sec (2.535 sec / exec)
        - call executed 10 times : 2.410 sec (0.241 sec / exec)
          - embed executed 10 times : 1.901 sec (0.190 sec / exec)
            - embed executed 20 times : 1.901 sec (0.095 sec / exec)
              - build_mask executed 20 times : 0.039 sec (0.002 sec / exec)
              - token embedding executed 20 times : 0.059 sec (0.003 sec / exec)
                - embed_tokens executed 20 times : 0.014 sec (0.001 sec / exec)
                - embed_token_types executed 20 times : 0.000 sec (0.000 sec / exec)
                - embed_positions executed 20 times : 0.013 sec (0.001 sec / exec)
              - layer call executed 160 times : 1.752 sec (0.011 sec / exec)
              - format_output executed 20 times : 0.000 sec (0.000 sec / exec)
              - subsample executed 10 times : 

[{'candidates': [{'text': 'Sisters of Society',
    'score': -1.5426561,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0},
   {'text': 'Catholic',
    'score': -2.540998,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0},
   {'text': 'the Sisters of Society',
    'score': -2.7306073,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0},
   {'text': 'the Society of Sisters',
    'score': -2.9683392,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0},
   {'text': 'Sisters of Sisters',
    'score': -3.2666168,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0}],
  'target': ['McCrary']},
 {'candidates': [{'text': 'nuclear',
    'score': -0.6824541,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0},
   {'text': 'solar, coal and nuclear',
    'score': -2.5996037,
    'EM': 0.0,
    'F1': 0.0,
    'precision': 0.0,
    'recall': 0.0},
   {'text': 'fossil power plants',
  

In [5]:
set_level('time')

# Same call as with the notebook-local `predict`, but through the model's own method
pred = model.predict(
    valid.iloc[:10],
    method           = 'beam',
    metrics          = ['f1'],
    save             = False,
    max_input_length = 512
)


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.11it/s]

timers :
  - predict : 9.027 sec
    - processing : 0.006 sec
    - metrics executed 50 times : 0.437 sec (0.009 sec / exec)





In [None]:
set_level('info')

config = model.get_dataset_config(batch_size = 5, is_validation = True, shuffle_size = 0)
ds = prepare_dataset(valid.sample(25, random_state = 0), ** config, is_rectangular = False)

# Number of samples of the batch to keep for inference
n = 5
for batch in ds.take(1):
    inputs, target = batch
    inputs = [inp[:n] for inp in inputs]
    # NOTE(review): the last 2 inputs are dropped before inference — presumably they are
    # target-related inputs ; confirm against `model.get_dataset_config`
    infer = model.infer(inputs[:-2], method = 'beam', num_beams = 10, num_sentences = 5, max_length = 50)
    if infer is not None:
        for beams, scores, tar in zip(infer.tokens, infer.score, target[0]):
            # Decode once and reuse ; a dedicated name avoids rebinding `target`,
            # which is the batch target this loop is iterating over
            text        = model.decode_text(beams)
            target_text = model.decode_text(tar)
            f1 = f1_score(target_text, text, as_matrix = True)
            print("Target : {}".format(target_text))
            print(model.infer_to_str(text, scores))
            print("Score : {}".format(f1))

    # Alternative : model.predict_with_target(batch, n_pred = 10, debug = False)


In [None]:
set_level('info')

config = model.get_dataset_config(batch_size = 5, is_validation = True, shuffle_size = 0)
ds = prepare_dataset(
    valid.sample(25, random_state = 0), ** config, is_rectangular = not use_doc
)

# Number of samples of the batch to keep for inference
n = 5
for batch in ds.take(1):
    inputs, target = batch
    inputs = [inp[:n] for inp in inputs]
    infer  = model.infer(
        inputs[:-2], method = 'greedy', num_beams = 10, num_sentences = 5, max_length = 50
    )
    if infer is None: continue

    # An axis is inserted at position 1 so that each sample iterates as a group of candidates
    all_tokens = tf.expand_dims(infer.tokens, 1)
    all_scores = tf.expand_dims(infer.score, 1)
    for beams, scores, tar in zip(all_tokens, all_scores, target[0]):
        text = model.decode_text(beams)
        print("Target : {}".format(model.decode_text(tar)))
        for score, candidate in zip(scores, text):
            print("Infer ({:.3f}) : {}".format(score, candidate))
        print()

    #model.predict_with_target(batch, n_pred = 10, debug = False)


In [7]:
import json

from tqdm import tqdm

# Paragraph shared by the first three contexts (as a whole, split by sentence, and
# paired with an unrelated sentence)
cat_paragraph = (
    'The cat is similar in anatomy to the other felid species: it has a strong flexible body, '
    'quick reflexes, sharp teeth and retractable claws adapted to killing small prey. Its night vision and sense of smell are well '
    'developed. Cat communication includes vocalizations like meowing, purring, trilling, hissing, growling and grunting as well as cat-'
    'specific body language. A predator that is most active at dawn and dusk (crepuscular), the cat is a solitary hunter but a social species. '
    'It can hear sounds too faint or too high in frequency for human ears, such as those made by mice and other small mammals.[7] It secretes and '
    'perceives pheromones.'
)

question = [
    'How is the night vision of cat ?',
    'How is the night vision of cat ?',
    'What is the anoatomy of a cat ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'How many paws does a cat have ?',
    'What is the origin of life ?'
]
# One context per question : either a single text or a list of paragraphs
context  = [
    cat_paragraph,
    [p.strip() + '.' for p in cat_paragraph.split('.') if len(p) > 0],
    [cat_paragraph, 'The answer to everything is 42'],
    'A cat is an animal which has 4 paws and whiskers.',
    'A cat is an animal which has 4 paws and whiskers. However, everyone knows that the answer to everything is 42 !',
    ['A cat is an animal which has 4 paws and whiskers.', 'However, everyone knows that the answer to everything is 42 !'],
    'The answer to everything is 42.'
]

# Uncomment the selection below to test a single (question, context) pair
n = 1
#question, context = question[n], [context[n]]

if not isinstance(question, list): question = [question]
if not isinstance(context, list): context = [context]

contexts_with_title = [{'context' : c, 'title' : 'cat'} for c in context]
answers = predict(model, question, contexts_with_title, method = 'beam', tqdm = tqdm)

for q, c, a in zip(question, context, answers):
    answer_json = json.dumps(to_json(a), indent = 4)
    print("Question : {}\nContext : {}\nAnswer : {}\n".format(q, c, answer_json))

100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:35<00:00,  5.06s/it]

Question : How is the night vision of cat ?
Context : The cat is similar in anatomy to the other felid species: it has a strong flexible body, quick reflexes, sharp teeth and retractable claws adapted to killing small prey. Its night vision and sense of smell are well developed. Cat communication includes vocalizations like meowing, purring, trilling, hissing, growling and grunting as well as cat-specific body language. A predator that is most active at dawn and dusk (crepuscular), the cat is a solitary hunter but a social species. It can hear sounds too faint or too high in frequency for human ears, such as those made by mice and other small mammals.[7] It secretes and perceives pheromones.
Answer : [
    {
        "text": "well developed",
        "score": -1.950272560119629
    },
    {
        "text": "Well developed",
        "score": -3.319387197494507
    },
    {
        "text": "is well developed",
        "score": -3.528207540512085
    },
    {
        "text": "best",
      


