In [2]:
# from transformers import GPT2Config, TFGPT2Model
import tensorflow as tf
import numpy as np
import os

# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all upfront.
physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
  try:
    tf.config.experimental.set_memory_growth(physical_device, True)
  except (ValueError, RuntimeError) as err:
    # ValueError: invalid device. RuntimeError: the runtime is already initialized,
    # so virtual devices cannot be modified anymore. Best effort: report and go on.
    print(f"Could not enable memory growth on {physical_device}: {err}")

import utils
import config

# Parent of the current working directory, handed to Config as the root path.
ROOT_PATH = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
conf = config.Config("single_instruments_type", ROOT_PATH)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# dataset = tf.data.Dataset.load(conf.lmda_genres_tf_data_path)
# Input pipeline: load the saved dataset, cache it, then shuffle, batch and prefetch.
dataset = (
    tf.data.Dataset.load(conf.tf_data_path)
    .cache()
    .shuffle(conf.SHUFFLE_SIZE)
    .batch(conf.BATCH_SIZE)
    .prefetch(conf.PREFETCH_SIZE)
)

In [4]:
# Pull one batch, take its component 0 and keep only the first SEQ_LEN-1 tokens.
batch_elements = next(dataset.take(1).as_numpy_iterator())
song_batch = batch_elements[0][:, :conf.SEQ_LEN-1, :]
print(f"Song_shape: {song_batch.shape}\n")

# Random values standing in for the decoder output, only used to exercise the output heads.
decoder_output = tf.random.uniform(
    shape=(conf.BATCH_SIZE, conf.SEQ_LEN, conf.TOKEN_DIM),
    minval=-1,
    maxval=1)
print(f"Decoder output shape: {decoder_output.shape}\n")

# One logits tensor per output dense layer.
out_logits = []
for layer in conf.output_dense_layers:
    out_logits.append(layer(decoder_output))

for i, out_logit_part in enumerate(out_logits):
    print(f"Output logit #{i}: {out_logit_part.shape}")

Song_shape: (2, 1023, 11)

Decoder output shape: (2, 1024, 512)

Output logit #0: (2, 1024, 8)
Output logit #1: (2, 1024, 256)
Output logit #2: (2, 1024, 131)
Output logit #3: (2, 1024, 128)
Output logit #4: (2, 1024, 136)
Output logit #5: (2, 1024, 256)
Output logit #6: (2, 1024, 129)
Output logit #7: (2, 1024, 128)
Output logit #8: (2, 1024, 25)
Output logit #9: (2, 1024, 153)
Output logit #10: (2, 1024, 49)


In [109]:
class MaskTypeProbabilitiesLayer(tf.keras.layers.Layer):
    '''
    Layer that computes, for every position of a song, a boolean mask over the
    token types that are allowed at that position.
    '''
    def __init__(self, trainable=False, name=None, dtype=None, dynamic=False, **kwargs):
        # Forward by keyword so that the call does not rely on the parameter order of the base class
        super().__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)

    @tf.function
    def create_mask(self, inputs):
        '''
        Builds the type masks for ONE song.

        inputs: the token types of the song, indexed from 0 to SEQ_LEN-2.
        Returns the stacked masks (SEQ_LEN rows of 8 booleans): row 0 only allows type 0,
        row i+1 is decided by the type found at position i.
        '''
        batch_gt_types = inputs
        mask = tf.TensorArray(tf.bool, size=conf.SEQ_LEN)
        # First position: only type 0 is allowed
        mask = mask.write(0, tf.constant([True, False, False, False, False, False, False, False], dtype=tf.bool))
        for i in tf.range(conf.SEQ_LEN-1):
            token_type = batch_gt_types[i]
            if token_type == 0: # only start of song token: cannot be anything else than instrument choice (1)
                type_mask = tf.constant([False, True, False, False, False, False, False, False], dtype=tf.bool)
            elif token_type == 1: # we reached instrument choice: cannot be anything else than instrument choice (1) or start of events (2)
                type_mask = tf.constant([False, True, True, False, False, False, False, False], dtype=tf.bool)
            elif token_type >= 2 and token_type < 7: # we reached start of events or notes
                type_mask = tf.constant([False, False, False, True, True, True, True, True], dtype=tf.bool)
            elif token_type == 7: # at the end of the song we can ONLY GUESS "000000000" TODO: change ending token to type 7s -> 7000000000
                type_mask = tf.constant([True, False, False, False, False, False, False, False], dtype=tf.bool)
            else:
                # ERROR. Define an all-False type mask so that it's defined in all branches for tf.function
                type_mask = tf.constant([False, False, False, False, False, False, False, False], dtype=tf.bool)
            mask = mask.write(i+1, type_mask)
        return mask.stack()

    def call(self, inputs, training=True):
        '''
        Takes as input the ground truth song (at training time) or the logits (at testing time) 
        and computes a mask for the type probabilities.

        Only the training branch is implemented: it returns a boolean tensor with one
        (SEQ_LEN x INPUT_RANGES['type']) mask per song of the batch.

        Raises NotImplementedError when training is falsy.
        '''
        if training:
            # Use the groundtruth song as a target
            song        = inputs
            gt_types    = song[:,:,0]       # Get the token types from the song (batch_size x seq_len-1)
            # Iterate over the batch to collect the appropriate masks from the song
            masks = tf.map_fn(fn=self.create_mask, 
                elems=gt_types, 
                fn_output_signature=tf.TensorSpec(
                    (conf.SEQ_LEN, conf.INPUT_RANGES['type']), 
                    dtype=tf.bool)
            )
            return masks
        else:
            # Compute the types and their masks one by one based on the type chosen at the previous iteration
            # TODO: implement this branch
            # Fail loudly instead of silently returning None to the next layer
            raise NotImplementedError(
                "MaskTypeProbabilitiesLayer only supports training=True for now")

