In [1]:
DRUM_TYPES = {
    "CC": [
        49,  # crash cymbal 1
        57,  # crash cymbal 2
        52,  # china cymbal
        55,  # splash cymbal
        51,  # ride cymbal
        59,  # ride cymbal 2
        "CC",  # crash (ddm-own)
    ],  # crash
    "OH": [
        "ohh",
        46,  # hi-hat open
        "overheads",  # drum kit data
    ],  # hi-hat open
    "CH": [
        "chh",
        42,  # hi-hat cloased
        "HH",  # closed hi-hat (ddm-own)
    ],  # hi-hat closed
    "TT": [
        "mt",
        45,  # mid tom
        47,  # mid tom
        48,  # high-mid tom
        50,  # high tom
        "toms",  # tom (drum kit data)
    ],  # tom
    "SD": [
        "sd",
        38,  # snare drum
        40,  # electric snare drum
        "snare",  # snare drum (drum kit data)
        "SD",  # snare (ddm-own)
    ],  # snare
    "KK": [
        "bd",
        35,  # bass drum
        36,  # kick drum
        "kick",  # kick (drum kit data)
        "KD",  # kick (idmt)
        "KK",  # kick (ddm-own)
    ],  # kick
}
DRUM_MAP = {}
# Iterate over the DRUM_TYPES
for drum_type, values in DRUM_TYPES.items():
    # Iterate over the values for each drum_type
    for value in values:
        # Add the mapping to the new dictionary
        DRUM_MAP[value] = drum_type

In [3]:
import pretty_midi

def _get_drum_track_from_mid(midi_data):
    # Find the drum track
    drum_track = next(
        (instrument for instrument in midi_data.instruments if instrument.is_drum),
        None,
    )

    if drum_track is None:
        print("No drum track found in the MIDI file.")
        return None

    return drum_track

def get_onsets_from_mid(
    midi_path: str, start: float = 0, end: float = None
):
    """
    -- MID file에서 onset 읽어오기
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    drum_track = _get_drum_track_from_mid(midi_data)
    onset_times = [note.start for note in drum_track.notes]
    onset_times.sort()
    return onset_times

@staticmethod
def get_onsets_instrument_from_mid(
    midi_path: str, start: float = 0, end: float = None, onset_dict={}
):
    """
    -- midi file에서 드럼 악기별로 onset을 구하는 함수

    onset_dict : {'CC':[], 'OH':[], 'CH':[], 'TT':[], 'SD':[], 'HH':[]}
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    drum_track = _get_drum_track_from_mid(midi_data)

    # Dictionary to store onsets for each selected drum instrument
    drum_onsets = onset_dict
    # print(drum_onsets)

    # 악기의 노트를 순회하며 onset을 찾음
    for note in drum_track.notes:
        if note.pitch in DRUM_MAP:
            drum_onsets[DRUM_MAP[note.pitch]].append(note.start)

    return drum_onsets

In [25]:
import pretty_midi

# midi 읽기
midi_path = "../data/test/8_rock_100_beat_4-4_1.midi"
midi_data = pretty_midi.PrettyMIDI(midi_path)
drum_track = _get_drum_track_from_mid(midi_data)

# Dictionary to store onsets for each selected drum instrument
drum_onsets = []
# print(drum_onsets)

# 악기의 노트를 순회하며 onset을 찾음
for idx, note in enumerate(drum_track.notes):
    drum_onsets.append({"pitch": note.pitch, "start": note.start})
    if note.pitch == 58:
        print(idx, "삐융")

In [26]:
drum_onsets

[{'pitch': 38, 'start': 0.0},
 {'pitch': 22, 'start': 0.0},
 {'pitch': 22, 'start': 0.28250000000000003},
 {'pitch': 36, 'start': 0.30625},
 {'pitch': 36, 'start': 0.5325},
 {'pitch': 38, 'start': 0.58125},
 {'pitch': 22, 'start': 0.58875},
 {'pitch': 22, 'start': 0.8562500000000001},
 {'pitch': 36, 'start': 0.9925},
 {'pitch': 36, 'start': 1.12125},
 {'pitch': 26, 'start': 1.17875},
 {'pitch': 40, 'start': 1.17875},
 {'pitch': 26, 'start': 1.4775},
 {'pitch': 36, 'start': 1.485},
 {'pitch': 38, 'start': 1.65875},
 {'pitch': 38, 'start': 1.71},
 {'pitch': 38, 'start': 1.8025},
 {'pitch': 44, 'start': 1.85},
 {'pitch': 38, 'start': 1.94875},
 {'pitch': 48, 'start': 2.0975},
 {'pitch': 48, 'start': 2.2475},
 {'pitch': 44, 'start': 2.3775},
 {'pitch': 55, 'start': 2.4},
 {'pitch': 36, 'start': 2.4175},
 {'pitch': 22, 'start': 2.69625},
 {'pitch': 36, 'start': 2.8375},
 {'pitch': 36, 'start': 2.9162500000000002},
 {'pitch': 38, 'start': 3.02375},
 {'pitch': 22, 'start': 3.0275},
 {'pitch':

In [14]:
onset_init={v: [] for v, _ in DRUM_TYPES.items()}

In [13]:
import glob
e_gmd_data = glob.glob('../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/**/*.midi', recursive=True)
no_hihat = []
for midi_file in e_gmd_data:
    # print(midi_file)
    drum_onsets = get_onsets_instrument_from_mid(midi_file, onset_dict={v: [] for v, _ in DRUM_TYPES.items()})
    chk = True
    for k, _ in DRUM_TYPES.items():
        if len(drum_onsets[k]) < 10:
            chk = False
            break
    if chk:
        print(midi_file)

../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_2.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_51.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_19.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_18.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_22.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_36.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_4.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_54.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_15.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_25.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_42.midi
../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer3/session1/1_rock_105_beat_4-4_43

KeyboardInterrupt: 

In [15]:
drum_onsets = get_onsets_instrument_from_mid("../data/raw/e-gmd-v1.0.0/e-gmd-v1.0.0/drummer8/eval_session/1_funk-groove1_138_beat_4-4_1.midi",onset_dict=onset_init)


{'CC': [],
 'OH': [],
 'CH': [0.0,
  0.206521925,
  0.423913425,
  0.6521745,
  0.8650370104166667,
  1.0860517020833333,
  1.30706639375,
  1.5525376291666666,
  1.7590595541666667,
  1.9773568520833333,
  2.2010889375000002,
  2.4112340541666666,
  2.615038585416667,
  2.819748914583333,
  3.2771768625,
  3.49185096875,
  3.7038076812500003,
  4.1349674895833335,
  4.3378662229166665,
  4.560692510416667,
  4.778989808333334,
  5.215584404166667,
  5.416671541666667,
  5.863229914583333,
  6.07337503125,
  6.29076653125,
  6.50815803125,
  6.747288681250001,
  6.9628685854166665,
  7.1766368937500005,
  7.4030863729166665,
  7.621383670833334,
  7.8242824041666665,
  8.032615925,
  8.50000765,
  8.930261660416667,
  9.362327266666666,
  9.572472383333334,
  9.78533489375,
  10.232799064583334,
  10.443849979166666,
  10.6576182875,
  10.87229239375,
  11.0869665,
  11.295300020833334,
  11.734612010416667,
  11.965590479166666,
  12.178452989583334,
  12.393127095833334,
  12.8460260