In [16]:
from keras.layers import Input,Conv2D,Flatten,Dense,Reshape,Lambda,Conv3D,Activation,Conv2DTranspose,BatchNormalization,LeakyReLU
from keras.models import Model,Sequential
from keras import backend as K
from keras.optimizers import Adam,RMSprop

In [8]:
import numpy as np
import os
import pickle
from music21 import midi
from music21 import note,stream,duration,tempo

In [17]:
class  MusicGenGan():
    def __init__(self,
                 input_dim,
                 latent_dim,
                 critic_learning_rate,
                 generator_learning_rate,
                 optimizer,
                 batch_size,
                 n_tracks,
                 n_bars,
                 n_steps_per_bar,
                 n_pitches    
    ):
        self.input_dim=input_dim
        self.latent_dim=latent_dim
        self.critic_learning_rate=critic_learning_rate
        self.generator_learning_rate=generator_learning_rate
        self.optimizer=optimizer
        self.batch_size=batch_size
        self.n_tracks=n_tracks
        self.n_bars=n_bars
        self.n_steps_per_bar=n_steps_per_bar
        self.n_pitches=n_pitches
        
        
        self.d_losses=[]
        self.g_losses=[]
        
        self._build_critic()
        self._build_generator()
        self._build_adversarial()
        
        
    def conv(self,x,f,k,s,a,p):
        """Apply a ``Conv3D`` layer to ``x``, optionally followed by an activation.

        Args:
            x: Keras tensor to convolve.
            f: number of filters.
            k: kernel size, passed straight to ``Conv3D``.
            s: strides, passed straight to ``Conv3D``.
            a: ``'relu'`` adds an ``Activation('relu')`` layer, ``'lrelu'``
                adds a ``LeakyReLU`` layer; any other value leaves the
                convolution output linear.
            p: padding mode passed to ``Conv3D`` (e.g. ``'valid'``/``'same'``).

        Returns:
            The resulting Keras tensor.
        """
        x = Conv3D(filters=f, kernel_size=k, strides=s, padding=p)(x)
        if a == 'relu':
            x = Activation(a)(x)
        elif a == 'lrelu':
            # LeakyReLU is a layer: build it, then call it on the tensor.
            # Writing `LeakyReLU(x)` would pass the tensor as the slope
            # argument and return a layer object instead of a tensor.
            x = LeakyReLU()(x)
        return x
    
    def _build_critic(self):
        """Build the critic network and store it as ``self.critic``.

        The critic is a stack of eight ``Conv3D`` + LeakyReLU stages (via
        ``self.conv``), then ``Flatten`` -> ``Dense(1024)`` -> ``LeakyReLU`` ->
        ``Dense(1)`` with no output activation, i.e. one unbounded score per
        input.

        NOTE(review): ``self.input_dim`` is passed directly as the ``Input``
        shape, so it presumably must be a tuple; the first two 'valid' convs
        over axis 0 ((2,1,1) then (n_bars-1,1,1)) suggest axis 0 is the bar
        axis -- confirm against the data pipeline.
        """
        critic_input = Input(shape=self.input_dim, name='critic_input')

        # (filters, kernel_size, strides, padding) for each conv stage, in order.
        stages = [
            (128, (2, 1, 1), (1, 1, 1), 'valid'),
            (128, (self.n_bars - 1, 1, 1), (1, 1, 1), 'valid'),
            (128, (1, 1, 12), (1, 1, 12), 'same'),
            (128, (1, 1, 7), (1, 1, 7), 'same'),
            (128, (1, 2, 1), (1, 2, 1), 'same'),
            (128, (1, 2, 1), (1, 2, 1), 'same'),
            (256, (1, 4, 1), (1, 2, 1), 'same'),
            (512, (1, 3, 1), (1, 2, 1), 'same'),
        ]

        features = critic_input
        for filters, kernel, stride, padding in stages:
            features = self.conv(features, f=filters, k=kernel, s=stride, a='lrelu', p=padding)

        features = Flatten()(features)
        features = Dense(1024)(features)
        features = LeakyReLU()(features)
        score = Dense(1, activation=None)(features)

        self.critic = Model(critic_input, score)
        
    def conv_t(self,x,f,k,s,a,p,bn):
        """Apply a ``Conv2DTranspose`` layer, optional batch norm, then optional activation.

        Args:
            x: Keras tensor to upsample/convolve.
            f: number of filters.
            k: kernel size, passed straight to ``Conv2DTranspose``.
            s: strides, passed straight to ``Conv2DTranspose``.
            a: ``'relu'`` adds an ``Activation('relu')`` layer, ``'lrelu'``
                adds a ``LeakyReLU`` layer; any other value leaves the output
                linear.
            p: padding mode passed to ``Conv2DTranspose``.
            bn: if truthy, insert ``BatchNormalization(momentum=0.9)`` between
                the transposed convolution and the activation.

        Returns:
            The resulting Keras tensor.
        """
        x = Conv2DTranspose(filters=f, kernel_size=k, strides=s, padding=p)(x)
        if bn:
            x = BatchNormalization(momentum=0.9)(x)
        if a == 'relu':
            x = Activation(a)(x)
        elif a == 'lrelu':
            # Instantiate the LeakyReLU layer, then apply it to the tensor;
            # `LeakyReLU(x)` would treat the tensor as the slope argument.
            x = LeakyReLU()(x)
        return x
    
    def TemporalNetowrk(self):
        input_layer=Input(shape=(self.latent_dim,),name='temporal_input')
        x=Reshape()