# Build the type masks for the sample batch and show their shape
type_mask_layer = MaskTypeProbabilitiesLayer()
mask_probabilities = type_mask_layer(song_batch, training=True)
mask_probabilities.shape

TensorShape([2, 1024, 8])

In [110]:
# With these masks we can compute the probabilities for the token types.
# One Softmax layer per token part. A comprehension is used so that every entry is its own
# layer instance (multiplying a one-element list would repeat a single shared object).
activations = [tf.keras.layers.Softmax() for _ in range(len(conf.INPUT_RANGES))]

types_probabilities = activations[0](out_logits[0], mask_probabilities) # (last out logit predicts a token that's out of bound in our sequence)
types_probabilities[0, :5]

<tf.Tensor: shape=(5, 8), dtype=float32, numpy=
array([[1.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 1.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.4315946 , 0.56840533, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.0821274 , 0.91787255, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.29944634, 0.70055366, 0.        , 0.        ,
        0.        , 0.        , 0.        ]], dtype=float32)>

In [129]:
# Now the second part of the layer: given the type probabilities, compute the other constraints
class MaskingActivationLayer(tf.keras.layers.Layer):
    def __init__(self, trainable=False, name=None, dtype=None, dynamic=False, **kwargs):
        '''
        Stores the masks and the numerator table coming from the configuration and
        precomputes the lower triangular mask applied to the songs.
        '''
        # Forward by keyword so that the call does not rely on the parameter order of the base class
        super().__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)
        self.default_mask = conf.default_mask
        self.full_mask    = conf.full_mask
        self._numerators  = conf.numerators
        self._tot_numerators = conf.tot_numerators
        self.triang_mask = tf.cast(
            tf.repeat(
                tf.expand_dims(np.tri(conf.SEQ_LEN-1), axis=-1),     # Much more efficient to do it like this: it's a lower triangular matrix
                repeats=len(conf.INPUT_RANGES), axis=-1), 
            dtype=tf.float32)                                        # (SEQ_LEN-1) x (SEQ_LEN-1) x len(INPUT_RANGES) float tensor

    @tf.function
    def get_max_beat_from_time_sign(self, time_sign):
        '''
        Reads the numerator table at position `time_sign mod tot_numerators`
        and returns that entry minus one.
        '''
        numerator_idx = tf.math.floormod(time_sign, self._tot_numerators)
        numerator = tf.gather(self._numerators, numerator_idx)
        return numerator - 1

    def get_mask(self, inputs):
        '''
        Builds the mask over every token part except the type for ONE position of ONE song.

        inputs is the tuple (chosen_type, song, scores, index_tensor) that tf.map_fn feeds
        from get_mask_for_all_tokens:
            - chosen_type:  type chosen for the token predicted at this position
            - song:         the song in which every token after position `index_tensor` is zeroed out
            - scores:       concatenated logits of all the token parts except the type
            - index_tensor: index of the current position in the song

        Returns a 1-D bool tensor: the concatenation of the masks for measure, beat, position,
        duration, pitch, instrument, velocity, key_sign, time_sign and tempo (in this order).
        '''
        chosen_type, song, scores, index_tensor = inputs
        chosen_type  = tf.cast(chosen_type, dtype=tf.int32)               # 1
        song         = tf.cast(song, dtype=tf.int32)                      # (SEQ_LEN-1) * 11
        index_tensor = tf.cast(index_tensor, dtype=tf.int32)              # 1
        default_token_parts = [True]*(len(conf.INPUT_RANGES)-1)
        default_flag = False

        # DEFAULT CREATION (tensorflow requires the variable to be present in all possible branches when if/else are present)
        # -1 is the "not set" sentinel checked when the masks are assembled below
        min_measure           = tf.constant(-1, dtype=tf.int32)
        min_beat              = tf.constant(-1, dtype=tf.int32)
        min_position          = tf.constant(-1, dtype=tf.int32)
        allowed_instruments   = tf.constant([0]*conf.INPUT_RANGES["instrument"], dtype=tf.int32)
        allowed_key_sign      = tf.constant(-1, dtype=tf.int32)
        allowed_time_sign     = tf.constant(-1, dtype=tf.int32)
        allowed_tempo         = tf.constant(-1, dtype=tf.int32)
        forbidden_instruments_flag = False
        forbidden_instruments = tf.constant([0]*conf.INPUT_RANGES["instrument"], dtype=tf.int32)
        forbidden_key_sign    = tf.constant(-1, dtype=tf.int32)
        forbidden_time_sign   = tf.constant(-1, dtype=tf.int32)
        forbidden_tempo       = tf.constant(-1, dtype=tf.int32)
        
        # Check all different possibilities

        if chosen_type == 0 or chosen_type == 2: # TODO: change 0s to 7s at the end of the song
            # Original comments: 
            # only way it chooses 0 is that max_type==7 --> AFTER END OF SONG --> only thing the model can do is guess all zeros
            # "does not have to learn nothing" --> it's all zeros just like the padding tensors
            default_token_parts = [True, True, True, True, True, True, True, True, True, True]
            default_flag = True

        # Instrument selection
        elif chosen_type == 1: # false only for type and instrument type (the ones that you can choose)
            # TODO: this function was GREATLY changed. Is it ok?    
            if tf.size(tf.where(song[:index_tensor, 0] == 1)[:,0]) == 0:
                # Choice of first instrument
                default_token_parts = [True, True, True, True, True, False, True, True, True, True]  # TODO: Element 6 should not be default = True right?
                default_flag = True
            else:
                # TODO: Original is this:
                # forbidden_instruments, _ = tf.unique(tf.gather(
                #     song[:, 6],
                #     tf.squeeze(tf.where(song[:, 0] == 2))           
                # ))
                # I don't think it does what it's supposed to. 
                # Forbidden instruments should be a 1D tensor of all previously defined instruments, right?
                # Instruments are defined with type 1, right?
                forbidden_instruments, _ = tf.unique(tf.gather(
                    song[:index_tensor, 6], 
                    tf.where(song[:index_tensor, 0] == 1)[:,0]        # Cast to 1D array
                ))
                forbidden_instruments_flag = True
        
        # Notes
        elif chosen_type == 3: # Note has same key_sign, time_sign and tempo as last previous event, everything has to be manually decided
            min_measure = song[index_tensor, 1]   # It has to be >= than the last measure
            # If in the MEASURE SCORES the MAX SCORE between all possible measures == min_measure, the measure is min_measure.
            # In this case, we need to make sure that beat >= last_beat
            # TODO: I changed this. Is it okay? We are trying to get the part of the score that's between 0 and 256 right? (there is no type in my scores)
            if tf.math.argmax(
                scores[:conf.INPUT_RANGES["measure"]], 
                    output_type=tf.int32) == min_measure:
                
                min_beat = song[index_tensor,2]      # It has to be >= than the last beat when measure is the same
                if tf.math.argmax(scores[
                    conf.INPUT_RANGES["measure"] : 
                    conf.INPUT_RANGES["measure"] + conf.INPUT_RANGES["beat"]], 
                    output_type=tf.int32) == min_beat:

                    min_position = song[index_tensor,3]  # It has to be >= than the last position (if beat and measure are the same)
                else:
                    min_position = tf.constant(0, dtype=tf.int32)
            else:
                min_beat = tf.constant(0, dtype=tf.int32)
                min_position = tf.constant(0, dtype=tf.int32)
            
            # Only some instruments, key signs, time signs and tempos are allowed for these events: 
            # - for instruments, the allowed ones are the ones that have been defined previously with type = 1
            # - for the others, the allowed ones are the ones that are collected right before the note from event types 4, 5 and 6
            allowed_instruments, _ = tf.unique(tf.gather(
                song[:index_tensor, 6], 
                tf.where(song[:index_tensor, 0] == 1)[:,0]
            ))
            # TODO: There are cases where there is not a LAST key_sign/time_sign, ...
            # If the model chooses 3, we cannot be certain that there is at least a 4, 5 or 6 before it
            # In these cases we use the default masks
            allowed_key_signs = tf.gather(
                song[:index_tensor, 8], 
                tf.where(song[:index_tensor, 0] == 4)[:,0]) # if type == 4 --> read the LAST key_sign
            if tf.size(allowed_key_signs) > 0:
                allowed_key_sign = allowed_key_signs[-1]

            allowed_time_signs = tf.gather(
                song[:index_tensor, 9], 
                tf.where(song[:index_tensor, 0] == 5)[:,0]) # if type == 5 --> read the LAST time_sign
            if tf.size(allowed_time_signs) > 0:
                allowed_time_sign = allowed_time_signs[-1]
            
            allowed_tempos = tf.gather(
                song[:index_tensor, 10], 
                tf.where(song[:index_tensor, 0] == 6)[:,0]) # if type == 6 --> read the LAST tempo
            if tf.size(allowed_tempos) > 0:
                allowed_tempo = allowed_tempos[-1]
        
        # key_sign, time_sign, tempo
        elif chosen_type >= 4 and chosen_type <= 6:
            # If last event is at the beginning of a measure, you can add an event at the same time
            if song[index_tensor, 3] == 0 and song[index_tensor, 2] == 0:  # if beat and position == 0, the event can be at this measure
                min_measure = song[index_tensor, 1]
            else:
                min_measure = song[index_tensor, 1] + 1                   # otherwise it goes to the next measure
            # Fine-grain checks
            # TODO: As before, there are cases where there is not a LAST key_sign/time_sign. 
            # In these cases we should use the default masks.
            if chosen_type == 4:
                # Cannot put the same key_sign again
                forbidden_key_signs = tf.gather(
                    song[:index_tensor, 8], 
                    tf.where(song[:index_tensor, 0] == 4)[:,0]) # if type == 4 --> read the LAST key_sign
                if tf.size(forbidden_key_signs) > 0:
                    forbidden_key_sign = forbidden_key_signs[-1]
            elif chosen_type == 5:
                # Cannot put the same time_sign again
                forbidden_time_signs = tf.gather(
                    song[:index_tensor, 9], 
                    tf.where(song[:index_tensor, 0] == 5)[:,0]) # if type == 5 --> read the LAST time_sign
                if tf.size(forbidden_time_signs) > 0:
                    forbidden_time_sign = forbidden_time_signs[-1]
            elif chosen_type == 6:
                # Cannot put the same tempo again
                forbidden_tempos = tf.gather(
                    song[:index_tensor, 10], 
                    tf.where(song[:index_tensor, 0] == 6)[:,0]) # if type == 6 --> read the LAST tempo
                # Same "at least one previous value" check as key_sign and time_sign
                if tf.size(forbidden_tempos) > 0:
                    forbidden_tempo = forbidden_tempos[-1]

        elif chosen_type == 7: # end of song --> only type can be chosen, all the others are default
            default_token_parts = [True, True, True, True, True, True, True, True, True, True]
            default_flag = True

        else:
            # cannot RAISE inside tf graph, because it WILL pass from every path
            default_token_parts = [True, True, True, True, True, True, True, True, True, True]
            default_flag = True
            # raise ValueError("Impossible that chosen type isn't in [1,7] --> {}".format(chosen_type))
        
        # Put together the masks
        if default_flag: 
            # No manual masking required: for each part of the token either
            # "can only choose default for this part of the token" (True) or 
            # "can freely choose this part of the token" (False)
            return tf.concat(
                    # Default mask only allows to predict a 0
                    # Full mask allows to predict any value
                    [self.default_mask[i] if default_token_parts[i] else self.full_mask[i] 
                        for i in range(len(default_token_parts))],
                    axis=-1)
        else: 
            # We need to do manual masking: start from the default masks and override what is needed
            # NOTE(review): duration, pitch and velocity are never overridden below, so they keep
            # the default mask for notes too -- confirm this is intended
            measure_mask     = self.default_mask[0]
            beat_mask        = self.default_mask[1]
            position_mask    = self.default_mask[2]
            duration_mask    = self.default_mask[3]
            pitch_mask       = self.default_mask[4]
            instruments_mask = self.default_mask[5]
            velocity_mask    = self.default_mask[6]
            key_sign_mask    = self.default_mask[7]
            time_sign_mask   = self.default_mask[8]
            tempo_mask       = self.default_mask[9]
            
            if not forbidden_instruments_flag:
                # TODO: I didn't understand this comment
                # Measure mask, beat and position go to default if type==2 and forbidden_instruments_flag == True
                # so if forbidden_instruments_flag == False --> you can change it
                measure_mask = tf.cast(
                    tf.concat([
                        tf.repeat([False], min_measure),        # Can be equal to or greater than min_measure
                        tf.repeat([True], conf.INPUT_RANGES["measure"]-min_measure)], 
                        axis=-1),
                    dtype=tf.dtypes.bool)

                if min_beat != -1:
                    # NOTE(review): the original note states that allowed_time_sign is always set when
                    # min_beat is set, but above it is only set if a type 5 token was found -- confirm
                    # TODO: Did not understand this function
                    max_beat = self.get_max_beat_from_time_sign(allowed_time_sign)
                    # allowed beats are only AFTER previous beat and BEFORE max_beat from the numerator of the time_sign
                    beat_mask = tf.cast(tf.concat([
                        tf.repeat([False], min_beat),
                        tf.repeat([True],  max_beat-min_beat), 
                        tf.repeat([False], conf.INPUT_RANGES["beat"]-max_beat)],
                        axis=-1), 
                    dtype=tf.dtypes.bool)

                if min_position != -1:
                    position_mask = tf.cast(tf.concat([
                        tf.repeat([False], min_position), 
                        tf.repeat([True],  conf.INPUT_RANGES["position"]-min_position)],
                        axis=-1), 
                    dtype=tf.dtypes.bool)

            else:
                # Everything is allowed (default_value=1) except the forbidden instruments (value 0)
                instruments_mask = tf.sparse.SparseTensor(  # Forbidden instruments
                    indices= tf.expand_dims(tf.cast(forbidden_instruments, tf.int64), axis=-1),
                    values = tf.zeros_like(forbidden_instruments),
                    dense_shape=[conf.INPUT_RANGES["instrument"]]
                )
                instruments_mask = tf.cast(
                    tf.sparse.to_dense(tf.sparse.reorder(instruments_mask), default_value=1), 
                    dtype=tf.dtypes.bool)
                # FORBIDDEN INSTRUMENTS is ONLY USED WHEN type==1 --> measure_mask, beat, position are all default
                
            if chosen_type==3:
                # Mask that's true only for defined instruments:
                # allowed instruments get value 1, every other entry falls back to default_value=0.
                # (With values=1 AND default_value=1 the mask would be True everywhere.)
                instruments_mask = tf.sparse.SparseTensor( # Allowed instruments
                    indices=tf.expand_dims(tf.cast(allowed_instruments, tf.int64), axis=-1),
                    values=tf.ones_like(allowed_instruments),
                    dense_shape=[conf.INPUT_RANGES["instrument"]]
                )
                instruments_mask = tf.cast(
                    tf.sparse.to_dense(tf.sparse.reorder(instruments_mask), default_value=0),
                    dtype=tf.dtypes.bool)
                
                # TODO: I think this part should be indented like this
                # NOTE(review): the forbidden_* values are only set for types 4-6 but are only read
                # inside this type 3 branch, so they currently never change a mask -- confirm the nesting
                # Deal with key signs and time signs
                if allowed_key_sign != -1:
                    key_sign_mask = tf.convert_to_tensor([
                        i == allowed_key_sign 
                        for i in range(conf.INPUT_RANGES["key_sign"])], 
                        dtype=tf.bool)
                elif forbidden_key_sign != -1:
                    # Inverse
                    key_sign_mask = tf.convert_to_tensor([
                        i != forbidden_key_sign 
                        for i in range(conf.INPUT_RANGES["key_sign"])], 
                        dtype=tf.bool)
                else: 
                    pass

                if allowed_time_sign != -1:
                    time_sign_mask = tf.convert_to_tensor([
                        i == allowed_time_sign 
                        for i in range(conf.INPUT_RANGES["time_sign"])], 
                        dtype=tf.bool)
                elif forbidden_time_sign != -1:
                        time_sign_mask = tf.convert_to_tensor([
                            i != forbidden_time_sign 
                            for i in range(conf.INPUT_RANGES["time_sign"])], 
                            dtype=tf.bool)
                else:
                    pass

                if allowed_tempo != -1:
                    tempo_mask = tf.convert_to_tensor([
                        i == allowed_tempo 
                        for i in range(conf.INPUT_RANGES["tempo"])], 
                        dtype=tf.bool)
                elif forbidden_tempo != -1:
                        tempo_mask = tf.convert_to_tensor([
                            i != forbidden_tempo 
                            for i in range(conf.INPUT_RANGES["tempo"])], 
                            dtype=tf.bool)
                else:
                    pass

            return tf.concat([
                measure_mask, beat_mask, position_mask, duration_mask,
                pitch_mask, instruments_mask, velocity_mask, key_sign_mask,
                time_sign_mask, tempo_mask], axis=-1)


    def get_mask_for_all_tokens(self, inputs): 
        '''
        Computes the masks of every position for a SINGLE ELEMENT OF A BATCH.

        inputs is the tuple (chosen_types, song, scores). get_mask is mapped over the
        first axis of these tensors; the result is a bool tensor with one row of
        length input_ranges_sum - INPUT_RANGES['type'] per position.
        '''
        chosen_types, song, scores = inputs
        n_positions = conf.SEQ_LEN-1

        # n_positions copies of the song: copy i keeps tokens 0..i and zeroes all the following ones
        stacked_song = tf.tile(tf.expand_dims(song, axis=0), multiples=[n_positions, 1, 1])
        triang_song  = tf.math.multiply(stacked_song, self.triang_mask)

        # Position indexes, one per copy
        index_tensor = tf.range(n_positions, dtype=tf.float32)

        # Signature of the mask returned by get_mask for one position
        mask_signature = tf.TensorSpec(
            shape=(conf.input_ranges_sum - conf.INPUT_RANGES['type']),
            dtype=tf.bool)  # TODO: change it accordingly to the output signature

        # TODO: to speedup we could change it to vectorized_map, but we need to try it out
        return tf.map_fn(
            fn=self.get_mask,
            elems=(chosen_types, triang_song, scores, index_tensor),
            fn_output_signature=mask_signature)

    def call(self, inputs, training=True):
        '''
        inputs is the list [songs, out_logits, types_probabilities].
        Returns, for every song of the batch, the bool masks computed by
        get_mask_for_all_tokens. The `training` argument is not used.
        '''
        songs, out_logits, types_probabilities = inputs

        # Most probable type at every position but the last one, with a trailing axis of size 1
        best_types   = tf.math.argmax(types_probabilities[:,:-1], axis=2)
        chosen_types = tf.expand_dims(best_types, axis=-1)

        # All logits (except type) side by side, and the dtype shared by the mapped inputs
        concat_logits = tf.concat(out_logits[1:], axis=-1)
        common_dtype  = concat_logits.dtype

        per_song_inputs = (
            tf.cast(chosen_types, common_dtype),                          # BATCH*(SEQ_LEN-1)*1
            tf.cast(songs, common_dtype),                                 # BATCH*(SEQ_LEN-1)*11
            concat_logits[:, :conf.SEQ_LEN-1, :],                         # BATCH*(SEQ_LEN-1)*(sum of non-type ranges)
        )
        per_song_signature = tf.TensorSpec(
            (conf.SEQ_LEN-1, conf.input_ranges_sum - conf.INPUT_RANGES['type']),
            dtype=tf.bool)

        # Iterate over the batch dimension
        return tf.map_fn(
            fn=self.get_mask_for_all_tokens,
            elems=per_song_inputs,
            fn_output_signature=per_song_signature)

