# Rich Drum Encoding

In [27]:
import json
import os

from collections import Counter
from pathlib import Path

import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [31]:
from src.constants import MIDI_DRUM_PATCH_NAMES

Explore the distribution of tokens in the encodings:

In [56]:
DATA_DIR = Path('../experiments/rich_drum/dataset/all')

In [110]:
toks = []
for fname in os.listdir(DATA_DIR):
    with open(DATA_DIR/fname) as f:
        enc_data = json.load(f)
    toks += enc_data['data']

In [111]:
len(toks)

4036462

In [112]:
len(set(toks))

111681

In [113]:
def process_midi_program(midi_prog):
    """
    Map instruments to some canonical version (all bass drums to the 'Acoustic Bass Drum'
    program for example)
    """
    replace = {
        36: 35,
        40: 38,
        51: 49,
        57: 51
    }
    return replace.get(midi_prog, midi_prog)

In [114]:
def make_readable(tok):
    sub_toks = tok.split('_')
    readable = []
    for t in sub_toks:
        if t[0] == 'P':
            try:
                midi_prog = process_midi_program(int(t.replace('P-', '')))
                readable.append(MIDI_DRUM_PATCH_NAMES[midi_prog])
                readable.sort()
            except KeyError:
                pass
        else:
            readable.append(int(t.replace('D-', '')))
    return tuple(readable)

In [115]:
nice_toks = [make_readable(tok) for tok in toks]

In [116]:
len(set(nice_toks))

40800

In [117]:
count = Counter(nice_toks)

In [126]:
count.most_common(2500)

[(('Closed Hi Hat', 2), 178227),
 (('Closed Hi Hat', 3), 138094),
 (('Closed Hi Hat', 4), 125183),
 (('Closed Hi Hat', 1), 122771),
 (('Acoustic Bass Drum', 'Closed Hi Hat', 4), 91820),
 (('Acoustic Bass Drum', 'Closed Hi Hat', 3), 75744),
 (('Acoustic Bass Drum', 'Closed Hi Hat', 2), 57275),
 (('Acoustic Snare', 2), 52756),
 (('Tambourine', 2), 50537),
 (('Acoustic Snare', 1), 50112),
 (('Acoustic Bass Drum', 4), 48909),
 (('Acoustic Bass Drum', 2), 47870),
 (('Acoustic Bass Drum', 3), 46374),
 (('Acoustic Snare', 'Closed Hi Hat', 4), 43834),
 (('Tambourine', 1), 42813),
 (('Acoustic Bass Drum', 1), 38734),
 (('Closed Hi Hat', 0), 38040),
 (('Acoustic Snare', 'Closed Hi Hat', 3), 34014),
 (('Acoustic Bass Drum', 'Closed Hi Hat', 1), 30594),
 (('Acoustic Snare', 3), 28960),
 (('Acoustic Bass Drum', 8), 28889),
 (('Acoustic Snare', 'Closed Hi Hat', 2), 27359),
 (('Pedal Hi-Hat', 2), 27055),
 (('Acoustic Bass Drum', 0), 25708),
 (('Cabasa', 1), 25110),
 (('Cabasa', 2), 24115),
 (('Acoust

In [127]:
top_actions = [list(toks) for toks, _ in count.most_common(2500)]

In [128]:
with open('../src/encoders/rich_drum/tokens.json', 'w') as f:
    json.dump(top_actions, f)