In [None]:
# Reload edited modules (such as the local `deepcomedy` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# On Colab, upload the project archives from the local machine.
# NOTE(review): the recorded upload saved `data.zip` and `deepcomedy.zip`, but the next cell extracts
# `deepcomedy.tar.gz`, and no visible cell extracts `data.zip`. Confirm the archive names.
if 'google.colab' in str(get_ipython()):
    from google.colab import files
    
    files.upload()

Saving data.zip to data.zip
Saving deepcomedy.zip to deepcomedy.zip


In [None]:
# Unpack the `deepcomedy` package (models, preprocessing, utils, metrics) into the working directory.
!tar zxvf deepcomedy.tar.gz

deepcomedy/
deepcomedy/utils.py
deepcomedy/models/
deepcomedy/models/layers.py
deepcomedy/models/decoder_only.py
deepcomedy/models/transformer.py
deepcomedy/models/__pycache__/
deepcomedy/models/__pycache__/layers.cpython-37.pyc
deepcomedy/models/__pycache__/__init__.cpython-37.pyc
deepcomedy/models/__pycache__/transformer.cpython-37.pyc
deepcomedy/models/__init__.py
deepcomedy/models/.ipynb_checkpoints/
deepcomedy/models/.ipynb_checkpoints/transformer-checkpoint.py
deepcomedy/preprocessing.py
deepcomedy/__pycache__/
deepcomedy/__pycache__/utils.cpython-37.pyc
deepcomedy/__pycache__/__init__.cpython-37.pyc
deepcomedy/__pycache__/preprocessing.cpython-37.pyc
deepcomedy/metrics.py
deepcomedy/__init__.py
deepcomedy/.ipynb_checkpoints/


In [None]:
import io
import os
import re
import time
import unicodedata
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers.experimental import preprocessing

from deepcomedy.models.transformer import *
from deepcomedy.preprocessing import *
from deepcomedy.utils import *
from deepcomedy.metrics import *
import tqdm

%load_ext autoreload
%autoreload 2

## 1. Data preprocessing

In [None]:
# Read both corpora as UTF-8.
# Reading bytes and decoding keeps line endings untouched (text mode would translate them).
# `with` guarantees the file handles are closed.
with open("./data/divina_textonly.txt", "rb") as raw_file:
    raw_text = raw_file.read().decode(encoding="utf-8")
with open("./data/divina_syll_textonly.txt", "rb") as raw_syll_file:
    raw_syll_text = raw_syll_file.read().decode(encoding="utf-8")

# Preprocess the plain and the syllabified text the same way, with an empty end-of-tercet marker.
syll_text = preprocess_text(raw_syll_text, end_of_tercet='')
text = preprocess_text(raw_text, end_of_tercet='')

Split the preprocessed text into verses, re-appending the `<EOV>` marker to each one.

In [None]:
sep = "<EOV>"
# Cut each corpus at every end-of-verse marker, re-append the marker, and drop the trailing piece.
# Each element is a single verse (the `*_tercets` names are kept for the cells below).
input_tercets, target_tercets = (
    [chunk.lstrip() + sep for chunk in corpus.split(sep)][:-1]
    for corpus in (text, syll_text)
)

Fit one tokenizer on the plain verses (input) and one on the syllabified verses (target), then encode both.

In [None]:
# Two independent word-level tokenizers (no character filtering, case preserved):
# one for the plain verses (encoder side), one for the syllabified verses (decoder side).
input_tokenizer, target_tokenizer = (
    tf.keras.preprocessing.text.Tokenizer(char_level=False, filters="", lower=False)
    for _ in range(2)
)
input_tokenizer.fit_on_texts(input_tercets)
target_tokenizer.fit_on_texts(target_tercets)

enc_input_tercets, enc_target_tercets = (
    tokenizer.texts_to_sequences(verses)
    for tokenizer, verses in ((input_tokenizer, input_tercets), (target_tokenizer, target_tercets))
)

# +1 because the Keras tokenizer never assigns index 0 (it is left for padding).
input_vocab_size, target_vocab_size = (
    len(tokenizer.word_index) + 1 for tokenizer in (input_tokenizer, target_tokenizer)
)

In [None]:
# Build sliding windows over the verse sequence, with one sample per starting verse `start`:
#   input_text[start]         - verses start..start+2 (plain), the encoder input
#   target_text_tercet[start] - verses start..start+2 (syllabified)
#   target_text[start]        - verses start..start+3 (syllabified): the tercet plus the next verse
# The last start index is len - 4, so every target really contains a 4th verse.
# The previous bound produced a final sample whose target had only 3 verses.
num_verses = len(enc_input_tercets)
input_text = []
target_text = []
target_text_tercet = []