In [142]:
# int32 instead of int8: some token parts have more than 128 classes (e.g. 256 for the measure),
# so their values would not fit in an int8 once the song is converted to the input dtype
song_input = tf.keras.layers.Input(shape=(conf.SEQ_LEN-1, len(conf.INPUT_RANGES)), dtype=tf.int32)

mask_type_probabilities_layer = MaskTypeProbabilitiesLayer()
final_masking_layer = MaskingActivationLayer()
# One distinct Softmax layer per token part (list multiplication would repeat the same object)
activations = [tf.keras.layers.Softmax() for _ in range(len(conf.INPUT_RANGES))]

mask_for_type_probabilities = mask_type_probabilities_layer(song_input, training=True)
type_probabilities = activations[0](out_logits[0], mask_for_type_probabilities)
final_mask = final_masking_layer([song_input, out_logits, type_probabilities])

# Unpack the final masks: slice the concatenated mask back into one mask per token part
index = 0
masks = []
for key in conf.INPUT_RANGES:
    if key != 'type':       # We have already checked for the type
        masks.append(final_mask[:, :, index:index+conf.INPUT_RANGES[key]])
        index += conf.INPUT_RANGES[key]

model = tf.keras.Model(inputs=song_input, outputs=masks)

In [143]:
# Run the masking model on the sample batch and list the shape of every returned mask
masks = model(song_batch)
[part_mask.shape for part_mask in masks]

