In [None]:
import os
import sys
# Configure TensorFlow's native log level through the environment; this runs
# ahead of the `import tensorflow` in the following cell.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Put the parent directory on the import path (only once) so that packages
# living one level above this notebook can be imported.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import tensorflow as tf
# NOTE(review): the return value is discarded and, since this is not the cell's
# last expression, it is not displayed either — the call has no visible effect.
tf.executing_eagerly()
# Restrict TensorFlow's Python logger to the ERROR level.
tf.get_logger().setLevel('ERROR')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Default figure size (width, height) applied to every plot in this notebook.
plt.rcParams['figure.figsize'] = [15, 10]

from tqdm import tqdm

from stog.utils.params import Params
from stog.data.dataset_builder import dataset_from_params, iterator_from_params
from stog.data.vocabulary import Vocabulary
from stog.training.trainer import Trainer
from stog.data.dataset import Batch
from model.text_to_amr import TextToAMR

from tensorflow.keras.layers import Embedding, Input, Dense, Flatten, LSTM, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.losses import MeanSquaredError

In [None]:
# Load the experiment configuration; later cells read its 'data' and 'vocab' sections.
params = Params.from_file("../model/model_params.yaml")

In [None]:
# The 'data' section drives dataset construction here and is reused later for
# the iterators and for the evaluation-mode datasets.
data_params = params['data']
# Container of data splits; the next cell looks splits up by name.
dataset = dataset_from_params(data_params)

In [None]:
# 'train' is looked up with [] (it must be present); 'dev' and 'test' go through .get().
# presumably .get() yields None for a missing split, as with a dict — confirm.
train_data = dataset['train']
dev_data = dataset.get('dev')
test_data = dataset.get('test')

In [None]:
# Build the vocabulary from the training instances and persist it to disk.
vocab_params = params.get('vocab', {})
vocab = Vocabulary.from_instances(instances=train_data, **vocab_params)
vocab.save_to_files("../data/processed/serialization")

# Batch and index the training instances exactly once. The batch gets its own
# name so that `dataset` (the split container from the earlier cell) is not
# overwritten and the earlier cells stay re-runnable.
train_batch = Batch(train_data)
train_batch.index_instances(vocab)
print(train_batch)

# Materialise the tensor dict a single time and reuse it for both the key
# preview below and the training input.
train_dataset = train_batch.as_tensor_dict()

# Preview the layout of the tensor dict: top-level keys and, for nested
# dicts, their inner keys.
for key in train_dataset:
    print(key)
    content = train_dataset[key]
    if isinstance(content, dict):
        for inner_key in content:
            print("  ", inner_key)

# One iterator per split, configured by the 'iterator' section of the data params.
train_iterator, dev_iterater, test_iterater = iterator_from_params(vocab, data_params['iterator'])

# Same batching/indexing for the test split.
# NOTE(review): `test_data` comes from dataset.get('test'); if that split is
# absent this will fail inside Batch — confirm the config always provides it.
test_batch = Batch(test_data)
test_batch.index_instances(vocab)
test_dataset = test_batch.as_tensor_dict()

In [None]:
def create_model_input(encoder_input, decoder_input, generator_input, parser_input):
    """Flatten the four grouped input mappings into one positional list.

    Every argument is read only through ``.get(key)``, so a key that is absent
    produces ``None`` in its slot instead of raising.

    Args:
        encoder_input: mapping read for 'token', 'pos_tag' and 'mask'.
        decoder_input: mapping read for 'token' and 'pos_tag'.
        generator_input: mapping read for 'copy_attention_maps',
            'coref_attention_maps', 'vocab_targets', 'coref_targets' and
            'copy_targets'.
        parser_input: mapping read for 'edge_heads', 'edge_labels', 'mask'
            and 'corefs'.

    Returns:
        A 14-element list whose order is fixed by ``slots`` below.
    """
    # (source mapping, key) for each slot, listed in output order.
    slots = (
        (encoder_input, 'token'),
        (encoder_input, 'pos_tag'),
        (decoder_input, 'token'),
        (decoder_input, 'pos_tag'),
        (generator_input, 'copy_attention_maps'),
        (generator_input, 'coref_attention_maps'),
        (parser_input, 'mask'),
        (parser_input, 'edge_heads'),
        (parser_input, 'edge_labels'),
        (parser_input, 'corefs'),
        (generator_input, 'vocab_targets'),
        (generator_input, 'coref_targets'),
        (generator_input, 'copy_targets'),
        (encoder_input, 'mask'),
    )
    return [source.get(key) for source, key in slots]


