<a href="https://colab.research.google.com/github/DanieleVeri/deep_comedy/blob/master/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Import & seed
import time
import re
import os
import math
import numpy as np
from matplotlib import pyplot as plt
import requests
import collections
import pickle
import copy, random
import nltk as nl
nl.download('punkt')
from itertools import zip_longest

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (
    Input, Reshape, BatchNormalization, Dense, Dropout, concatenate,
    Embedding, LSTM, Dense, GRU, Bidirectional, Add
)
from tensorflow.keras.activations import elu, relu, softmax, sigmoid
from tensorflow.keras import regularizers
from tensorflow.keras.utils import to_categorical

print(tf.__version__)

np.random.seed(1234)
!nvidia-smi

In [None]:
#@title Setup wandb
!pip install wandb
# NOTE(review): the token on the next line looks like a personal wandb API key
# committed in clear text -- presumably it should be revoked and supplied via
# an environment variable or an interactive `wandb login` instead; confirm.
!wandb login f57cb185d23a8b60d349a4ea02278a6eee82550a
import wandb
# Start a run in the "deep_comedy" project; the run name is a free-form label.
wandb.init(project="deep_comedy", name="lr 13e-5 voc1800")

In [3]:
##@title Model

# Size shared by the input and target vocabularies (see below).
vocab_size = 1800
terces_per_batch = 4
terces_len = 75

# 4 * (75 + 1) = 304; only referenced by train_step_signature in this part of
# the file.
batch_len = terces_per_batch * (terces_len + 1)

# Model hyper-parameters are stored on the wandb run config and read back from
# there when the Transformer is built.
wandb.config.num_layers = 4
wandb.config.d_model = 128
wandb.config.dff = 256
wandb.config.num_heads = 4
wandb.config.dropout = 0.1
input_vocab_size = vocab_size
target_vocab_size = vocab_size
EPOCHS = 250
# NOTE(review): the wandb run name passed to wandb.init says "lr 13e-5" while
# the value here is 15e-5 -- confirm which one is intended.
learning_rate = 15e-5