[TensorShape([2, 1023, 256]),
 TensorShape([2, 1023, 131]),
 TensorShape([2, 1023, 128]),
 TensorShape([2, 1023, 136]),
 TensorShape([2, 1023, 256]),
 TensorShape([2, 1023, 129]),
 TensorShape([2, 1023, 128]),
 TensorShape([2, 1023, 25]),
 TensorShape([2, 1023, 153]),
 TensorShape([2, 1023, 49])]

In [None]:
# TODO: Check that the masks are ok

In [None]:
# class MusicGenerator(tf.keras.Model):

#     def __init__(self, conf: config.Config, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#         self.conf = conf
#         self.SEQ_LEN = conf.SEQ_LEN
#         self.TOKEN_DIM = conf.TOKEN_DIM
#         self.INPUT_RANGES = conf.INPUT_RANGES

#         self.dense_genre_emb = tf.keras.layers.Dense(self.TOKEN_DIM)
#         self.embeddings = conf.embedding_layers
#         self.concat_layer = tf.keras.layers.Concatenate(axis=2)

#         self.pos_embedding_matrix = conf.get_positional_embedding_matrix()
#         self.positional_embeddings = tf.stack([self.pos_embedding_matrix]*conf.BATCH_SIZE)
#         self.sum_layer = tf.keras.layers.Add()

