In [1]:
import sys
import os
import shutil
import tarfile
import pretty_midi
import fluidsynth
from IPython import display
import xml.etree.ElementTree as ET
import pprint as pp
import glob
import ntpath
import re
import json
import xml.dom.minidom
import numpy as np
from functools import total_ordering
import random

import midi_constants as mc

cwd = os.getcwd()

midi_file_dir = cwd + '/generated_pattern_midi/'
output_path = cwd + '/generated_patterns/'
drum_pattern_names_file = cwd + '/drum_pattern_names.txt'
drum_map_path = cwd + '/drum_kit_map.json'

MAXCOUNT = 25
maxDim = 192

In [2]:
# configuration
pattern_names_list = open(drum_pattern_names_file).readlines()
# pp.pprint(pattern_names_list)

filenames = glob.glob(midi_file_dir + '*.mid*')
print('Number of files:', len(filenames))
print(filenames[0])

# Sampling rate for audio playback
_SAMPLING_RATE = 16000.0


Number of files: 290
/home/clayc/projects/midi_to_hydrogen/midi_to_hydrogen_drum_pattern/generated_pattern_midi/pattern_midi_00000122.midi


In [3]:
# convert constants to dict
midi_instruments = {}
for v in dir(mc):
    if not v.startswith('__'):
        attr = getattr(mc, v)
        # print(f'variable: {v} attr: {attr}')
        midi_instruments[v] = attr

# get overrides
over_ride_dict = {}

if os.path.isfile(drum_map_path) and os.access(drum_map_path, os.R_OK):
    with open(drum_map_path, 'r') as fp:    
        over_ride_dict = json.load(fp)
    
pp.pprint(over_ride_dict)


{'COWBELL': '11',
 'HAND CLAP': '3',
 'PAISTE BELL': '17',
 'PAISTE RIDE': '12',
 'PAISTE RIDE FLINK': '18',
 'PEARL KICK': '0',
 'PEARL SIDE STICK': '1',
 'PEARL SNARE': '2',
 'PEARL SNARE RIMSHOT': '4',
 'PEARL TOM 1': '9',
 'PEARL TOM 2': '7',
 'PEARL TOM FLOOR': '5',
 'SABIAN CRASH': '13',
 'SABIAN CRASH FLINK': '14',
 'SABIAN HAT CHOKE': '16',
 'SABIAN HAT CLOSED': '6',
 'SABIAN HAT OPEN': '21',
 'SABIAN HAT PEDAL': '8',
 'SABIAN HAT SEMI-OPEN': '20',
 'SABIAN HAT SWISH': '10',
 'ZILDJIAN SPLASH': '15',
 'ZILDJIAN SPLASH CHOKE': '19'}


In [4]:
@total_ordering
class Note:
  def __init__(self, position:int, leadlag=0, velocity=0.8, pan_L=0.5, 
               pan_R=0.5, pitch=0,key='C0', length=-1, instrument=0, 
               note_off='false', probability=1, auto_pan=True, auto_velocity=True):
    self.position = position
    self.leadlag = leadlag
    self.velocity = velocity
    if auto_velocity:
      velocity = round(random.uniform(0.5, 0.9), 2)
      self.velocity = velocity
    self.pan_L = pan_L
    self.pan_R = pan_R
    if auto_pan:
      pan_values = self.get_rand_pan_values()
      self.pan_L = pan_values[0]
      self.pan_R = pan_values[1]
    self.pitch = pitch
    self.key = key
    self.length = length
    self.instrument = instrument
    self.note_off = note_off
    self.probability = probability
  
  def create_note_element(self, root:ET.Element, el_name:str, el_value):
    se = ET.SubElement(root, el_name)
    se.text = str(el_value)        
    return se  
  
  def generate_note_root(self):
    note_root = ET.Element('note')
    self.create_note_element(note_root, 'position', self.position)
    self.create_note_element(note_root, 'leadlag', self.leadlag)
    self.create_note_element(note_root, 'velocity', self.velocity)
    self.create_note_element(note_root, 'pan_L', self.pan_L)
    self.create_note_element(note_root, 'pan_R', self.pan_R)
    self.create_note_element(note_root, 'pitch', self.pitch)
    self.create_note_element(note_root, 'key', self.key)
    self.create_note_element(note_root, 'length', self.length)
    self.create_note_element(note_root, 'instrument', self.instrument)
    self.create_note_element(note_root, 'note_off', self.note_off)
    self.create_note_element(note_root, 'probability', self.probability)
    return note_root
  
  def tostring(self):
    xmlstr= ET.tostring(self.generate_note_root(), encoding='UTF-8', xml_declaration=False)
    temp = xml.dom.minidom.parseString(xmlstr)
    return temp.toprettyxml()

  def get_rand_pan_values(self):
    pan_L = round(random.uniform(0, 1), 2)
    pan_R = round(1 - pan_L, 2)
    return [pan_L, pan_R]

  def __str__(self):   
    return self.tostring()
  
  def _is_valid_operand(self, other):
        return (hasattr(other, "position") and
                hasattr(other, "instrument"))
  def __eq__(self, other):
      if not self._is_valid_operand(other):
          return NotImplemented
      return (self.position == other.position)
  def __lt__(self, other):
      if not self._is_valid_operand(other):
          return NotImplemented
      return (self.position < other.position)

    

