# MuseGAN

In [29]:
import os
import matplotlib.pyplot as plt
import numpy as np
import types

from new.MuseGAN import MuseGAN

from music21 import midi
from music21 import note, stream, duration

In [30]:
def load_music(data_name, filename, n_bars, n_steps_per_bar, data_root="./data"):
    """Load a note-index dataset and encode fixed-length excerpts as +/-1 piano rolls.

    Reads the ``'train'`` array from ``<data_root>/<data_name>/<filename>``
    (an ``.npz`` file). Each song is expected to be a 2-D array of shape
    ``(time_steps, n_tracks)`` holding pitch indices in ``0..83``, with NaN
    marking steps that have no note. For every song, leading 4-step chunks
    that contain any NaN are skipped. The next ``n_bars * n_steps_per_bar``
    steps are then kept. Songs too short to supply a full excerpt after the
    skipped prefix are dropped.

    Args:
        data_name: Subfolder of ``data_root`` that contains the file.
        filename: Name of the ``.npz`` file.
        n_bars: Number of bars per excerpt.
        n_steps_per_bar: Number of time steps per bar.
        data_root: Root data directory. Defaults to ``"./data"``, the
            previously hard-coded location.

    Returns:
        A tuple ``(data_binary, data_ints, data)``:

        * ``data_binary``: float array of shape
          ``(n_songs, n_bars, n_steps_per_bar, 84, n_tracks)``. It holds 1 at
          the active pitch and -1 elsewhere; rest (NaN) steps are all -1.
        * ``data_ints``: int array ``(n_songs, n_bars, n_steps_per_bar,
          n_tracks)`` of pitch indices, with rests encoded as 84.
        * ``data``: the raw ``'train'`` array as loaded.

    Raises:
        ValueError: If no song is long enough to provide a full excerpt.
    """
    file = os.path.join(data_root, data_name, filename)
    seq_len = n_bars * n_steps_per_bar

    # allow_pickle is needed because 'train' may be an object array of
    # per-song arrays with different lengths.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with np.load(file, encoding='bytes', allow_pickle=True) as f:
        data = f['train']

    data_ints = []
    for x in data:
        # Skip leading 4-step chunks containing any NaN. The loop stops at
        # the first fully-defined chunk, or at the end of the song (an empty
        # slice has no NaN).
        counter = 0
        while np.any(np.isnan(x[counter:(counter + 4)])):
            counter += 4

        # Keep the song only if a FULL window of seq_len steps exists after
        # the skipped prefix. Checking against x.shape[0] alone would let
        # short slices through and make the stacked array ragged.
        if counter + seq_len <= x.shape[0]:
            data_ints.append(x[counter:(counter + seq_len), :])

    if not data_ints:
        raise ValueError(
            "no song in {!r} has {} usable steps".format(file, seq_len))

    data_ints = np.array(data_ints)
    n_songs, _, n_tracks = data_ints.shape
    data_ints = data_ints.reshape([n_songs, n_bars, n_steps_per_bar, n_tracks])

    # Pitch indices 0..83 give 84 pitch classes. NaN steps (rests) become an
    # extra class 84, which is one-hot encoded and then deleted, so a rest is
    # represented by an all -1 pitch vector.
    n_pitch_classes = 84
    rest_class = n_pitch_classes
    data_ints[np.isnan(data_ints)] = rest_class
    data_ints = data_ints.astype(int)

    data_binary = np.eye(n_pitch_classes + 1)[data_ints]
    data_binary[data_binary == 0] = -1
    data_binary = np.delete(data_binary, rest_class, -1)

    # (songs, bars, steps, tracks, pitches) -> (songs, bars, steps, pitches, tracks)
    data_binary = data_binary.transpose([0, 1, 2, 4, 3])
    return data_binary, data_ints, data

In [31]:
# run params
SECTION = 'compose'
RUN_ID = '001'
DATA_NAME = 'chorales'
FILENAME = 'Jsb16thSeparated.npz'
RUN_FOLDER = 'run/{}/{}_{}'.format(SECTION, RUN_ID, DATA_NAME)