for start in range(num_verses - 3):
    input_text.append(list(chain(*enc_input_tercets[start : start + 3])))
    target_text_tercet.append(list(chain(*enc_target_tercets[start : start + 3])))
    target_text.append(list(chain(*enc_target_tercets[start : start + 4])))

Right-pad all sequences to a common length.

In [None]:
# Right-pad (padding="post") every sample to the longest sequence in its collection.
padded_input_text, padded_target_text = (
    tf.keras.preprocessing.sequence.pad_sequences(samples, padding="post")
    for samples in (input_text, target_text)
)

In [None]:
# Fixed seed so the train/validation split, and hence the validation metrics, is reproducible across runs.
SPLIT_SEED = 42
input_train, input_val, target_train, target_val = train_test_split(
    padded_input_text, padded_target_text, random_state=SPLIT_SEED
)

## 2. Hyperparameter sweep

In [None]:
# Reference corpus for the n-gram plagiarism metric: one verse per line, word-level tokens,
# with the ` <SEP> ` separators collapsed back into plain spaces.
original_text = re.sub(
    r' <SEP> ',
    ' ',
    preprocess_text(raw_text, end_of_verse='\n', end_of_tercet='', start_of_verse='', word_level=True),
)
original_text

# Vocabulary of real words in the Divine Comedy, used to score the correctness of generated words.
# TODO create function to obtain word-level vocabulary from divine comedy
word_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='\n-:,?“‘)—»«!”(";.’ ', lower=False)
word_tokenizer.fit_on_texts([raw_text])
# Iterating a dict yields its keys, so this is the set of vocabulary words.
real_words = set(word_tokenizer.word_index)

def generation_metrics(result, reference_text=None, vocabulary=None):
    """Compute generation-quality metrics for a generated, syllabified text.

    Args:
        result: generated text after `strip_tokens`, with one verse per line and
            `|` syllable separators.
        reference_text: corpus used for the n-gram plagiarism check. Defaults to the
            notebook-level `original_text`.
        vocabulary: set of known words used for the correctness scores. Defaults to the
            notebook-level `real_words`.

    Returns:
        Tuple ``(avg_syll, hend_ratio, plagiarism, correctness, incorrectness_score,
        rhyme_ratio)``.
    """
    # Fall back to the notebook globals for backward compatibility. Callers (e.g. a sweep)
    # can pass the references explicitly instead of relying on hidden state.
    if reference_text is None:
        reference_text = original_text
    if vocabulary is None:
        vocabulary = real_words

    # Syllable-level metrics need the `|` separators, one verse per list element.
    syllabified_verses = result.split("\n")
    avg_syll = average_syllables(syllabified_verses)
    hend_ratio = correct_hendecasyllables_ratio(syllabified_verses)

    # Word-level metrics work on plain text: drop syllable separators and punctuation.
    plain_text = remove_punctuation(re.sub(r'\|', '', result))

    plagiarism = ngrams_plagiarism(plain_text, reference_text)

    # Split the generated text into words on spaces and newlines.
    gen_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters="\n", lower=False)
    gen_tokenizer.fit_on_texts([plain_text])
    gen_words = set(gen_tokenizer.word_index)

    correctness, _ = correct_words_ratio(gen_words, vocabulary, return_errors=True)
    incorrectness_score = incorrectness(gen_words, vocabulary)

    rhyme_ratio = chained_rhymes_ratio(plain_text.split('\n'))

    return avg_syll, hend_ratio, plagiarism, correctness, incorrectness_score, rhyme_ratio

In [None]:
# Token ids of the verse-start and end-of-verse markers in the target vocabulary.
start_symbol, stop_symbol = (target_tokenizer.word_index[marker] for marker in ("<GO>", "<EOV>"))