In [None]:
# Print the vocabulary's string representation as a quick sanity check.
print(vocab)

## Test Model

In [None]:
# Instantiate the model on the vocabulary built above, then let it split the
# training tensor dict into the four input groups consumed by create_model_input().
text_to_amr = TextToAMR(vocab)
encoder_input, decoder_input, generator_input, parser_input = text_to_amr.prepare_input(train_dataset)

In [None]:
# Flatten the four input groups into the model's positional list and cast each
# entry to int32 (every entry must therefore expose .astype).
train_model_input = [
    model_tensor.astype('int32')
    for model_tensor in create_model_input(encoder_input, decoder_input, generator_input, parser_input)
]

In [None]:
# Inspect the first entry of the target-side decoder tokens.
print(train_dataset['tgt_tokens']['decoder_tokens'][0])

### Train

In [None]:
EPOCHS = 100

# Loss histories, one entry appended per epoch; the plotting cells read these.
total_losses, token_losses, edge_losses = [], [], []

epoch_tqdm = tqdm(range(EPOCHS))
for epoch in epoch_tqdm:
    # train() is called once per epoch with the complete input list and returns
    # the total loss followed by its token and edge components.
    loss, token_loss, edge_loss = text_to_amr.train(train_model_input)

    total_losses.append(loss)
    token_losses.append(token_loss)
    edge_losses.append(edge_loss)

    # Mirror the latest losses in the progress-bar label.
    epoch_tqdm.set_description(
        "TOKEN LOSS: {:.4f}, EDGE LOSS: {:.4f}, TOTAL LOSS: {:.4f}".format(
            *(float(value.numpy()) for value in (token_loss, edge_loss, loss))
        )
    )

In [None]:
# Stacked view of the two loss components, one x position per recorded epoch.
# Both the x axis and the title are derived from the recorded history, so the
# figure stays correct when EPOCHS changes or training was stopped early
# (the previous title was hard-coded to a count that did not match EPOCHS).
epochs_recorded = len(token_losses)
plt.stackplot(list(range(epochs_recorded)), token_losses, edge_losses, labels=['Token Loss', 'Edge Loss'])
plt.title("TRAIN {} EPOCH".format(epochs_recorded))
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Edge-loss history; one point per training epoch, so the x axis is the epoch index.
plt.plot(edge_losses, label="Edge Losses", color="orange")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Token-loss history; one point per training epoch, so the x axis is the epoch index.
plt.plot(token_losses, label="Token Losses")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

## Inference

In [None]:
# Rebuild the datasets with evaluation=True for inference.
# NOTE(review): this rebinds `test_dataset`, which until now held the tensor
# dict of the test split produced during the data-preparation cell.
test_dataset = dataset_from_params(data_params, evaluation=True)

In [None]:
# Batch the evaluation-mode test split, index it against the training
# vocabulary, and materialise the tensors used for inference.
test_data = Batch(test_dataset['test'])
test_data.index_instances(vocab)
test_data_tensor = test_data.as_tensor_dict()

In [None]:
# Split the test tensors into the same four input groups used for training.
# NOTE(review): this rebinds the four names that previously held the training groups.
encoder_input, decoder_input, generator_input, parser_input = text_to_amr.prepare_input(test_data_tensor)

In [None]:
# Start from the same positional list used for training, then append the three
# extra entries passed only at prediction time. The last one is optional and
# is None when the tensor dict does not provide it.
test_model_input = create_model_input(encoder_input, decoder_input, generator_input, parser_input)
test_model_input.extend([
    test_data_tensor['src_copy_vocab'],
    test_data_tensor['tag_lut'],
    test_data_tensor.get('source_copy_invalid_ids', None),
])

In [None]:
# Run inference; the following cells index the result by 'heads', 'head_labels' and 'nodes'.
outputs = text_to_amr.predict(test_model_input)

In [None]:
# Display the 'heads' entry of the prediction output.
outputs['heads']

In [None]:
# Display the first element of the 'head_labels' entry.
outputs['head_labels'][0]

In [None]:
# Print each element of the 'nodes' entry on its own line.
for node in outputs['nodes']:
    print(node)