#         # TODO: add dense layer from embeddings to input decoder

#         self.decoder = conf.get_decoder()

#         self.output_dense_layers = conf.output_dense_layers

#         self.full_mask = conf.full_mask
#         self.default_mask = conf.default_mask

#         self.masked_activations = [tf.keras.layers.Softmax()]*len(self.embeddings)


#     def softargmax(x, beta=1e10):
#         x = tf.convert_to_tensor(x, dtype=tf.float64)
#         x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
#         return tf.reduce_sum(tf.nn.softmax(x*beta) * x_range, axis=-1)


#     def apply_activations(self, logits, masks):
#         return [activation(elem, mask) for elem, mask, activation in zip(logits, masks, self.masked_activations)]


#     def get_mask(self, 
#             default = [], 
#             min_measure = None, 
#             min_beat = None, 
#             min_position = None,
#             note = False,
#             allowed_instruments = None,
#             allowed_key_sign = None,
#             allowed_time_sign = None,
#             allowed_tempo = None,
#             forbidden_key_sign = None,
#             forbidden_time_sign = None,
#             forbidden_tempo = None
#         ): 

#         '''
#         Returns a list of ndarrays of bool type used for masking
#         '''
#         if len(default) > 0: # no manual masking, either "can freely choose this part of the token" or "can only choose default for this part of the token"
#             return [self.default_mask[i] if default[i] else self.full_mask[i] for i in range(len(default))]
        
