In [1]:
# Imports
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision

from transformers import GPT2Config, TFGPT2Model

from config import Config

2022-11-26 18:14:23.734772: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-26 18:14:23.851201: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-26 18:14:23.882939: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-26 18:14:24.447294: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; 

# Config

In [2]:
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
conf = Config("single_instruments_type", ROOT_PATH)

2022-11-26 18:14:25.498247: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


2022-11-26 18:14:26.588167: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30970 MB memory:  -> device: 0, name: Tesla V100S-PCIE-32GB, pci bus id: 0000:3b:00.0, compute capability: 7.0
2022-11-26 18:14:26.588759: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 30970 MB memory:  -> device: 1, name: Tesla V100S-PCIE-32GB, pci bus id: 0000:86:00.0, compute capability: 7.0


Decoder creation

In [3]:
decoder = conf.get_decoder()

Testing the decoder on random inputs

In [5]:
output = decoder({'inputs_embeds': tf.ones((conf.BATCH_SIZE, conf.SEQ_LEN, conf.TOKEN_DIM))})
output['last_hidden_state'].shape

TensorShape([12, 6144, 512])

# Dataset

Load the dataset from disk and process it (batching, shuffling, ...)

In [8]:
DATASET_PATH = os.path.join('..', 'data', 'tf_data7')
dataset = tf.data.Dataset.load(DATASET_PATH).batch(conf.BATCH_SIZE).cache().shuffle(conf.SHUFFLE_SIZE).prefetch(conf.PREFETCH_SIZE)
dataset

<PrefetchDataset element_spec=(TensorSpec(shape=(None, 6143, 11), dtype=tf.uint8, name=None), TensorSpec(shape=(None, 3), dtype=tf.uint8, name=None))>

In [9]:
X, y = next(dataset.as_numpy_iterator())
print(X.shape, y.shape)

(12, 6143, 11) (12, 3)


2022-11-26 18:31:55.627353: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


# Embedding layers

The inputs need to be encoded by some embedding layer (a specific embedding layer for each token type).

