In [1]:
import os
from copy import deepcopy

import numpy as np
import tensorflow as tf
from tensorflow import keras

from typing import List, Dict, Tuple
from tqdm import trange

import config, music_model, utils

### CONFIGURATION ###

USE_DOUBLE_HEAD = True   # selects the double-head checkpoint and model variant
USE_ONE_GPU = True       # run generation on a single device

### CONSTANTS ###

# The project root is taken to be the parent of the current working directory
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
DATASET_NAME = 'lmd_matched_final_2048_cut'
# NOTE(review): the single-head name reads "vellmd" while the double-head one reads
# "vel_lmd" — confirm the single-head checkpoint directory is really spelled this way.
MODEL_NAME = 'model_GPT_baseline_with_mse_vellmd_matched_2048' if not USE_DOUBLE_HEAD else \
             'model_GPT_baseline_with_mse_vel_lmd_matched_2048_double_head'
WEIGHTS_PATH = os.path.join(ROOT_PATH, 'training', 'checkpoints', MODEL_NAME, MODEL_NAME)
USE_SMALL_GENRE_SET = DATASET_NAME == 'tf_data7dict'

conf = config.Config("single_instruments_type", ROOT_PATH)

if USE_SMALL_GENRE_SET:
    conf.accepted_subgenres = ['folk', 'nes', 'maestro']
# If we need to use only the first GPU (or a single CPU when no GPU is visible)
if USE_ONE_GPU:
    available_gpus = tf.config.list_physical_devices('GPU')
    # Indexing [0] unconditionally raises IndexError on CPU-only machines;
    # in that case conf.GPUS is left as configured and we just run on one device.
    if available_gpus:
        conf.GPUS = available_gpus[0]
    conf.GLOBAL_BATCH_SIZE = conf.BATCH_SIZE
    conf.num_devices = 1

### MODEL CREATION ###

# Keyword arguments shared by the single-device and multi-device paths
model_kwargs = dict(
    num_genres=len(conf.accepted_subgenres),
    use_regularization=False,
    use_masking_layers=False,
    double_head=USE_DOUBLE_HEAD,
)
if conf.num_devices <= 1:
    print("Using single GPU/CPU device")
    model = music_model.create_model(conf, **model_kwargs)
else:
    print("Using multiple GPUs with Mirrored Strategy")
    # Model variables have to be created inside the distribution strategy's scope
    with conf.training_strategy.scope():
        model = music_model.create_model(conf, **model_kwargs)

2023-05-23 16:20:02.514837: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-23 16:20:02.633653: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-05-23 16:20:03.095613: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/volpepe/miniconda3/envs/music_gen/lib/:/home/volpepe/miniconda3/envs/music_gen/lib/python3.10/site-packages/nvidia/cudnn/lib
2023-05-23 16:20:03.095715: W tensorflow/stream_executor/platform/default/dso_l

Using single GPU/CPU device


2023-05-23 16:20:04.690897: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-23 16:20:04.691066: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-23 16:20:04.691221: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-23 16:20:05.250702: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-05-23 16:20:05.250848: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from S

In [2]:
# Restore the trained weights into the freshly built model; the returned
# checkpoint load status is displayed as the cell output.
weights_load_status = model.load_weights(WEIGHTS_PATH)
weights_load_status

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f4b4c973df0>

---

## Quick example

In [3]:
# Function to choose a style for the song
def encode_styles(styles_array):
    one_hot_enc = np.zeros((len(styles_array), len(conf.accepted_subgenres)), dtype=np.int8)
    for i, style in enumerate(styles_array):
        one_hot_enc[i, conf.accepted_subgenres.index(style)] = 1
    return one_hot_enc

# Example: encode four genres, producing one one-hot row per song
example_genres = ['rock', 'pop', 'dance', 'electronic']
styles = encode_styles(example_genres)
print(styles)

[[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]]


---

## Generation functions

In [4]:
# Ordered components of a token. 'prob' maps each component to its position in
# the per-component model outputs and in the token vector; 'mask' maps every
# component except 'type' to its position in the per-component mask lists.
TOKEN_COMPONENTS = ('type', 'measure', 'beat', 'position', 'duration', 'pitch',
                    'instrument', 'velocity', 'key_sign', 'time_sign', 'tempo')