#         else: # manual masking

#             measure_mask = np.asarray([False]*min_measure + [True]*(self.INPUT_RANGES["measure"]-min_measure), dtype=bool)
#             # TODO: Implement BEAT MASK only if measure == last_measure
#             # TODO: Implement POSITION MASK only if measure == last_measure AND beat == last_beat

#             if min_beat == None:
#                 beat_mask = self.default_mask[2]
#             else: # oss: allowed_time_sign is always != None if min_beat != None
#                 max_beat = conf.get_max_beat_from_time_sign(allowed_time_sign)
#                 # allowed beats are only AFTER previous beat and BEFORE max_beat from the numerator of the time_sign
#                 beat_mask = np.asarray(
#                     [False]*min_beat + \
#                     [True]*(max_beat-min_beat) + \
#                     [False]*(self.INPUT_RANGES["beat"]-max_beat)
#                 , dtype=bool)
            

#             if min_position == None:
#                 position_mask = self.default_mask[3]
#             else:
#                 position_mask = np.asarray([False]*min_position + [True]*(self.INPUT_RANGES["position"]-min_position), dtype=bool)

#             if note:
#                 duration_mask = self.full_mask[4]
#                 pitch_mask = self.full_mask[5]
#                 velocity_mask = self.full_mask[7]
#             else:
#                 duration_mask = self.default_mask[4]
#                 pitch_mask = self.default_mask[5]
#                 velocity_mask = self.default_mask[7]