In [12]:
embedding_layers = [
    # Type embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['type'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Measure embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['measure'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Beat embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['beat'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Position embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['position'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Duration embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['duration'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Pitch embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['pitch'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Instrument embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['instrument'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Velocity embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['velocity'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Key sign embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['key_sign'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Time sign embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['time_sign'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN),
    # Tempo embedding
    tf.keras.layers.Embedding(conf.INPUT_RANGES['tempo'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN)
]

Run the embedding layers on our inputs

In [15]:
outputs = []
for i in tf.range(X.shape[2]):
    outputs.append(embedding_layers[i](X[:, : ,i]))

We also need to encode the genre using some layers.

In [20]:
genre_embedding_module = tf.keras.Sequential([
    tf.keras.layers.Dense(conf.SINGLE_EMB_SIZE, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(conf.GENRE_DIM, activation='relu')
])

In [21]:
genre_embedding = genre_embedding_module(y)
genre_embedding.shape

TensorShape([12, 512])

## Embedding concatenation

We concatenate the output embeddings into a single tensor

In [22]:
types_concat_layer = tf.keras.layers.Concatenate(axis=2)
concat_outputs = types_concat_layer(outputs)
concat_outputs.shape

TensorShape([12, 6143, 704])

Then we need to resize them into a known dimensionality

In [23]:
dense_layer = tf.keras.layers.Dense(conf.TOKEN_DIM)
encoding = dense_layer(concat_outputs)
encoding.shape

TensorShape([12, 6143, 512])

Finally, we need to preprend the genre embedding token to the sequence

In [24]:
sequence_concat_layer = tf.keras.layers.Concatenate(axis=1)
final_sequence = sequence_concat_layer([genre_embedding[:, np.newaxis, :], encoding])
final_sequence.shape

TensorShape([12, 6144, 512])

## Positional encoding

We also add positional encodings to encode which is the position of each token in the sequence.

In [25]:
positional_encoding_matrix = conf.get_positional_embedding_matrix()

In transformers, it is common to add the positional embedding to the elements embeddings.

In [26]:
sum_layer = tf.keras.layers.Add()
positional_encoding = tf.repeat(positional_encoding_matrix[np.newaxis, :, :], tf.constant(conf.BATCH_SIZE), axis=0)
final_encoding = sum_layer([final_sequence, positional_encoding])
final_encoding.shape

TensorShape([12, 6144, 512])

# Output management

In [27]:
output = decoder({'inputs_embeds': final_encoding})
output['last_hidden_state'].shape

TensorShape([12, 6144, 512])

We need a dense + softmax layer for each of the tokens for trying to reconstruct the input.

In [28]:
output_dense_layers = [
    # Type
    tf.keras.layers.Dense(conf.INPUT_RANGES['type'], activation='softmax'),
    # Measure
    tf.keras.layers.Dense(conf.INPUT_RANGES['measure'], activation='softmax'),
    # Beat
    tf.keras.layers.Dense(conf.INPUT_RANGES['beat'], activation='softmax'),
    # Position
    tf.keras.layers.Dense(conf.INPUT_RANGES['position'], activation='softmax'),
    # Duration
    tf.keras.layers.Dense(conf.INPUT_RANGES['duration'], activation='softmax'),
    # Pitch
    tf.keras.layers.Dense(conf.INPUT_RANGES['pitch'], activation='softmax'),
    # Instrument
    tf.keras.layers.Dense(conf.INPUT_RANGES['instrument'], activation='softmax'),
    # Velocity
    tf.keras.layers.Dense(conf.INPUT_RANGES['velocity'], activation='softmax'),
    # Key sign
    tf.keras.layers.Dense(conf.INPUT_RANGES['key_sign'], activation='softmax'),
    # Time sign
    tf.keras.layers.Dense(conf.INPUT_RANGES['time_sign'], activation='softmax'),
    # Tempo
    tf.keras.layers.Dense(conf.INPUT_RANGES['tempo'], activation='softmax')
]

In [None]:
out_scores = [output_dense_layers[i](output['last_hidden_state']) 
              for i in range(len(output_dense_layers))]

for i in range(len(out_scores)):
    print(out_scores[i].shape)

(12, 6144, 8)
(12, 6144, 256)
(12, 6144, 131)
(12, 6144, 128)
(12, 6144, 136)
(12, 6144, 256)
(12, 6144, 129)
(12, 6144, 128)
(12, 6144, 25)
(12, 6144, 153)
(12, 6144, 49)


## Groundtruth vectors definition

In [30]:
gt_vectors = [X[:,:,i] for i in range(len(out_scores))]

for i in range(len(out_scores)):
    print(gt_vectors[i].shape)

(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)
(12, 6143)


 ## Loss definition

We can use a simple sparse categorical crossentropy loss function. The two distributions we are comparing are the input sequence (so we ignore the genre embedding token representation) and the output sequence up to the last token representation (`output[:-1]`)
- Note: can we use regularizers or other kinds of constraint enforcing methods for some of the fields? Like, we know that regarding the type field of events there is a strict order to follow (start of song, start of events, ..., notes and end of song). Can we enforce this structure?

In [31]:
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
losses = []
for i in range(len(out_scores)):
    losses.append(loss_function(gt_vectors[i], out_scores[i][:, :-1, :]))
losses

[<tf.Tensor: shape=(), dtype=float32, numpy=3.624619>,
 <tf.Tensor: shape=(), dtype=float32, numpy=6.1048756>,
 <tf.Tensor: shape=(), dtype=float32, numpy=5.639281>,
 <tf.Tensor: shape=(), dtype=float32, numpy=6.1507883>,
 <tf.Tensor: shape=(), dtype=float32, numpy=4.737446>,
 <tf.Tensor: shape=(), dtype=float32, numpy=6.3642745>,
 <tf.Tensor: shape=(), dtype=float32, numpy=5.920309>,
 <tf.Tensor: shape=(), dtype=float32, numpy=5.8808045>,
 <tf.Tensor: shape=(), dtype=float32, numpy=4.109704>,
 <tf.Tensor: shape=(), dtype=float32, numpy=6.3772664>,
 <tf.Tensor: shape=(), dtype=float32, numpy=5.1257977>]

To these loss terms we can add some regularization terms that can help the model produce a grammatically correct sequence.

In [32]:
types = gt_vectors[0]
max_pred_types = tf.argmax(out_scores[0], axis=2) # 6, 6144
# Use a StaticHashTable to map values to their consecutive version within Tensorflow
keys_tensor = tf.range(TYPE_RANGE, dtype=tf.int32)
vals_tensor = tf.constant([0,1,2,3,3,3,3,4], dtype=tf.int32)
table = tf.lookup.StaticHashTable(tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), default_value=-1)
consecutive_gt_types   = table.lookup(tf.cast(types, tf.int32))
consecutive_pred_types = table.lookup(tf.cast(max_pred_types, tf.int32))
# Note: we assume that after token token type 7 all following token types are 7s
differences = consecutive_pred_types[:, 1:] - consecutive_pred_types[:, :-1]

In [33]:
# There are some constraint to pose for regularization
reg_term_1 = tf.reduce_sum(tf.math.maximum(0, -differences))                           # Difference between one element's type and the next is >= 0
reg_term_2 = tf.reduce_sum(tf.math.maximum(0, tf.math.maximum(1, differences) - 1))    # Difference between one element's type and the next is < 1

reg_term_1, reg_term_2

(<tf.Tensor: shape=(), dtype=int32, numpy=5554>,
 <tf.Tensor: shape=(), dtype=int32, numpy=1567>)

In [35]:
REG_SCALER = 0.001

total_loss = tf.reduce_sum(losses) + \
             REG_SCALER * tf.cast(reg_term_1, tf.float32) + \
             REG_SCALER * tf.cast(reg_term_2, tf.float32)
total_loss

<tf.Tensor: shape=(), dtype=float32, numpy=67.156166>

When defining the whole Keras model for training, we can set up multiple outputs and give different weights for the multiple losses.

# Single model

Let's try and define everything that this model does into a complete callable model.

In [1]:
# Imports
import os
import math

import numpy as np
import tensorflow as tf
from transformers import GPT2Config, TFGPT2Model

from config import Config

ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
conf = Config("single_instruments_type", ROOT_PATH)

2022-11-27 03:33:15.938035: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-27 03:33:16.091152: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-27 03:33:16.126891: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-27 03:33:16.745832: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; 

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


2022-11-27 03:33:18.948459: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 26391 MB memory:  -> device: 0, name: Tesla V100S-PCIE-32GB, pci bus id: 0000:3b:00.0, compute capability: 7.0
2022-11-27 03:33:18.949039: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 30487 MB memory:  -> device: 1, name: Tesla V100S-PCIE-32GB, pci bus id: 0000:86:00.0, compute capability: 7.0


In [2]:
### CUSTOM LAYERS
# Custom intermediate layer for allowing types transformation (no parameters to be learnt)
class SubsequentTypeTransformationLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(SubsequentTypeTransformationLayer, self).__init__()
        # Use a StaticHashTable to map values to their consecutive version within Tensorflow
        self.keys_tensor = tf.range(conf.INPUT_RANGES['type'])
        self.vals_tensor = tf.constant([0,1,2,3,3,3,3,4])
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(self.keys_tensor, self.vals_tensor), 
            default_value=-1)

    def call(self, inputs):
        return self.table.lookup(inputs)
    
    
