In [1]:
from transformers import GPT2Config, TFGPT2Model
import tensorflow as tf
import numpy as np
import os

import utils
import config

ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
conf = config.Config("single_instruments_type", ROOT_PATH)

  from .autonotebook import tqdm as notebook_tqdm
2022-11-21 15:51:23.948439: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-21 15:51:24.093552: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-21 15:51:24.595014: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/lib64
2022-11-21 15:51:24.595054: W tensorflow/stream_executor/platform/default/dso_loa

In [3]:
min_beat = 120
allowed_time_sign = 18

max_beat = utils.get_max_beat_from_time_sign(allowed_time_sign, conf) # TODO: implement it
print(max_beat)

beat_mask = np.asarray(
    [False]*min_beat + \
    [True]*(max_beat-min_beat) + \
    [False]*(conf.INPUT_RANGES["beat"]-max_beat)
, dtype=bool)

print(beat_mask)
print(len(beat_mask))

131
[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
  True  True  True  True  True  True  True  True  True  True  True]
131


In [10]:
# dataset = tf.data.Dataset.load(conf.lmda_genres_tf_data_path)   \
dataset = tf.data.Dataset.load(conf.tf_data_path)   \
    .shuffle(conf.SHUFFLE_SIZE)                                 \
    .batch(conf.BATCH_SIZE)                                     \
    .prefetch(conf.PREFETCH_SIZE)                               \
    .cache()                                                    