INDEX = {
    'prob': {name: pos for pos, name in enumerate(TOKEN_COMPONENTS)},
    'mask': {name: pos for pos, name in enumerate(TOKEN_COMPONENTS[1:])},
}


def apply_sampling_strategy(probs: np.ndarray, sampling_settings: Dict) -> np.ndarray:
    """Reshape a (possibly masked) probability vector according to a sampling strategy.

    The input is first renormalised to sum to 1. This ensures masked-out mass
    does not distort the top-p threshold. It also ensures ``np.random.choice``
    always receives a valid distribution, including plain sampling of a masked vector.

    Args:
        probs: 1-D array of non-negative, possibly unnormalised probabilities.
            It is NOT modified (the caller may pass a view into model outputs).
        sampling_settings: dict with key ``'generation_mode'`` in
            {'standard_sampling', 'top_k_sampling', 'top_p_sampling'}, plus
            ``'top_k_ratio'`` (fraction of entries kept, top-k sampling) or
            ``'current_top_p_ratio'`` (cumulative-probability threshold, top-p sampling).

    Returns:
        A new float64 array of the same length that sums to 1 whenever the
        input has positive mass; entries outside the kept set are 0.

    Raises:
        ValueError: for an unknown generation mode.
    """
    generation_mode = sampling_settings['generation_mode']
    if generation_mode not in ('standard_sampling', 'top_k_sampling', 'top_p_sampling'):
        raise ValueError("Unknown generation mode: {}".format(generation_mode))

    # Work on a normalised float64 copy of the input
    probs = np.array(probs, dtype=np.float64)
    total = probs.sum()
    if total > 0:
        probs /= total

    ## Standard sampling: keep the (renormalised) distribution as it is
    if generation_mode == 'standard_sampling':
        return probs

    # Indices ordered from the most to the least probable value
    ranking = np.argsort(-probs)
    if generation_mode == 'top_k_sampling':
        ## Top-k: k is a ratio of the number of possibilities, because each component has a different range
        k = int(np.ceil(len(probs) * sampling_settings['top_k_ratio']))
    else:
        ## Top-p: smallest prefix of the ranking whose cumulative probability reaches the threshold
        cumulative = np.cumsum(probs[ranking])
        k = int(np.searchsorted(cumulative, sampling_settings['current_top_p_ratio'], side='left')) + 1
    # Always keep at least one element, and never more than exist
    k = max(1, min(k, len(probs)))

    # Zero everything outside the kept set and redistribute the kept mass to sum to 1
    kept = ranking[:k]
    filtered = np.zeros_like(probs)
    filtered[kept] = probs[kept]
    kept_mass = filtered.sum()
    return filtered / kept_mass if kept_mass > 0 else filtered


def sample_from(p: np.ndarray, mask: np.ndarray=None, sampling_settings: Dict=None) -> int:
    """Sample an index from the distribution ``p``, restricted to ``mask``.

    Args:
        p: 1-D array of probabilities over the values of one token component.
        mask: optional array of the same length; entries equal to 0/False are
            excluded. An all-zero mask short-circuits to index 0, so that the
            rest of the generation code can deterministically choose end tokens.
        sampling_settings: forwarded to ``apply_sampling_strategy``; defaults to
            ``{'generation_mode': 'standard_sampling'}``.

    Returns:
        The sampled index as a Python int.
    """
    # Avoid a shared mutable default argument
    if sampling_settings is None:
        sampling_settings = {'generation_mode': 'standard_sampling'}

    p = np.asarray(p, dtype=np.float64)
    if mask is not None:
        mask = np.asarray(mask)
        # Completely empty mask: return index 0 (see docstring)
        if np.sum(mask) == 0:
            return 0
        p = p * mask

    # Renormalise after masking: np.random.choice requires p to sum to 1
    total = p.sum()
    if total > 0:
        p = p / total
    else:
        # The model gave zero probability to every allowed value (e.g. softmax
        # underflow): fall back to a uniform choice among the allowed values
        allowed = np.ones(len(p)) if mask is None else (mask != 0).astype(np.float64)
        p = allowed / allowed.sum()

    # Then modify the remaining probabilities according to the sampling strategy
    p = apply_sampling_strategy(p, sampling_settings)
    return int(np.random.choice(len(p), p=p))