# Create each output subfolder independently (exist_ok=True). If the run
# folder already exists but a subfolder is missing, it is created instead of
# the whole setup being skipped.
for subfolder in ('viz', 'images', 'weights', 'samples'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' creates and saves a new model; any other value (e.g. 'load')
# loads weights from RUN_FOLDER instead.
mode = 'build'

## данные

In [32]:
BATCH_SIZE = 64
n_bars = 2
n_steps_per_bar = 16
n_pitches = 84  # must match the 84 pitch columns load_music produces
n_tracks = 4

# load_music already returns a 5-D tensor of shape
# (n_songs, n_bars, n_steps_per_bar, n_pitches, n_tracks). It is used as-is:
# np.squeeze would drop EVERY length-1 axis, so with n_bars == 1, a single
# song or a single track, the model's input_dim (data_binary.shape[1:])
# would lose an axis.
data_binary, data_ints, raw_data = load_music(DATA_NAME, FILENAME, n_bars, n_steps_per_bar)

## архитектура

In [27]:
# Build the MuseGAN. input_dim is the per-sample shape of the training tensor.
gan = MuseGAN(
    input_dim=data_binary.shape[1:],
    critic_learning_rate=0.001,
    generator_learning_rate=0.001,
    optimiser='adam',
    grad_weight=10,
    z_dim=32,
    batch_size=BATCH_SIZE,
    n_tracks=n_tracks,
    n_bars=n_bars,
    n_steps_per_bar=n_steps_per_bar,
    n_pitches=n_pitches,
)

# Any mode other than 'build' restores weights from RUN_FOLDER;
# 'build' saves the freshly constructed model there.
if mode != 'build':
    gan.load_weights(RUN_FOLDER)
else:
    gan.save(RUN_FOLDER)

In [28]:
# Print the layer-by-layer summary of the chords temporal network.
chords_temp_network = gan.chords_tempNetwork
chords_temp_network.summary()

Model: "model_27"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
temporal_input (InputLayer)  [(None, 32)]              0         
_________________________________________________________________
reshape_36 (Reshape)         (None, 1, 1, 32)          0         
_________________________________________________________________
conv2d_transpose_60 (Conv2DT (None, 2, 1, 1024)        66560     
_________________________________________________________________
batch_normalization_60 (Batc (None, 2, 1, 1024)        4096      
_________________________________________________________________
activation_60 (Activation)   (None, 2, 1, 1024)        0         
_________________________________________________________________
conv2d_transpose_61 (Conv2DT (None, 2, 1, 32)          32800     
_________________________________________________________________
batch_normalization_61 (Batc (None, 2, 1, 32)          128

In [16]:
# Print the summary of the first bar generator (index 0 of gan.barGen).
first_bar_generator = gan.barGen[0]
first_bar_generator.summary()

Model: "model_19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bar_generator_input (InputLa [(None, 128)]             0         
_________________________________________________________________
dense_8 (Dense)              (None, 1024)              132096    
_________________________________________________________________
batch_normalization_40 (Batc (None, 1024)              4096      
_________________________________________________________________
activation_40 (Activation)   (None, 1024)              0         
_________________________________________________________________
reshape_28 (Reshape)         (None, 2, 1, 512)         0         
_________________________________________________________________
conv2d_transpose_40 (Conv2DT (None, 4, 1, 512)         524800    
_________________________________________________________________
batch_normalization_41 (Batc (None, 4, 1, 512)         204

In [17]:
# Print the summary of the full generator model.
generator_model = gan.generator
generator_model.summary()

Model: "model_23"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
melody_input (InputLayer)       [(None, 4, 32)]      0                                            
__________________________________________________________________________________________________
chords_input (InputLayer)       [(None, 32)]         0                                            
__________________________________________________________________________________________________
lambda_21 (Lambda)              (None, 32)           0           melody_input[0][0]               
__________________________________________________________________________________________________
lambda_22 (Lambda)              (None, 32)           0           melody_input[0][0]               
___________________________________________________________________________________________

In [18]:
# Print the summary of the critic model.
critic_model = gan.critic
critic_model.summary()

Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
critic_input (InputLayer)    [(None, 2, 16, 84, 4)]    0         
_________________________________________________________________
conv3d_8 (Conv3D)            multiple                  1152      
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    multiple                  0         
_________________________________________________________________
conv3d_9 (Conv3D)            multiple                  16512     
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   multiple                  0         
_________________________________________________________________
conv3d_10 (Conv3D)           multiple                  196736    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   multiple                  0  

## тренировка

In [19]:

# Set gan.epoch to 0 before training (presumably the epoch counter that
# gan.train advances -- confirm in MuseGAN).
gan.epoch = 0

# Training schedule, passed to gan.train below.
PRINT_EVERY_N_BATCHES = 10
EPOCHS = 1000

In [None]:
# Train on the encoded excerpts. All options are collected as keyword
# arguments and forwarded to gan.train.
train_kwargs = dict(
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
)
gan.train(data_binary, **train_kwargs)

In [None]:
fig = plt.figure()

# Each entry of gan.d_losses is indexed as a sequence (components 0..2 are
# plotted). Each curve is labelled and a legend is added, so the four
# overlaid lines can be told apart without knowing the colour scheme.
d_loss_colors = ['black', 'green', 'red']
for component, color in enumerate(d_loss_colors):
    plt.plot([x[component] for x in gan.d_losses], color=color,
             linewidth=0.25, label='d_loss[{}]'.format(component))
plt.plot(gan.g_losses, color='orange', linewidth=0.25, label='g_loss')

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

plt.xlim(0, len(gan.d_losses))
plt.legend()

plt.show()