In [5]:
# functions

def get_pattern_name(pat_name_list:list):
    if pat_name_list:
        pat_name = pat_name_list[int(round(random.uniform(0, len(pat_name_list)-1),0))]    
        return pat_name.strip().strip('\"')
    else:
        return None

# delete all files and subdirectories in directory
def delete_directory(directory):
    for root, dirs, files in os.walk(directory, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))
    os.rmdir(directory)

# create new directory or delete directory and recreate if exists
def recreate_directory(directory):
    if  os.path.exists(directory):
        delete_directory(directory)
    os.makedirs(directory)

def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):
    waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
    # Take a sample of the generated waveform to mitigate kernel resets
    waveform_short = waveform[:seconds*int(_SAMPLING_RATE)]
    return display.Audio(waveform_short, rate=int(_SAMPLING_RATE))

def write_pattern_file(el_root:str, output_path:str):
    xmlstr= ET.tostring(el_root, encoding='UTF-8', xml_declaration=True)
    temp = xml.dom.minidom.parseString(xmlstr)
    pat_xml = temp.toprettyxml()
    output_xml_file = output_path
    # print(f'Output dir: {output_xml_file}')
    with open(output_xml_file, "w") as f:
        f.write(pat_xml)
        
def create_subelement(root:ET.Element, el_name:str, el_text = None):
    se = ET.SubElement(root, el_name)
    if el_text:
        se.text = el_text        
    return se

def create_note(root:ET.Element, note:Note):
    se = ET.SubElement(root, note.tostring())    
    return se

def print_xmltree(root:ET.Element):
    xmlstr= ET.tostring(root, encoding='UTF-8', xml_declaration=True)
    temp = xml.dom.minidom.parseString(xmlstr)
    new_xml = temp.toprettyxml()
    print(new_xml)
    
def get_hydrogen_inst(note_name:str, inst_dict:dict, override_dict:dict=None):
    n_name = 'H_' + note_name.upper().replace(' ', '_')
    # print (f'Checking {n_name}')
    inst =  inst_dict.get(n_name, 1)
    if override_dict:
        name_list = note_name.upper().split()
        for name in name_list:
            for key in override_dict.keys():
                if name in key:            
                    override = override_dict.get(key)
                    if override:
                        inst = override
                        print(f'Found override {key}:{override} for instrument {note_name}')
                        break
        
    return inst
 

In [6]:
# more functions
# find closest key in array
def closest(lst, K):      
     lst = np.asarray(lst)
     idx = (np.abs(lst - K)).argmin()
     return lst[idx]

def get_hyd_beat_from_start_time(midi_beat_array:list, beat_map:dict, note_start:float):
     max_beat = max(midi_beat_array)
     if note_start > max_beat:
          # print('Note outside of pattern range')
          return None
     c_beat = closest(midi_beat_array, note_start)
     h_beat = beat_map.get(c_beat)
     # print(h_beat)
     return h_beat