def sample_type(song: np.ndarray, current_token_idx: int, 
                current_token_probabilities: np.ndarray, 
                use_masking:bool=True, sampling_settings: Dict=None) -> int:
    """Sample the type of the next token, optionally enforcing the grammar of token types.

    With masking, the allowed types depend on the previous token's type:

      * 0 -> 1
      * 1 -> 1 or 2
      * 2 -> 4
      * 3 -> 3 and above
      * 4 -> 5 if no time-sign token (type 5) is in ``song`` yet, otherwise 3 and above
      * 5 -> 6 if no tempo token (type 6) is in ``song`` yet, otherwise 3 and above
      * 6 -> 3 and above
      * 7 -> 7 and above

    Any other previous type leaves every type allowed.

    Args:
        song: token array of the song being generated (rows are tokens).
        current_token_idx: index of the token being sampled; row ``current_token_idx - 1``
            is the previous token.
        current_token_probabilities: per-component model outputs for this position,
            ordered as ``INDEX['prob']``.
        use_masking: if False, the type is sampled from the raw probabilities.
        sampling_settings: forwarded to ``sample_from``; defaults to standard sampling.

    Returns:
        The sampled type index.
    """
    # Avoid a shared mutable default argument
    if sampling_settings is None:
        sampling_settings = {'generation_mode': 'standard_sampling'}

    type_col = INDEX['prob']['type']
    type_probabilities = current_token_probabilities[type_col]
    type_mask = None
    if use_masking:
        n_types = conf.INPUT_RANGES['type']
        previous_type = song[current_token_idx - 1, type_col]
        types_so_far = song[:, type_col]

        def only(*allowed_types):
            # Mask allowing exactly the given types
            allowed = np.zeros(n_types, dtype=np.int8)
            allowed[list(allowed_types)] = 1
            return allowed

        def from_type(first_type):
            # Mask allowing every type >= first_type
            allowed = np.ones(n_types, dtype=np.int8)
            allowed[:first_type] = 0
            return allowed

        if previous_type == 0:
            type_mask = only(1)
        elif previous_type == 1:
            type_mask = only(1, 2)
        elif previous_type == 2:
            type_mask = only(4)
        elif previous_type == 4 and not np.any(types_so_far == 5):
            # No time_sign defined yet: it must be defined now
            type_mask = only(5)
        elif previous_type == 5 and not np.any(types_so_far == 6):
            # No tempo defined yet: it must be defined now (and only it, like the time_sign case)
            type_mask = only(6)
        elif previous_type in (3, 4, 5, 6):
            type_mask = from_type(3)
        elif previous_type == 7:
            # After an end token only end-type tokens may follow
            type_mask = from_type(7)
        else:
            type_mask = np.ones(n_types, dtype=np.int8)

    # Sample the type
    return sample_from(type_probabilities, type_mask, sampling_settings)