def get_angles(pos, i, d_model):
    """Return the angles ``pos / 10000 ** (2 * (i // 2) / d_model)``.

    Args:
        pos: position index (scalar or array).
        i: embedding-dimension index (scalar or array, broadcast against pos).
        d_model: embedding width used to scale the exponent.

    Returns:
        The broadcast product of ``pos`` and the per-dimension angle rate.
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    return pos * (1 / np.power(10000, exponent))

def positional_encoding(position, d_model):
    """Build the sinusoidal positional-encoding table.

    Args:
        position: number of positions (rows) to generate.
        d_model: embedding width (columns).

    Returns:
        A float32 tensor of shape (1, position, d_model); even columns hold
        sines and odd columns hold cosines of the angles from get_angles.
    """
    positions = np.arange(position)[:, np.newaxis]
    dimensions = np.arange(d_model)[np.newaxis, :]
    table = get_angles(positions, dimensions, d_model)

    # Even columns (2i) get sin, odd columns (2i+1) get cos, in place.
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    # Leading axis of size 1 lets the table broadcast over the batch.
    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

def create_padding_mask(seq):
    """Return a float mask that is 1.0 wherever ``seq`` equals ``pad``.

    ``pad`` is a module-level name that is not defined in this part of the
    file; it must exist by the time this function is called.

    Two singleton axes are inserted so the mask can be broadcast against the
    attention logits: result shape is (batch_size, 1, 1, seq_len).
    """
    is_padding = tf.math.equal(seq, pad)
    mask = tf.cast(is_padding, tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    """Return a (size, size) mask with 1.0 strictly above the diagonal.

    The lower triangle (diagonal included) is 0, so position q is only
    blocked from positions k > q.
    """
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle

def scaled_dot_product_attention(q, k, v, mask):
    """Compute softmax(q @ k^T / sqrt(depth)) @ v, with optional masking.

    q, k, v must have matching leading dimensions, and k, v must share their
    penultimate dimension (seq_len_k == seq_len_v). The mask, whatever its
    origin (padding or look-ahead), must be broadcastable for addition.

    Args:
        q: query, shape (..., seq_len_q, depth)
        k: key, shape (..., seq_len_k, depth)
        v: value, shape (..., seq_len_v, depth_v)
        mask: float tensor broadcastable to (..., seq_len_q, seq_len_k),
            or None to disable masking.

    Returns:
        (output, attention_weights)
    """
    scores = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # Scale by the square root of the key depth.
    depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = scores / tf.math.sqrt(depth)

    # Masked positions receive a large negative logit so that softmax drives
    # their weight towards zero.
    if mask is not None:
        logits = logits + (mask * -1e9)

    # Normalise over the key axis so each query's weights sum to 1.
    attention_weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(attention_weights, v), attention_weights

def point_wise_feed_forward_network(d_model, dff):
    """Return a two-layer feed-forward stack: Dense(dff, relu) -> Dense(d_model)."""
    hidden = tf.keras.layers.Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
    projection = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    return tf.keras.Sequential([hidden, projection])


class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention built on scaled_dot_product_attention.

    ``d_model`` must be divisible by ``num_heads``; each head works on
    ``d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads

        # Linear projections for queries, keys and values, plus the output
        # projection applied after the heads are merged back together.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``.

        Note the argument order: value, key, query, mask.

        Returns:
            (output, attention_weights) where output has shape
            (batch_size, seq_len_q, d_model) and attention_weights has shape
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        # Project, each result being (batch_size, seq_len, d_model).
        q, k, v = self.wq(q), self.wk(k), self.wv(v)

        # Split into heads: (batch_size, num_heads, seq_len, depth).
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        per_head_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Back to (batch_size, seq_len_q, num_heads, depth), then merge heads.
        per_head_output = tf.transpose(per_head_output, perm=[0, 2, 1, 3])
        merged = tf.reshape(per_head_output, (batch_size, -1, self.d_model))

        return self.dense(merged), attention_weights

class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added to its input and
    then layer-normalised. This class overrides ``__call__`` directly rather
    than ``call``.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def __call__(self, x, training, mask):
        """Run the block on ``x``; ``training`` toggles dropout, ``mask`` is
        forwarded to the attention layer. Output has the same shape as ``x``."""
        # Self-attention sub-layer.
        attended, _ = self.mha(x, x, x, mask)
        attended = self.dropout1(attended, training=training)
        after_attention = self.layernorm1(x + attended)

        # Feed-forward sub-layer.
        transformed = self.ffn(after_attention)
        transformed = self.dropout2(transformed, training=training)
        return self.layernorm2(after_attention + transformed)

class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: masked self-attention, attention over the encoder
    output, then a feed-forward network.

    Each sub-layer's output goes through dropout, a residual addition and
    layer normalisation. This class overrides ``__call__`` directly rather
    than ``call``.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def __call__(self, x, enc_output, training, look_ahead_mask, padding_mask):
        """Run the block.

        Args:
            x: decoder-side input.
            enc_output: encoder output used as keys/values in the second
                attention block.
            training: toggles dropout.
            look_ahead_mask: mask for the self-attention block.
            padding_mask: mask for the encoder-decoder attention block.

        Returns:
            (output, self_attention_weights, encoder_attention_weights)
        """
        # Block 1: self-attention over the decoder input.
        self_attended, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        self_attended = self.dropout1(self_attended, training=training)
        after_self = self.layernorm1(self_attended + x)

        # Block 2: queries from the decoder, keys/values from the encoder.
        cross_attended, attn_weights_block2 = self.mha2(
            enc_output, enc_output, after_self, padding_mask)
        cross_attended = self.dropout2(cross_attended, training=training)
        after_cross = self.layernorm2(cross_attended + after_self)

        # Block 3: position-wise feed-forward network.
        transformed = self.ffn(after_cross)
        transformed = self.dropout3(transformed, training=training)
        result = self.layernorm3(transformed + after_cross)

        return result, attn_weights_block1, attn_weights_block2

class Encoder(tf.keras.layers.Layer):
    """Token embedding + positional encoding followed by a stack of
    EncoderLayer blocks. Overrides ``__call__`` directly rather than ``call``."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        # Precomputed table; sliced to the actual sequence length at call time.
        self.pos_encoding = positional_encoding(maximum_position_encoding,
                                                self.d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def __call__(self, x, training, mask):
        """Encode token ids ``x``; returns (batch_size, input_seq_len, d_model)."""
        seq_len = tf.shape(x)[1]

        # Embed, scale by sqrt(d_model), then add the positional encoding.
        hidden = self.embedding(x)
        hidden = hidden * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        hidden = hidden + self.pos_encoding[:, :seq_len, :]
        hidden = self.dropout(hidden, training=training)

        for layer in self.enc_layers:
            hidden = layer(hidden, training, mask)
        return hidden

class Decoder(tf.keras.layers.Layer):
    """Token embedding + positional encoding followed by a stack of
    DecoderLayer blocks. Overrides ``__call__`` directly rather than ``call``."""

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        # Precomputed table; sliced to the actual sequence length at call time.
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def __call__(self, x, enc_output, training, look_ahead_mask, padding_mask):
        """Decode token ids ``x`` against ``enc_output``.

        Returns:
            (output, attention_weights) where output has shape
            (batch_size, target_seq_len, d_model) and attention_weights maps
            'decoder_layer<N>_block1' / 'decoder_layer<N>_block2' (N starting
            at 1) to the weights of each layer's two attention blocks.
        """
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        # Embed, scale by sqrt(d_model), then add the positional encoding.
        hidden = self.embedding(x)
        hidden = hidden * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        hidden = hidden + self.pos_encoding[:, :seq_len, :]
        hidden = self.dropout(hidden, training=training)

        for number, layer in enumerate(self.dec_layers, start=1):
            hidden, block1, block2 = layer(hidden, enc_output, training,
                                           look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(number)] = block1
            attention_weights['decoder_layer{}_block2'.format(number)] = block2

        return hidden, attention_weights

class Transformer(tf.keras.Model):
    """Encoder-decoder model with a final Dense projection onto the target
    vocabulary. Overrides ``__call__`` directly rather than ``call``."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, pe_input, pe_target, rate=0.1):
        super().__init__()

        # pe_input / pe_target are the maximum positions for which positional
        # encodings are precomputed on each side.
        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, pe_target, rate)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def __call__(self, inp, tar, training, enc_padding_mask,
                 look_ahead_mask, dec_padding_mask):
        """Run encoder then decoder.

        Returns:
            (logits, attention_weights) where logits has shape
            (batch_size, tar_seq_len, target_vocab_size) and
            attention_weights is the dict produced by the decoder.
        """
        encoded = self.encoder(inp, training, enc_padding_mask)

        decoded, attention_weights = self.decoder(
            tar, encoded, training, look_ahead_mask, dec_padding_mask)

        logits = self.final_layer(decoded)
        return logits, attention_weights

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up / inverse-square-root learning-rate schedule.

    The rate is ``d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``:
    it grows linearly for ``warmup_steps`` steps and then decays with the
    inverse square root of the step.
    """

    def __init__(self, d_model, warmup_steps=1000):
        super(CustomSchedule, self).__init__()

        # Kept as a float32 tensor so rsqrt can be applied to it directly.
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        # The step may arrive as an integer tensor, on which rsqrt is not
        # defined; do the arithmetic in float32.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

# The warm-up schedule is currently disabled; the constant `learning_rate`
# defined with the other hyper-parameters is used instead.
#learning_rate = CustomSchedule(wandb.config.d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

# reduction='none' keeps one loss value per target position so that
# loss_function can apply its own mask before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    """Mean cross-entropy over the non-padding target positions.

    Padding is identified with the same module-level ``pad`` id that
    create_padding_mask uses, so the loss mask and the attention masks agree
    on what counts as padding. (Comparing against the literal 0 would instead
    exclude whichever token has id 0, which the Vocabulary class assigns to
    '<UNK>'.)

    Args:
        real: integer target ids, shape (batch_size, seq_len).
        pred: logits, shape (batch_size, seq_len, vocab_size).

    Returns:
        Scalar tensor: sum of per-position losses at non-padding positions
        divided by the number of such positions.
    """
    not_padding = tf.math.logical_not(tf.math.equal(real, pad))
    per_position_loss = loss_object(real, pred)
    weights = tf.cast(not_padding, dtype=per_position_loss.dtype)
    per_position_loss *= weights
    return tf.reduce_sum(per_position_loss) / tf.reduce_sum(weights)

# Running metrics accumulated by train_step.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')

# Running metrics accumulated by val_step. They carry their own names so they
# cannot be confused with the training metrics when inspected or logged.
val_loss = tf.keras.metrics.Mean(name='val_loss')
val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='val_accuracy')

# Build the model from the hyper-parameters stored on the wandb config.
# NOTE(review): pe_input / pe_target set the number of positions for which
# positional encodings are precomputed; they are given the vocabulary size
# here, which is presumably just a convenient upper bound on the sequence
# length -- confirm.
transformer = Transformer(wandb.config.num_layers, wandb.config.d_model, 
                          wandb.config.num_heads, wandb.config.dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=wandb.config.dropout)
def create_masks(inp, tar):
    """Build the three attention masks used by the Transformer.

    Returns:
        (enc_padding_mask, combined_mask, dec_padding_mask) where
        - enc_padding_mask hides padding in ``inp`` for the encoder,
        - combined_mask hides both padding and future positions in ``tar``
          for the decoder's first attention block,
        - dec_padding_mask hides padding in ``inp`` when the decoder's second
          attention block attends over the encoder output.
    """
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    # A target position is masked if it is in the future OR it is padding;
    # the element-wise maximum of the two 0/1 masks expresses that union.
    target_len = tf.shape(tar)[1]
    future_mask = create_look_ahead_mask(target_len)
    target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(target_padding_mask, future_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask

checkpoint_path = "./checkpoints/train"

# Track both the model and the optimizer so training can be resumed.
ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)
# Only the 5 most recent checkpoints are kept on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

# @tf.function trace-compiles train_step / val_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to variable sequence lengths or variable
# batch sizes (the last batch is smaller), an input_signature with more
# generic shapes can be supplied.
# NOTE: the signature below is currently unused -- both decorators further
# down have their `input_signature=train_step_signature` argument commented
# out.

train_step_signature = [
    tf.TensorSpec(shape=(None, batch_len), dtype=tf.int64),
    tf.TensorSpec(shape=(None, batch_len), dtype=tf.int64),
]

@tf.function()#(input_signature=train_step_signature)
def train_step(inp, tar):
    """Run one optimisation step on a batch and update the training metrics.

    The target is shifted by one position: the decoder is fed ``tar[:, :-1]``
    and is trained to predict ``tar[:, 1:]``.

    Args:
        inp: encoder input token ids, shape (batch_size, inp_seq_len).
        tar: target token ids, shape (batch_size, tar_seq_len).
    """
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    # Only the forward pass and the loss need to be recorded by the tape.
    with tf.GradientTape() as tape:
        predictions, _ = transformer(inp, tar_inp,
                                     True,
                                     enc_padding_mask,
                                     combined_mask,
                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    # Differentiate and apply the update once recording has stopped.
    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    train_accuracy(tar_real, predictions)

@tf.function()#(input_signature=train_step_signature)
def val_step(inp, tar):
    """Evaluate one batch (dropout disabled) and update the validation metrics.

    Uses the same one-position target shift as train_step, but performs no
    gradient computation or weight update.
    """
    decoder_input, expected = tar[:, :-1], tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, decoder_input)

    # training=False: run the model in inference mode.
    predictions, _ = transformer(inp, decoder_input,
                                 False,
                                 enc_padding_mask,
                                 combined_mask,
                                 dec_padding_mask)
    batch_loss = loss_function(expected, predictions)

    val_loss(batch_loss)
    val_accuracy(expected, predictions)


In [None]:
#@title Preprocessing

def get_hyp_lm_tercets(tercets):
    """Flatten each verse's hyphenated words into one token list.

    :param tercets: nested lists shaped tercet -> verse -> word -> syllables.
    :return: nested lists shaped tercet -> verse -> tokens, where the
    syllables of consecutive words are separated by a '<SEP>' token (no
    trailing separator).
    """
    def _join_words(verse):
        tokens = []
        for position, hyp_w in enumerate(verse):
            if position:
                tokens.append('<SEP>')
            tokens.extend(hyp_w)
        return tokens

    return [[_join_words(verse) for verse in tercet] for tercet in tercets]

def is_vowel(c):
    """Return True if ``c`` occurs in the vowel string (plain and accented).

    This is a substring test, so it is meant to be called with a single
    character.
    """
    vowels = 'aeiouAEIOUàìíèéùúüòï'
    return c in vowels

def unsplittable_cons():
    """Return the consonant pairs that must stay in the same syllable.

    The list holds every combination of b/c/d/f/g/p/t/v followed by l or r,
    then 'gn', 'gh' and 'ch'.
    """
    clusters = [first + second
                for first in ('b', 'c', 'd', 'f', 'g', 'p', 't', 'v')
                for second in ('l', 'r')]
    clusters += ['gn', 'gh', 'ch']
    return clusters


def are_cons_to_split(c1, c2):
    """Return True if a syllable boundary falls between ``c1`` and ``c2``.

    The pair is split when any of these holds:
    - it is one of the explicitly listed pairs;
    - it is a doubled non-vowel;
    - it is not an unsplittable cluster, neither character is a vowel and
      ``c1`` is not 's'.
    """
    to_split = ('cq', 'cn', 'lm', 'rc', 'bd', 'mb', 'mn', 'ld', 'ng', 'nd', 'tm', 'nv', 'nc', 'ft', 'nf', 'gm', 'fm', 'rv', 'fp')
    pair = c1 + c2

    if pair in to_split:
        return True
    if not is_vowel(c1) and c1 == c2:
        return True
    if pair in unsplittable_cons():
        return False
    return (not is_vowel(c1)) and (not is_vowel(c2)) and c1 != 's'


def is_diphthong(c1, c2):
    """Return True if ``c1 + c2`` is one of the listed two-vowel diphthongs."""
    diphthongs = ('ia', 'ie', 'io', 'iu', 'ua', 'ue', 'uo', 'ui', 'ai', 'ei', 'oi', 'ui', 'au', 'eu', 'ïe', 'iú', 'iù')
    candidate = c1 + c2
    return candidate in diphthongs


def is_triphthong(c1, c2, c3):
    """Return True if ``c1 + c2 + c3`` is one of the listed triphthongs."""
    triphthongs = ('iai', 'iei', 'uoi', 'uai', 'uei', 'iuo')
    candidate = c1 + c2 + c3
    return candidate in triphthongs


def is_toned_vowel(c):
    """Return True if ``c`` occurs in the string of accented vowels.

    This is a substring test, so it is meant to be called with a single
    character.
    """
    toned = 'àìèéùòï'
    return c in toned

def has_vowels(sy):
    """Return True if at least one character of ``sy`` is a vowel."""
    return any(is_vowel(char) for char in sy)


def hyphenation(word):
    """
    Split word in syllables
    :param word: input string
    :return: a list containing syllables of the word; an empty list for an
    empty (or otherwise falsy) word.
    """
    if not word or word == '':
        return []
    # Three-letter words are handled by dedicated rules before the general scan.
    # elif len(word) == 3 and (is_vowel(word[1]) and is_vowel(word[2]) and not is_toned_vowel(word[2]) and (
    #     not is_diphthong(word[1], word[2]) or (word[1] == 'i'))):
    # Last two letters are vowels that do not form a diphthong (and the last
    # one is not accented): split as 2 + 1.
    elif len(word) == 3 and (is_vowel(word[1]) and is_vowel(word[2]) and not is_toned_vowel(word[2]) and (
        not is_diphthong(word[1], word[2]))):
        return [word[:2]] + [word[2]]
    # vowel-consonant-vowel: also split as 2 + 1.
    elif len(word) == 3 and is_vowel(word[0]) and not is_vowel(word[1]) and is_vowel(word[2]):
        return [word[:2]] + [word[2]]
    # Any other three-letter word is kept as a single syllable.
    elif len(word) == 3:
        return [word]

    # General case: scan left to right; `count` is the index of the next
    # character to consume and each outer iteration builds one syllable.
    syllables = []
    is_done = False
    count = 0
    while not is_done and count <= len(word) - 1:
        syllables.append('')
        c = word[count]
        # Consume the leading non-vowels, stopping at the first vowel or at
        # the last character of the word.
        while not is_vowel(c) and count < len(word) - 1:
            syllables[-1] = syllables[-1] + c
            count += 1
            c = word[count]

        # Add the character the scan stopped on.
        syllables[-1] = syllables[-1] + word[count]

        if count == len(word) - 1:
            is_done = True
        else:
            count += 1

            # Decide which of the following characters still belong to the
            # current syllable.
            if count < len(word) and not is_vowel(word[count]):
                if count == len(word) - 1:
                    # A final non-vowel closes the current syllable.
                    syllables[-1] += word[count]
                    count += 1
                elif count + 1 < len(word) and are_cons_to_split(word[count], word[count + 1]):
                    # The pair is split: the first character stays here.
                    syllables[-1] += word[count]
                    count += 1
                elif count + 2 < len(word) and not is_vowel(word[count + 1]) and not is_vowel(word[count + 2]) and word[
                    count] != 's':
                    # Three non-vowels in a row (not starting with 's'): the
                    # first one stays with the current syllable.
                    syllables[-1] += word[count]
                    count += 1
            elif count < len(word):
                # The next character is a vowel: keep triphthongs and
                # diphthongs together with the vowel just consumed.
                if count + 1 < len(word) and is_triphthong(word[count - 1], word[count], word[count + 1]):
                    syllables[-1] += word[count] + word[count + 1]
                    count += 2
                elif is_diphthong(word[count - 1], word[count]):
                    syllables[-1] += word[count]
                    count += 1

                # After the vowel group, a pair that must be split leaves its
                # first character in the current syllable.
                if count + 1 < len(word) and are_cons_to_split(word[count], word[count + 1]):
                    syllables[-1] += word[count]
                    count += 1

            else:
                is_done = True

    # A trailing fragment without vowels is merged into the previous syllable.
    if not has_vowels(syllables[-1]) and len(syllables) > 1:
        syllables[-2] = syllables[-2] + syllables[-1]
        syllables = syllables[:-1]

    return syllables



def get_dc_hyphenation(canti):
    """Hyphenate every verse of every canto.

    :param canti: a list of cantos, each a list of verses, each a list of
    words (strings).
    :return: a pair (hyp_canti, hyp_tokens): hyp_canti mirrors the input with
    each word replaced by its list of syllables; hyp_tokens is the flat list
    of all syllables in reading order.
    """
    hyp_canti, hyp_tokens = [], []
    for canto in canti:
        hyp_canto = []
        for verso in canto:
            hyphenated_words = seq_hyphentation(verso)
            hyp_canto.append(hyphenated_words)
            hyp_tokens.extend(
                syllable for word in hyphenated_words for syllable in word)
        hyp_canti.append(hyp_canto)

    return hyp_canti, hyp_tokens


def seq_hyphentation(words):
    """
    Hyphenate each word of a sequence.
    :param words: a list of words (strings)
    :return: a list with, for each word, the list of its syllables
    """
    return list(map(hyphenation, words))


def get_dc_cantos(filename, encoding=None):
    """Read a text file and group its tokenized verses by canto.

    Each line is stripped, lightly normalised and tokenized with NLTK. A line
    yielding exactly two tokens starts a new canto (presumably a heading);
    a line yielding more than two tokens is a verse of the current canto;
    anything shorter is ignored.

    :param filename: path of the text file to read.
    :param encoding: text encoding passed to open() (None = platform default).
    :return: a triple (cantos, words, raw): cantos is a list of cantos, each a
    list of verses, each a list of lower-cased tokens; words is the flat list
    of all verse tokens; raw holds, per canto, the normalised verse strings.
    :raises IndexError: if a verse line appears before any canto heading.
    """
    # (old, new) substitutions applied to every line, in this order.
    # NOTE(review): the first pair matches a literal backslash followed by a
    # dot, which is unlikely to occur in the text, so it is probably a no-op
    # (a regex-style escape used with plain string replacement). It is kept
    # as-is, but written with explicit backslashes so it is no longer an
    # invalid escape sequence.
    replacements = (
        ("\\.", " \\. "),
        ("[", ''),
        ("]", ''),
        ("-", ''),
        (";", " ; "),
        (",", " , "),
        ("'", " ' "),
    )

    cantos, words, raw = [], [], []
    with open(filename, "r", encoding=encoding) as f:
        for line in f:
            sentence = line.strip()
            for old, new in replacements:
                sentence = sentence.replace(old, new)

            if len(sentence) <= 1:
                continue

            # Lower-case the tokens and drop empty ones and any token
            # containing a guillemet.
            tokenized_sentence = [
                w.lower() for w in nl.word_tokenize(sentence)
                if len(w) > 0 and "«" not in w and "»" not in w
            ]

            if len(tokenized_sentence) == 2:
                cantos.append([])
                raw.append([])
            elif len(tokenized_sentence) > 2:
                raw[-1].append(sentence)
                cantos[-1].append(tokenized_sentence)
                words.extend(tokenized_sentence)

    return cantos, words, raw


def create_tercets(cantos):
    """Group the verses of each canto into consecutive runs of three.

    After every canto the most recently collected group is discarded (the
    closing group of a canto is expected to be incomplete).

    :param cantos: a list of cantos, each a list of verses.
    :return: a flat list of tercets (lists of up to three verses).
    """
    tercets = []
    for canto in cantos:
        for start in range(0, len(canto), 3):
            tercets.append(list(canto[start:start + 3]))
        # Drop the last group collected so far.
        tercets = tercets[:-1]

    return tercets

def pad_list(l, pad_token, max_l_size, keep_lasts=False, pad_right=True):
    """
    Pad or truncate a list to exactly ``max_l_size`` elements.
    :param l: input list to pad.
    :param pad_token: value to add as padding.
    :param max_l_size: length of the returned list; longer inputs are
    truncated and receive no padding.
    :param keep_lasts: when truncating, keep the last ``max_l_size`` elements
    (in their original order) instead of the first ones. E.g. with
    max_l_size=3, [1,2,3,4] becomes [2,3,4].
    :param pad_right: if True the padding is appended, otherwise prepended.
    :return: a new list, padded or truncated.
    """
    size = len(l)
    if keep_lasts and size > max_l_size:
        indices = range(size - max_l_size, size)
    else:
        indices = range(min(max_l_size, size))
    kept = [l[index] for index in indices]

    # A non-positive difference yields an empty padding list.
    padding = [pad_token] * (max_l_size - size)
    if pad_right:
        return kept + padding
    return padding + kept


def save_data(data, file):
    """Pickle ``data`` to the path ``file`` using the highest protocol."""
    with open(file, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

def load_data(file):
    """Return the object unpickled from the path ``file``."""
    with open(file, 'rb') as handle:
        restored = pickle.load(handle)
    return restored

def print_and_write(file, s):
    """Echo ``s`` to standard output and write it, unchanged, to ``file``.

    :param file: an open, writable file-like object.
    :param s: the string to emit (no newline is added to the file copy).
    """
    print(s)
    file.write(s)


class Vocabulary(object):
    """Bidirectional mapping between tokens and integer ids.

    Attributes:
        dictionary: token -> id.
        rev_dictionary: id -> token.
        count: list of [token, occurrences] entries; entry 0 is '<UNK>'.
        special_tokens: the special tokens registered in the vocabulary.
        vocab_size: number of entries (or the requested maximum size before
            the vocabulary is built).
    """

    def __init__(self, vocab_size=None):
        self.dictionary = dict()
        self.rev_dictionary = dict()
        self.count = []
        self.special_tokens = []
        self.vocab_size = vocab_size

    def build_vocabulary_from_counts(self, count, special_tokens=None):
        """
        Sets all the attributes of the Vocabulary object.
        :param count: a list of lists as follows: [['token', number_of_occurrences],...]
        :param special_tokens: a sequence of strings. E.g. ['<EOS>', '<PAD>',...]
        :return: None
        """
        special_tokens = [] if special_tokens is None else special_tokens

        # Ids follow the order of `count`: the first token gets id 0.
        dictionary = dict()
        for word, _ in count:
            dictionary[word] = len(dictionary)

        # Special tokens (e.g. <EOS>, <PAD>) take the ids right after the
        # counted tokens.
        first_special_id = len(dictionary)
        for offset, token in enumerate(special_tokens):
            dictionary[token] = first_special_id + offset

        self.count = count
        self.dictionary = dictionary
        self.rev_dictionary = {token_id: token for token, token_id in dictionary.items()}
        self.special_tokens = special_tokens
        self.vocab_size = len(dictionary)

    def build_vocabulary_from_tokens(self, tokens, vocabulary_size=None, special_tokens=None):
        """
        Given a list of tokens, it sets the Vocabulary object attributes by constructing
        a dictionary mapping each token to a unique id.
        :param tokens: a list of strings.
         E.g. ["the", "cat", "is", ... ".", "the", "house" ,"is" ...].
         NB: Here you should put all your token instances of the corpus.
        :param vocabulary_size: The number of elements of your vocabulary. If there are more
        than 'vocabulary_size' elements on tokens, it considers only the 'vocabulary_size'
        most frequent ones. Defaults to the size given to the constructor.
        :param special_tokens: Optional. A sequence of strings. Useful to add special tokens
        in vocabulary.
        :return: None
        """
        special_tokens = [] if special_tokens is None else special_tokens

        vocabulary_size = vocabulary_size if vocabulary_size is not None else self.vocab_size
        # Reserve room for the special tokens and for '<UNK>'.
        vocabulary_size = vocabulary_size - (len(special_tokens) + 1) if vocabulary_size else None
        # '<UNK>' always comes first, so it gets id 0; its count is filled in below.
        count = [['<UNK>', -1]]
        # Keep only the most frequent tokens; a size of None keeps them all.
        count.extend(collections.Counter(tokens).most_common(vocabulary_size))
        self.build_vocabulary_from_counts(count, special_tokens)
        self._set_unk_count(tokens)  # set the number of OOV instances

    @staticmethod
    def merge_vocabulary(vocab0, vocab1, vocabulary_size=-1):
        """
        Merge two Vocabulary objects into a new one.
        :param vocab0: first Vocabulary object
        :param vocab1: second Vocabulary object
        :param vocabulary_size: maximum number of (most common) tokens kept from the
        merged counts. With the default value -1 the limit is the sum of the two
        vocabulary sizes, so all the words of both vocabularies are preserved. Any
        other value is used directly as the limit.
        :return: a new vocabulary
        """
        # get size of the new vocabulary
        vocab_size = vocab0.vocab_size + vocab1.vocab_size if vocabulary_size == -1 else vocabulary_size
        merged_special_tokens = list(set(vocab0.special_tokens) | set(vocab1.special_tokens))

        # Sum the counts of the two vocabularies, then keep the most common tokens.
        # Adding Counters drops entries whose total is not positive, which
        # removes the '<UNK>' placeholders; a fresh one is put back in front.
        merged_counts = collections.Counter(dict(vocab0.count)) + collections.Counter(dict(vocab1.count))
        merged_counts = merged_counts.most_common(vocab_size)
        count = [['<UNK>', -1]]
        count.extend(merged_counts)

        # create the new vocabulary
        merged_vocab = Vocabulary(vocab_size)
        merged_vocab.build_vocabulary_from_counts(count, merged_special_tokens)
        return merged_vocab

    @staticmethod
    def merge_vocabularies(vocab_list, vocab_size=None):
        """
        Join a list of vocabularies into a new one.
        :param vocab_list: a list of Vocabulary objects
        :param vocab_size: the maximum size of the merged vocabulary; defaults to the
        sum of the sizes of the vocabularies in the list.
        :return: a vocabulary merging them all.
        """
        vocab_size = vocab_size if vocab_size else sum([v.vocab_size for v in vocab_list])
        merged_vocab = Vocabulary(vocab_size)
        for voc in vocab_list:
            merged_vocab = Vocabulary.merge_vocabulary(merged_vocab, voc, vocab_size)
        return merged_vocab

    def string2id(self, dataset):
        """
        Converts a dataset of strings into a dataset of ids according to the object dictionary.
        :param dataset: any string-based dataset with any nested lists.
        :return: a new dataset, with the same shape of dataset, where each string is mapped into its
        corresponding id associated in the dictionary (the '<UNK>' id for unknown tokens).
        """

        def _recursive_call(items):
            new_items = []
            for item in items:
                if isinstance(item, (str, int, float)):
                    new_items.append(self.word2id(item))
                else:
                    new_items.append(_recursive_call(item))
            return new_items

        return _recursive_call(dataset)

    def id2string(self, dataset):
        """
        Converts a dataset of integer ids into a dataset of string according to the reverse dictionary.
        :param dataset: any int-based dataset with any nested lists. Python ints and
        NumPy integer scalars are treated as ids.
        :return: a new dataset, with the same shape of dataset, where each token is mapped into its
        corresponding string associated in the reverse dictionary.
        """
        def _recursive_call(items):
            new_items = []
            for item in items:
                # np.integer covers every NumPy integer width; the removed
                # `np.int` alias is not needed.
                if isinstance(item, (int, np.integer)):
                    new_items.append(self.id2word(item))
                else:
                    new_items.append(_recursive_call(item))
            return new_items

        return _recursive_call(dataset)

    def word2id(self, item):
        """
        Maps a string token to its corresponding id.
        :param item: a string.
        :return: the id of the token if it belongs to the vocabulary, otherwise
        the id associated to the unknown symbol '<UNK>'.
        """
        return self.dictionary[item] if item in self.dictionary else self.dictionary['<UNK>']

    def id2word(self, token_id):
        """
        Maps an integer token to its corresponding string.
        :param token_id: an integer.
        :return: If the id belongs to the vocabulary, it returns the string
        associated to it, otherwise it returns the string associated
        to the unknown symbol, that is '<UNK>'.
        """

        return self.rev_dictionary[token_id] if token_id in self.rev_dictionary else self.rev_dictionary[self.dictionary['<UNK>']]

    def get_unk_count(self):
        """Return the occurrence count stored for the first entry of `count`."""
        return self.count[0][1]

    def _set_unk_count(self, tokens):
        """
        Sets the number of OOV instances in the tokens provided
        :param tokens: a list of tokens
        :return: None
        """
        self.count[0][1] = sum(1 for word in tokens if word not in self.dictionary)

    def add_element(self, name, is_special_token=False):
        """
        Adds a new token to the vocabulary; does nothing if it is already present.
        The token receives the next free id, i.e. the current vocabulary size,
        so ids stay contiguous and below the new `vocab_size`.
        :param name: the token to add.
        :param is_special_token: if True the token is also recorded as special.
        :return: None
        """
        if name not in self.dictionary:
            new_id = self.vocab_size if self.vocab_size is not None else len(self.dictionary)
            self.dictionary[name] = new_id
            self.rev_dictionary[new_id] = name
            self.vocab_size = new_id + 1

            if is_special_token:
                # Copy first: the stored sequence may be a tuple or be shared.
                self.special_tokens = list(self.special_tokens)
                self.special_tokens.append(name)

            self.count.append([name, 1])

    def set_vocabulary(self, dictionary, rev_dictionary, special_tokens, vocab_size):
        """Replace the mappings, special tokens and size with the given values."""
        self.dictionary = dictionary
        self.rev_dictionary = rev_dictionary
        self.special_tokens = special_tokens
        self.vocab_size = vocab_size

    @staticmethod
    def load_vocabulary(filename):
        """Return the object unpickled from `filename` via load_data.

        NOTE(review): this unpickles the file; only use it on trusted files.
        """
        return load_data(filename)

    def save_vocabulary(self, filename):
        """Pickle this vocabulary to `filename` via save_data."""
        save_data(self, filename)

class SyLMDataset(object):
    """
    Base class for token-level language-model datasets.

    Subclasses implement ``load``; ``build`` then splits the raw examples,
    creates the vocabulary when none was supplied and converts every split
    into padded sequences of integer ids.
    """

    def __init__(self, config, sy_vocab=None):
        """
        :param config: object exposing ``sentence_max_len`` and
        ``input_vocab_size``.
        :param sy_vocab: (optional) an already built vocabulary; when None it
        is created from the training split inside ``build``.
        """
        self.config = config
        self.vocabulary = sy_vocab

        self.raw_train_x = []
        self.raw_val_x = []
        self.raw_test_x = []
        self.raw_x = []

        self.train_x, self.train_y = [], []
        self.val_x, self.val_y = [], []
        self.test_x, self.test_y = [], []
        self.x, self.y = [], []

    def initialize(self, sess):
        """No-op hook; ``sess`` is ignored."""
        pass

    def load(self, sources):
        """
        Extract raw texts from sources and gather them all together.
        Must be overridden by subclasses.
        :param sources: a string or an iterable of strings containing the file(s)
        to process in order to build the dataset.
        :return: a list of examples, each one a list of tokens (strings).
        :raises NotImplementedError: always, in this base class.
        """
        # Raise (instead of returning the exception class) so a missing
        # override fails here rather than later inside build().
        raise NotImplementedError

    def build(self, sources, split_size=0.8):
        """
        Load the raw examples, split them, and fill the train/val/test
        attributes with padded id sequences.
        :param sources: a string or an iterable of strings containing the file(s)
        to process in order to build the dataset.
        :param split_size: the size to split the dataset, set >=1.0 to not split.
        The same ratio is applied twice: first to separate the test examples
        from the rest, then to separate validation from training.
        """

        raw_x = self.load(sources)
        # raw_x = self.tokenize([self.preprocess(ex) for ex in raw_x])  # fixme
        # splitting data
        self.raw_x = raw_x
        if split_size < 1.0:
            self.raw_train_x, self.raw_test_x = self.split(self.raw_x, train_size=split_size)
            self.raw_train_x, self.raw_val_x = self.split(self.raw_train_x, train_size=split_size)
        else:
            self.raw_train_x = self.raw_x

        if self.vocabulary is None:
            # creates vocabulary from the training examples only
            tokens = [item for sublist in self.raw_train_x for item in sublist]  # get tokens
            special_tokens = ("<GO>", "<PAD>", "<SEP>", "<EOS>", "<EOV>")
            self._create_vocab(tokens, special_tokens=special_tokens)

        max_len = self.config.sentence_max_len

        # x and y are built with identical settings for every split.
        # NOTE(review): presumably the consumer shifts the target by one
        # position — confirm against the training loop.

        # creates x,y for train
        self.train_x = self._build_dataset(self.raw_train_x, insert_go=True, max_len=max_len, shuffle=False)
        self.train_y = self._build_dataset(self.raw_train_x, insert_go=True, max_len=max_len, shuffle=False)

        # creates x,y for validation
        self.val_x = self._build_dataset(self.raw_val_x, insert_go=True, max_len=max_len, shuffle=False)
        self.val_y = self._build_dataset(self.raw_val_x, insert_go=True, max_len=max_len, shuffle=False)

        # creates x,y for test
        self.test_x = self._build_dataset(self.raw_test_x, insert_go=True, max_len=max_len, shuffle=False)
        self.test_y = self._build_dataset(self.raw_test_x, insert_go=True, max_len=max_len, shuffle=False)

    def _create_vocab(self, tokens, special_tokens=("<PAD>", "<GO>", "<SEP>", "<EOV>", "<EOS>")):
        """
        Create the vocabulary. Special tokens can be added to the tokens obtained from
        the corpus.
        :param tokens: a list of all the tokens in the corpus. Each token is a string.
        :param special_tokens: a list of strings.
        """
        vocab = Vocabulary(vocab_size=self.config.input_vocab_size)
        vocab.build_vocabulary_from_tokens(tokens, special_tokens=special_tokens)
        self.vocabulary = vocab

    @staticmethod
    def split(raw_data, train_size=0.8):
        """
        Cut ``raw_data`` in two consecutive parts, without shuffling.
        :param raw_data: a sliceable sequence.
        :param train_size: fraction of the elements kept in the first part.
        :return: the pair (first part, remaining part).
        """
        size = math.floor(len(raw_data) * train_size)
        return raw_data[:size], raw_data[size:]

    @staticmethod
    def preprocess(txt):
        """Identity hook: returns ``txt`` unchanged."""
        return txt

    @staticmethod
    def shuffle(x):
        """Return a new list with the elements of ``x`` in random order."""
        return random.sample(x, len(x))

    @staticmethod
    def tokenize(txt):
        """Identity hook: returns ``txt`` unchanged."""
        return txt

    def _build_dataset(self, raw_data, max_len=100, insert_go=True, keep_lasts=False, pad_right=True, shuffle=True):
        """
        Converts all the tokens in raw_data by mapping each token with its corresponding
        value in the dictionary. In case of token not in the dictionary, they are assigned to
        a specific id. Every sequence gets a trailing <EOS> and is then handed to
        ``pad_list`` together with the <PAD> id and ``max_len``.

        :param raw_data: list of sequences, each sequence is a list of tokens (strings).
        :param max_len: max length of a sequence, crop longer and pad smaller ones.
        :param insert_go: True to insert <GO>, False otherwise.
        :param keep_lasts: True to truncate initial elements of a sequence.
        :param pad_right: pad to the right (default value True), otherwise pads to left.
        :param shuffle: Optional. If True data are shuffled.
        :return: A list of sequences where each token in each sequence is an int id.
        """
        word2id = self.vocabulary.word2id
        dataset = []
        for sentence in raw_data:
            sentence_ids = [word2id("<GO>")] if insert_go else []
            sentence_ids.extend(word2id(w) for w in sentence)
            sentence_ids.append(word2id("<EOS>"))
            sentence_ids = pad_list(sentence_ids, word2id("<PAD>"), max_len, keep_lasts=keep_lasts, pad_right=pad_right)

            dataset.append(sentence_ids)

        if shuffle:
            return random.sample(dataset, len(dataset))
        return dataset

    def get_batches(self, batch_size=32, split_sel='train'):
        """
        Iterator over one split of the dataset. Useful method to run experiments.
        Each batch is a pair of flat lists: the sequences of the batch are
        concatenated one after the other, each followed by the <EOV> id.
        :param batch_size: number of sequences per mini-batch (the last batch
        may hold fewer).
        :param split_sel: 'train' or 'val'; any other value selects the test split.
        :return: yields (input, target) pairs.
        :raises ValueError: if batch_size is smaller than 1.
        """
        if batch_size < 1:
            # A non-positive size could never make progress through the data.
            raise ValueError("batch_size must be at least 1")

        if split_sel == 'train':
            x, y = self.train_x, self.train_y
        elif split_sel == 'val':
            x, y = self.val_x, self.val_y
        else:
            x, y = self.test_x, self.test_y

        eov = self.vocabulary.word2id("<EOV>")
        total = len(x)
        for start in range(0, total, batch_size):
            batch_x, batch_y = [], []
            for k in range(start, min(start + batch_size, total)):
                batch_x.extend(x[k])
                batch_x.append(eov)
                batch_y.extend(y[k])
                batch_y.append(eov)
            yield batch_x, batch_y

class DanteSyLMDataset(SyLMDataset):
    """Dataset of hyphenated tercets from Dante Alighieri's Divine Comedy."""

    def __init__(self, config, sy_vocab=None):
        """
        Class to create a dataset from Dante Alighieri's Divine Comedy.
        :param config: a Config object
        :param sy_vocab: (optional) a Vocabulary object where tokens of the dictionary
        are syllables. If None, the vocabulary is created automatically from the source.
        """
        super().__init__(config, sy_vocab)

    def load(self, sources):
        """
        Load examples from dataset
        :param sources: data filepath.
        :return: one list of tokens per tercet; inside it the tokens of every
        verse are followed by the "<EOV>" marker.
        """
        # Only the first value of each helper result is used further on.
        cantos, _, _ = get_dc_cantos(filename=sources)  # get raw data from file
        cantos, _ = get_dc_hyphenation(cantos)

        tercets = get_hyp_lm_tercets(create_tercets(cantos))

        # Examples are kept in file order (no shuffling).
        return [
            [token for verse in tercet for token in (*verse, "<EOV>")]
            for tercet in tercets
        ]

def seq2str(seq):
    """
    Render a sequence of token ids as text, using the vocabulary of the
    module-level ``poetry_sy_lm_dataset``.

    The ids of <PAD>, <SEP>, <GO>, <EOS> and the id 0 become a blank, the id
    of <EOV> becomes a newline, every other known id becomes its token string
    and unknown ids become '<UNK>'.
    :param seq: an iterable of token ids.
    :return: the rendered string.
    """
    vocabulary = poetry_sy_lm_dataset.vocabulary
    id_to_token = vocabulary.rev_dictionary
    blank_ids = [vocabulary.word2id("<PAD>"), 0, vocabulary.word2id("<SEP>"),
                 vocabulary.word2id("<GO>"), vocabulary.word2id("<EOS>")]
    newline_ids = [vocabulary.word2id("<EOV>")]

    pieces = []
    for token in seq:
        # Blanks take precedence over newlines, which take precedence over
        # the plain vocabulary lookup.
        if token in blank_ids:
            pieces.append(' ')
        elif token in newline_ids:
            pieces.append('\n')
        elif token in id_to_token:
            pieces.append(id_to_token[token])
        else:
            pieces.append('<UNK>')
    return ''.join(pieces)

class cnfg:
  """Plain settings holder handed to the dataset classes as their config."""
  # Both size attributes take the value of the module-level vocab_size.
  vocab_size = vocab_size
  input_vocab_size = vocab_size
  # Taken from the module-level terces_len; read by the dataset's build step
  # as the maximum sequence length.
  sentence_max_len = terces_len

# Download the corpus, build the dataset and pre-compute the batches of each split.
config = cnfg()
poetry_sy_lm_dataset = DanteSyLMDataset(config, sy_vocab=None)
url = "https://gitlab.com/zugo91/nlgpoetry/-/raw/release/data/la_divina_commedia.txt"
response = requests.get(url)
# Stop on an HTTP error instead of saving an error page as the corpus.
response.raise_for_status()
response.encoding = 'ISO-8859-1'
# The context manager closes the file even when the write fails.
with open("divcom.txt", "w") as fi:
    fi.write(response.text)
data_path = os.path.join(os.getcwd(), "divcom.txt")  # dataset location, here just the name of the source file
poetry_sy_lm_dataset.build(data_path, split_size=0.99)  # actual creation of  vocabulary (if not provided) and dataset
print("Train size: " + str(len(poetry_sy_lm_dataset.train_y)))
print("Val size: " + str(len(poetry_sy_lm_dataset.val_y)))
print("Test size: " + str(len(poetry_sy_lm_dataset.test_y)))

# Ids of the special tokens, looked up once for the rest of the notebook.
eov = poetry_sy_lm_dataset.vocabulary.word2id("<EOV>")
pad = poetry_sy_lm_dataset.vocabulary.word2id("<PAD>")
go = poetry_sy_lm_dataset.vocabulary.word2id("<GO>")
eos = poetry_sy_lm_dataset.vocabulary.word2id("<EOS>")

# Materialize the batches of every split and show the first one as a sanity check.
batches = list(poetry_sy_lm_dataset.get_batches(terces_per_batch))
print(batches[0][0])
print(batches[0][1])
print(len(batches[0][0]))
val_b = list(poetry_sy_lm_dataset.get_batches(terces_per_batch, split_sel='val'))
print(val_b[0][0])
print(val_b[0][1])
print(len(val_b[0][0]))
test_b = list(poetry_sy_lm_dataset.get_batches(terces_per_batch, split_sel='test'))
print(test_b[0][0])
print(test_b[0][1])
print(len(test_b[0][0]))

'''
d = poetry_sy_lm_dataset.vocabulary.dictionary.items()
d_view = [ (v,k) for k,v in d]
d_view.sort(reverse=False) # natively sort tuples by first element
for v,k in d_view:
    print(k,v)
'''
len(poetry_sy_lm_dataset.vocabulary.dictionary.items())

In [5]:
#@title Evaluation

import tensorflow_datasets as tfds

def ngrams_plagiarism(generated_text, n=4):
    """
    Measure how many n-grams of the generated text do NOT occur in the
    original text stored in "divcom.txt".

    Both texts are lower-cased and tokenized; every window of ``n``
    consecutive generated tokens is joined back into a string and looked up
    as a substring of the joined original text.
    :param generated_text: the text to check.
    :param n: number of tokens per n-gram.
    :return: 1 - (matching n-grams / total n-grams); 1.0 when the generated
    text has fewer than ``n`` tokens, since there is nothing to compare.
    """
    # the tokenizer is used to remove non-alphanumeric symbols
    # NOTE(review): newer tensorflow_datasets releases expose this class under
    # tfds.deprecated.text — confirm against the installed version.
    tokenizer = tfds.features.text.Tokenizer()
    with open("divcom.txt") as f:
        original_text = f.read()
    original_text = tokenizer.join(tokenizer.tokenize(original_text.lower()))
    generated_text_tokens = tokenizer.tokenize(generated_text.lower())

    total_ngrams = len(generated_text_tokens) - n + 1
    if total_ngrams <= 0:
        # Without this guard a text of exactly n-1 tokens divides by zero.
        return 1.0

    plagiarism_counter = sum(
        1
        for i in range(total_ngrams)
        if tokenizer.join(generated_text_tokens[i:i + n]) in original_text
    )
    return 1 - (plagiarism_counter / total_ngrams)

# coding=utf-8

# Syllabification module.
# A special thanks goes to Simona S., Italian linguist, teacher and friend, without whom this module could never exist.

# This module is used both for building the dataset and for computing metrics.
# IMPORTANT: the #, @ and § characters are used internally to correctly split syllables, the input string should not contain them.

# Splits a string along word boundaries (empty spaces and punctuation marks). If synalepha is True, doesn't split
# words which have a vowel boundary (eg. selva_oscura).
def split_words(strn, synalepha=False):
    """
    Split a string into consecutive pieces at the whitespace/punctuation gaps
    that ``_is_split_acceptable`` approves.

    Pieces are slices of the input, so joining them gives the input back.
    :param strn: the string to split.
    :param synalepha: forwarded to ``_is_split_acceptable``.
    :return: the list of pieces.
    """
    boundary = re.compile(r"""[,.;:"“”«»?—'`‘’\s]*\s+[,.;:"“”«»?—'`‘’\s]*""")
    cuts = [0]

    for gap in boundary.finditer(strn):
        # Widen the gap by one character per side so the check can look at
        # the characters surrounding it.
        left = max(gap.start() - 1, 0)
        right = gap.end() + 1
        if _is_split_acceptable(strn[left:right], synalepha):
            cuts.append(left + 1)

    # A trailing None makes the last slice run to the end of the string.
    cuts.append(None)
    return [strn[cuts[k]:cuts[k + 1]] for k in range(len(cuts) - 1)]

# Splits a single word into syllables.
def syllabify_word(strn):
    """
    Split a single word into syllables by running the initial splits and
    then the final splits on their result.
    :param strn: the word to process.
    :return: the value returned by ``_perform_final_splits``.
    """
    return _perform_final_splits(_perform_initial_splits(strn))

# Splits a block into words and then into syllables.
def syllabify_block(strn, synalepha=False):
    """
    Split a block of text into words and every word into syllables.
    :param strn: the text to process.
    :param synalepha: forwarded to ``split_words``.
    :return: the syllabified words joined with "#".
    """
    return "#".join(map(syllabify_word, split_words(strn, synalepha)))

# Removes capitalization, punctuation marks and, optionally, diacritics (accents and dieresis).
def prettify(strn, keep_diacritics=True):
    """
    Lower-case a string and strip punctuation marks and whitespace from it.
    :param strn: the string to clean.
    :param keep_diacritics: False to also replace accented and dieresis
    vowels with their plain counterparts.
    :return: the cleaned string.
    """
    lowered = strn.lower()
    if not keep_diacritics:
        lowered = _remove_diacritics(lowered)
    return _strip_spaces(_strip_punctuaction(lowered))

# Removes hash characters from a string.
def strip_hashes(strn):
    """Return ``strn`` with every "#" character removed."""
    return strn.replace("#", "")

# Determines if a split between two words is acceptable, ie. if there are no synalepha nor elision (eg. "l' amico" should be kept together).
# Heuristic: all apostrophes are considered a non-breakable point. This is not always the case (eg. "perch’ i’ fu’" should be split into "perch’ i’"-"fu’).
def _is_split_acceptable(strn, synalepha=False):
    prev = strn[0]
    next = strn[len(strn) - 1]
    vowel = re.compile(r"""[AEIOUaeiouàèéìòóùÈ]""")
    apostrophe = re.compile(r""".*['`‘’].*""")
    newline = re.compile(r""".*\n+.*""")

    out = newline.match(strn) or \
          not (apostrophe.match(strn) and (vowel.match(prev) or vowel.match(next)))

    if synalepha:
        out = out and not (vowel.match(prev) and vowel.match(next))

    return out

# Removes punctuation from a string.
def _strip_punctuaction(str):
    return re.sub(r"""[,.;:"“”!?«»—'`’]+""", "", str)

# Removes diacritic marks from a string.
def _remove_diacritics(str):
    out = re.sub(r"""[àä]""", "a", str)
    out = re.sub(r"""[èéë]""", "e", out)
    out = re.sub(r"""[ìï]""", "i", out)
    out = re.sub(r"""[òóö]""", "o", out)
    out = re.sub(r"""[ùü]""", "u", out)
    return out

# Removes spaces from a string.
def _strip_spaces(str):
    return re.sub(r"""\s+""", "", str)

# Performs the first (easy and unambiguous) phase of syllabification.
def _perform_initial_splits(str):
    """
    Run the first phase of syllabification.

    The splitters are applied in this order: multiple consonants, double
    consonants, dieresis, hiatus.
    :param str: the string to process.
    :return: the string returned by the last splitter.
    """
    return _split_hiatus(_split_dieresis(_split_double_cons(_split_multiple_cons(str))))

# Performs the second (difficult and heuristic) phase of syllabification.
def _perform_final_splits(str):
    """
    Run the second phase of syllabification, inserting "#" at the remaining
    syllable boundaries.
    :param str: the output of the first phase.
    :return: the string with the additional "#" separators.
    """
    cvcv = r"""(?i)([bcdfglmnpqrstvz][,.;:"“”«»?—'`‘’\s]*[aeiouàèéìóòùÈËÏ]+)([bcdfglmnpqrstvz]+[,.;:"“”«»?—'`‘’\s]*[aeiouàèéìóòùÈËÏ]+)"""
    vcv = r"""(?i)([aeiouàèéìóòùÈËÏ]+)([bcdfglmnpqrstvz]+[,.;:"“”«»?—'`‘’\s]*[aeiouàèéìóòùÈËÏ]+)"""
    vv = r"""(?i)(?<=[aeiouàèéìóòùÈËÏ])(?=[aeiouàèéìóòùÈËÏ])"""

    out = str
    # Deterministic cases first: contoid vocoid - contoid vocoid (eg. ca-ne),
    # then vocoid - contoid vocoid (eg. ae-reo).
    for pattern in (cvcv, vcv):
        out = re.sub(pattern, r"""\1#\2""", out)

    # Heuristic vocoid - vocoid case (eg. a-iuola): the vowel groups marked
    # with "§" by _clump_diphthongs are not split.
    out = _clump_diphthongs(out)
    out = re.sub(vv, r"""#""", out)

    # Drop the temporary "§" markers.
    return out.replace("§", "")

# Splits double consonants (eg. al-legro)
def _split_double_cons(str):
    doubles = re.compile(r"""(?i)(([bcdfglmnpqrstvz])(?=\2)|c(?=q))""")
    return "#".join(doubles.sub(r"""\1@""", str).split("@"))

# Splits multiple consonants, except: impure s (sc, sg, etc.), mute followed by liquide (eg. tr), digrams and trigrams.
def _split_multiple_cons(str):
    impures = re.compile(r"""(?i)(s(?=[bcdfghlmnpqrtvz]))""")
    muteliquide = re.compile(r"""(?i)([bcdgpt](?=[lr]))""")
    digrams = re.compile(r"""(?i)(g(?=li)|g(?=n[aeiou])|s(?=c[ei])|[cg](?=h[eèéiì])|[cg](?=i[aou]))""")
    trigrams = re.compile(r"""(?i)(g(?=li[aou])|s(?=ci[aou]))""")
    multicons = re.compile(r"""(?i)([bcdfglmnpqrstvz](?=[bcdfglmnpqrstvz]+))""")

    # Preserve non admissibile splits.
    out ="§".join(impures.sub(r"""\1@""", str).split("@"))
    out = "§".join(muteliquide.sub(r"""\1@""", out).split("@"))
    out = "§".join(digrams.sub(r"""\1@""", out).split("@"))
    out = "§".join(trigrams.sub(r"""\1@""", out).split("@"))
    # Split everything else.
    out = "#".join(multicons.sub(r"""\1@""", out).split("@"))

    return "".join(re.split("§", out))

# Splits dieresis.
def _split_dieresis(str):
    dieresis = re.compile(r"""(?i)([äëïöüËÏ](?=[aeiou])|[aeiou](?=[äëïöüËÏ]))""")
    return "#".join(dieresis.sub(r"""\1@""", str).split("@"))

# Splits SURE hiatuses only. Ambiguous ones are heuristically considered diphthongs.
def _split_hiatus(str):
    hiatus = re.compile(r"""(?i)([aeoàèòóé](?=[aeoàèòóé])|[rb]i(?=[aeou])|tri(?=[aeou])|[ìù](?=[aeiou]))""")
    return "#".join(hiatus.sub(r"""\1@""", str).split("@"))

# Prevents splitting of diphthongs and triphthongs.
def _clump_diphthongs(str):
    diphthong = r"""(?i)(i[,.;:"“”«»?—'`‘’\s]*[aeouàèéòóù]|u[,.;:"“”«»?—'`‘’\s]*[aeioàèéìòó]|[aeouàèéòóù][,.;:"“”«»?—'`‘’\s]*i|[aeàèé][,.;:"“”«»?—'`‘’\s]*u)"""
    diphthongsep = r"""(\{.[,.;:"“”«»?—'`‘’\s]*)(.\})"""
    triphthong = r"""(?i)(i[àèé]i|u[àòó]i|iu[òó])"""
    triphthongsep = r"""(\{.)(.)(.\})"""

    out = re.sub(triphthong, r"""{\1}""", str)
    out = re.sub(triphthongsep, r"""\1§\2§\3""", out)
    out = re.sub(diphthong, r"""{\1}""", out)
    out = re.sub(diphthongsep, r"""\1§\2""", out)
    out = re.sub(r"""[{}]""", "", out)

    return out

# coding=utf-8

# Rhyme scoring and extraction module. Exploits informations about accents, syllables and heuristics to perform
# the difficult task of determining if two words form a rhyme.

# Computes a rhyming score between two words.
def rhyme_score(w1, w2):
    """
    Compute a rhyming score between two words.

    When neither word carries a marked accent the heuristic match is used.
    Otherwise the words, with diacritics removed, are compared from the
    accent position (counted from the end of the word); a word without a
    marked accent borrows the position found in the other one.
    :param w1: first word.
    :param w2: second word.
    :return: 0 if one of the words is empty, otherwise the score computed by
    the matching helpers.
    """
    if w1 == "" or w2 == "":  # One of the two words is missing.
        return 0

    # Accents are located on the accent-preserving form.
    accent1 = _locate_accent(prettify(w1, True))
    accent2 = _locate_accent(prettify(w2, True))

    if accent1 == 0 and accent2 == 0:  # No accent is known: heuristic match.
        return _heuristic_rhyme(w1, w2)

    plain1 = prettify(w1, False)
    plain2 = prettify(w2, False)

    # A position of 0 means "unknown", so fall back on the other word's accent.
    start1 = accent1 or accent2
    start2 = accent2 or accent1
    return _match_syllable(plain1[start1:], plain2[start2:])

# Determines if a word is tronca (accent on the last syllable). Exact cases: word ending with an accented letter (morì) or word ending with a consonant (mangiàr).
# Heuristic: NO other words are considered tronche since the majority of Italian words are piane (accent on the second to last syllable) or sdrucciole (third to last).
def is_tronca(word):
    """
    Determine if a word is tronca (accent on the last syllable).

    After prettification a word counts as tronca when it ends with one of the
    listed consonants, when its syllabification contains no "#" (a single
    syllable), or when its last syllable contains an accented vowel.
    :param word: the word to check.
    :return: True if the word is considered tronca, False otherwise (also
    when the word is made of punctuation only).
    """
    consonant = re.compile(r"""[bcdfghlmnprstvz]""")
    accentlastsyl = re.compile(r""".*#[^#]*[àèéìóòù][^#]*""")
    w = prettify(word, True)

    if w == "":  # The "word" was actually composed by punctuation only.
        return False
    if consonant.match(w[-1]):
        return True

    sw = syllabify_word(w)
    if sw.count("#") == 0:
        return True

    # fullmatch anchors the pattern at the end of the word, so only an accent
    # in the LAST syllable counts; a plain match would also accept accents in
    # earlier syllables (eg. "dor#mì#re").
    return accentlastsyl.fullmatch(sw) is not None

# Not used:
# def _is_piana(word): # Most common case.
#     return not (_is_tronca(word) or _is_sdrucciola(word))
#
# def _is_sdrucciola(word): # Detected only if the accent is marked.
#     accentlastsyl = re.compile(r""".*[àèéìóòù].*#.*#.*""") # The accent is marked and followed by at least two hashes.
#     return accentlastsyl.match(s.syllabify_word(s.prettify(word, True)))

# Returns the accent position FROM THE END of the word (eg. mangiò -> -1, dormìre -> -3).
# NOTE: prettification is done by the caller, since it could change accent position.
def _locate_accent(word):
    accent = re.compile(r"""[àèéìóòù]""")
    match = accent.search(word)
    if match:
        pos = match.start()
    else:
        pos = len(word)

    return pos - len(word)

# Determines a rhyming score if the two words don't have accents.
def _heuristic_rhyme(w1, w2):
    """
    Compute a rhyming score for two words without marked accents.
    :param w1: first word.
    :param w2: second word.
    :return: the score of ``_match_syllable`` on the last syllable of each
    word when both are tronche, otherwise on their last two syllables.
    """
    syllables1 = syllabify_word(prettify(w1, False)).split("#")
    syllables2 = syllabify_word(prettify(w2, False)).split("#")

    # Approximate match:
    if is_tronca(w1) and is_tronca(w2):
        # Both words are tronche: only the last syllable matters.
        tail1, tail2 = syllables1[-1], syllables2[-1]
    else:
        # Treated as piane: the last two syllables are compared.
        tail1 = "".join(syllables1[-2:])
        tail2 = "".join(syllables2[-2:])

    return _match_syllable(tail1, tail2)

# Computes a score based on how many letters match from the end of the two strings, up to the last vowel of the first vocoid (eg. "men#te" vs. "can#te" tries to match ente and ante: the last three letters match, giving (16 + 8 + 4) / 16 = 1.75, while "iuo#la" vs. "suo#la" tries to match ola and ola: all three letters match, giving (8 + 4 + 2) / 8 = 1.75).
# HEURISTIC: since no accent is known, the match is as PERMISSIVE as possible (ie. matches from the LAST vowel of a diphthong). This rhymes correctly "quivi/sorgivi" (while a restrictive heuristic wouldn't).
# The computed score is the sum of all matching characters (truncated at the first difference), weighted exponentially with the distance from the putative beginning of the rhyme.
# As such, it's a score which can scale well on different matching lengths (eg. "più" and "fu" have a score similar to "frangente" and "assolutamente"), at the expenses of not having a "natural" meaning which could be easier to threshold.
def _match_syllable(s1, s2):
    lastvowel = re.compile(r"""[aeiou](?![aeiou])""") # Inside a syllable vowels can only be together, so only the NEXT character needs to be checked.

    match1 = lastvowel.search(s1)
    match2 = lastvowel.search(s2)

    if match1 and match2:
        ss1 = s1[match1.start():]
        ss2 = s2[match2.start():]

        # maxlength = len(ss2) if len(ss1) < len(ss2) else len(ss1) # The two lengths could be different.
        minlength = len(ss1) if len(ss1) < len(ss2) else len(ss2)
        out = 0.0
        a = (ss1 if len(ss1) < len(ss2) else ss2)[::-1] # reverse.
        b = (ss2 if len(ss1) < len(ss2) else ss1)[::-1]

        i = 0
        while (i < minlength) and (a[i] == b[i]): # Iterate only over the shared part of the string.
            out += 2.0 ** (minlength - i)
            i += 1
        out /= 2 ** minlength
    else:
        out = 0.0

    return out

# Metrics evaluation module.

# Evaluates metrics on a string, computing each value on a per-terzina basis and then outputting the average scores.
# If verbose, outputs the scores referred to each terzina.
def eval_txt(string, verbose=False, synalepha=False, permissive=True, rhyme_threshold=1.0):
    """
    Evaluate the metrics on a string, computing each value on a per-terzina
    basis and then reporting the averages.
    :param string: the text to evaluate.
    :param verbose: True to print the scores of every terzina.
    :param synalepha: forwarded to the hendecasyllabicness score.
    :param permissive: forwarded to the hendecasyllabicness score.
    :param rhyme_threshold: forwarded to the rhymeness score.
    :return: a list of six formatted report strings, or None (after printing
    an error) when fewer than two terzine are found.
    """
    terzine = _extract_terzine(string)

    # Rhymes are checked on consecutive pairs, so at least two terzine are
    # needed. Checking up front also covers the case of no terzina at all,
    # which would otherwise fail on terzine[0].
    if len(terzine) <= 1:
        print()
        print("ERROR: no valid terzina detected.")
        return None

    avg_hendecasyllabicness = 0.0
    avg_rhymeness = 0.0
    last_terzina = terzine[0]
    for terzina in terzine[1:]:
        hendecasyllabicness = _hendecasyllabicness(terzina, synalepha, permissive)
        tmp = "\n".join(terzina.split("\n")[1:])

        # In order to properly check chaining, two terzine at the time need to be considered.
        rhymeness = _rhymeness(last_terzina + tmp, rhyme_threshold)
        avg_hendecasyllabicness += hendecasyllabicness
        avg_rhymeness += rhymeness

        last_terzina = terzina

        if verbose:
            print()
            print(terzina)
            print("Hendecasyllabicness: {}, Rhymeness: {}".format(hendecasyllabicness, rhymeness))

    print()
    line_count = len(string.split("\n"))
    # Each "optimal" terzina has 5 lines, the last of which is shared with the next one
    # (therefore a file with n perfect terzine has 4n + 2 lines, due to the final stray verse and empty line).
    avg_structuredness = (4 * len(terzine) + 2) / line_count
    # NOTE(review): the sum covers terzine[1:] but is divided by len(terzine)
    # — confirm whether len(terzine) - 1 was intended.
    avg_hendecasyllabicness /= len(terzine)
    avg_rhymeness /= len(terzine) - 1  # The rhymes on the first terzina are not checked.

    return ["Number of putative terzine: {}".format((line_count - 1) // 4),
        "Number of well formed terzine: {}".format(len(terzine)),
        "Average structuredness: {}".format(avg_structuredness),
        "Average hendecasyllabicness: {}".format(avg_hendecasyllabicness),
        "Average rhymeness: {}".format(avg_rhymeness),
        "N-grams plagiarism: {}".format(ngrams_plagiarism(string))]

# Hendecasyllabicness score. For each of the four verses in input, computes a score and returns their average.
# The score is 1.0 if a verse has 10, 11 or 12 syllables, and decreases towards 0.0 the more the number of syllables diverges.
# Syllabification is done using Italian grammar rules, ignoring synalepha.
def _hendecasyllabicness(strn, synalepha, permissive):
    """
    Score how close the verses of a terzina are to the expected length.

    The target is 10 syllables when the last word of the verse is tronca and
    11 otherwise. A verse scores 1 - |syllables - target| / target; when
    ``permissive`` a deviation of at most one syllable scores 1.0 instead.
    :param strn: the verses, separated by newlines (empty lines are skipped).
    :param synalepha: forwarded to ``syllabify_block``.
    :param permissive: True to tolerate one syllable more or less.
    :return: the sum of the verse scores divided by 4.
    """
    total = 0.0
    for line in strn.split("\n"):
        if line == "":
            continue

        # In order to avoid cheating, strip all # characters and perform syllabification according to grammar.
        hyphenated = syllabify_block(strip_hashes(line), synalepha)
        target = 10 if is_tronca(split_words(line, False)[-1]) else 11

        syllable_count = sum(1 for s in hyphenated.split("#") if s != "")
        deviation = abs(syllable_count - target)
        if permissive and deviation <= 1:  # Tolerate one syllable more or less.
            total += 1.0
        else:
            total += 1 - deviation / target

    return total / 4

# Rhymeness score. In order to correctly detect chaining, TWO terzine need to be passed, but the score is referred only to the second one.
# Since a terzina formally includes the stray verse which begins the next one, the rhyming scheme to be checked is the following:
# don't care
# B
# don't care
#
# B
# C
# B
#
# C.
# For each of the three rhymes (BB, CC and BB) assign 1.0 if the rhyme score (computed by rhyme_score in an encoding-agnostic way) is greater than or equal to rhyme_threshold.
# NOTE: due to the intrinsic difficulty of formally define a rhyme, this threshold has no clear semantic and was chosen empirically.
def _rhymeness(strn, rhyme_threshold):
    """
    Score the rhymes of two chained terzine.

    The last words of the non-empty lines are extracted and three pairs are
    checked: verses 1-3, 3-5 and 4-6 (0-based).
    :param strn: the text of the two terzine.
    :param rhyme_threshold: minimum rhyme score for a pair to count.
    :return: the fraction of the three pairs reaching the threshold.
    """
    last_words = _extract_last_words(strn)

    # The pair (1, 5) is not checked: is transitivity implied?
    checked_pairs = ((1, 3), (3, 5), (4, 6))
    rhymes = [rhyme_score(last_words[a], last_words[b]) for a, b in checked_pairs]

    score = 0.0
    for value in rhymes:
        if value >= rhyme_threshold:
            score += 1.0

    return score / len(rhymes)

# Extracts a list of terzine from a string, skipping malformed lines.
# Each well formed terzina has the following structure:
# Verse
# Verse
# Verse
#
# Verse,
# In order to correctly handle chaining, the last verse of each terzina is also the first verse of the next one.
def _extract_terzine(strn):
    terzinaA = re.compile(r"""([^\n]+\n[^\n]+\n[^\n]+\n\n[^\n]+\n)""") # Case LLL L. Extract 3 + 1 lines and then skip 4 lines.
    terzinaB = re.compile(r"""[^\n]+\n([^\n]+\n\n[^\n]+\n[^\n]+\n[^\n]+)""") # Case LL LLL. Ignore 1 line, extract 1 + 3 lines and then skip 3 lines. After the skip, only case A can appear.
    skipA = re.compile(r"""[^\n]+\n[^\n]+\n[^\n]+(\n\n)?""")
    skipB = re.compile(r"""[^\n]+\n[^\n]+(\n\n)?""")
    out = []
    tmp = strn

    m = terzinaA.search(tmp)
    if m:
        while m:
            out.append(m.group(0))
            tmp = tmp[skipA.search(tmp).end():]
            m = terzinaA.search(tmp)
    else:
        m = terzinaB.search(tmp)
        if m:
            out.append(m.group(0)) # The regex will not capture the first line.
            tmp = tmp[skipB.search(tmp).end():]
            m = terzinaA.search(tmp)  # After the first skip, the case A appears.
            while m:
                out.append(m.group(0))
                tmp = tmp[skipA.search(tmp).end():]
                m = terzinaA.search(tmp)

    return out

# Extract the last words from each verse of a string.
# NOTE: empty lines are skipped.
def _extract_last_words(strn):
    """
    Extract the last word of every verse of a string.
    NOTE: empty lines are skipped.
    :param strn: the verses, separated by newlines.
    :return: the list of last words, prettified (diacritics kept) and
    without "#" characters.
    """
    return [
        strip_hashes(prettify(split_words(verse, False)[-1], True))
        for verse in strn.split("\n")
        if verse != ""
    ]

In [11]:
#@title Generation
def generate(k=1, t=0.5):
    """Generate text with the trained transformer, print it and log it to wandb.

    The prompt is the last ``terces_len`` tokens of ``test_b[0][0]``. The
    top-k sampler is then called repeatedly; each round feeds the last
    ``terces_len`` tokens of the previous output back in as the next decoder
    prompt, while the encoder always receives the single token ``[pad]``.

    Args:
        k: number of highest-scoring candidates sampled from at every step.
            With the default of 1 the highest-scoring token is always chosen.
        t: temperature; the top-k scores are divided by it before sampling.
    """
    def evaluate_greedy(inp_sentence, decoder_input):
        """Decode by always appending the arg-max token.

        Not called by generate() at the moment; kept as an alternative to
        evaluate_topk. Returns ``(output, attention_weights)`` where output
        is the decoder prompt followed by the generated tokens.
        """
        encoder_input = tf.expand_dims(inp_sentence, 0)

        output = tf.expand_dims(decoder_input, 0)

        terces = 0
        for _ in range(batch_len):
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
                encoder_input, output)

            # predictions.shape == (batch_size, seq_len, vocab_size)
            predictions, attention_weights = transformer(encoder_input,
                                                        output,
                                                        False,
                                                        enc_padding_mask,
                                                        combined_mask,
                                                        dec_padding_mask)

            # select the last word from the seq_len dimension
            predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)

            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

            # Count end tokens; stop (without appending the last one) once
            # terces_per_batch-1 of them have been produced.
            if predicted_id == eos:
                terces += 1
                if terces == terces_per_batch-1:
                    return tf.squeeze(output, axis=0), attention_weights
            # concatenate the predicted_id to the output which is given to the decoder
            # as its input.
            output = tf.concat([output, predicted_id], axis=-1)

        return tf.squeeze(output, axis=0), attention_weights


    def evaluate_topk(inp_sentence, decoder_input, k=5, temperature=0.5):
        """Decode by sampling each token among the k highest-scoring ones.

        At every step the k largest scores for the last position are divided
        by ``temperature`` and passed to ``tf.random.categorical``; the drawn
        position is mapped back to a vocabulary id through the top-k indices.
        NOTE(review): this assumes the transformer returns unnormalised
        logits, which is what tf.random.categorical expects - confirm.

        Returns ``(output, attention_weights)`` where output is the decoder
        prompt followed by the generated tokens.
        """
        encoder_input = tf.expand_dims(inp_sentence, 0)

        output = tf.expand_dims(decoder_input, 0)

        # Min-max rescaling helper; only referenced by the commented-out
        # line in the loop below.
        def scale(tensor):
            tensor = tf.math.divide(
                tf.subtract(
                    tensor,
                    tf.reduce_min(tensor)
                ),
                tf.subtract(
                    tf.reduce_max(tensor),
                    tf.reduce_min(tensor))
                )
            return tensor

        terces = 0
        for _ in range(batch_len):
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
                encoder_input, output)

            # predictions.shape == (batch_size, seq_len, vocab_size)
            predictions, attention_weights = transformer(encoder_input,
                                                        output,
                                                        False,
                                                        enc_padding_mask,
                                                        combined_mask,
                                                        dec_padding_mask)

            # select the last word from the seq_len dimension
            predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)
            predictions, indices = tf.math.top_k(predictions,k=k)
            predictions /= temperature
            #predictions = scale(predictions)
            predictions = np.squeeze(predictions, axis=0)  # (1, k)
            indices = np.squeeze(indices, axis=0)
            indices = np.squeeze(indices, axis=0)  # (k,)
            # Draw a position inside the top-k list, then translate it to
            # the corresponding vocabulary id.
            predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
            predicted_id = indices[predicted_id]

            # Count end tokens; stop (without appending the last one) once
            # terces_per_batch-1 of them have been produced.
            if predicted_id == eos:
                terces += 1
                if terces == terces_per_batch-1:
                    return tf.squeeze(output, axis=0), attention_weights
            # concatenate the predicted_id to the output which is given to the decoder
            # as its input. The scalar id is first expanded to shape (1, 1).
            predicted_id = tf.expand_dims(predicted_id, 0)
            predicted_id = tf.expand_dims(predicted_id, 0)
            output = tf.concat([output, predicted_id], axis=-1)

        return tf.squeeze(output, axis=0), attention_weights


    def plot_attention_weights(attention, sentence, result, layer):
        """Draw one heat map per attention head of ``attention[layer]``.

        Only referenced by the disabled plotting snippet in the loop below.
        The subplot grid is fixed at 2 x 4, so at most 8 heads fit.
        """
        fig = plt.figure(figsize=(32, 16))
        attention = tf.squeeze(attention[layer], axis=0)
        for head in range(attention.shape[0]):
            ax = fig.add_subplot(2, 4, head+1)
            # plot the attention weights
            ax.matshow(attention[head][:-1, :], cmap='viridis')
            fontdict = {'fontsize': 10}
            ax.set_xticks(range(len(sentence)+2))
            ax.set_yticks(range(len(result)))
            ax.set_ylim(len(result)-1.5, -0.5)
            ax.set_xticklabels(sentence, fontdict=fontdict, rotation=90)
            ax.set_yticklabels(result, fontdict=fontdict)
            ax.set_xlabel('Head {}'.format(head+1))
        plt.tight_layout()
        plt.show()

    # Seed sequence: first element of the first test batch
    # (presumably the input half of an (inp, tar) pair - verify).
    out_list = test_b[0][0]
    print(seq2str(out_list)+"---------------------------")

    offset = terces_len # a tercet
    txt_gen = seq2str(out_list[-offset:])+"\n"
    for i in range(32//(terces_per_batch-1)): # 30 terces = cantica
        out, att_w = evaluate_topk([pad], out_list[-offset:], k, t)
        aa = out.numpy().tolist()
        # Disabled attention plots (re-enable by uncommenting):
        # if i==0: #only once
        #     plot_attention_weights(att_w, out_list, aa, 'decoder_layer1_block1')
        #     plot_attention_weights(att_w, out_list, aa, 'decoder_layer2_block1')
        #     plot_attention_weights(att_w, out_list, aa, 'decoder_layer3_block1')
        #     plot_attention_weights(att_w, out_list, aa, 'decoder_layer4_block1')
        out_list = aa
        # The first `offset` tokens are the prompt; everything after is new.
        out_str = seq2str(out_list[offset:])
        txt_gen += out_str + "\n"
        print(out_str)

    wandb.log({"generated":
            wandb.Html("k="+str(k)+" t="+str(t)+
                       "<pre>"+txt_gen+"</pre>", inject=False)})

In [None]:
 #@title Train loop

# Main training loop: one pass over the shuffled training batches per epoch,
# with periodic checkpointing, validation and sample generation.
for epoch in range(EPOCHS):
    random.shuffle(batches)
    start = time.time()

    # Training metrics accumulate over the whole epoch.
    train_loss.reset_states()
    train_accuracy.reset_states()

    for (batch, (inp, tar)) in enumerate(batches):
        # Only sequences of exactly batch_len tokens are fed to the model.
        if (len(inp) != batch_len or len(tar) != batch_len):
            print("discarded batch", batch)
            continue
        # Add a leading axis of size 1 before calling the train step.
        train_step(np.expand_dims(inp, axis=0), np.expand_dims(tar, axis=0))

        if batch % 50 == 0:
            print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                epoch + 1, batch, train_loss.result(), train_accuracy.result()))

    # Checkpoint after every 5th epoch (1-based: 5, 10, ...).
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))

    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, train_loss.result(), train_accuracy.result()))
    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

    wandb.log({
        'train_loss': train_loss.result(),
        'train_accuracy': train_accuracy.result()
    }, step=epoch+1)

    # validation (0-based epochs 0, 5, 10, ...)
    if epoch % 5 == 0:
        loss_l, acc_l = [], []
        for (batch, (inp, tar)) in enumerate(val_b):
            # Metrics are reset per batch, so each result() below is the
            # value for that single batch; the epoch figure is their mean.
            val_loss.reset_states()
            val_accuracy.reset_states()

            if (len(inp) != batch_len or len(tar) != batch_len):
                print("discarded batch", batch)
                continue

            val_step(np.expand_dims(inp, axis=0), np.expand_dims(tar, axis=0))

            loss_l.append(val_loss.result())
            acc_l.append(val_accuracy.result())

        if loss_l:
            loss_mean = sum(loss_l)/len(loss_l)
            acc_mean = sum(acc_l)/len(acc_l)
            print('Epoch {} VALIDATION: Loss {:.4f} Accuracy {:.4f}\n'.format(epoch + 1, loss_mean, acc_mean))

            wandb.log({
                'val_loss': loss_mean,
                'val_accuracy': acc_mean
            }, step=epoch+1)
        else:
            # Every validation batch was discarded (or val_b is empty):
            # there is nothing to average, so skip instead of dividing by 0.
            print('Epoch {} VALIDATION: skipped, no usable validation batch\n'.format(epoch + 1))

    # generation (0-based epoch indices)
    if epoch in [100, 150, 200, 249, 299]:
        generate()


In [None]:
# Persist the trained weights.
transformer.save_weights("./optimus_rhyme")
#transformer.load_weights("./optimus_rhyme")


# Dump the decoder embedding matrix as two TSV files: one row of
# tab-separated floats per vocabulary entry (vecs.tsv) and the matching
# vocabulary entries, one per line (meta.tsv).
#emb_enc_w = transformer.encoder.embedding.get_weights()[0]
emb_enc_w = transformer.decoder.embedding.get_weights()[0]
print(emb_enc_w.shape)

# Context managers guarantee both files are closed even if a row lookup or a
# write fails half-way through.
with open('vecs.tsv', 'w', encoding='utf-8') as out_v, \
        open('meta.tsv', 'w', encoding='utf-8') as out_m:
    for num, word in enumerate(poetry_sy_lm_dataset.vocabulary.dictionary):
        # NOTE: every index is exported, including 0; the earlier
        # "skip 0, it's padding" remark did not match what the code does.
        vec = emb_enc_w[num]
        out_m.write(word + "\n")
        out_v.write('\t'.join([str(x) for x in vec]) + "\n")


'''
%load_ext tensorboard
%tensorboard --logdir .
'''