#             if allowed_instruments == None:
#                 instruments_mask = self.default_mask[6]
#             else:
#                 instruments_mask = np.asarray([True if i in allowed_instruments else False for i in range(self.INPUT_RANGES["instrument"])], dtype=bool)

#             if allowed_key_sign == None:
#                 if forbidden_key_sign == None:
#                     raise AssertionError("Cannot have both allowed and forbidden key_sign not instanciated")
#                 else:
#                     key_sign_mask = np.asarray([False if i == forbidden_key_sign else True for i in range(self.INPUT_RANGES["key_sign"])], dtype=bool)
#             else:
#                 key_sign_mask = np.asarray([True if i == allowed_key_sign else False for i in range(self.INPUT_RANGES["key_sign"])], dtype=bool)

#             if allowed_time_sign == None:
#                 if forbidden_time_sign == None:
#                     raise AssertionError("Cannot have both allowed and forbidden time_sign not instanciated")
#                 else:
#                     time_sign_mask = np.asarray([False if i == forbidden_time_sign else True for i in range(self.INPUT_RANGES["time_sign"])], dtype=bool)
#             else:
#                 time_sign_mask = np.asarray([True if i == allowed_time_sign else False for i in range(self.INPUT_RANGES["time_sign"])], dtype=bool)

#             if allowed_tempo == None:
#                 if forbidden_tempo == None:
#                     raise AssertionError("Cannot have both allowed and forbidden tempo not instanciated")
#                 else:
#                     tempo_mask = np.asarray([False if i == forbidden_tempo else True for i in range(self.INPUT_RANGES["tempo"])], dtype=bool)
#             else:
#                 tempo_mask = np.asarray([True if i == allowed_tempo else False for i in range(self.INPUT_RANGES["tempo"])], dtype=bool)


#             return [
#                 self.full_mask[0], # type is not masked
#                 measure_mask,
#                 beat_mask,
#                 position_mask,
#                 duration_mask,
#                 pitch_mask,
#                 instruments_mask,
#                 velocity_mask,
#                 key_sign_mask,
#                 time_sign_mask,
#                 tempo_mask
#             ]


#     def mask_and_activate_outputs(self, out_logits, song):
#         '''
#         Takes as input:
#             - out_logits: the scores outputted by the decoder + dense layers
#             - song: the input song, token by token

#         This function, based on the chosen token "type" given by the first dense layer,
#         masks all the different parts of the token accordingly, also taking into consideration
#         the previous tokens of the song

#         Output:
#             - probabilities for each different part of the predicted token (the following one in the song)
        
#         '''

#         max_type = tf.math.reduce_max(song[:,0])

#         # do not have to be tensors because the loss shouldn't flow through them
#         if max_type == 0: # only start of song token
#             # cannot be anything else than instrument choice (1)
#             type_mask = np.asarray([False, True, False, False, False, False, False, False], dtype=bool)
#         elif max_type == 1: # we reached instrument choice
#             # cannot be anything else than instrument choice (1) or start of events (2)
#             type_mask = np.asarray([False, True, True, False, False, False, False, False], dtype=bool)
#         elif max_type >= 2 and max_type < 7: # we reached start of events or notes
#             type_mask = np.asarray([False, False, False, True, True, True, True, True], dtype=bool)

#         elif max_type == 7: # at the end of the song we can ONLY GUESS "000000000" TODO: change to zero
#             type_mask = np.asarray([True, False, False, False, False, False, False, False], dtype=bool)

#         type_scores = self.masked_activations[0](out_logits[0], type_mask) # the first masked activation is for the type
        
#         # needs to be differentiable (but the masks shouldn't need to be)
#         chosen_type = self.softargmax(type_scores)

#         # could change == i with x<i+eps and x>i-eps because it's softargmax and not argmax
#         if chosen_type == 0: # TODO: change to 7 # only way it chooses 0 is that max_type==7 --> AFTER END OF SONG --> only thing the model can do is guess all zeros
#             # "does not have to learn anything" --> it's all zeros just like the padding tensors
#             mask = self.get_mask(default = [False, True, True, True, True, True, True, True, True, True, True])
#         # instrument selection
#         if chosen_type == 1: # false only for type and instrument type (the ones that you can choose)
#             mask = self.get_mask(default = [False, True, True, True, True, True, False, True, True, True, True])
#         # start of events
#         elif chosen_type == 2: # false only for type (cannot choose anything in "start of events" token)
#             mask = self.get_mask(default = [False, True, True, True, True, True, True, True, True, True, True])
#         # notes
#         elif chosen_type == 3: # note: has same key_sign, time_sign and tempo as last previous event, everything has to be manually decided
            