In [6]:
def sample_current_token(song: np.ndarray, current_token_idx: int, 
                         current_token_probabilities: List[np.ndarray], 
                         current_settings: Dict, use_masking:bool = True,
                         sampling_settings: Dict={'generation_mode': 'standard_sampling'}) -> Tuple[np.ndarray, Dict]:
    """Sample every component of the next token, keeping it consistent with the song so far.

    The type is sampled first with ``sample_type``. When ``use_masking`` is set,
    per-component masks built from ``conf.full_mask`` / ``conf.default_mask`` and
    ``current_settings`` then restrict the other components:

    * type 1 (instrument definition): instruments already defined are excluded;
      every other component uses ``conf.default_mask``.
    * types 2 and 7: every non-type component uses ``conf.default_mask``.
    * type 3 (note): (measure, beat, position) may not precede the current one;
      the beat must be below the current time-signature numerator. The instrument
      must be an already-defined one, and key sign / time sign / tempo must equal
      the current ones.
    * types 4 / 5 / 6: change of key sign / time sign / tempo respectively. The
      measure is not earlier than the current one, and beat..velocity use
      ``conf.default_mask``; the two settings not being changed keep their values.

    If any component mask ends up all zero, the token becomes an end token
    (type 7, every other component 0).

    Args:
        song: token array of the song being generated.
        current_token_idx: index of the token being sampled.
        current_token_probabilities: per-component model outputs for this position,
            ordered as ``INDEX['prob']``.
        current_settings: running state ('instruments', 'measure', 'beat', 'position',
            'key_sign', 'time_sign', 'tempo'); updated in place and also returned.
        use_masking: if False, components are sampled without masks.
        sampling_settings: forwarded to the sampling helpers.

    Returns:
        (token, current_settings): the token as an int32 vector ordered as
        ``INDEX['prob']``, and the updated settings.
    """
    # Create token parts
    token_components = {k: None for k in conf.INPUT_RANGES.keys()}
    
    # Sample the type
    token_components['type'] = sample_type(song, current_token_idx, current_token_probabilities, use_masking, sampling_settings)
    # Depending on the sampled type, sample the other tokens using customized masks
    if use_masking:
        # NOTE(review): masks are inverted below with np.bitwise_not, which is a logical NOT
        # only for boolean arrays — assumes conf.full_mask / conf.default_mask hold boolean
        # arrays (TODO confirm in config).
        mask_from_current_token = deepcopy(conf.full_mask)
        
        # Decide what to mask based on the sampled type
        if token_components['type'] == 1:
            # The other token parts (except instruments) must have index 0
            mask_from_current_token = deepcopy(conf.default_mask[:INDEX['mask']['instrument']]) + \
                                      [conf.full_mask[INDEX['mask']['instrument']].copy()]      + \
                                      deepcopy(conf.default_mask[INDEX['mask']['instrument']+1:])
            # Mask to not allow duplicate instruments
            for instrument_idx in current_settings["instruments"]:
                mask_from_current_token[INDEX['mask']['instrument']][instrument_idx] = 0
            # Sample the instrument and append it to current settings
            token_components['instrument'] = sample_from(current_token_probabilities[INDEX['prob']['instrument']], 
                                                         mask_from_current_token[INDEX['mask']['instrument']],
                                                         sampling_settings)
            current_settings['instruments'].append(token_components['instrument'])
        
        elif token_components['type'] == 2:
            mask_from_current_token = deepcopy(conf.default_mask) # only index zero on the other parts of the token

        elif token_components['type'] == 3:
            # Mask the previous measures
            m = current_settings['measure']
            mask_from_current_token[INDEX['mask']['measure']][:m] = [0]*m
            token_components['measure'] = sample_from(current_token_probabilities[INDEX['prob']['measure']], 
                                                      mask_from_current_token[INDEX['mask']['measure']], 
                                                      sampling_settings)

            # Beats at or beyond the time-signature numerator are invalid in ANY measure,
            # not only in the current one, so this restriction is applied unconditionally
            numerator = utils.time_sign_inverse_map(current_settings["time_sign"], conf)[0]
            mask_from_current_token[INDEX['mask']['beat']][numerator:] = [0]*(conf.INPUT_RANGES['beat']-numerator)

            # Within the current measure, also forbid going back in beat and position
            if token_components['measure'] == m:
                b = current_settings['beat']
                mask_from_current_token[INDEX['mask']['beat']][:b] = [0]*b

                token_components['beat'] = sample_from(current_token_probabilities[INDEX['prob']['beat']], 
                                                       mask_from_current_token[INDEX['mask']['beat']],
                                                       sampling_settings)

                if token_components['beat'] == b:
                    p = current_settings["position"]
                    mask_from_current_token[INDEX['mask']['position']][:p] = [0]*p
                    token_components['position'] = sample_from(current_token_probabilities[INDEX['prob']['position']], 
                                                               mask_from_current_token[INDEX['mask']['position']],
                                                               sampling_settings)

            # Otherwise sample beat and position without the ordering constraints
            if token_components['beat'] is None:
                token_components['beat'] = sample_from(current_token_probabilities[INDEX['prob']['beat']], 
                                                       mask_from_current_token[INDEX['mask']['beat']],
                                                       sampling_settings)
            if token_components['position'] is None:
                token_components['position'] = sample_from(current_token_probabilities[INDEX['prob']['position']], 
                                                           mask_from_current_token[INDEX['mask']['position']],
                                                           sampling_settings)

            # Mask instruments to only accept defined ones
            for instrument_idx in current_settings['instruments']:
                mask_from_current_token[INDEX['mask']['instrument']][instrument_idx] = 0
            mask_from_current_token[INDEX['mask']['instrument']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['instrument']])

            # only accept previous key_sign, time_sign and tempo
            mask_from_current_token[INDEX['mask']['key_sign']][current_settings['key_sign']] = 0
            mask_from_current_token[INDEX['mask']['key_sign']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['key_sign']])

            mask_from_current_token[INDEX['mask']['time_sign']][current_settings['time_sign']] = 0
            mask_from_current_token[INDEX['mask']['time_sign']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['time_sign']])

            mask_from_current_token[INDEX['mask']['tempo']][current_settings['tempo']] = 0
            mask_from_current_token[INDEX['mask']['tempo']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['tempo']])

            current_settings['measure']  = token_components['measure']
            current_settings['beat']     = token_components['beat']
            current_settings['position'] = token_components['position']

        elif token_components['type'] in [4,5,6]:
            # We allow the measure, key sign, time sign and tempo to change
            mask_from_current_token = [conf.full_mask[INDEX['mask']['measure']].copy()]                                 + \
                                      deepcopy(conf.default_mask[INDEX['mask']['measure']+1:INDEX['mask']['key_sign']]) + \
                                      deepcopy(conf.full_mask[INDEX['mask']['key_sign']:])
            # Mask the previous measures (the current one too, unless we are at its very start)
            if current_settings["beat"] == 0 and current_settings["position"] == 0:
                m = current_settings["measure"]
            else:
                m = current_settings["measure"] + 1
            mask_from_current_token[0][:m] = [0]*m
            token_components['measure'] = sample_from(current_token_probabilities[INDEX['prob']['measure']], 
                                                      mask_from_current_token[INDEX['mask']['measure']],
                                                      sampling_settings)

            if token_components['type'] == 4:
                mask_from_current_token[INDEX['mask']['key_sign']][current_settings["key_sign"]] = 0 # cannot choose the same key_sign as the current

                mask_from_current_token[INDEX['mask']['time_sign']][current_settings["time_sign"]] = 0 # can only choose the current time sign
                mask_from_current_token[INDEX['mask']['time_sign']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['time_sign']])

                mask_from_current_token[INDEX['mask']['tempo']][current_settings["tempo"]] = 0 # can only choose the current tempo
                mask_from_current_token[INDEX['mask']['tempo']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['tempo']])

                # Sample the new key sign
                token_components['key_sign'] = sample_from(current_token_probabilities[INDEX['prob']['key_sign']], 
                                                           mask_from_current_token[INDEX['mask']['key_sign']],
                                                           sampling_settings)
                current_settings["key_sign"] = token_components['key_sign']

            if token_components['type'] == 5:
                mask_from_current_token[INDEX['mask']['key_sign']][current_settings["key_sign"]] = 0 # can only choose the current key_sign
                mask_from_current_token[INDEX['mask']['key_sign']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['key_sign']])

                mask_from_current_token[INDEX['mask']['time_sign']][current_settings["time_sign"]] = 0 # cannot choose the same time_sign as the current

                mask_from_current_token[INDEX['mask']['tempo']][current_settings["tempo"]] = 0 # can only choose the current tempo
                mask_from_current_token[INDEX['mask']['tempo']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['tempo']])

                token_components['time_sign'] = sample_from(current_token_probabilities[INDEX['prob']['time_sign']], 
                                                            mask_from_current_token[INDEX['mask']['time_sign']],
                                                            sampling_settings)
                current_settings["time_sign"] = token_components['time_sign']

            if token_components['type'] == 6:
                mask_from_current_token[INDEX['mask']['key_sign']][current_settings["key_sign"]] = 0 # can only choose the current key sign
                mask_from_current_token[INDEX['mask']['key_sign']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['key_sign']])

                mask_from_current_token[INDEX['mask']['time_sign']][current_settings["time_sign"]] = 0 # can only choose the current time sign
                mask_from_current_token[INDEX['mask']['time_sign']] = np.bitwise_not(mask_from_current_token[INDEX['mask']['time_sign']])

                mask_from_current_token[INDEX['mask']['tempo']][current_settings["tempo"]] = 0 # cannot choose the same tempo as the current

                token_components['tempo'] = sample_from(current_token_probabilities[INDEX['prob']['tempo']], 
                                                        mask_from_current_token[INDEX['mask']['tempo']],
                                                        sampling_settings)
                current_settings["tempo"] = token_components['tempo']
            
            current_settings["measure"] = token_components['measure']
            current_settings["beat"] = 0
            current_settings["position"] = 0

        elif token_components['type'] == 7:
            mask_from_current_token = deepcopy(conf.default_mask) # only index zero on the other parts of the token

    else:
        # No masking
        mask_from_current_token = [None] * len(conf.default_mask)

    # If any of the masks is full of False, we set the type to 7 and the other components to index 0.
    # A None mask (no masking) never counts as empty.
    if any(mask_from_current_token[INDEX['mask'][key]] is not None and
           np.sum(mask_from_current_token[INDEX['mask'][key]]) == 0
           for key in token_components.keys() if key != 'type'):
        token_components['type'] = 7
        for k in (token_components.keys() - {'type'}):
            token_components[k] = 0
    else:        
        # Sample all values that are still None and concatenate them (except for the velocity)
        for key in token_components.keys():
            if key == 'velocity':
                # Velocity is a scalar regression output, not a distribution (in this notebook it
                # goes through a ReLU, so it is unbounded above): scale it to an index and clip
                # it into the valid range [0, INPUT_RANGES['velocity'] - 1].
                velocity_range = conf.INPUT_RANGES['velocity']
                scaled_velocity = np.asarray(current_token_probabilities[INDEX['prob']['velocity']] * velocity_range)
                token_components['velocity'] = int(np.clip(int(scaled_velocity.item()), 0, velocity_range - 1))
            elif token_components[key] is None:
                token_components[key] = sample_from(current_token_probabilities[INDEX['prob'][key]], 
                                                    mask_from_current_token[INDEX['mask'][key]],
                                                    sampling_settings)

    current_token = np.array([
        token_components[k] 
        for k in sorted(INDEX['prob'], key=lambda x: INDEX['prob'][x])], dtype=np.int32)
    return current_token, current_settings


