In [102]:
import io
import os
import re
import time
import unicodedata

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers.experimental import preprocessing

import wandb
from deepcomedy.models.transformer import *
from deepcomedy.preprocessing import *
from deepcomedy.utils import *
from deepcomedy.metrics import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 1. Data loading and preprocessing

In [2]:
def _read_utf8(path):
    """Return the contents of ``path`` decoded as UTF-8.

    The file is read in binary mode (so no newline translation is applied)
    and the handle is closed by the ``with`` block as soon as it is read.
    """
    with open(path, "rb") as fh:
        return fh.read().decode(encoding="utf-8")


# Plain text and its syllabified counterpart.
raw_text = _read_utf8("./data/divina_textonly.txt")
raw_syll_text = _read_utf8("./data/divina_syll_textonly.txt")
# NOTE(review): end_of_tercet="" presumably disables an end-of-tercet marker;
# confirm against deepcomedy.preprocessing.preprocess_text.
syll_text = preprocess_text(raw_syll_text, end_of_tercet="")
text = preprocess_text(raw_text, end_of_tercet="")

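Split the preprocessed texts into verses on the `<EOV>` marker, re-attaching the marker to each verse and dropping the fragment after the last marker.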

In [3]:
sep = "<EOV>"
# Discard whatever follows the final marker, then put the marker back on
# every remaining verse.
input_verses = [verse + sep for verse in text.split(sep)[:-1]]
target_verses = [verse + sep for verse in syll_text.split(sep)[:-1]]

Encode the verses as integer sequences with a word-level Keras tokenizer fitted on the syllabified (target) verses.

In [4]:
# Word-level tokenizer fitted on the syllabified verses only. filters=""
# keeps punctuation and lower=False preserves case.
# NOTE(review): no oov_token is set, so texts_to_sequences silently drops any
# input token that never appears in the target verses. Confirm that
# preprocess_text makes the two vocabularies overlap.
tokenizer = tf.keras.preprocessing.text.Tokenizer(
    filters="", lower=False, char_level=False
)
tokenizer.fit_on_texts(target_verses)
enc_input_verses, enc_target_verses = (
    tokenizer.texts_to_sequences(verses) for verses in (input_verses, target_verses)
)
# Word indices start at 1; index 0 is left free (it is the padding value below).
vocab_size = len(tokenizer.word_index) + 1

Pad the sequences with zeros at the end (`padding="post"`) so each set has a uniform length.

In [5]:
# Inputs and targets are padded independently, so their padded lengths may differ.
input_text, target_text = (
    tf.keras.preprocessing.sequence.pad_sequences(seqs, padding="post")
    for seqs in (enc_input_verses, enc_target_verses)
)

In [6]:
# Hold out 25% of the verses (scikit-learn's default test_size, made explicit).
# A fixed random_state makes the split, and therefore every metric reported
# below, reproducible across kernel restarts.
SPLIT_SEED = 42
input_train, input_test, target_train, target_test = train_test_split(
    input_text, target_text, test_size=0.25, random_state=SPLIT_SEED
)

In [7]:
batch_size = 32
# Build the training dataset first, then the validation dataset, with the same batch size.
dataset, validation_dataset = (
    make_dataset(inputs, targets, batch_size=batch_size)
    for inputs, targets in ((input_train, target_train), (input_test, target_test))
)

## 2. Training

In [None]:
# W&B grid-sweep definition. Each parameter is either fixed ({"value": x})
# or a list of candidates to grid over ({"values": [...]}).
# NOTE(review): assumes the trainer logs a metric named "loss" when
# log_wandb=True. Confirm it matches the metric name below.
sweep_config = {
    "name": "sweep-test-1",
    "method": "grid",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "batch_size": {"value": 32},
        # Fixed epoch count. A scalar needs "value"; "values" must be a list.
        "epochs": {"value": 10},
        "num_layers": {"values": [4, 8, 12]},
        "num_heads": {"values": [4, 8]},
        "d_model": {"values": [128, 256]},
        "dff": {"value": 1024},
        # TODO include architecture + dataset
    },
}

sweep_id = wandb.sweep(sweep_config, project="deepcomedy", entity="deepcomedy")

In [None]:
def sweep():
    """Run one W&B sweep trial.

    Reads the sampled hyperparameters from ``wandb.config``, re-batches the
    training split with the trial's batch size, builds a transformer the same
    way as the final-model cell, and trains it with W&B logging enabled.
    """
    with wandb.init():
        config = wandb.config
        # Local name chosen so the notebook-level `dataset` is not shadowed.
        train_dataset = make_dataset(
            input_train, target_train, batch_size=config["batch_size"]
        )
        # One tokenizer serves both sides, so source and target vocab sizes match.
        _model, trainer = make_transformer_model(
            config, vocab_size, vocab_size, checkpoint_save_path=None
        )
        trainer.train(train_dataset, config["epochs"], log_wandb=True)


wandb.agent(sweep_id, function=sweep)

In [10]:
# Hyperparameters for the final model, hard-coded here (presumably the best
# setting found by the sweep above).
best_config = dict(num_layers=4, d_model=256, num_heads=4, dff=1024)

# A single tokenizer covers inputs and targets, so both vocabulary sizes are
# the same value.
transformer, transformer_trainer = make_transformer_model(
    best_config,
    vocab_size,
    vocab_size,
    checkpoint_save_path=None,
)

In [11]:
# NOTE(review): validation_dataset is built earlier but not passed here.
n_epochs = 10
transformer_trainer.train(dataset, n_epochs)

Epoch 1 Batch 0 Loss 4.7690 Accuracy 0.0115
Epoch 1 Batch 50 Loss 3.8194 Accuracy 0.1316
Epoch 1 Batch 100 Loss 3.4262 Accuracy 0.1715
Epoch 1 Batch 150 Loss 3.2363 Accuracy 0.1933
Epoch 1 Batch 200 Loss 3.0199 Accuracy 0.2273
Epoch 1 Batch 250 Loss 2.8322 Accuracy 0.2575
Epoch 1 Batch 300 Loss 2.6857 Accuracy 0.2810
Epoch 1 Loss 2.6094 Accuracy 0.2937
Time taken for 1 epoch: 29.22 secs

Epoch 2 Batch 0 Loss 1.8332 Accuracy 0.4261
Epoch 2 Batch 50 Loss 1.8122 Accuracy 0.4312
Epoch 2 Batch 100 Loss 1.7746 Accuracy 0.4397
Epoch 2 Batch 150 Loss 1.7385 Accuracy 0.4491
Epoch 2 Batch 200 Loss 1.6993 Accuracy 0.4599
Epoch 2 Batch 250 Loss 1.6459 Accuracy 0.4759
Epoch 2 Batch 300 Loss 1.5578 Accuracy 0.5037
Epoch 2 Loss 1.4831 Accuracy 0.5278
Time taken for 1 epoch: 16.92 secs

Epoch 3 Batch 0 Loss 0.6504 Accuracy 0.7937
Epoch 3 Batch 50 Loss 0.5474 Accuracy 0.8276
Epoch 3 Batch 100 Loss 0.4717 Accuracy 0.8513
Epoch 3 Batch 150 Loss 0.4157 Accuracy 0.8681
Epoch 3 Batch 200 Loss 0.3749 Accurac

In [12]:
# Token ids for '<GO>' (seeds the decoder) and '<EOV>' (end of verse).
start_symbol, stop_symbol = (tokenizer.word_index[tok] for tok in ('<GO>', '<EOV>'))

In [92]:
# Evaluate on the first 100 held-out verses. Each decoder sequence starts
# with the single start token, giving a (n_examples, 1) tensor.
encoder_input = tf.convert_to_tensor(input_test[:100])
n_examples = encoder_input.shape[0]
decoder_input = tf.repeat([[start_symbol]], repeats=n_examples, axis=0)

In [93]:
# Decode the evaluation batch. Stopping is controlled by stop_after_stop_symbol.
output = evaluate(
    transformer,
    encoder_input,
    decoder_input,
    stop_symbol,
    stopping_condition=stop_after_stop_symbol,
)

In [95]:
# Only take output before the first end of verse
# Keep only the text before the first end-of-verse marker of each decoded sequence.
decoded_texts = tokenizer.sequences_to_texts(output.numpy())
stripped_output = [seq_text.split('<EOV>')[0] for seq_text in decoded_texts]

In [96]:
# Remove the remaining special tokens from each predicted verse.
stripped_output = [strip_tokens(verse) for verse in stripped_output]

In [97]:
# Display the cleaned predicted syllabifications for the evaluation verses.
stripped_output

['|lu|cen|te |più |as|sai |di |quel |ch’ el|l’ |e|ra.',
 '|che |si |sta|va|no a |l’ om|bra |die|tro al |sas|so',
 '|Poi, |ral|lar|ga|ti |per |la |stra|da |so|la,',
 '|Po|scia |ch’ io |v’ eb|bi al|cun |ri|co|no|sciu|sco,',
 '|e |co|me |quel |ch’ è |pa|sto |la |ri|mi|ra;',
 '|con |le |quai |la |tua |E|ta|ca |per|trat|ta',
 '|ma |noi |siam |pe|re|grin |co|me |voi |sie|te.',
 '|La |lin|gua |ch’ io |par|lai |fu |tut|ta |spen|ta',
 '|che |guar|da ’l |pon|te, |che |Fio|ren|za |fes|se',
 '« |Io |sa|rò |pri|mo, e |tu |sa|rai |se|con|do».',
 '|por|re un |uom |per |lo |po|po|lo a’ |mar|tì|ri.',
 '|pri|ma |che |pos|sa |tut|ta in |sé |mu|tar|si;',
 '|con|tra ’l |di|sio, |fo |ben |ch’ io |non |di|man|do”.',
 '|se |non |co|me |tri|sti|zia o |se|te o |fa|me:',
 '|vie |più |lu|cen|do, |co|mi|cia|ron |can|ti',
 '|E |se |più |fu |lo |suo |par|lar |dif|fu|so,',
 '|a |Ce|pe|ran, |là |do|ve |fu |bu|giar|do',
 '|al |mio |di|sio |cer|ti|fi|ca|to |fer|mi.',
 '|non |fos|se |sta|ta a |Ce|sa|re |no|ver|ca,',
 '|p

In [99]:
# Reference syllabifications for the same 100 test verses fed to the model.
# Each reference is post-processed exactly like the predictions: keep the text
# before the first '<EOV>', then apply strip_tokens. This keeps references and
# predictions aligned one-to-one. The previous version joined all verses and
# relied on strip_tokens turning '<EOV>' into newlines before splitting.
correct_syll = [
    strip_tokens(verse.split('<EOV>')[0])
    for verse in tokenizer.sequences_to_texts(target_test[:100])
]
assert len(correct_syll) == len(stripped_output), (
    f"{len(correct_syll)} references vs {len(stripped_output)} predictions"
)

In [110]:
# Compare each prediction with its reference. Each result is unpacked as an
# (exact_match, similarity) pair and transposed into two tuples.
validation_results = validate_syllabification(stripped_output, correct_syll)
exact_matches, similarities = zip(*validation_results)

In [111]:
# Fraction of verses syllabified exactly right, and the mean similarity score.
# NOTE(review): in the saved notebook the print cell (In [109]) ran before this
# cell (In [111]). Re-run top to bottom to be sure the printed value is current.
n_verses = len(exact_matches)
accuracy = sum(exact_matches) / n_verses
avg_similarities = np.mean(similarities)

In [109]:
print(f'Syllabification exact matches: {accuracy * 100:.2f}%')

Syllabification exact matches: 76.00%


In [112]:
print(f'Average similarity: {avg_similarities:.2f}')

Average similarity: 0.99


In [121]:
# Collect the verses whose predicted syllabification differs from the reference.
# Arrays get new names, so the list-valued `stripped_output` and
# `correct_syll` used by the metric cells are not rebound and this cell can
# be re-run safely.
predicted_arr = np.asarray(stripped_output)
reference_arr = np.asarray(correct_syll)
# Cast explicitly to bool. On an integer array `~` is bitwise NOT (~1 == -2),
# which would build a wrong mask if the match flags were 0/1 ints.
error_mask = ~np.asarray(exact_matches, dtype=bool)

errors_output = predicted_arr[error_mask]
errors_correct = reference_arr[error_mask]

In [126]:
# Reference syllabification of an example mismatching verse. Index 1 is an
# arbitrary pick and raises IndexError if there are fewer than two errors.
errors_correct[1]

'|Po|scia |ch’ io |v’ eb|bi al|cun |ri|co|no|sciu|to,'

In [127]:
# Model prediction for the same example verse, for side-by-side comparison
# with the reference shown above.
errors_output[1]

'|Po|scia |ch’ io |v’ eb|bi al|cun |ri|co|no|sciu|sco,'