# Custom layer that computes masks for type probabilities computation
class MaskTypeProbabilitiesLayer(tf.keras.layers.Layer):
    def __init__(self, trainable=False, name=None, dtype=None, dynamic=False, **kwargs):
        super().__init__(trainable, name, dtype, dynamic, **kwargs)

    @tf.function
    def create_mask(self, inputs):
        batch_gt_types = inputs
        mask = tf.TensorArray(tf.bool, size=conf.SEQ_LEN)
        mask = mask.write(0, tf.constant([True, False, False, False, False, False, False, False], dtype=tf.bool))
        for i in tf.range(conf.SEQ_LEN-1):
            token_type = batch_gt_types[i]
            if token_type == 0: # only start of song token: cannot be anything else than instrument choice (1)
                type_mask = tf.constant([False, True, False, False, False, False, False, False], dtype=tf.bool)
            elif token_type == 1: # we reached instrument choice: cannot be anything else than instrument choice (1) or start of events (2)
                type_mask = tf.constant([False, True, True, False, False, False, False, False], dtype=tf.bool)
            elif token_type == 2: # after a 2 there must be at least a 4
                type_mask = tf.constant([False, False, False, False, True, False, False, False], dtype=tf.bool)
            elif token_type == 3: # allow 3,4,5,6,7
                type_mask = tf.constant([False, False, False, True, True, True, True, True], dtype=tf.bool)
            elif token_type >= 4 and token_type <= 6:
                # - if there are at least a 5 and a 6 (there is always a 4)   --> [3, 4, 5, 6, 7]
                # - if a 5 is missing, we only allow 5                        --> [5]
                # - if a 6 is missing, we only allow 6                        --> [6]
                if tf.size(tf.where(batch_gt_types[:i] == 5)) == 0:
                    type_mask = tf.constant([False, False, False, False, False, True, False, False], dtype=tf.bool)
                if tf.size(tf.where(batch_gt_types[:i] == 6)) == 0:
                    type_mask = tf.constant([False, False, False, False, False, False, True, False], dtype=tf.bool)
                else:
                    type_mask = tf.constant([False, False, False, True, True, True, True, True], dtype=tf.bool)
            elif token_type == 7: # at the end of the song we can ONLY GUESS "000000000" TODO: change ending token to type 7s -> 7000000000
                type_mask = tf.constant([False, False, False, False, False, False, False, True], dtype=tf.bool)
            else:
                # ERROR. Define a random type mask so that it's defined in all branches for tf.function
                type_mask = tf.constant([False, False, False, False, False, False, False, False], dtype=tf.bool)
            mask = mask.write(i+1, type_mask)
        return mask.stack()

    def call(self, inputs, training=True):
        '''
        Takes as input the ground truth song (at training time) or the logits (at testing time) 
        and computes a mask for the type probabilities.
        '''
        if training:
            # Use the groundtruth song as a target
            song        = inputs
            gt_types    = song[:,:,0]       # Get the token types from the song (batch_size x seq_len-1)
            # Iterate over the batch to collect the appropriate masks from the song
            masks = tf.map_fn(fn=self.create_mask, 
                elems=gt_types, 
                fn_output_signature=tf.TensorSpec(
                    (conf.SEQ_LEN, conf.INPUT_RANGES['type']), 
                    dtype=tf.bool)
            )
            return masks
        else:
            # Compute the types and their masks one by one based on the type chosen at the previous iteration
            # TODO: implement this branch
            pass