In [7]:
## Song generation function
def generate_songs(model, style_list:List[str], max_length:int=conf.SEQ_LEN-1, 
                   terminator_type:int=7, generation_mode:str='standard_sampling', 
                   temperature:float=1.0, top_k_ratio=0.3, top_p_start:float=0.9, 
                   top_p_min=0.9, use_masking:bool=True, seed:int=None):
    """Autoregressively generate one song per entry of ``style_list``.

    Args:
        model: the loaded music model. Its 'final_encoding' sub-model, the
            'tfgpt2_model' transformer and the per-component output layers are reused.
        style_list: genre name for each song to generate.
        max_length: tokens per returned song (at most ``conf.SEQ_LEN - 1``). Token 0
            is the all-zero start token; tokens 1..max_length-1 are sampled.
        terminator_type: token type at whose first occurrence the rest of the song
            is overwritten with end tokens.
        generation_mode: 'standard_sampling', 'top_k_sampling' or 'top_p_sampling'.
        temperature: divides every output score except the velocity regression output.
        top_k_ratio: fraction of each component's values kept by top-k sampling.
        top_p_start, top_p_min: top-p threshold, linearly moved from start to min
            over the generation steps.
        use_masking: enforce the token grammar while sampling.
        seed: optional seed for NumPy's global RNG, which drives token sampling.

    Returns:
        int32 array of shape (len(style_list), max_length, n_components) with one
        column per token component, ordered as ``INDEX['prob']``.
    """
    # Check the validity of the parameters
    assert temperature > 0, "Temperature must be greater than 0"
    assert max_length <= conf.SEQ_LEN-1, f"The maximum length of the generated song must be at most {conf.SEQ_LEN-1}"
    assert generation_mode in ['standard_sampling', 'top_k_sampling', 'top_p_sampling'], \
        "Parameter 'generation_mode' must be one of the following: 'standard_sampling', 'top_k_sampling', 'top_p_sampling'"
    if generation_mode == 'top_k_sampling':
        assert 0 < top_k_ratio <= 1, "Parameter 'top_k_ratio' must be in (0, 1]"
    elif generation_mode == 'top_p_sampling':
        assert 0 < top_p_start <= 1 and 0 < top_p_min <= 1, \
            "Parameters 'top_p_start' and 'top_p_min' must be in (0, 1]"

    # Token sampling uses NumPy's global RNG: seed it for reproducible generations
    if seed is not None:
        np.random.seed(seed)

    # Collect explicitly the number of songs to generate
    num_songs = len(style_list)
    num_components = len(INDEX['prob'])
    velocity_head = INDEX['prob']['velocity']

    # Separate preprocessing model, transformer and output layers in order to be able to
    # inject the attention masks into the transformer and still use the loaded weights
    preprocessing_model = keras.Model(inputs=model.input, outputs=model.get_layer('final_encoding').output)
    transformer = model.get_layer('tfgpt2_model')
    output_layers = [
        model.get_layer('type_scores'), model.get_layer('measure_scores'), model.get_layer('beat_scores'),
        model.get_layer('position_scores'), model.get_layer('duration_scores'), model.get_layer('pitch_scores'),
        model.get_layer('instrument_scores'), model.get_layer('velocity_values'), model.get_layer('keysign_scores'),
        model.get_layer('timesign_scores'), model.get_layer('tempo_scores')
    ]
    activations = [tf.keras.activations.softmax] * 7 + [tf.keras.activations.relu] + [tf.keras.activations.softmax] * 3
    
    # Create the empty song array.
    # The first token of the songs is always the start token (0,...,0)
    # The rest of the tensor is filled with padding of zeroes, which will be masked by the attention masks
    generated_songs = np.zeros((num_songs, conf.SEQ_LEN-1, num_components), dtype=np.int32)
    # Create the current settings for the songs
    current_settings = [{
        'instruments': [], 'measure': 0, 'beat': 0, 'position': 0, 'key_sign': 0, 'time_sign': 0, 'tempo': 0,
    } for _ in range(num_songs)]

    # Create the sampling settings dictionary
    sampling_settings = {'generation_mode': generation_mode}
    if generation_mode == 'top_k_sampling': sampling_settings['top_k_ratio'] = top_k_ratio
    elif generation_mode == 'top_p_sampling':
        # Compute the dynamic top-p thresholds at each step
        top_p_sequence = np.linspace(top_p_start, top_p_min, max_length-1)  # Generate max_length-1 evenly spaced values
    
    # Generate the one-hot encodings of the genres
    styles = encode_styles(style_list)

    # Start to generate the songs token by token
    for i in trange(1, max_length):

        # Attend the genre position plus tokens 0..i-1 (1+i positions); the rest is padding
        attention_mask = np.ones((num_songs, 1+i), dtype=np.int8)
        padding_attention_mask = np.zeros((num_songs, conf.SEQ_LEN-1-i), dtype=np.int8)
        attention_mask = np.concatenate([attention_mask, padding_attention_mask], axis=-1)

        # Preprocess the songs and pass them through the transformer
        preprocessed_tensors = preprocessing_model((generated_songs, styles))
        out_transformer      = transformer({'inputs_embeds': preprocessed_tensors},
                               attention_mask=attention_mask)['last_hidden_state']
        # Use the output layers to generate the probabilities for the next token
        # Output from transformer has SEQ_LEN tokens, so we trim it by removing the last one,
        # since it's the probability of a token that's out of our bounds.
        out_scores = [layer(out_transformer)[:, :-1, :] for layer in output_layers]
        # Apply temperature to the scores but mind the velocity scores which is a scalar
        out_scores_tempered = [scores if head == velocity_head else scores / temperature
                               for head, scores in enumerate(out_scores)]
        out_probs = [np.array(activation(scores))
                     for activation, scores in zip(activations, out_scores_tempered)]
        
        if generation_mode == 'top_p_sampling':
            sampling_settings['current_top_p_ratio'] = top_p_sequence[i-1]

        # Sample the next token, using the requested generation mode
        next_tokens = []
        for song in range(num_songs):
            next_token, current_settings[song] = sample_current_token(song=generated_songs[song], current_token_idx=i, 
                                                    current_token_probabilities=[out_probs[h][song, i] for h in range(len(out_probs))],
                                                    current_settings=current_settings[song], use_masking=use_masking,
                                                    sampling_settings=sampling_settings)
            next_tokens.append(np.array(next_token, dtype=np.int32))
        batch = np.stack(next_tokens, axis=0)

        # Add the new tokens to the songs
        generated_songs[:, i] = batch

    # If a token in a song has type of terminator, simply overwrite the rest of the song with end tokens
    for song in range(num_songs):
        terminator_indices = np.argwhere(generated_songs[song, :, 0] == terminator_type)
        if len(terminator_indices) > 0:
            first_terminator_index = terminator_indices[0,0]
            generated_songs[song, first_terminator_index:, :] = [7] + [0]*(num_components-1)
        
    # Internally, the generated song is always SEQ_LEN - 1 long, so we cut it before returning it.
    return generated_songs[:, :max_length, :]