#             mask = self.get_mask(
#                 min_measure = song[-1,1],   # it has to be >= the last measure
#                 min_beat = song[-1,2],      # it has to be >= the last beat (if measure is the same)
#                 min_position = song[-1,3],  # it has to be >= the last position (if beat and measure are the same)
#                 note = True,
#                 allowed_instruments = np.unique(song[np.where(song[:,0] == 2),6]), # if type == 2 --> read the instruments (unique = set)
#                 allowed_key_sign =  song[np.where(song[:,0] == 4), 8][0][-1], # if type == 4 --> read the LAST key_sign
#                 allowed_time_sign = song[np.where(song[:,0] == 5), 9][0][-1], # if type == 5 --> read the LAST time_sign
#                 allowed_tempo =     song[np.where(song[:,0] == 6),10][0][-1]  # if type == 6 --> read the LAST tempo
#             )

#         # key_sign, time_sign, tempo
#         elif chosen_type == 4 or chosen_type == 5 or chosen_type == 6:
#             # if last event is at the beginning of a measure, you can add an event at the same time
#             if song[-1,3] == 0 and song[-1,2] == 0:  # if beat and position == 0, can be this measure
#                 min_measure = song[-1,1]
#             # otherwise it goes to the next measure
#             else:
#                 min_measure = song[-1,1] + 1
            
#             if chosen_type == 4:
#                 mask = self.get_mask(
#                     min_measure = min_measure,
#                     forbidden_key_sign = song[np.where(song[:,0] == 4), 8][0][-1] # cannot put the same key_sign again
#                 )

#             if chosen_type == 5:
#                 mask = self.get_mask(
#                     min_measure = min_measure,
#                     forbidden_time_sign = song[np.where(song[:,0] == 5), 9][0][-1] # cannot put the same time_sign again
#                 )
            
#             if chosen_type == 6:
#                 mask = self.get_mask(
#                     min_measure = min_measure,
#                     forbidden_tempo = song[np.where(song[:,0] == 6),10][0][-1] # cannot put the same tempo again
#                 )
        
#         elif chosen_type == 7: # end of song --> only type can be chosen, all the others are default
#             mask = self.get_mask(default = [False, True, True, True, True, True, True, True, True, True, True])

#         else:
#             raise ValueError("Impossible that chosen type isn't in [1,7] --> {}".format(chosen_type))

#         # TODO: check if it works for every type!
#         return [masked_act(type_logit, type_mask) for masked_act, type_logit, type_mask in zip(self.masked_activations, out_logits, mask)]


#     def call(self, inputs):

#         # to train you need to add an "end_song" input --> in this way you create the attention mask
#         # for the decoder

#         if type(inputs) == dict:
#             song = inputs["song"]
#             genre = inputs["genre"]
#             attention_mask = inputs["attention_mask"] # TODO: remove if decoder outputs sequences

#         elif type(inputs) == tuple:
#             song = inputs[0]
#             genre = inputs[1]
#             attention_mask = inputs[2] # TODO: remove if decoder outputs sequences
            
#         # EMBEDDING GENERATION for decoder
#         genre_embedding = self.dense_genre_emb(genre)

#         # TODO: check how they come out (should be 11*64 numbers for each line, and 6143 lines for each song in batch)
#         # TODO: batch?
#         token_embeddings = [self.embeddings[i](song[:,i]) for i in range(len(self.embeddings))]

#         final_embeddings = self.concat_layer([genre_embedding, token_embeddings])

#         decoder_output = self.decoder(
#             input_embeds = final_embeddings,
#             attention_mask = attention_mask, # TODO: remove if decoder outputs sequences
#             position_ids = self.positional_embeddings
#         )

#         out_logits = [layer(decoder_output["last_hidden_state"]) for layer in self.output_dense_layers]
        
#         # insert for if decoder outputs sequences

#         out_scores = self.mask_and_activate_outputs(out_logits, song)
        
#         # return genre_embedding, token_embeddings, final_embeddings, attention_mask, out_logits, out_scores
#         return out_scores

    
#     def train_step(self, data):
#         song, genre = data

#         with tf.GradientTape() as tape:
#             # TODO: how to do it for batches?
#             trainable_vars = self.trainable_variables
#             # we only want to predict up to the real finish of the song
#             # but they are padded --> for each song we stop at the token BEFORE the "end_song" token
#             # we predict as last token the "end_song" one

#             # pass through the network THE ENTIRE SONG (even the last PADDED TOKENS! so the batch flows together)
#             # OR we could stop on the biggest last token in batch
#             for i, y in enumerate(song[1:]):
#                 # TODO: check and simplify if decoder output is already sequence

#                 # y is the current token
#                 y_pred = self((
#                     song,
#                     genre,
#                     np.asarray([1]*(i+1) + [0]*(self.SEQ_LEN-1-(i+1))) # attention mask
#                 ))

#                 loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
#                 gradients = tape.gradient(loss, trainable_vars)

#                 self.optimizer.apply_gradients(zip(gradients, trainable_vars))
#                 self.compiled_metrics.update_state(y, y_pred)
        
#         return {m.name: m.result() for m in self.metrics}