# The main masking layer applying all constraints based on the predicted types 
class MaskingActivationLayer(tf.keras.layers.Layer):
    def __init__(self, trainable=False, name=None, dtype=None, dynamic=False, **kwargs):
        super().__init__(trainable, name, dtype, dynamic, **kwargs)
        self.default_mask = conf.default_mask
        self.full_mask    = conf.full_mask
        self._numerators  = tf.constant(conf.numerators)
        self._tot_numerators = tf.constant(conf.tot_numerators)

    @tf.function
    def get_max_beat_from_time_sign(self, time_sign):
        '''
        Since the time sign is defined (in utils.time_sign_map()) as: 
            conf.numerators.index(time_sign[0]) + conf.denominators.index(time_sign[1])*conf.tot_numerators

        to retrieve the NUMERATOR of the time_sign given the index you need to divide by conf.tot_numerators and take the rest of the division
        that gives you the index of the corresponding numerator in conf.numerators
        then you use gather or, more simply, a slice to get the actual value of the numerator

        You then subtract 1 because the beat is in [0, numerator)
        '''
        idx = tf.math.floormod(time_sign, self._tot_numerators)
        return self._numerators[idx] - 1

    @tf.function
    def get_mask_for_all_tokens(self, inputs): 
        '''
        Inputs:
        - chosen_types:         (SEQ_LEN-1)*1
        - song_tokens:          (SEQ_LEN-1)*11
        - seq_scores:           (SEQ_LEN-1)*1391

        Returns a list of ndarrays of bool type used for masking
        Inputs are for a SINGLE ELEMENT OF A BATCH of size SEQ_LEN*(1+11+1391) where 1391 is the summed length of logits (minus the type)
        '''
        # Collect inputs from longer tensor
        chosen_types, song_tokens, seq_scores = inputs
        chosen_types = tf.cast(chosen_types, dtype=tf.int32)
        song_tokens  = tf.cast(song_tokens , dtype=tf.int32)
        seq_scores   = tf.cast(seq_scores  , dtype=tf.int32)
        # Indexes
        index_tensor = tf.range(conf.SEQ_LEN-1, dtype=tf.int32)
        # Define mask (output) using a TensorArray
        mask = tf.TensorArray(dtype=tf.bool, size=conf.SEQ_LEN-1)
        # Iterate over the indexes
        for idx in index_tensor:
            ## SETUP ##
            # Define the default variables and flags
            default_token_parts   = [True]*(len(conf.INPUT_RANGES)-1)
            default_flag          = False
            min_measure           = tf.constant(-1, dtype=tf.int32)
            min_beat              = tf.constant(-1, dtype=tf.int32)
            min_position          = tf.constant(-1, dtype=tf.int32)
            # TODO: variable length arrays: can we do it with tensorarrays?
            allowed_instruments   = tf.constant([0]*conf.INPUT_RANGES["instrument"], dtype=tf.int32)
            allowed_key_sign      = tf.constant(-1, dtype=tf.int32)
            allowed_time_sign     = tf.constant(-1, dtype=tf.int32)
            allowed_tempo         = tf.constant(-1, dtype=tf.int32)
            forbidden_instruments_flag = False
            forbidden_instruments = tf.constant([0]*conf.INPUT_RANGES["instrument"], dtype=tf.int32)
            forbidden_key_sign    = tf.constant(-1, dtype=tf.int32)
            forbidden_time_sign   = tf.constant(-1, dtype=tf.int32)
            forbidden_tempo       = tf.constant(-1, dtype=tf.int32)
            # Define the inputs
            chosen_type = chosen_types[idx]
            scores      = seq_scores[idx]
            song        = song_tokens * (tf.expand_dims([1]*idx + [0]*(conf.SEQ_LEN-1-idx), axis=-1)) # Mask all tokens after index idx
            ## MAIN BODY ##
            if chosen_type == 0 or chosen_type == 2: # TODO: change 0s to 7s at the end of the song
                # Original comments: 
                # only way it chooses 0 is that max_type==7 --> AFTER END OF SONG --> only thing the model can do is guess all zeros
                # "does not have to learn nothing" --> it's all zeros just like the padding tensors
                default_token_parts = [True, True, True, True, True, True, True, True, True, True]
                default_flag = True
            elif chosen_type == 1: # Instrument selection, false only for type and instrument type (the ones that you can choose)
                if tf.size(tf.where(song[:idx, 0] == 1)[:,0]) == 0:
                    # Choice of first instrument
                    default_token_parts = [True, True, True, True, True, False, True, True, True, True]  # TODO: Element 6 should not be default = True right?
                    default_flag = True
                else:
                    forbidden_instruments, _ = tf.unique(tf.gather(
                        song[:idx, 6], 
                        tf.where(song[:idx, 0] == 1)[:,0]        # Cast to 1D array
                    ))
                    forbidden_instruments_flag = True
            elif chosen_type == 3: # Notes: They have the same key_sign, time_sign and tempo as last previous event, everything has to be manually decided
                min_measure = song[idx, 1]   # It has to be >= than the last measure
                # If in the MEASURE SCORES the MAX SCORE between all possible measures == min_measure, the measure is min_measure.
                # In this case, we need to make sure that beat >= last_beat
                if tf.math.argmax(
                    scores[:conf.INPUT_RANGES["measure"]], 
                        output_type=tf.int32) == min_measure:  
                    min_beat = song[idx,2]      # It has to be >= than the last beat when measure is the same
                    if tf.math.argmax(scores[
                        conf.INPUT_RANGES["measure"] : 
                        conf.INPUT_RANGES["measure"] + conf.INPUT_RANGES["beat"]], 
                        output_type=tf.int32) == min_beat:
                        min_position = song[idx,3]  # It has to be >= than the last position (if beat and measure are the same)
                    else:
                        min_position = tf.constant(0, dtype=tf.int32)
                else:
                    min_beat = tf.constant(0, dtype=tf.int32)
                    min_position = tf.constant(0, dtype=tf.int32)
                # Only some instruments, key signs, time signs and tempos are allowed for these events: 
                # - for instruments, the allowed ones are the ones that have been defined previously with type = 1
                # - for the others, the allowed ones are the ones that are collected right before the note from event types 4, 5 and 6
                allowed_instruments, _ = tf.unique(tf.gather(
                    song[:idx, 6], 
                    tf.where(song[:idx, 0] == 1)[:,0]
                ))
                # We have made it so that the model should output 3s only after at least a 4, 5 and 6.
                allowed_key_sign = tf.gather(
                    song[:idx, 8], 
                    tf.where(song[:idx, 0] == 4)[:,0]   # if type == 4 --> read the LAST key_sign
                )[-1] 
                allowed_time_sign = tf.gather(
                    song[:idx, 9], 
                    tf.where(song[:idx, 0] == 5)[:,0]   # if type == 5 --> read the LAST time_sign
                )[-1] 
                allowed_tempo = tf.gather(
                    song[:idx, 10], 
                    tf.where(song[:idx, 0] == 6)[:,0]   # if type == 6 --> read the LAST tempo
                )[-1] 
            elif chosen_type >= 4 and chosen_type <= 6:     # key_sign, time_sign, tempo
                # If last event is at the beginning of a measure, you can add an event at the same time
                if song[idx, 3] == 0 and song[idx, 2] == 0:  # if beat and position == 0, the event can be at this measure
                    min_measure = song[idx, 1]
                else:
                    min_measure = song[idx, 1] + 1                   # otherwise it goes to the next measure
                # Fine-grain checks
                # Here, there are cases where there is not a LAST key_sign/time_sign (when this is the first 4, 5 or 6). 
                # In these cases we should use the default masks.
                if chosen_type == 4:
                    # Cannot put the same key_sign again
                    forbidden_key_signs = tf.gather(
                        song[:idx, 8], 
                        tf.where(song[:idx, 0] == 4)[:,0]) # if type == 4 --> read the LAST key_sign
                    if tf.size(forbidden_key_signs) > 0:
                        forbidden_key_sign = forbidden_key_signs[-1]
                elif chosen_type == 5:
                    # Cannot put the same time_sign again
                    forbidden_time_signs = tf.gather(
                        song[:idx, 9], 
                        tf.where(song[:idx, 0] == 5)[:,0]) # if type == 5 --> read the LAST time_sign
                    if tf.size(forbidden_time_signs) > 0:
                        forbidden_time_sign = forbidden_time_signs[-1]
                elif chosen_type == 6:
                    # Cannot put the same tempo again
                    forbidden_tempos = tf.gather(
                        song[:idx, 10], 
                        tf.where(song[:idx, 0] == 6)[:,0]) # if type == 6 --> read the LAST tempo
                    if tf.size(forbidden_tempos) > 9:
                        forbidden_tempo = forbidden_tempos[-1]
            elif chosen_type == 7: # end of song --> only type can be chosen, all the others are default
                default_token_parts = [True, True, True, True, True, True, True, True, True, True]
                default_flag = True

            ## ENDING PART ##
            # Put together the masks
            if default_flag: 
                # No manual masking required, either "can freely choose this part of the token" (True) or 
                # "can only choose default for this part of the token" (False)
                mask.write(idx, tf.concat(
                    # Default mask only allows to predict a 0
                    # Full mask allows to predict any value
                    [self.default_mask[i] if default_token_parts[i] else self.full_mask[i] 
                        for i in range(len(default_token_parts))], axis=-1)
                )
            elif forbidden_instruments_flag:
                # Default flag is False and forbidden instruments contains some elmeents (which means that the chosen type is 1)
                instruments_mask = tf.sparse.SparseTensor(  # Forbidden instruments
                        indices= tf.expand_dims(tf.cast(forbidden_instruments, tf.int64), axis=-1),
                        values = tf.zeros_like(forbidden_instruments),
                        dense_shape=[conf.INPUT_RANGES["instrument"]]
                    )
                instruments_mask = tf.cast(
                    tf.sparse.to_dense(tf.sparse.reorder(instruments_mask), default_value=1), 
                    dtype=tf.dtypes.bool)
                # Only mask the forbidden instruments, all the rest is default
                mask.write(idx, tf.concat(
                    [self.default_mask[i] for i in range(5)] + \
                    [instruments_mask] + \
                    [self.default_mask[i] for i in range(6,len(default_token_parts))], 
                    axis=-1))
            elif chosen_type >= 3 and chosen_type <= 6:
                # General event. What we do depends on which specific event it is, but
                # in general there is always a measure mask.
                measure_mask = tf.cast(
                    tf.concat([
                        tf.repeat([False], min_measure),        # Can be equal to or greater than min_measure
                        tf.repeat([True],  conf.INPUT_RANGES["measure"]-min_measure)], 
                        axis=-1),
                    dtype=tf.dtypes.bool)
                # We need to do manual masking. Define all tensors
                measure_mask     = self.default_mask[0]
                beat_mask        = self.default_mask[1]
                position_mask    = self.default_mask[2]
                duration_mask    = self.default_mask[3]
                pitch_mask       = self.default_mask[4]
                instruments_mask = self.default_mask[5]
                velocity_mask    = self.default_mask[6]
                key_sign_mask    = self.default_mask[7]
                time_sign_mask   = self.default_mask[8]
                tempo_mask       = self.default_mask[9]
                # Create more specific masks depending on the type
                if chosen_type == 4:
                    if forbidden_key_sign != -1: ## forbidden_key_sign can only appear if chosen_type = 4
                        # True in all places but the forbidden key signs
                        key_sign_mask = tf.convert_to_tensor([
                            i != forbidden_key_sign 
                            for i in range(conf.INPUT_RANGES["key_sign"])], 
                            dtype=tf.bool)
                elif chosen_type == 5:
                    if forbidden_time_sign != -1: ## forbidden_time_sign can only appear if chosen_type = 5
                        # True in all places but the forbidden time signs
                        time_sign_mask = tf.convert_to_tensor([
                            i != forbidden_time_sign 
                            for i in range(conf.INPUT_RANGES["time_sign"])], 
                            dtype=tf.bool)
                elif chosen_type == 6:
                    if forbidden_tempo != -1: ## forbidden_tempo can only appear if chosen_type = 6
                        # True in all places but the forbidden tempos
                        tempo_mask = tf.convert_to_tensor([
                            i != forbidden_tempo 
                            for i in range(conf.INPUT_RANGES["tempo"])], 
                            dtype=tf.bool)
                elif chosen_type == 3:
                    # If the event is a note, we have ALLOWED time signs/tempos/key signs, not
                    # forbidden ones. Also, there are many other elements to take into account
                    if min_beat != -1:
                        # oss: allowed_time_sign is always != None if min_beat != None
                        max_beat = self.get_max_beat_from_time_sign(allowed_time_sign)
                        # allowed beats are only AFTER previous beat and BEFORE max_beat from the numerator of the time_sign
                        beat_mask = tf.cast(tf.concat([
                            tf.repeat([False], min_beat),
                            tf.repeat([True],  max_beat-min_beat), 
                            tf.repeat([False], conf.INPUT_RANGES["beat"]-max_beat)],
                            axis=-1), 
                        dtype=tf.dtypes.bool)
                    if min_position != -1:
                        position_mask = tf.cast(tf.concat([
                            tf.repeat([False], min_position), 
                            tf.repeat([True],  conf.INPUT_RANGES["position"]-min_position)],
                            axis=-1), 
                        dtype=tf.dtypes.bool)
                    instruments_mask = tf.sparse.SparseTensor( # Allowed instruments
                        indices=tf.expand_dims(tf.cast(allowed_instruments, tf.int64), axis=-1),
                        values=tf.ones_like(allowed_instruments),
                        dense_shape=[conf.INPUT_RANGES["instrument"]]
                    )
                    instruments_mask = tf.cast(
                        tf.sparse.to_dense(tf.sparse.reorder(instruments_mask), default_value=0),
                        dtype=tf.dtypes.bool)
                    if allowed_key_sign != -1:
                        key_sign_mask = tf.convert_to_tensor([
                            i == allowed_key_sign 
                            for i in range(conf.INPUT_RANGES["key_sign"])], 
                            dtype=tf.bool)
                    if allowed_time_sign != -1:
                        time_sign_mask = tf.convert_to_tensor([
                            i == allowed_time_sign 
                            for i in range(conf.INPUT_RANGES["time_sign"])], 
                            dtype=tf.bool)
                    if allowed_tempo != -1:
                        tempo_mask = tf.convert_to_tensor([
                            i == allowed_tempo 
                            for i in range(conf.INPUT_RANGES["tempo"])], 
                            dtype=tf.bool)
                # Write on the mask
                mask.write(idx, tf.concat([
                    measure_mask, beat_mask, position_mask, duration_mask,
                    pitch_mask, instruments_mask, velocity_mask, key_sign_mask,
                    time_sign_mask, tempo_mask], axis=-1))
        # Return the whole mask
        return mask.stack()

    def call(self, inputs, training=True):
        '''
        Inputs:
        - songs:                BATCH*(SEQ_LEN-1)*11
        - out_logits:           BATCH*(SEQ_LEN-1)*1391 (all except type)
        - types_probabilities:  BATCH*(SEQ_LEN-1)*8 --> becomes chosen_types through argmax --> BATCH*(SEQ_LEN-1)*1

        passes through map_fn --> get_mask_fro_all_tokens to debatch
        '''
        songs, out_logits, types_probabilities = inputs
        chosen_types  = tf.expand_dims(tf.math.argmax(types_probabilities[:,:-1], axis=2), axis=-1)# TODO: check if SEQ_LEN -1 or -2
        concat_logits = tf.concat(out_logits[1:], axis=-1)                 # Concatenate all logits (except type) into a tensor batch_size x seq_len x 1391
        masks = tf.map_fn(fn=self.get_mask_for_all_tokens, elems=(         # Iterate function over batch dimension 
                tf.cast(chosen_types, concat_logits.dtype),                # BATCH*(SEQ_LEN-1)*1
                tf.cast(songs,   concat_logits.dtype),                     # BATCH*(SEQ_LEN-1)*11
                concat_logits[:, :conf.SEQ_LEN-1, :]                       # BATCH*(SEQ_LEN-1)*1391 # TODO: check if SLICE is needed or we could directly pass the full concat_logits
            ), fn_output_signature=tf.TensorSpec(                          # Total: a BATCH * SEQ_LEN-1 * 1403 tensor
                (conf.SEQ_LEN-1, conf.input_ranges_sum - conf.INPUT_RANGES['type']),
                dtype=tf.bool
            ))
        return masks