In [11]:
class MusicGenerator(tf.keras.Model):

    def __init__(self, conf: config.Config, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.conf = conf
        self.SEQ_LEN = conf.SEQ_LEN
        self.TOKEN_DIM = conf.TOKEN_DIM
        self.INPUT_RANGES = conf.INPUT_RANGES

        self.dense_genre_emb = tf.keras.layers.Dense(self.TOKEN_DIM)
        self.embeddings = conf.embedding_layers
        self.concat_layer = tf.keras.layers.Concatenate(axis=2)

        self.pos_embedding_matrix = conf.get_positional_embedding_matrix()
        self.positional_embeddings = tf.stack([self.pos_embedding_matrix]*conf.BATCH_SIZE)
        self.sum_layer = tf.keras.layers.Add()

        self.decoder = conf.get_decoder()

        self.output_dense_layers = conf.output_dense_layers
        self.full_mask = conf.full_mask
        self.default_mask = conf.default_mask

        self.masked_activations = [tf.keras.layers.Softmax()]*len(self.embeddings)


    def apply_activations(self, logits, masks):
        return [activation(elem, mask) for elem, mask, activation in zip(logits, masks, self.masked_activations)]


    def get_mask(self, 
            default = [], 
            min_measure = None, 
            min_beat = None, 
            min_position = None,
            note = False,
            allowed_instruments = None,
            allowed_key_sign = None,
            allowed_time_sign = None,
            allowed_tempo = None,
            forbidden_key_sign = None,
            forbidden_time_sign = None,
            forbidden_tempo = None
        ):

        if len(default) > 0:
            return [self.default_mask[i] if default[i] else self.full_mask[i] for i in range(len(default))]
        else:
            measure_mask = np.asarray([False]*min_measure + [True]*(self.INPUT_RANGES["measure"]-min_measure), dtype=bool)

            if min_beat == None:
                beat_mask = self.default_mask[2]
            else: # oss: allowed_time_sign is always != None if min_beat != None
                max_beat = conf.get_max_beat_from_time_sign(allowed_time_sign)
                # allowed beats are only AFTER previous beat and BEFORE max_beat from the numerator of the time_sign
                beat_mask = np.asarray(
                    [False]*min_beat + \
                    [True]*(max_beat-min_beat) + \
                    [False]*(self.INPUT_RANGES["beat"]-max_beat)
                , dtype=bool)
            

            if min_position == None:
                position_mask = self.default_mask[3]
            else:
                position_mask = np.asarray([False]*min_position + [True]*(self.INPUT_RANGES["position"]-min_position), dtype=bool)

            if note:
                duration_mask = self.full_mask[4]
                pitch_mask = self.full_mask[5]
                velocity_mask = self.full_mask[7]
            else:
                duration_mask = self.default_mask[4]
                pitch_mask = self.default_mask[5]
                velocity_mask = self.default_mask[7]

            if allowed_instruments == None:
                instruments_mask = self.default_mask[6]
            else:
                instruments_mask = np.asarray([True if i in allowed_instruments else False for i in range(self.INPUT_RANGES["instrument"])], dtype=bool)

            if allowed_key_sign == None:
                if forbidden_key_sign == None:
                    raise AssertionError("Cannot have both allowed and forbidden key_sign not instanciated")
                else:
                    key_sign_mask = np.asarray([False if i == forbidden_key_sign else True for i in range(self.INPUT_RANGES["key_sign"])], dtype=bool)
            else:
                key_sign_mask = np.asarray([True if i == allowed_key_sign else False for i in range(self.INPUT_RANGES["key_sign"])], dtype=bool)

            if allowed_time_sign == None:
                if forbidden_time_sign == None:
                    raise AssertionError("Cannot have both allowed and forbidden time_sign not instanciated")
                else:
                    time_sign_mask = np.asarray([False if i == forbidden_time_sign else True for i in range(self.INPUT_RANGES["time_sign"])], dtype=bool)
            else:
                time_sign_mask = np.asarray([True if i == allowed_time_sign else False for i in range(self.INPUT_RANGES["time_sign"])], dtype=bool)

            if allowed_tempo == None:
                if forbidden_tempo == None:
                    raise AssertionError("Cannot have both allowed and forbidden tempo not instanciated")
                else:
                    tempo_mask = np.asarray([False if i == forbidden_tempo else True for i in range(self.INPUT_RANGES["tempo"])], dtype=bool)
            else:
                tempo_mask = np.asarray([True if i == allowed_tempo else False for i in range(self.INPUT_RANGES["tempo"])], dtype=bool)


            return [
                self.full_mask[0], # type is not masked
                measure_mask,
                beat_mask,
                position_mask,
                duration_mask,
                pitch_mask,
                instruments_mask,
                velocity_mask,
                key_sign_mask,
                time_sign_mask,
                tempo_mask
            ]


    def mask_and_activate_outputs(self, out_logits, song):
        '''
        Takes as input:
            - out_logits: the scores outputted by the decoder + dense layers
            - song: the input song, token by token

        This function, based on the chosen token "type" given by the first dense layer,
        masks all the different parts of the token accordingly, also taking into consideration
        the previous tokens of the song

        Output:
            - probabilities for each different part of the predicted token (the following one in the song)
        
        '''

        max_type = max(song[:,0])
        # TODO: is there any point in which the model has to guess "start of song"?
        if max_type == 0: # only start of song token
            # cannot be anything else than instrument choice (1)
            type_mask = np.asarray([False, True, False, False, False, False, False, False], dtype=bool)
        elif max_type == 1: # we reached instrument choice
            # cannot be anything else than instrument choice (1) or start of events (2)
            type_mask = np.asarray([False, True, True, False, False, False, False, False], dtype=bool)
        elif max_type >= 2: # we reached start of events or notes
            type_mask = np.asarray([False, False, False, True, True, True, True, True], dtype=bool)
            
        type_scores = self.masked_activations[0](out_logits[0], type_mask) # the first masked activation is for the type
        
        chosen_type = np.argmax(type_scores)

        # instrument selection
        if chosen_type == 1: # false only for type and instrument type (the ones that you can choose)
            mask = self.get_mask(default = [False, True, True, True, True, True, False, True, True, True, True])
        # start of events
        elif chosen_type == 2: # false only for type (cannot choose anything in "start of events" token)
            mask = self.get_mask(default = [False, True, True, True, True, True, True, True, True, True, True])
        # notes
        elif chosen_type == 3: # note: has same key_sign, time_sign and tempo as last previous event, everything has to be manually decided
            
            mask = self.get_mask(
                min_measure = song[-1,1],
                min_beat = song[-1,2],
                min_position = song[-1,3],
                note = True,
                allowed_instruments = np.unique(song[np.where(song[:,0] == 2),6]), # if type == 2 --> read the instruments (unique = set)
                allowed_key_sign =  song[np.where(song[:,0] == 4), 8][0][-1], # if type == 4 --> read the LAST key_sign
                allowed_time_sign = song[np.where(song[:,0] == 5), 9][0][-1], # if type == 5 --> read the LAST time_sign
                allowed_tempo =     song[np.where(song[:,0] == 6),10][0][-1]  # if type == 6 --> read the LAST tempo
            )

        # key_sign, time_sign, tempo
        elif chosen_type == 4 or chosen_type == 5 or chosen_type == 6:
            # if last event is at the beginning of a measure, you can add an event at the same time
            if song[-1,3] == 0 and song[-1,2] == 0: 
                min_measure = song[-1,1]
            # otherwise it goes to the next measure
            else:
                min_measure = song[-1,1] + 1
            
            if chosen_type == 4:
                mask = self.get_mask(
                    min_measure = min_measure,
                    forbidden_key_sign = song[np.where(song[:,0] == 4), 8][0][-1] # cannot put the same key_sign again
                )

            if chosen_type == 5:
                mask = self.get_mask(
                    min_measure = min_measure,
                    forbidden_key_sign = song[np.where(song[:,0] == 5), 9][0][-1] # cannot put the same time_sign again
                )
            
            if chosen_type == 6:
                mask = self.get_mask(
                    min_measure = min_measure,
                    forbidden_key_sign = song[np.where(song[:,0] == 6),10][0][-1] # cannot put the same tempo again
                )
        
        elif chosen_type == 7: # end of song --> only type can be chosen, all the others are default
            mask = self.get_mask(default = [False, True, True, True, True, True, True, True, True, True, True])

        else:
            raise ValueError("Impossible that chosen type isn't in [1,7] --> {}".format(chosen_type))

        # TODO: check if it works for every type!
        return [masked_act(type_logit, type_mask) for masked_act, type_logit, type_mask in zip(self.masked_activations, out_logits, mask)]


    def call(self, inputs):

        # to train you need to add a "end_song" input --> in this way you create the attention mask
        # for the decoder
        end_song = 0

        if type(inputs) == dict:
            song = inputs["song"]
            genre = inputs["genre"]
            if "end_song" in inputs.keys():
                end_song = inputs["end_song"]

        elif type(inputs) == tuple:
            song = inputs[0]
            genre = inputs[1]
            if len(inputs) == 3:
                end_song = inputs[2]
            elif len(inputs) == 2:
                found_end_song = len( tmp := np.where(song[:,0] == 7)[0])
                if found_end_song == 0:
                    raise ValueError("Passed a song without an end_token and without the 'end_song' input")
                elif found_end_song == 1:
                    end_song = tmp
            else:
                raise ValueError("Inputs are incorrect for this model: pass song and genre (optional:end_song)")

        assert end_song != 0, "Something wrong with end_song"

        # EMBEDDING GENERATION for decoder
        genre_embedding = self.dense_genre_emb(genre)

        # TODO: check how they come out (should be 11*64 numbers for each line, and 6143 lines for each song in batch)
        # TODO: batch?
        token_embeddings = [self.embeddings[i](song[:,i]) for i in range(len(self.embeddings))]

        final_embeddings = self.concat_layer([genre_embedding, token_embeddings])

        # ATTENTION MASK TODO: batch?
        attention_mask = [1]*end_song + [0]*(self.SEQ_LEN-end_song)

        decoder_output = self.decoder(
            input_embeds = final_embeddings,
            attention_mask = attention_mask,
            position_ids = self.positional_embeddings
        )

        out_logits = [layer(decoder_output["last_hidden_state"]) for layer in self.output_dense_layers]

        out_scores = self.mask_and_activate_outputs(out_logits, song)
        
        return genre_embedding, token_embeddings, final_embeddings, attention_mask

    def train_step(self, data):
        song, genre = data

        with tf.GradientTape() as tape:
            # TODO: how to do it for batches?
            trainable_vars = self.trainable_variables
            # we only want to predict up to the real finish of the song
            # but they are padded --> for each song we stop at the token BEFORE the "end_song" token
            # we predict as last token the "end_song" one 
            for i, y in enumerate(np.where(song==7)-1): # TODO: for batches they end in different indexes!!
                y_pred = self((
                    song[:i], 
                    genre
                ))

                loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
                gradients = tape.gradient(loss, trainable_vars)

                self.optimizer.apply_gradients(zip(gradients, trainable_vars))
                self.compiled_metrics.update_state(y, y_pred)
        
        return {m.name: m.result() for m in self.metrics}







In [14]:
model = MusicGenerator(conf)

for song_batch, genre_batch in dataset.take(1):
    tf.gather(song_batch)
    notes = [tf.slice(song_batch, [0, 0, i], [-1, 0, i]) for i in range(8192)]
    print(notes[0])

2022-11-17 18:59:03.635505: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


InvalidArgumentError: {{function_node __wrapped__Slice_device_/job:localhost/replica:0/task:0/device:CPU:0}} Expected size[2] in [0, 5], but got 6 [Op:Slice]