# Imports

In [None]:
import os

import numpy as np
import tensorflow as tf

import wandb
from wandb.keras import WandbCallback

import config, music_model

# Setup

Set the following variables before training. They select the model variant, experiment logging, GPU usage, and the genre subset.

In [2]:
# Suffix appended to MODEL_NAME in the model name passed to config.Config
ADDITIONAL_MODEL_NAME = '_3'

# Model variant flags: forwarded to music_model.create_model and used to pick MODEL_NAME
USE_MASK    = False
USE_REG     = True

# When True, the run is initialised with wandb.init and a WandbCallback is added to training
USE_WANDB   = False
USE_ONE_GPU = False            # set to True to restrict the config to the first GPU (batch size 4, one device)

USE_SMALL_GENRE_SET = True    # or False if we want to use the dataset with the full genre subset list
DATASET_NAME = 'tf_data7dict' # label recorded in the W&B run config only; the data is loaded from conf.tf_data7dict_path

In [3]:
# Model family name, keyed by (masking enabled, regularization enabled).
_MODEL_NAME_BY_FLAGS = {
    (True, False): 'mask_only',
    (False, True): 'reg_only',
    (True, True): 'reg_and_mask',
    (False, False): 'baseline',
}
MODEL_NAME = _MODEL_NAME_BY_FLAGS[(bool(USE_MASK), bool(USE_REG))]

In [4]:
# Parent of the current working directory; passed to Config as its root path
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
conf = config.Config("single_instruments_type", ROOT_PATH, f'model_{MODEL_NAME + ADDITIONAL_MODEL_NAME}')
# Config object has by default the full list of accepted subgenres and works on multi-gpus
# If we use the small dataset
if USE_SMALL_GENRE_SET:
    conf.accepted_subgenres = ['folk', 'nes', 'maestro']
# If we need to use only the first GPU
if USE_ONE_GPU:
    # NOTE(review): [0] raises IndexError when no GPU is visible, and this stores a single
    # device object rather than a list — confirm that is what Config's consumers expect.
    conf.GPUS = tf.config.experimental.list_physical_devices('GPU')[0]
    conf.BATCH_SIZE = 4
    # With one device the global batch size equals the per-device batch size
    conf.GLOBAL_BATCH_SIZE = conf.BATCH_SIZE
    # The model-creation cell checks num_devices to decide whether to build under the strategy scope
    conf.num_devices = 1

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


# Training

## Model creation

In [5]:
# The model arguments are the same for every device setup; only the
# scope in which the model is constructed differs.
model_kwargs = dict(
    num_genres=len(conf.accepted_subgenres),
    use_masking_layers=USE_MASK,
    use_regularization=USE_REG,
)

if conf.num_devices > 1:
    print("Using multiple GPUs with Mirrored Strategy")
    # Build the model inside the distribution strategy's scope
    with conf.training_strategy.scope():
        model = music_model.create_model(**model_kwargs)
else:
    print("Using single GPU/CPU device")
    model = music_model.create_model(**model_kwargs)

Using multiple GPUs with Mirrored Strategy
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/t

## Wandb setup

In [6]:
if USE_WANDB:
    # Hyperparameters recorded with the W&B run.
    # 'masking' and 'regularization' are taken from the notebook flags that are
    # actually passed to music_model.create_model, so the logged values always
    # describe the model being trained (Config is built without these flags).
    wandb_config = {
        'gpus': conf.num_devices,
        'dataset': DATASET_NAME,
        'genres': conf.accepted_subgenres,
        'embedding_size': conf.SINGLE_EMB_SIZE,
        'batch_size': conf.BATCH_SIZE,
        'global_batch_size': conf.GLOBAL_BATCH_SIZE,
        'reg_loss_scale': conf.REG_LOSS_SCALE,
        'masking': USE_MASK,
        'regularization': USE_REG,
        'dropout_prob': conf.DROPOUT_VALUE,
        'seq_len': conf.SEQ_LEN,
        'token_dim': conf.TOKEN_DIM,
        'genre_dim': conf.GENRE_DIM,
        'n_heads': conf.ATTENTION_HEADS,
        'n_blocks': conf.ATTENTION_BLOCKS,
        'activation_func': conf.DECODER_ACTIVATION_FUNCTION
    }

    # Runs are grouped by model variant; `run` is closed in the last cell of the notebook
    run = wandb.init(project="Music Generation", entity="marcello-e-federico",
                     group=MODEL_NAME, job_type='train', config=wandb_config)

## Loading dataset

In [7]:
# Seed for the one-off shuffle that decides which batches go to train / validation
SPLIT_SEED = 0

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

# Shuffle ONCE with a fixed order before splitting. Splitting a dataset that is
# reshuffled on every iteration makes `take` and `skip` select different,
# overlapping batches each epoch, which leaks training data into validation.
dataset = tf.data.Dataset.load(conf.tf_data7dict_path).\
            batch(conf.GLOBAL_BATCH_SIZE).\
            shuffle(conf.SHUFFLE_SIZE, seed=SPLIT_SEED, reshuffle_each_iteration=False)

# One fifth of the batches is held out for validation
num_test_batches = len(dataset) // 5

# Validation split: fixed content and order, cached, never reshuffled
test_dataset = dataset.take(num_test_batches).\
            cache().\
            prefetch(conf.PREFETCH_SIZE).\
            with_options(options)

# Training split: cached, then reshuffled every epoch (within the training batches only)
train_dataset = dataset.skip(num_test_batches).\
            cache().\
            shuffle(conf.SHUFFLE_SIZE).\
            prefetch(conf.PREFETCH_SIZE).\
            with_options(options)

## Training

In [8]:
# Work on a copy: appending to conf.MODEL_CALLBACKS itself would mutate the config
# and add one more WandbCallback every time this cell is re-run.
callbacks = list(conf.MODEL_CALLBACKS or [])
if USE_WANDB:
    # Log metrics and weights to W&B; model and graph saving are disabled
    callbacks.append(WandbCallback(
        save_model=False, save_graph=False,
        log_weights=True
    ))

In [None]:
# Keyword options for the training run, kept together so they are easy to tweak.
fit_kwargs = {
    'epochs': 500,
    'callbacks': callbacks,
    'validation_data': test_dataset,
    # 'initial_epoch': initial_epoch,  # change if resuming from previous checkpoint
}
model.fit(train_dataset, **fit_kwargs)

In [None]:
# Close the W&B run opened in the "Wandb setup" cell (`run` only exists when USE_WANDB is True)
if USE_WANDB:
    run.finish()