In [3]:
# Model creation function (to be called within a scope in case of MultiGPU training)
def create_model(input_shape=(conf.SEQ_LEN-1, len(conf.INPUT_RANGES)), num_genres=len(conf.accepted_subgenres), 
                 use_regularization=True, use_masking_layers=True, reg_loss_scale=conf.REG_LOSS_SCALE):
    
    # Get input shapes
    seq_len = input_shape[0]
    events_elements = input_shape[1]
    
    # Instantiate transformer decoder (n_emb % n_head must be 0)
    decoder = conf.get_decoder()
    
    # Define inputs
    songs  = tf.keras.Input(shape=input_shape, name='songs',  dtype=tf.int32)
    genres = tf.keras.Input(shape=num_genres , name='genres', dtype=tf.float32)
    
    # Define loss
    loss_function = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    subsequent_type_transform_layer = SubsequentTypeTransformationLayer()
    reg_scaler = tf.constant(reg_loss_scale, dtype=tf.float32)
    
    # Embedding layers
    embedding_layers = [
        # Type embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['type'],       conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='type_embeddings'),
        # Measure embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['measure'],    conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='measure_embeddings'),
        # Beat embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['beat'],       conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='beat_embeddings'),
        # Position embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['position'],   conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='position_embeddings'),
        # Duration embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['duration'],   conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='duration_embeddings'),
        # Pitch embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['pitch'],      conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='pitch_embeddings'),
        # Instrument embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['instrument'], conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='instrument_embeddings'),
        # Velocity embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['velocity'],   conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='velocity_embeddings'),
        # Key sign embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['key_sign'],   conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='key_sign_embeddings'),
        # Time sign embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['time_sign'],  conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='time_sign_embeddings'),
        # Tempo embedding
        tf.keras.layers.Embedding(conf.INPUT_RANGES['tempo'],      conf.SINGLE_EMB_SIZE, input_length=conf.SEQ_LEN, name='tempo_embeddings')
    ]
    
    genre_embedding_layer = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(conf.GENRE_DIM)
    ], name='genre_embedding')
    
    # Input processing layers
    input_concat_layer         = tf.keras.layers.Concatenate(axis=2)
    sequence_concat_layer      = tf.keras.layers.Concatenate(axis=1)
    encoding_processing_layer  = tf.keras.layers.Dense(conf.TOKEN_DIM, name='encoding_processing')
    
    # Positional encoding
    positional_encoding_matrix = conf.get_positional_embedding_matrix()
    positional_encoding        = tf.repeat(positional_encoding_matrix[tf.newaxis, :, :], tf.shape(songs)[0], axis=0)
    sum_layer                  = tf.keras.layers.Add(name='final_encoding')

    # Output layers
    output_dense_layers = [
        # Type
        tf.keras.layers.Dense(conf.INPUT_RANGES['type'],       name='type_scores'),
        # Measure
        tf.keras.layers.Dense(conf.INPUT_RANGES['measure'],    name='measure_scores'),
        # Beat
        tf.keras.layers.Dense(conf.INPUT_RANGES['beat'],       name='beat_scores'),
        # Position
        tf.keras.layers.Dense(conf.INPUT_RANGES['position'],   name='position_scores'),
        # Duration
        tf.keras.layers.Dense(conf.INPUT_RANGES['duration'],   name='duration_scores'),
        # Pitch
        tf.keras.layers.Dense(conf.INPUT_RANGES['pitch'],      name='pitch_scores'),
        # Instrument
        tf.keras.layers.Dense(conf.INPUT_RANGES['instrument'], name='instrument_scores'),
        # Velocity
        tf.keras.layers.Dense(conf.INPUT_RANGES['velocity'],   name='velocity_scores'),
        # Key sign
        tf.keras.layers.Dense(conf.INPUT_RANGES['key_sign'],   name='keysign_scores'),
        # Time sign
        tf.keras.layers.Dense(conf.INPUT_RANGES['time_sign'],  name='timesign_scores'),
        # Tempo
        tf.keras.layers.Dense(conf.INPUT_RANGES['tempo'],      name='tempo_scores')
    ]
    
    output_probs_layers = [
        # Type
        tf.keras.layers.Softmax(name='type_probabilities'),
        # Measure
        tf.keras.layers.Softmax(name='measure_probabilities'),
        # Beat
        tf.keras.layers.Softmax(name='beat_probabilities'),
        # Position
        tf.keras.layers.Softmax(name='position_probabilities'),
        # Duration
        tf.keras.layers.Softmax(name='duration_probabilities'),
        # Pitch
        tf.keras.layers.Softmax(name='pitch_probabilities'),
        # Instrument
        tf.keras.layers.Softmax(name='instrument_probabilities'),
        # Velocity
        tf.keras.layers.Softmax(name='velocity_probabilities'),
        # Key sign
        tf.keras.layers.Softmax(name='keysign_probabilities'),
        # Time sign
        tf.keras.layers.Softmax(name='timesign_probabilities'),
        # Tempo
        tf.keras.layers.Softmax(name='tempo_probabilities')
    ]
    
    # Masking layers
    if use_masking_layers:
        type_masking_layer = MaskTypeProbabilitiesLayer()
        activations_masking =  MaskingActivationLayer()
    
    # Model dynamics
    embeddings        = [embedding_layers[i](songs[:,:,i]) for i in range(events_elements)]
    genre_embedding   = genre_embedding_layer(genres)
    input_embedding   = input_concat_layer(embeddings)
    input_embedding   = encoding_processing_layer(input_embedding)
    input_embedding   = sequence_concat_layer([genre_embedding[:, np.newaxis, :], input_embedding])
    input_embedding   = sum_layer([input_embedding, positional_encoding])
    model_output      = decoder({'inputs_embeds': input_embedding})['last_hidden_state']
    out_scores        = [output_dense_layers[i](model_output)[:,:-1,:] for i in range(len(output_dense_layers))]
    # We don't care about the last scores, since they refer to a token that's out of bounds.
    if use_masking_layers:
        type_mask           = type_masking_layer(songs, training=True)[:,:-1,:]
        types_probabilities = output_probs_layers[0](out_scores[0], type_mask)
        full_mask           = activations_masking([songs, out_scores, types_probabilities])
        index = 0;    masks = []          # Unpack the final masks into a list of masks
        for key in conf.INPUT_RANGES:
            if key != 'type':
                masks.append(full_mask[:, :, index:index+conf.INPUT_RANGES[key]])
                index += conf.INPUT_RANGES[key]
        out_probabilities = [types_probabilities] + [
            output_probs_layers[i](out_scores[i], masks[i-1]) 
            for i in range(1, len(output_dense_layers))]
    else:
        out_probabilities = [output_probs_layers[i](out_scores[i]) for i in range(len(output_dense_layers))]

    # Create model
    model = tf.keras.Model(inputs=[songs, genres], outputs=out_probabilities, name='music_generation_model')
    
    # Define loss
    def custom_loss(songs, y_pred):
        gt_vectors = [songs[:,:,i] for i in range(len(conf.INPUT_RANGES))]
        # Base loss term
        losses = []
        for i in range(len(y_pred)):
            losses.append(tf.math.reduce_sum(
                tf.cast(loss_function(gt_vectors[i], y_pred[i]), tf.float32) * \
                (1. / conf.GLOBAL_BATCH_SIZE)))
        return tf.math.reduce_sum(losses)
    
    # Define regularizers
    def custom_regularizers(songs, y_pred):
        gt_vectors = [songs[:,:,i] for i in range(len(conf.INPUT_RANGES))]
        # Regularization loss: transform the actual vectors into consecutive-type representation
        types = gt_vectors[0]
        max_pred_types = tf.argmax(y_pred[0], axis=2, output_type=tf.int32)
        consecutive_gt_types   = subsequent_type_transform_layer(types)
        consecutive_pred_types = subsequent_type_transform_layer(max_pred_types)
        # Compute difference
        differences = consecutive_pred_types[:, 1:] - consecutive_pred_types[:, :-1]
        # Compute regularization terms
        # Difference between one element's type and the next is >= 0
        reg_term_1 = tf.math.reduce_sum(tf.math.maximum(0, -differences))
        # Difference between one element's type and the next is < 1
        reg_term_2 = tf.math.reduce_sum(tf.math.maximum(0, tf.math.maximum(1, differences) - 1))
        return reg_scaler * tf.cast(reg_term_1, tf.float32) + reg_scaler * tf.cast(reg_term_2, tf.float32)
    
    # Add losses
    model.add_loss(custom_loss(songs, out_scores))
    if use_regularization:
        model.add_loss(custom_regularizers(songs, out_scores))
    
    # Compile and return
    model.compile(optimizer="adam")
    return model