def syllabify_tercets(input_text, correct_text, n=10, model=None, max_length=400):
    """Syllabify the first `n` tercets with a trained model and return them with their references.

    Windows are taken every 3rd index (0, 3, 6, ...). With the sliding windows built during
    preprocessing, this gives non-overlapping tercets.

    Args:
        input_text: encoded plain-text windows (the encoder inputs).
        correct_text: encoded syllabified windows aligned with `input_text`.
        n: number of tercets to syllabify; clamped to the available windows.
        model: transformer to decode with. Defaults to the notebook-level `transformer`,
            so callers such as a sweep can pass the model they just trained.
        max_length: maximum decoding length passed to `evaluate`.

    Returns:
        (output, syll_correct):
        - output: exactly three generated verses per tercet, padded with '' when the model
          produced fewer, so verse k of every tercet stays aligned with its reference.
        - syll_correct: the reference verses of the same tercets.
    """
    if model is None:
        model = transformer

    # Never index past the available windows.
    N = min(3 * n, len(input_text), len(correct_text))

    syll_correct = target_tokenizer.sequences_to_texts(correct_text[:N:3])
    syll_correct = strip_tokens(''.join(syll_correct)).split('\n')

    output = []
    for i in tqdm.tqdm(range(0, N, 3)):
        encoder_input = tf.convert_to_tensor([input_text[i]])
        decoder_input = tf.convert_to_tensor([[start_symbol]])

        syll_output = evaluate(model, encoder_input, decoder_input, stop_symbol, max_length=max_length)
        syll_output = target_tokenizer.sequences_to_texts(syll_output.numpy())[0]
        verses = strip_tokens(syll_output).split('\n')[:3]  # Only take first 3 produced verses
        # Pad short outputs; otherwise every later verse is compared against the wrong reference.
        verses += [''] * (3 - len(verses))
        output += verses

    return output, syll_correct

# Smoke test of `syllabify_tercets` on a single tercet. It needs a trained model named `transformer`
# (created in the Training section); on a fresh kernel skip it instead of raising NameError.
if 'transformer' in globals():
    syll_output, syll_target = syllabify_tercets(input_text, target_text_tercet, n=1)
else:
    print("Skipping syllabification smoke test: train or load `transformer` first.")

  0%|          | 0/1 [00:00<?, ?it/s]


NameError: name 'transformer' is not defined

In [None]:
# Grid search over depth, attention heads and feed-forward width (3 * 2 * 2 = 12 runs).
# NOTE(review): `wandb` is not imported in any visible cell; presumably it comes from a star import. Confirm.
sweep_config = dict(
    name="char2char-sweep",
    method="grid",
    metric={"name": "loss", "goal": "minimize"},
    parameters=dict(
        batch_size={"value": 32},
        epochs={"value": 50},
        num_layers={"values": [4, 8, 12]},
        num_heads={"values": [4, 8]},
        d_model={"value": 256},
        dff={"values": [512, 1024]},
    ),
)

sweep_id = wandb.sweep(sweep_config, project='deepcomedy', entity='deepcomedy')



Create sweep with ID: xzets3ic
Sweep URL: https://wandb.ai/deepcomedy/deepcomedy/sweeps/xzets3ic


In [None]:
start_symbol, stop_symbol = (target_tokenizer.word_index[marker] for marker in ("<GO>", "<EOV>"))

# Prompt for the sample generated at the end of every sweep run:
# the first tercet as encoder input and its syllabified version as decoder prefix.
encoder_input, decoder_input = [input_text[0]], [target_text_tercet[0]]

def sweep():
    """Train one transformer with the hyper-parameters chosen by the W&B sweep agent.

    After training, one verse is generated from the fixed prompt (`encoder_input` /
    `decoder_input`). The result is printed and logged to W&B as HTML.
    NOTE(review): `generate` is defined in the Generation section below. On a fresh kernel,
    run that cell first (unless a star import already provides it).
    """
    with wandb.init():
        config = wandb.config
        batch_size = config["batch_size"]
        dataset = make_dataset(input_train, target_train, batch_size=batch_size)
        validation_dataset = make_dataset(input_val, target_val, batch_size=batch_size)
        model, trainer = make_transformer_model(config, input_vocab_size, target_vocab_size, checkpoint_save_path=None)
        trainer.train(dataset, config["epochs"], log_wandb=True, validation_dataset=validation_dataset, validation_every=5)

        generated = strip_tokens(
            generate(model, encoder_input, decoder_input, input_tokenizer, target_tokenizer, 1, start_symbol, stop_symbol)
        )
        # Keep the newline-separated text for metrics and build a separate HTML view for logging.
        # Overwriting the text with its <br> form would make later metrics see a single line.
        html = '<br>'.join(generated.split('\n'))
        print(html)
        wandb.log({"generated": wandb.Html("<pre>" + html + "</pre>", inject=False)})

        # TODO: also log generation metrics (`generation_metrics(generated)`) and syllabification
        # metrics (exact-match ratio, mean edit distance). The syllabification must use `model`,
        # the freshly trained network, not the notebook-level `transformer`.
        
        
# Launch the W&B agent; it calls `sweep` once for each configuration it receives from the sweep.
wandb.agent(sweep_id, function=sweep)

