# Generating music with a GAN

We can use the techniques we've learned so far to generate music as well. In this notebook, we'll build a GAN (generative adversarial network) that can generate reasonable-sounding music based on MIDI files. This will also demonstrate how to work with MIDI in Python.

## MIDI dataset

This notebook uses the "clean MIDI" subset of the Lakh MIDI Dataset:

* http://colinraffel.com/projects/lmd/
* http://colinraffel.com/projects/lmd/#get

In [1]:
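import os
import re
import subprocess

CACHE_DIR = os.path.expanduser('~/.cache/dl-cookbook')

def download(url):
    """Download `url` once into CACHE_DIR and return the local path."""
    # Turn the URL into a filesystem-safe file name
    filename = os.path.join(CACHE_DIR, re.sub('[^a-zA-Z0-9.]+', '_', url))
    if not os.path.exists(filename):
        os.makedirs(CACHE_DIR, exist_ok=True)
        # Download under a temporary name so an interrupted transfer is never cached
        tmp_filename = filename + '.part'
        subprocess.check_call(['wget', '-O', tmp_filename, url])  # raises on failure
        os.rename(tmp_filename, filename)
    return filename

tar_filename = download('http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz')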

In [2]:
import tarfile
midi_tar = tarfile.open(tar_filename)

## Extracting notes

MIDI is a relatively complex format, but for our purposes we just want to extract a single stream of notes. We'll use the convenient `midicsv` program (http://www.fourmilab.ch/webtools/midicsv/), which converts a MIDI file into one CSV record per event, to help with this.
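Each note record that `midicsv` writes has the layout `Track, Time, Note_on_c/Note_off_c, Channel, Note, Velocity`, with a space after every comma. This explains the six column names used below. The shell filter on `Note_` lines drops all other record types. A minimal pure-Python sketch of that filtering step, using only the standard `csv` module, follows. The `SAMPLE` text is hypothetical: its note values are copied from the first rows of the `midi_df.head()` output further down, and the surrounding `Header`/`Start_track`/`End_track`/`End_of_file` records only illustrate the kind of non-note records that get discarded.

```python
import csv
import io

# Hypothetical midicsv output: note values taken from the head() table below,
# the other records only illustrate what gets filtered out.
SAMPLE = """0, 0, Header, 1, 2, 480
2, 0, Start_track
2, 0, Note_on_c, 1, 73, 127
2, 4, Note_on_c, 1, 46, 127
2, 168, Note_off_c, 1, 73, 0
2, 192, Note_on_c, 1, 53, 127
2, 384, End_track
0, 0, End_of_file
"""

def parse_note_events(text):
    """Return (track, time, op, channel, note, velocity) tuples for note records."""
    events = []
    # skipinitialspace drops the blank that midicsv writes after each comma
    for row in csv.reader(io.StringIO(text), skipinitialspace=True):
        if len(row) >= 3 and row[2] in ('Note_on_c', 'Note_off_c'):
            track, time, op, channel, note, velocity = row
            events.append((int(track), int(time), op,
                           int(channel), int(note), int(velocity)))
    return events

for event in parse_note_events(SAMPLE):
    print(event)
```

Shelling out to `grep` is simpler inside the notebook. The Python version makes it easier to see the record layout and to strip the leading spaces at parse time.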

In [14]:
import io
import subprocess

import pandas as pd

def extract_notes(filename, midi_data):
    """Run one MIDI file through midicsv and return its note events as a DataFrame."""
    # midicsv reads the MIDI file from stdin and writes one CSV record per event;
    # grep keeps only the Note_on_c / Note_off_c records.
    p = subprocess.Popen('midicsv | grep "Note_"', shell=True,
                         stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    stdout, _ = p.communicate(midi_data.read())
    if not stdout.strip():
        return None  # no note events, or midicsv could not parse the file
    df = pd.read_csv(io.BytesIO(stdout), index_col=False,
                     names=['track', 'offset', 'op', 'channel', 'note', 'velocity'])
    df['source'] = filename
    # Optional: store the gap since the previous event on the same track
    # instead of absolute ticks (not used below):
    # df['offset'] = df.groupby('track')['offset'].diff().fillna(0)
    return df

all_midis = []
for i, f in enumerate(midi_tar):
    if f.isfile():
        df = extract_notes(f.name, midi_tar.extractfile(f))
        if df is not None:
            all_midis.append(df)

    if i % 10 == 0:
        print('\r %05d' % i, end='', flush=True)
    # Keep the demo small: stop after about 100 files
    if len(all_midis) > 100:
        break

midi_df = pd.concat(all_midis, ignore_index=True)
midi_df.head()

 00110

| | track | offset | op | channel | note | velocity | source |
|---|---|---|---|---|---|---|---|
| 0 | 2 | 0 | Note_on_c | 1 | 73 | 127 | clean_midi/Hugues Aufray/Celine.mid |
| 1 | 2 | 4 | Note_on_c | 1 | 46 | 127 | clean_midi/Hugues Aufray/Celine.mid |
| 2 | 2 | 168 | Note_off_c | 1 | 73 | 0 | clean_midi/Hugues Aufray/Celine.mid |
| 3 | 2 | 192 | Note_on_c | 1 | 53 | 127 | clean_midi/Hugues Aufray/Celine.mid |
| 4 | 2 | 384 | Note_on_c | 1 | 73 | 127 | clean_midi/Hugues Aufray/Celine.mid |
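The commented-out line in `extract_notes` points at an alternative time encoding: the gap since the previous event on the same track, rather than an absolute tick count. In pandas this is `df.groupby('track')['offset'].diff().fillna(0)`. The plain-Python sketch below spells out the same computation on `(track, time)` pairs. `per_track_deltas` is a hypothetical helper, not part of the notebook, and the example input is the track-2 rows of the table above.

```python
def per_track_deltas(events):
    """Replace absolute tick times with the gap since the previous event
    on the same track; the first event of each track gets 0."""
    last_time = {}  # track -> time of that track's previous event
    out = []
    for track, time in events:
        delta = time - last_time.get(track, time)
        last_time[track] = time
        out.append((track, delta))
    return out

# Track-2 rows from the head() output above
print(per_track_deltas([(2, 0), (2, 4), (2, 168), (2, 192), (2, 384)]))
# [(2, 0), (2, 4), (2, 164), (2, 24), (2, 192)]
```

Deltas stay small no matter how far into a song an event falls. Absolute offsets, by contrast, grow without bound, which can matter for a model trained with a mean-squared-error loss.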


In [15]:
midi_df.sort_values(by=['source', 'offset'], inplace=True)
midi_df.reset_index(inplace=True, drop=True)

In [19]:
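# midicsv writes a space after each comma, so the raw value may be ' Note_on_c'.
# By MIDI convention a Note_on with velocity 0 is a note-off, so treat it as one.
midi_df['op'] = (midi_df['op'].str.strip() == 'Note_on_c') & (midi_df['velocity'] > 0)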

In [20]:
midi_df

| | track | offset | op | channel | note | velocity | source |
|---|---|---|---|---|---|---|---|
| 0 | 5 | 4608 | True | 3 | 69 | 86 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 1 | 7 | 4608 | True | 5 | 53 | 99 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 2 | 8 | 4608 | True | 9 | 35 | 109 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 3 | 8 | 4608 | True | 9 | 42 | 124 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 4 | 8 | 4672 | False | 9 | 42 | 0 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 5 | 7 | 4712 | False | 5 | 53 | 0 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 6 | 7 | 4736 | True | 5 | 57 | 99 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 7 | 8 | 4736 | True | 9 | 42 | 93 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 8 | 8 | 4800 | False | 9 | 42 | 0 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
| 9 | 8 | 4800 | False | 9 | 35 | 0 | clean_midi/ARMSTRONG LOUIS/(What a) Wonderful ... |
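In this event stream a single sounding note is split across two rows: a note-on (`op` is `True`) and a later note-off (`op` is `False`). A minimal sketch pairs them back into `(channel, note, start, duration)` tuples, assuming the events are already in time order as in the table above. `pair_notes` is a hypothetical helper, not part of the notebook. The test events are the `offset`, `op`, `channel` and `note` columns of the ten rows shown.

```python
def pair_notes(events):
    """Pair note-on/note-off events into (channel, note, start, duration) tuples.

    events: (time, is_on, channel, note) tuples in time order, where is_on
    mirrors the boolean `op` column. Notes still sounding at the end are dropped.
    """
    open_notes = {}  # (channel, note) -> start time of the sounding note
    notes = []
    for time, is_on, channel, note in events:
        key = (channel, note)
        if is_on:
            # A repeated note-on for an already-sounding key restarts it
            open_notes[key] = time
        elif key in open_notes:
            start = open_notes.pop(key)
            notes.append((channel, note, start, time - start))
    return notes

# (offset, op, channel, note) from the ten rows above
rows = [(4608, True, 3, 69), (4608, True, 5, 53), (4608, True, 9, 35),
        (4608, True, 9, 42), (4672, False, 9, 42), (4712, False, 5, 53),
        (4736, True, 5, 57), (4736, True, 9, 42), (4800, False, 9, 42),
        (4800, False, 9, 35)]
print(pair_notes(rows))
# [(9, 42, 4608, 64), (5, 53, 4608, 104), (9, 42, 4736, 64), (9, 35, 4608, 192)]
```

Notes are keyed by channel and pitch, not by track, because MIDI note-offs are matched per channel. The two notes that are still sounding at tick 4800 (channel 3 note 69, channel 5 note 57) are left out of the result.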


In [21]:
import keras.backend as K
import keras.layers as L
import keras.models as M

Using TensorFlow backend.


In [32]:
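def midi_model():
    # One input per event feature; each has shape (timesteps, 1), with the
    # number of timesteps left open (None)
    offset = L.Input(name='offset', shape=(None, 1))
    channel = L.Input(name='channel', shape=(None, 1))
    note = L.Input(name='note', shape=(None, 1))
    op = L.Input(name='op', shape=(None, 1))

    # Stack the features into (timesteps, 4) and run two LSTM layers;
    # the second returns only its final state
    combined = L.concatenate([offset, channel, note, op])
    lstm_1 = L.LSTM(units=256, return_sequences=True)(combined)
    lstm_2 = L.LSTM(units=256)(lstm_1)
    # Four linear outputs, one per input feature, trained with MSE
    output = L.Dense(4)(lstm_2)

    model = M.Model(inputs=[offset, channel, note, op],
                    outputs=output)
    model.compile(optimizer='adam', loss='mse')
    return model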

model = midi_model()
model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
offset (InputLayer)              (None, None, 1)       0                                            
____________________________________________________________________________________________________
channel (InputLayer)             (None, None, 1)       0                                            
____________________________________________________________________________________________________
note (InputLayer)                (None, None, 1)       0                                            
____________________________________________________________________________________________________
op (InputLayer)                  (None, None, 1)       0                                            
___________________________________________________________________________________________

In [34]:
import random
import numpy as np

TRAINING_CONTEXT = 64  # events the model sees per example
BATCH_SIZE = 32        # an arbitrary choice
# Column order matches the model's Input names and its 4 outputs
FEATURES = ['offset', 'channel', 'note', 'op']

def data_generator(midi_df, batch_size=BATCH_SIZE):
    # (n_events, 4) float matrix; the boolean op column becomes 0.0 / 1.0
    data = midi_df[FEATURES].astype('float32').values
    while True:
        xs, ys = [], []
        for _ in range(batch_size):
            # Leave room for the event after the window, which is the target.
            # Windows may straddle two songs; they are not filtered out here.
            start = random.randint(0, len(data) - TRAINING_CONTEXT - 1)
            xs.append(data[start:start + TRAINING_CONTEXT])
            ys.append(data[start + TRAINING_CONTEXT])
        x = np.stack(xs)  # (batch, TRAINING_CONTEXT, 4)
        y = np.stack(ys)  # (batch, 4)
        # fit_generator needs an (inputs, targets) tuple; yielding only the
        # input dict raises "ValueError: output of generator should be a
        # tuple (x, y, sample_weight) or (x, y)". Each named input gets a
        # (batch, TRAINING_CONTEXT, 1) slice.
        inputs = {name: x[:, :, i:i + 1] for i, name in enumerate(FEATURES)}
        yield inputs, y

model.fit_generator(data_generator(midi_df), steps_per_epoch=1000, epochs=1)
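The model above only learns to predict the event that follows a window of past events. Below is a hedged sketch of how such a next-event predictor could be used to generate music. The loop feeds the most recent events back in and appends each prediction. `generate` and `predict_next` are hypothetical names. With the Keras model, `predict_next` would wrap `model.predict` after reshaping the window into the four named inputs. `toy_predict` exists only to make the example runnable.

```python
from collections import deque

def generate(predict_next, seed, n_events, context=64):
    """Extend a seed sequence of events by n_events, one prediction at a time.

    predict_next: hypothetical callable that takes the list of the most recent
    `context` events and returns the next event.
    """
    window = deque(seed[-context:], maxlen=context)
    generated = []
    for _ in range(n_events):
        event = predict_next(list(window))
        generated.append(event)
        window.append(event)  # maxlen drops the oldest event automatically
    return generated

# Toy stand-in for a trained model: repeat the last event 10 ticks later.
# Events are (offset, channel, note, op) tuples.
def toy_predict(window):
    offset, channel, note, op = window[-1]
    return (offset + 10, channel, note, op)

print(generate(toy_predict, [(0, 1, 60, 1)], 3))
# [(10, 1, 60, 1), (20, 1, 60, 1), (30, 1, 60, 1)]
```

The real model's outputs are unrounded floats from a regression head. The channel, note and op values would therefore need rounding before they could be written back out as MIDI events.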