In [4]:
if conf.num_devices > 1:
    print("Using multiple GPUs with Mirrored Strategy")
    with conf.training_strategy.scope():
        model = create_model()
else:
    print("Using single GPU/CPU device")
    model = create_model()

Using multiple GPUs with Mirrored Strategy
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [5]:
model.summary()
# tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

Model: "music_generation_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 songs (InputLayer)             [(None, 6143, 11)]   0           []                               
                                                                                                  
 tf.__operators__.getitem_1 (Sl  (None, 6143)        0           ['songs[0][0]']                  
 icingOpLambda)                                                                                   
                                                                                                  
 tf.__operators__.getitem_2 (Sl  (None, 6143)        0           ['songs[0][0]']                  
 icingOpLambda)                                                                                   
                                                                             

We can test the model with some inputs from our dataset

In [6]:
DATASET_PATH = os.path.join('..', 'data', 'tf_data7')
dataset = tf.data.Dataset.load(DATASET_PATH).batch(conf.GLOBAL_BATCH_SIZE).cache().shuffle(conf.SHUFFLE_SIZE).prefetch(conf.PREFETCH_SIZE)

In [7]:
X, y = next(dataset.take(1).as_numpy_iterator())

2022-11-27 03:33:30.144708: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [8]:
output = model([X, y])
print([x.shape for x in output])

[TensorShape([12, 6143, 8]), TensorShape([12, 6143, 256]), TensorShape([12, 6143, 131]), TensorShape([12, 6143, 128]), TensorShape([12, 6143, 136]), TensorShape([12, 6143, 256]), TensorShape([12, 6143, 129]), TensorShape([12, 6143, 128]), TensorShape([12, 6143, 25]), TensorShape([12, 6143, 153]), TensorShape([12, 6143, 49])]


In [9]:
model.losses

[<tf.Tensor: shape=(), dtype=float32, numpy=656649.5>,
 <tf.Tensor: shape=(), dtype=float32, numpy=7.8700004>]

TODO: There is something weird with the multi-GPU loss. I bet I have to divide for the global batch size or something.