In [1]:
from Transformer import Transformer

In [2]:
from Parameters import TrainingParameters, EnviromentParameters
from Controller import TrainingController
from Utils.SaveUtils import load_parameters
import tensorflow as tf
import numpy as np
from Data import XESDatasetWithResource

In [3]:
# Directory of the saved run whose parameters are reloaded below.
folder_path = "./SavedModels/" + (
    "0.8264_BPI2012WithResource_BaselineLSTMWithResource_2021-06-18 06:11:10.009443"  # AOW
)

In [4]:
# Restore the training configuration stored with the run, then seed NumPy and
# TensorFlow from its `dataset_split_seed`.
parameters_json = load_parameters(folder_path=folder_path)
parameters = TrainingParameters(**parameters_json)

split_seed = parameters.dataset_split_seed
np.random.seed(split_seed)
tf.random.set_seed(split_seed)

In [5]:
# The controller loads the preprocessed dataset; later cells use its `dataset` and `train_dataset`.
trainer = TrainingController(parameters=parameters)


| Running on /job:localhost/replica:0/task:0/device:CPU:0  

| Preprocessed data loaded successfully: ./datasets/preprocessed/BPI_Challenge_2012_with_resource/AOW 


In [6]:
# Length of the longest trace and the vocabulary size, used to size the Transformer.
max_trace_length = max(len(trace) for trace in trainer.dataset.df['trace'])
vocab_size = len(trainer.dataset.vocab)

In [7]:
#### Prepare model ####
# Layer/width hyperparameters come from the restored run; the vocabulary and
# positional-encoding sizes come from the loaded dataset.
transformer_config = parameters.transformerParameters
# 10x the longest trace for positional encodings — presumably headroom; TODO confirm intent.
max_positions = max_trace_length * 10

model = Transformer(
    num_layers=transformer_config.num_layers,
    d_model=transformer_config.model_dim,
    num_heads=transformer_config.num_heads,
    dff=transformer_config.feed_forward_dim,
    input_vocab_size=vocab_size,
    target_vocab_size=vocab_size,
    pe_input=max_positions,
    pe_target=max_positions,
)

In [8]:
#### Take a batch from dataset ####
# Pull only the first element. The previous `list(...)[0]` iterated and
# materialised every batch of the dataset just to keep the first one.
batch_index = next(iter(trainer.train_dataset.as_numpy_iterator()), None)
assert batch_index is not None, "trainer.train_dataset is empty"

In [9]:
batch_training_data = trainer.dataset.collate_fn(batch_index)  # collate the sampled batch into padded tensors

In [10]:
# Split the collated batch into its six named components.
(
    caseids,
    padded_data_traces,
    lengths,
    padded_data_resources,
    batch_amount,
    padded_target_traces,
) = batch_training_data

In [12]:
from Transformer.scheduler import CustomSchedule
# Learning-rate schedule derived from the model dimension, driving Adam.
learning_rate = CustomSchedule(parameters.transformerParameters.model_dim)
optimizer = tf.keras.optimizers.Adam(
    learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

# Running means of per-batch loss and accuracy.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

In [13]:
from Transformer.masking import create_masks

def train_step(inp, tar):
    """Run one gradient step on a batch and update the running metrics.

    Args:
        inp: padded input traces; passed as BOTH the encoder and the decoder
            input, and used to build all three masks.
        tar: padded target traces that the predictions are scored against.

    Side effects: updates `model` weights via `optimizer`, and accumulates
    into the global `train_loss` / `train_accuracy` metrics.

    NOTE(review): the encoder sees the full `inp` and the decoder input is
    `inp` itself rather than a shifted `tar`. If `tar` is `inp` shifted by one
    step, cross-attention may let decoder position t see the token it must
    predict — confirm how `padded_target_traces` is built and what
    `dec_padding_mask` masks.
    """
    masks = create_masks(inp, inp)
    enc_padding_mask, combined_mask, dec_padding_mask = masks

    with tf.GradientTape() as tape:
        predictions, _ = model(
            inp, inp, True, enc_padding_mask, combined_mask, dec_padding_mask
        )
        loss = loss_function(tar, predictions)

    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    train_loss(loss)
    train_accuracy(accuracy_function(tar, predictions))

In [14]:
import time

In [15]:
from Transformer.utils import loss_function, accuracy_function

In [18]:
EPOCHS = 1
LOG_EVERY_N_STEPS = 10  # print running metrics every this many batches
step = 0  # global step counter; intentionally not reset between epochs

for epoch in range(EPOCHS):
    # perf_counter is monotonic, so the measured duration is not affected by
    # wall-clock adjustments (unlike time.time()).
    start = time.perf_counter()

    train_loss.reset_states()
    train_accuracy.reset_states()

    # Each element of train_dataset is collated into padded input traces
    # (model input) and padded target traces (what predictions are scored against).
    for train_idxs in trainer.train_dataset:
        step += 1
        caseids, padded_data_traces, lengths, padded_data_resources, batch_amount, padded_target_traces = trainer.dataset.collate_fn(train_idxs)

        train_step(padded_data_traces, padded_target_traces)

        if step % LOG_EVERY_N_STEPS == 0:
            print(f'Epoch {epoch + 1} Step {step} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    print(f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    print(f'Time taken for 1 epoch: {time.perf_counter() - start:.2f} secs\n')


Epoch 1 Step 10 Loss 3.1712 Accuracy 0.1259
Epoch 1 Step 20 Loss 3.0832 Accuracy 0.1481
Epoch 1 Step 30 Loss 3.0053 Accuracy 0.1686
Epoch 1 Step 40 Loss 2.9431 Accuracy 0.1837
Epoch 1 Step 50 Loss 2.8798 Accuracy 0.1969
Epoch 1 Step 60 Loss 2.8180 Accuracy 0.2091
Epoch 1 Step 70 Loss 2.7481 Accuracy 0.2256
Epoch 1 Step 80 Loss 2.6746 Accuracy 0.2440
Epoch 1 Loss 2.6604 Accuracy 0.2475
Time taken for 1 epoch: 125.01 secs