[34m[1mwandb[0m: Agent Starting Run: gke9pmr3 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	d_model: 256
[34m[1mwandb[0m: 	dff: 512
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	num_heads: 4
[34m[1mwandb[0m: 	num_layers: 12
[34m[1mwandb[0m: wandb version 0.10.31 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


## 3. Training

In [None]:
# Train only on the training split. Training on all windows would include the validation
# windows and make the validation loss/accuracy optimistic.
dataset = make_dataset(input_train, target_train)
val_dataset = make_dataset(input_val, target_val)

In [None]:
# Hyper-parameters of the model trained in this section.
config = dict(num_layers=6, d_model=256, num_heads=4, dff=512)

# NOTE(review): not passed to `make_transformer_model` in the next cell (it receives None),
# so no checkpoints are written.
checkpoint_save_path = "./checkpoints/char-input_char-output_gen"

In [None]:
# Build the transformer and its trainer; checkpointing is disabled (checkpoint_save_path=None).
transformer, transformer_trainer = make_transformer_model(config, input_vocab_size, target_vocab_size, checkpoint_save_path= None)

In [None]:
# Log to the same W&B project/entity as the sweep, so all runs can be compared in one place.
# The context manager finishes the run when training ends.
EPOCHS = 50
with wandb.init(project='deepcomedy', entity='deepcomedy'):
    transformer_trainer.train(dataset, EPOCHS, validation_dataset=val_dataset, validation_every=1, log_wandb=True)

Epoch 1 Batch 0 Loss 5.2680 Accuracy 0.0026
Epoch 1 Batch 50 Loss 4.0824 Accuracy 0.1124
Epoch 1 Batch 100 Loss 3.5939 Accuracy 0.1597
Epoch 1 Batch 150 Loss 3.3976 Accuracy 0.1779
Epoch 1 Batch 200 Loss 3.2241 Accuracy 0.1992
Epoch 1 Batch 250 Loss 3.0314 Accuracy 0.2269
Epoch 1 Batch 300 Loss 2.8759 Accuracy 0.2488
Epoch 1 Batch 350 Loss 2.7553 Accuracy 0.2656
Epoch 1 Batch 400 Loss 2.6590 Accuracy 0.2791
Epoch 1 Batch 0 Validation Loss 1.9128 Validation Accuracy 0.3933
Epoch 1 Batch 50 Validation Loss 1.8951 Validation Accuracy 0.3944
Epoch 1 Batch 100 Validation Loss 1.8935 Validation Accuracy 0.3953
Epoch 1 Loss 2.5909 Accuracy 0.2887
Time taken for 1 epoch: 127.46 secs

Epoch 2 Batch 0 Loss 1.9310 Accuracy 0.3846
Epoch 2 Batch 50 Loss 1.9296 Accuracy 0.3830
Epoch 2 Batch 100 Loss 1.9165 Accuracy 0.3863
Epoch 2 Batch 150 Loss 1.9031 Accuracy 0.3896
Epoch 2 Batch 200 Loss 1.8896 Accuracy 0.3934
Epoch 2 Batch 250 Loss 1.8744 Accuracy 0.3978
Epoch 2 Batch 300 Loss 1.8573 Accuracy 0.4

## 4. Generation

In [None]:
def generate(transformer, input_sequence, target_sequence, input_tokenizer, target_tokenizer, steps, start_symbol, stop_symbol, choose_next_token=None):
    """Autoregressively extend a syllabified text, one verse per step.

    At every step:
    - the last (up to) three generated verses are re-encoded with `target_tokenizer`
      and become the decoder prefix;
    - the same verses, after `remove_syll_token`, are encoded with `input_tokenizer`
      and become the encoder input.

    Args:
        transformer: trained model passed to `evaluate`.
        input_sequence: batch of one encoded plain-text sequence for the encoder.
        target_sequence: batch of one encoded syllabified sequence used as decoder prefix;
            its decoded text starts the returned result.
        input_tokenizer: tokenizer for the plain text.
        target_tokenizer: tokenizer for the syllabified text.
        steps: number of generation steps (one verse appended per step).
        start_symbol: unused; kept for compatibility with existing callers.
        stop_symbol: token id passed to `evaluate` as the stop token.
        choose_next_token: sampling strategy forwarded to `evaluate`; defaults to `choose_topk`.

    Returns:
        The prompt text followed by the generated verses, still containing the special tokens.
        Generation stops early if a step yields no verse.
    """
    if choose_next_token is None:
        choose_next_token = choose_topk

    result = target_tokenizer.sequences_to_texts(target_sequence)[0]

    encoder_input = input_sequence
    decoder_input = target_sequence

    for _ in range(steps):
        encoder_input = tf.convert_to_tensor(encoder_input)
        decoder_input = tf.convert_to_tensor(decoder_input)
        output = evaluate(transformer, encoder_input, decoder_input, stop_symbol, choose_next_token=choose_next_token)

        generated_text = target_tokenizer.sequences_to_texts(output.numpy())[0]

        # Split into verses, re-attaching the end-of-verse marker to each one.
        verses = [line.lstrip() + '<EOV> ' for line in generated_text.split('<EOV>') if line.strip() != '']
        if not verses:
            # Nothing usable was produced; stop rather than raise IndexError on verses[-1].
            break

        # Separate the prompt's final <EOV> from the new <GO> with a space, matching the
        # '<EOV> ' separator that follows every generated verse.
        if result and not result.endswith(' '):
            result += ' '
        result += verses[-1]

        # Keep only the last three verses as context for the next step.
        context = ''.join(verses[-3:])
        decoder_input = target_tokenizer.texts_to_sequences([context])

        # Encoder side: syllable tokens removed and spacing normalized.
        context = remove_syll_token(context)
        context = re.sub("<EOV><GO>", "<EOV> <GO>", context)
        context = context.strip()
        encoder_input = input_tokenizer.texts_to_sequences([context])

    return result

In [None]:
start_symbol, stop_symbol = (target_tokenizer.word_index[marker] for marker in ("<GO>", "<EOV>"))

# Prompt: the first tercet (plain for the encoder, syllabified for the decoder).
encoder_input, decoder_input = [input_text[0]], [target_text_tercet[0]]

# Continue the prompt for 5 generation steps.
result = generate(transformer, encoder_input, decoder_input, input_tokenizer, target_tokenizer, 5, start_symbol, stop_symbol)

In [None]:
# Show the generated text with the special tokens stripped.
print(strip_tokens(result))

|Nel |mez|zo |del |cam|min |di |no|stra |vi|ta
|mi |ri|tro|vai |per |u|na |sel|va o|scu|ra,
|ché |la |di|rit|ta |via |e|ra |smar|ri|ta.
|La |spa|da |fiam|ma e |io |at|ten|to |lu|ce;
|non |che |con|tr’ al |ciel |che |vi |sog|gior|no,
|pa|u|ra |di |là, |se |sta|re è |più |for|te.
|Lo |du|ca |mio |sem|bian|za |mi |ri|sco|sta,
|qual |si |mos|se |con|vien |che |tu |ti |pe|ne;


In [None]:
# (avg_syll, hend_ratio, plagiarism, correctness, incorrectness, rhyme_ratio) for the generated text.
generation_metrics(strip_tokens(result))

mi ritrovai per una selva oscura La spada fiamma e io attento luce
La spada fiamma e io attento luce paura di là se stare è più forte
non che contr al ciel che vi soggiorno Lo duca mio sembianza mi riscosta


(10.875, 1.0, 0.3018867924528302, 0.9807692307692307, 0.02, 0.25)

### Hendecasyllabicness

In [None]:
# Syllabified verses (with `|` syllable bars), one per list element, for the metrics below.
x = strip_tokens(result).split('\n')
x

['|Nel |mez|zo |del |cam|min |di |no|stra |vi|ta',
 '|mi |ri|tro|vai |per |u|na |sel|va o|scu|ra,',
 '|ché |la |di|rit|ta |via |e|ra |smar|ri|ta.',
 '|E |quel|l’ al|tro |che |più |e |più |d’ in|con|tran|to,',
 '|e |con |le |sue |lu|ci |ri|pen|ne e |cal|le,',
 '|so|lo in|tor|no |co|me |quei |c’ han|no |mu|to.',
 '|Ma |se |l’ a|mor |che |tu |di|scer|nes|se',
 '|da |l’ o|ra |che |vien |per |la |vil|ta |de|gna,']

In [None]:
# Average syllable count over the verses in `x`.
average_syllables(x)

11.0

In [None]:
# Ratio of correct hendecasyllables among the verses in `x`.
correct_hendecasyllables_ratio(x)

1.0

### Rhymeness

In [None]:
# Plain verses for the rhyme check: syllable bars and punctuation removed, one verse per element.
x = remove_punctuation(re.sub(r'\|', '', strip_tokens(result))).split('\n')
x

['Nel mezzo del cammin di nostra vita',
 'mi ritrovai per una selva oscura',
 'ché la diritta via era smarrita',
 'E quell altro che più e più d incontranto',
 'e con le sue luci ripenne e calle',
 'solo intorno come quei c hanno muto',
 'Ma se l amor che tu discernesse',
 'da l ora che vien per la vilta degna']

In [None]:
# Rhyme score over the plain verses in `x` (the function also prints verse pairs; see the output).
chained_rhymes_ratio(x)

mi ritrovai per una selva oscura E quell altro che più e più d incontranto
E quell altro che più e più d incontranto solo intorno come quei c hanno muto
e con le sue luci ripenne e calle Ma se l amor che tu discernesse


0.25

### N-gram plagiarism

In [None]:
# Plain generated text (no syllable bars, no punctuation) for the n-gram comparison.
x = remove_punctuation(re.sub(r'\|', '', strip_tokens(result)))

In [None]:
# Same reference corpus as in the sweep section: one verse per line, ` <SEP> ` collapsed to spaces.
original_text = re.sub(
    r' <SEP> ',
    ' ',
    preprocess_text(raw_text, end_of_verse='\n', end_of_tercet='', start_of_verse='', word_level=True),
)
original_text

'Nel mezzo del cammin di nostra vita\nmi ritrovai per una selva oscura , \nché la diritta via era smarrita . \nAhi quanto a dir qual era è cosa dura\nesta selva selvaggia e aspra e forte\nche nel pensier rinova la paura ! \nTant ’  è amara che poco è più morte ; \nma per trattar del ben ch ’ i ’  vi trovai , \ndirò de l ’ altre cose ch ’ i ’  v ’ ho scorte . \nIo non so ben ridir com ’  i ’  v ’ intrai , \ntant ’  era pien di sonno a quel punto\nche la verace via abbandonai . \nMa poi ch ’ i ’  fui al piè d ’ un colle giunto , \nlà dove terminava quella valle\nche m ’ avea di paura il cor compunto , \nguardai in alto e vidi le sue spalle\nvestite già de ’  raggi del pianeta\nche mena dritto altrui per ogne calle . \nAllor fu la paura un poco queta , \nche nel lago del cor m ’ era durata\nla notte ch ’ i ’  passai con tanta pieta . \nE come quei che con lena affannata , \nuscito fuor del pelago a la riva , \nsi volge a l ’ acqua perigliosa e guata , \ncosì l ’ animo mio ,  ch ’ ancor fu

In [None]:
# N-gram overlap between the generated text `x` and the original poem (plagiarism score).
ngrams_plagiarism(x, original_text)

0.39285714285714285

### Word correctness

In [None]:
# Preview only the beginning of the raw corpus; displaying the whole poem floods the output.
raw_text[:500]

'\n\n  Nel mezzo del cammin di nostra vita\n  mi ritrovai per una selva oscura\n  ché la diritta via era smarrita\n\n  Ahi quanto a dir qual era è cosa dura\n  esta selva selvaggia e aspra e forte\n  che nel pensier rinova la paura\n\n  Tant è amara che poco è più morte\n  ma per trattar del ben chi vi trovai\n  dirò de laltre cose chi vho scorte\n\n  Io non so ben ridir com i vintrai\n  tant era pien di sonno a quel punto\n  che la verace via abbandonai\n\n  Ma poi chi fui al piè dun colle giunto\n  là dove terminava quella valle\n  che mavea di paura il cor compunto\n\n  guardai in alto e vidi le sue spalle\n  vestite già de raggi del pianeta\n  che mena dritto altrui per ogne calle\n\n  Allor fu la paura un poco queta\n  che nel lago del cor mera durata\n  la notte chi passai con tanta pieta\n\n  E come quei che con lena affannata\n  uscito fuor del pelago a la riva\n  si volge a lacqua perigliosa e guata\n\n  così lanimo mio chancor fuggiva\n  si volse a retro a rimirar lo passo\n 

In [None]:
# Plain generated text (no syllable bars, no punctuation) for the word-correctness check.
x = remove_punctuation(re.sub(r'\|', '', strip_tokens(result)))

In [None]:
# Rebuild the real-word vocabulary from the raw corpus (same filters as in the sweep section).
word_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='\n-:,?“‘)—»«!”(";.’ ', lower=False)
word_tokenizer.fit_on_texts([raw_text])
real_words = set(word_tokenizer.word_index)

In [None]:
# `x` already went through remove_punctuation when it was built. Applying it again made this
# cell non-idempotent and diverged from `generation_metrics`, which applies it once.
gen_tokenizer = tf.keras.preprocessing.text.Tokenizer(
    filters="\n", lower=False
)
gen_tokenizer.fit_on_texts([x])

In [None]:
# Distinct words of the generated text.
gen_words = set(gen_tokenizer.word_index)

In [None]:
# Ratio of generated words found in the real vocabulary, plus the words that were not found.
correct_words_ratio(gen_words, real_words, return_errors=True)

(0.9411764705882353, array(['incontranto', 'ripenne', 'vilta'], dtype='<U11'))

In [None]:
# Vocabulary size plus a small sorted sample; the full set floods the output.
len(real_words), sorted(real_words)[:20]

dict_keys(['<SEP>', '<GO>', '<EOV>', ',', '’', 'e', 'che', '.', 'l', 'la', 'a', 'di', ';', 'non', 'per', 'in', '«', '»', 'si', 'ch', ':', 'io', 'le', 'è', 'mi', 'sì', 'li', 'il', 'de', 'più', 'con', 'da', 'come', 'del', 'al', 'd', 'i', 'lo', 'un', 'se', 's', 'tu', 'ne', 'E', 'ma', 'quel', 'me', 'fu', 'mio', 'nel', 'suo', 'sua', 'ti', 'era', '?', 'm', 'là', 'lui', 'tanto', 'son', 'già', 'poi', '!', 'quando', 'così', 'altro', 'quella', 'occhi', 'disse', 'sé', 'noi', 'lor', 'mia', 'una', 'ché', 'o', 'qual', 'perché', 'ben', 'chi', 'tutto', 'ha', 'fa', 'questo', 'dal', 'qui', 'esser', '“', 'ogne', 'elli', 'cui', 'giù', 'pur', 'vidi', 'ad', 'com', 'né', 'altra', 'ciò', 'n', 'tal', '”', 'Io', 'prima', 'questa', 'mondo', 'Ma', 'ancor', 'te', 'poco', 'mai', 'sù', 'terra', 'fuor', 'su', 'sanza', 'onde', 'quanto', 'tra', 'dove', 'Dio', 'però', 'O', 'gente', 'tua', 'avea', 'parte', 'altri', 'tuo', 'due', 'Non', 'col', 'dentro', 'lei', 'ciel', 'voi', 'veder', 'sotto', 'mente', 'tutti', 'sol', 'col

In [None]:
# Incorrectness score of the generated words against the real vocabulary.
incorrectness(gen_words, real_words)

0.04

## 5. Syllabification

In [None]:
# Token ids of the verse-start and end-of-verse markers in the target vocabulary.
start_symbol, stop_symbol = (target_tokenizer.word_index[marker] for marker in ("<GO>", "<EOV>"))

To syllabify a tercet, we pass it to the encoder and feed the decoder only the `start_symbol`; the decoder then produces the syllabified text.

In [None]:
# Encoder: the first tercet window; decoder: only the <GO> token.
encoder_input, decoder_input = (
    tf.convert_to_tensor(batch) for batch in ([input_text[0]], [[start_symbol]])
)

In [None]:
# NOTE(review): the next cell repeats this exact `evaluate` call and overwrites `syll_output`,
# so this cell only doubles the decoding time. Consider deleting it.
syll_output = evaluate(transformer, encoder_input, decoder_input, stop_symbol, max_length=400)

In [None]:
# Decode the first tercet and keep the first three verses the model produced.
syll_output = strip_tokens(
    target_tokenizer.sequences_to_texts(
        evaluate(transformer, encoder_input, decoder_input, stop_symbol, max_length=400).numpy()
    )[0]
).split('\n')[:3]  # Only take first 3 produced verses
syll_output

In [None]:
# Reference syllabification of the first tercet, one verse per list element.
syll_correct = strip_tokens(target_tokenizer.sequences_to_texts([target_text_tercet[0]])[0]).split('\n')
syll_correct

In [None]:
# Decode the first encoder window back to text to inspect what the model receives.
input_tokenizer.sequences_to_texts([input_text[0]])

In [None]:
# Syllabify the first `n` tercets (every 3rd window) with the trained model.
# Reuses `syllabify_tercets` instead of duplicating its loop; `N` is kept for the cells below.
n = 1
N = 3 * n

output = syllabify_tercets(input_text, target_text_tercet, n=n)[0]
    

100%|██████████| 1/1 [00:21<00:00, 21.94s/it]


In [None]:
# Syllabified verses produced by the model.
output

['|ma |per |lo |vi|so |che |di |là |si |ca|la,',
 '|e |dis|se:« |Qui |non |ti |da|reb|be |frut|to.',
 '|Non |ti |po|rò |ch’ al|cu|na |ver|go|gna|ta;']

In [None]:
# Reference verses for the same tercets syllabified above (every 3rd window up to N).
syll_correct = strip_tokens(''.join(target_tokenizer.sequences_to_texts(target_text_tercet[:N:3]))).split('\n')
syll_correct

['|Nel |mez|zo |del |cam|min |di |no|stra |vi|ta',
 '|mi |ri|tro|vai |per |u|na |sel|va o|scu|ra,',
 '|ché |la |di|rit|ta |via |e|ra |smar|ri|ta.']

In [None]:
# Per-verse (exact match, distance) pairs between the model output and the reference.
validate_syllabification(output, syll_correct)

[(False, 0.4130434782608695),
 (False, 0.36170212765957444),
 (False, 0.3555555555555555)]

## 6. Save model

In [None]:
# Writing an .h5 file fails if the parent directory is missing, so make sure `models/` exists.
os.makedirs('models', exist_ok=True)
transformer.save_weights('models/c2c-gen.h5')

In [None]:
# Same hyper-parameters as the trained model; saved weights only fit an identical architecture.
config = dict(num_layers=6, d_model=256, num_heads=4, dff=512)

# NOTE(review): pe_input / pe_target / rate are assumed to match what make_transformer_model used. Confirm.
new_transformer = Transformer(
    **config,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    pe_input=1000,
    pe_target=1000,
    rate=0.1,
)

In [None]:
# The model must be called once so its variables exist before the weights can be loaded.
# Any (inp, tar) pair works here; a single <GO> token is the cheapest.
start_symbol, stop_symbol = (target_tokenizer.word_index[marker] for marker in ("<GO>", "<EOV>"))

inp = tar = tf.convert_to_tensor([[start_symbol]])

enc_padding_mask, look_ahead_mask, dec_padding_mask = create_masks(inp, tar)

# Inference-mode forward pass (training=False); the trailing ';' suppresses the output.
new_transformer(inp, tar, False, enc_padding_mask, look_ahead_mask, dec_padding_mask);

In [None]:
# Load the saved weights into the rebuilt model and make it the notebook's active `transformer`.
new_transformer.load_weights('models/c2c-gen.h5')
transformer = new_transformer

In [None]:
encoder_input, decoder_input = [input_text[0]], [target_text_tercet[0]]

# Sanity check of the reloaded weights: continue the first tercet for 6 generation steps.
result = generate(new_transformer, encoder_input, decoder_input, input_tokenizer, target_tokenizer, 6, start_symbol, stop_symbol)

In [None]:
# Raw generated text, still containing the special tokens (<GO>, <SEP>, <EOV>).
result

'<GO> | N e l <SEP> | m e z | z o <SEP> | d e l <SEP> | c a m | m i n <SEP> | d i <SEP> | n o | s t r a <SEP> | v i | t a <EOV> <GO> | m i <SEP> | r i | t r o | v a i <SEP> | p e r <SEP> | u | n a <SEP> | s e l | v a <SEP> o | s c u | r a , <EOV> <GO> | c h é <SEP> | l a <SEP> | d i | r i t | t a <SEP> | v i a <SEP> | e | r a <SEP> | s m a r | r i | t a . <EOV><GO> | P o i <SEP> | c h ’ <SEP> a <SEP> | m e <SEP> | s t e s | s o <SEP> | l e | v a r <SEP> | l i <SEP> | f é <SEP> | p i e | n i ; <EOV> <GO> | c o | s ì <SEP> | d i s | s e <SEP> ’ l <SEP> | m a | e | s t r o ; <SEP> e d <SEP> | e l <SEP> | s ’ <SEP> a c | c e n | d e <EOV> <GO> | q u a n | t u n | q u e <SEP> | c o | s a <SEP> | c h e <SEP> | t u <SEP> | v e | d i <SEP> | s a p | p i e . <EOV> <GO> | L a <SEP> | p e r | c h é <SEP> | t u a <SEP> | d a <SEP> | l i <SEP> | s p e | r a n | z a <SEP> | p o r | t i <EOV> <GO> | n o n <SEP> | t i <SEP> | f e r | r ò <SEP> | c o | s ì , <SEP> | p e r | c h é <SEP> | n o n <SEP> | 

In [None]:
# On Colab, download the saved weights to the local machine.
# `files` is imported by the upload cell at the top of the notebook under the same Colab guard.
if 'google.colab' in str(get_ipython()):
    files.download('models/c2c-gen.h5')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>