# Imports

In [None]:
import os

import numpy as np
import tensorflow as tf

import wandb
from wandb.keras import WandbCallback

import config
import utils
import music_model

# Setup

Set the following variables before running the training cells.

In [None]:
# --- Naming -----------------------------------------------------------------
# Optional suffix appended to the automatically generated model name
ADDITIONAL_MODEL_NAME = ''
# A log name for visualization on Wandb (if left empty it will be a random name)
LOG_NAME = ''

# --- Model options ----------------------------------------------------------
# "GPT" or "XL"
MODEL_TYPE = 'GPT'
# Forwarded to the model factory as masking / regularization / velocity-loss switches
USE_MASK = True
USE_REG = True
USE_MSE_FOR_VELOCITY = True

# --- Runtime ----------------------------------------------------------------
# Log the run to Wandb
USE_WANDB = True
# True to train on one GPU only, or False if another GPU is available
USE_ONE_GPU = True

# --- Data -------------------------------------------------------------------
# True for the reduced genre list, or False to use the dataset with the full
# genre subset list
USE_SMALL_GENRE_SET = True
# Key used to look up the dataset path in the configuration (or whatever)
DATASET_NAME = 'tf_data7dict'

In [None]:
# Derive the base model name from the (masking, regularization) flag pair.
# The flags are normalized with bool() so any truthy/falsy value selects the
# same entry an if/elif chain on the flags would.
MODEL_NAME = {
    (True, False): 'mask_only',
    (False, True): 'reg_only',
    (True, True): 'reg_and_mask',
    (False, False): 'baseline',
}[(bool(USE_MASK), bool(USE_REG))]

In [None]:
# ROOT_PATH is the parent directory of the current working directory
ROOT_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
conf = config.Config(config_string="single_instruments_type",
                     root_path=ROOT_PATH,
                     model_type=MODEL_TYPE,
                     model_name=f'model_{MODEL_TYPE}_{MODEL_NAME + ADDITIONAL_MODEL_NAME}')

# Config object has by default the full list of accepted subgenres and works on multi-gpus
# If we use the small dataset
if USE_SMALL_GENRE_SET:
    conf.accepted_subgenres = ['folk', 'nes', 'maestro']

# If we need to use only the first GPU
if USE_ONE_GPU:
    available_gpus = tf.config.experimental.list_physical_devices('GPU')
    # On a CPU-only machine the list is empty; indexing it unconditionally
    # would raise IndexError, so conf.GPUS keeps its Config default in that case.
    if available_gpus:
        conf.GPUS = available_gpus[0]
    # Single-device run: the global batch equals the per-device batch
    conf.BATCH_SIZE = 4
    conf.GLOBAL_BATCH_SIZE = conf.BATCH_SIZE
    conf.num_devices = 1

# Training

## Model creation

In [None]:
# Keyword arguments shared by both model-creation paths
model_options = dict(
    use_masking_layers=USE_MASK,
    use_regularization=USE_REG,
    use_mse_for_velocity=USE_MSE_FOR_VELOCITY,
)

if conf.num_devices <= 1:
    print("Using single GPU/CPU device")
    model = music_model.create_model(conf, **model_options)
else:
    print("Using multiple GPUs with Mirrored Strategy")
    # With several devices the model is created inside the training strategy scope
    with conf.training_strategy.scope():
        model = music_model.create_model(conf, **model_options)

## Wandb setup

In [None]:
if USE_WANDB:
    # Hyperparameters logged for every model type
    # NOTE(review): 'masking' is read from conf.USE_MASKING while the model is
    # built from the USE_MASK flag - confirm the two always agree.
    wandb_config = {
        'gpus': conf.num_devices,
        'dataset': DATASET_NAME,
        'genres': conf.accepted_subgenres,
        'embedding_size': conf.SINGLE_EMB_SIZE,
        'batch_size': conf.BATCH_SIZE,
        'global_batch_size': conf.GLOBAL_BATCH_SIZE,
        'mse_for_velocity': USE_MSE_FOR_VELOCITY,
        'reg_loss_scale': conf.REG_LOSS_SCALE,
        'masking': conf.USE_MASKING,
        'dropout_prob': conf.DROPOUT_VALUE,
        'seq_len': conf.SEQ_LEN,
        'token_dim': conf.TOKEN_DIM,
        'genre_dim': conf.GENRE_DIM,
        'attn_heads': conf.ATTENTION_HEADS,
        'attn_blocks': conf.ATTENTION_BLOCKS,
    }

    # Hyperparameters that only exist for one of the two model types
    if MODEL_TYPE == 'GPT':
        wandb_config.update(activation_func=conf.DECODER_ACTIVATION_FUNCTION)
    elif MODEL_TYPE == 'XL':
        wandb_config.update(
            sequence_blocks=conf.DIV_VAL,
            head_dim=conf.HEAD_DIM,
            inner_dim=conf.INNER_DIM,
            memory_length=conf.MEMORY_LEN,
        )

    # An empty LOG_NAME is passed as None
    run = wandb.init(
        project="Music Generation",
        entity="marcello-e-federico",
        group=MODEL_NAME,
        job_type='train',
        config=wandb_config,
        name=LOG_NAME or None,
    )

## Loading dataset

In [None]:
# Look up where the selected dataset is stored, then load its three splits
dataset_path = conf.dataset_paths[DATASET_NAME]
dataset_splits = utils.get_dataset_splits(dataset_path, conf)
train_dataset, val_dataset, test_dataset = dataset_splits

## Training

In [None]:
# Work on a copy of the configured callbacks: appending to conf.MODEL_CALLBACKS
# directly would mutate the config object, so every re-run of this cell would
# stack one more WandbCallback onto the same list.
callbacks = list(conf.MODEL_CALLBACKS)
if USE_WANDB:
    # Only log metrics and weights to Wandb; model and graph saving are disabled
    callbacks.append(WandbCallback(
        save_model=False, save_graph=False,
        log_weights=True
    ))

In [None]:
# Number of training epochs
NUM_EPOCHS = 100

# Train on the training split and validate on the validation split
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=NUM_EPOCHS,
    callbacks=callbacks,
    # initial_epoch = initial_epoch # change if resuming from previous checkpoint
)

In [None]:
# Close the Wandb run; `run` is only created when USE_WANDB is enabled
if USE_WANDB:
    run.finish()