def generate_pattern_xml(pattern_note_list:list, pattern_name='test_pattern'):
    dkp_root = ET.Element('drumkit_pattern')
    create_subelement(dkp_root,'pattern_for_drumkit', 'Test Kit')
    create_subelement(dkp_root,'author', 'Clayton Corrello')
    create_subelement(dkp_root,'license', 'LAL-1.3/GPL-2')

    # create pattern elements
    pattern = create_subelement(dkp_root, 'pattern')
    create_subelement(pattern, 'pattern_name', pattern_name)
    create_subelement(pattern, 'info', )
    create_subelement(pattern, 'category', 'ML Generated')
    create_subelement(pattern, 'size', '192')

    # create notelist and add Note objects
    note_list = create_subelement(pattern, 'noteList')
    for note in pattern_note_list:
        note_list.append(note.generate_note_root())
    return dkp_root 
 
def convert_pmnote(note: pretty_midi.Note):
    position = get_hyd_beat_from_start_time(midi_beat_array=midi_beat_array, beat_map=beat_map, note_start=note.start)
    velocity = round(note.velocity/127, 2)
    drum_name = pretty_midi.note_number_to_drum_name(note.pitch)    
    instrument = get_hydrogen_inst(drum_name, midi_instruments, override_dict=over_ride_dict)
    return Note(position=position, velocity=velocity, instrument=instrument)


In [7]:
midi_names = []
counter = 0
for filename in filenames:
    f_name = os.path.basename(filename)
    midi_names.append(f_name)
    counter = counter + 1
    # use for limiting output during testing, etc
    if counter > MAXCOUNT:
        break

# clear and recreate output dir
recreate_directory(directory=output_path)

In [8]:
for midi_name in midi_names:
    m_file_name = os.path.splitext(midi_name)[0]
    path = os.path.join(midi_file_dir, midi_name)
    pm = pretty_midi.PrettyMIDI(path)
    try:
        mba = pm.get_beats()
        if len(mba)< 16:
            continue
        mba2 = mba[0 :16]
        midi_beat_array = np.around(mba2, decimals=2)
        hyd_beat_array = list(range(0,192, 12))
        beat_map = {midi_beat_array[i]: hyd_beat_array[i] for i in range(len(hyd_beat_array))}
    except Exception as e:
        print(f'Processing file {m_file_name} Exception: {e}')
        print(f'{midi_beat_array=}')
        print(f'{hyd_beat_array=}')
        continue
    p_notes = []
    for instrument in pm.instruments:
        # print(f'{instrument}')
        counter = 0
        for note in instrument.notes:
            # print(f'{note=}')
            pat_note = convert_pmnote(note)
            if pat_note:
                if pat_note.position:
                    p_notes.append(pat_note)
            counter = counter + 1
            # if counter > 10:
            #     break
    pattern_notes = []
    [pattern_notes.append(x) for x in p_notes if x not in pattern_notes]
    if len(pattern_notes) < 1:
        continue
    pattern_name = get_pattern_name(pattern_names_list)
    dkp_root = generate_pattern_xml(pattern_notes, pattern_name=pattern_name)
    
    pat_name = pattern_name + '_' + m_file_name
    pattern_file = output_path + pat_name + '.h2pattern'
    write_pattern_file(dkp_root, pattern_file)

Found override PEARL TOM 1:9 for instrument Bass Drum 1
Found override PEARL TOM 1:9 for instrument Bass Drum 1
Found override SABIAN HAT PEDAL:8 for instrument Pedal Hi Hat
Found override SABIAN HAT CLOSED:6 for instrument Pedal Hi Hat
Found override PEARL SIDE STICK:1 for instrument Side Stick
Found override PEARL SIDE STICK:1 for instrument Side Stick
Found override SABIAN HAT PEDAL:8 for instrument Pedal Hi Hat
Found override SABIAN HAT CLOSED:6 for instrument Pedal Hi Hat
Found override PEARL SIDE STICK:1 for instrument Side Stick
Found override PEARL SIDE STICK:1 for instrument Side Stick
Found override SABIAN HAT PEDAL:8 for instrument Pedal Hi Hat
Found override SABIAN HAT CLOSED:6 for instrument Pedal Hi Hat
Found override SABIAN HAT CLOSED:6 for instrument Closed Hi Hat
Found override SABIAN HAT CLOSED:6 for instrument Closed Hi Hat
Found override SABIAN HAT PEDAL:8 for instrument Pedal Hi Hat
Found override SABIAN HAT CLOSED:6 for instrument Pedal Hi Hat
Found override PEARL