In [8]:
# Generate 4 songs per genre and store each batch as a .npy array
GENERATED_SONGS_DIR = os.path.join(conf.DATA_PATH, 'generated_songs', 'repr')
# Create the output folder up-front: otherwise np.save only fails after the
# first (several-minute) generation has already run
os.makedirs(GENERATED_SONGS_DIR, exist_ok=True)

for genre in conf.accepted_subgenres:
    print(f"Generating {genre} songs...")
    out_songs = generate_songs(model, [genre]*4, max_length=2047, temperature=0.9,
                               generation_mode='top_p_sampling', top_p_start=0.9, top_p_min=0.6)
    song_name = f'songs_{genre}{"_double_head" if USE_DOUBLE_HEAD else ""}.npy'
    np.save(os.path.join(GENERATED_SONGS_DIR, song_name), out_songs, allow_pickle=True)

Generating rock songs...


  0%|          | 0/2046 [00:00<?, ?it/s]2023-05-23 16:20:50.029992: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
100%|██████████| 2046/2046 [06:59<00:00,  4.88it/s]


Generating pop songs...


100%|██████████| 2046/2046 [06:57<00:00,  4.90it/s]


Generating dance songs...


100%|██████████| 2046/2046 [07:03<00:00,  4.84it/s]


Generating country songs...


100%|██████████| 2046/2046 [07:03<00:00,  4.83it/s]


Generating metal songs...


100%|██████████| 2046/2046 [07:03<00:00,  4.83it/s]


Generating classical songs...


100%|██████████| 2046/2046 [06:40<00:00,  5.11it/s]


Generating folk songs...


100%|██████████| 2046/2046 [06:37<00:00,  5.15it/s]


Generating blues songs...


100%|██████████| 2046/2046 [06:37<00:00,  5.15it/s]


Generating house songs...


100%|██████████| 2046/2046 [07:00<00:00,  4.87it/s]


Generating indie songs...


100%|██████████| 2046/2046 [06:38<00:00,  5.14it/s]


Generating latin songs...


100%|██████████| 2046/2046 [06:46<00:00,  5.03it/s]


Generating jazz songs...


100%|██████████| 2046/2046 [06:46<00:00,  5.03it/s]


Generating funk songs...


100%|██████████| 2046/2046 [06:46<00:00,  5.04it/s]


Generating rap songs...


100%|██████████| 2046/2046 [06:47<00:00,  5.02it/s]


Generating punk songs...


100%|██████████| 2046/2046 [06:47<00:00,  5.03it/s]


Generating r&b songs...


100%|██████████| 2046/2046 [06:46<00:00,  5.03it/s]


Generating gospel songs...


100%|██████████| 2046/2046 [06:45<00:00,  5.04it/s]


Generating electronic songs...


100%|██████████| 2046/2046 [04:34<00:00,  7.46it/s]
