In [1]:
!pip install torchlibrosa

Collecting torchlibrosa
  Downloading torchlibrosa-0.1.0-py3-none-any.whl.metadata (3.5 kB)
Downloading torchlibrosa-0.1.0-py3-none-any.whl (11 kB)
Installing collected packages: torchlibrosa
Successfully installed torchlibrosa-0.1.0


In [2]:
!pip install mido

Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mido
Successfully installed mido-1.3.3


In [3]:
!pip install sox

Collecting sox
  Downloading sox-1.5.0.tar.gz (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.9/63.9 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
Building wheels for collected packages: sox
  Building wheel for sox (setup.py) ... [?25l- \ | done
[?25h  Created wheel for sox: filename=sox-1.5.0-py3-none-any.whl size=40038 sha256=f63b6c7d3184be5c6b5b06bc366984d99b93adef12c409a94d61a5d42e6f7a25
  Stored in directory: /root/.cache/pip/wheels/74/e7/7b/8033be3ec5e4994595d01269fc9657c8fd83a0dcbf8536666a
Successfully built sox
Installing collected packages: sox
Successfully installed sox-1.5.0


In [4]:
!pip install mir_eval

Collecting mir_eval
  Downloading mir_eval-0.7.tar.gz (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.7/90.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: mir_eval
  Building wheel for mir_eval (setup.py) ... [?25l- \ done
[?25h  Created wheel for mir_eval: filename=mir_eval-0.7-py3-none-any.whl size=100701 sha256=ecf4e70863fefc973a568f6a785395e473c8c4a92ac47112aab3f2fcf64fa178
  Stored in directory: /root/.cache/pip/wheels/3e/2f/0d/dda9c4c77a170e21356b6afa2f7d9bb078338634ba05d94e3f
Successfully built mir_eval
Installing collected packages: mir_eval
Successfully installed mir_eval-0.7


## config.py -----------------------------------------------------------------------------------------------------------------------

In [5]:
class Configuration(object):
    def __init__(self):
        self.sample_rate = 16000
        self.classes_num = 88    # Number of notes of piano
        self.begin_note = 21     # MIDI note of A0, the lowest note of a piano.
        self.segment_seconds = 10.     # Training segment duration
        self.hop_seconds = 1.
        self.frames_per_second = 100
        self.velocity_scale = 128

config = Configuration()
print(config.sample_rate)

16000


## piano_vad.py -------------------------------------------------------------------------------------------------------------------

In [6]:
import numpy as np


def note_detection_with_onset_offset_regress(frame_output, onset_output, 
    onset_shift_output, offset_output, offset_shift_output, velocity_output,
    frame_threshold):
    """Process prediction matrices to note events information.
    First, detect onsets with onset outputs. Then, detect offsets
    with frame and offset outputs.
    
    Args:
      frame_output: (frames_num,)
      onset_output: (frames_num,)
      onset_shift_output: (frames_num,)
      offset_output: (frames_num,)
      offset_shift_output: (frames_num,)
      velocity_output: (frames_num,)
      frame_threshold: float

    Returns: 
      output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity], 
      e.g., [
        [1821, 1909, 0.47498, 0.3048533, 0.72119445], 
        [1909, 1947, 0.30730522, -0.45764327, 0.64200014], 
        ...]
    """
    output_tuples = []
    bgn = None
    frame_disappear = None
    offset_occur = None

    for i in range(onset_output.shape[0]):
        if onset_output[i] == 1:
            """Onset detected"""
            if bgn:
                """Consecutive onsets. E.g., pedal is not released, but two 
                consecutive notes being played."""
                fin = max(i - 1, 0)
                output_tuples.append([bgn, fin, onset_shift_output[bgn], 
                    0, velocity_output[bgn]])
                frame_disappear, offset_occur = None, None
            bgn = i

        if bgn and i > bgn:
            """If onset found, then search offset"""
            if frame_output[i] <= frame_threshold and not frame_disappear:
                """Frame disappear detected"""
                frame_disappear = i

            if offset_output[i] == 1 and not offset_occur:
                """Offset detected"""
                offset_occur = i

            if frame_disappear:
                if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
                    """bgn --------- offset_occur --- frame_disappear"""
                    fin = offset_occur
                else:
                    """bgn --- offset_occur --------- frame_disappear"""
                    fin = frame_disappear
                output_tuples.append([bgn, fin, onset_shift_output[bgn], 
                    offset_shift_output[fin], velocity_output[bgn]])
                bgn, frame_disappear, offset_occur = None, None, None

            if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
                """Offset not detected"""
                fin = i
                output_tuples.append([bgn, fin, onset_shift_output[bgn], 
                    offset_shift_output[fin], velocity_output[bgn]])
                bgn, frame_disappear, offset_occur = None, None, None

    # Sort pairs by onsets
    output_tuples.sort(key=lambda pair: pair[0])

    return output_tuples


def pedal_detection_with_onset_offset_regress(frame_output, offset_output, 
    offset_shift_output, frame_threshold):
    """Process prediction array to pedal events information.
    
    Args:
      frame_output: (frames_num,)
      offset_output: (frames_num,)
      offset_shift_output: (frames_num,)
      frame_threshold: float

    Returns: 
      output_tuples: list of [bgn, fin, onset_shift, offset_shift], 
      e.g., [
        [1821, 1909, 0.4749851, 0.3048533], 
        [1909, 1947, 0.30730522, -0.45764327], 
        ...]
    """
    output_tuples = []
    bgn = None
    frame_disappear = None
    offset_occur = None

    for i in range(1, frame_output.shape[0]):
        if frame_output[i] >= frame_threshold and frame_output[i] > frame_output[i - 1]:
            """Pedal onset detected"""
            if bgn:
                pass
            else:
                bgn = i

        if bgn and i > bgn:
            """If onset found, then search offset"""
            if frame_output[i] <= frame_threshold and not frame_disappear:
                """Frame disappear detected"""
                frame_disappear = i

            if offset_output[i] == 1 and not offset_occur:
                """Offset detected"""
                offset_occur = i

            if offset_occur:
                fin = offset_occur
                output_tuples.append([bgn, fin, 0., offset_shift_output[fin]])
                bgn, frame_disappear, offset_occur = None, None, None

            if frame_disappear and i - frame_disappear >= 10:
                """offset not detected but frame disappear"""
                fin = frame_disappear
                output_tuples.append([bgn, fin, 0., offset_shift_output[fin]])
                bgn, frame_disappear, offset_occur = None, None, None

    # Sort pairs by onsets
    output_tuples.sort(key=lambda pair: pair[0])

    return output_tuples


###### Google's onsets and frames post processing. Only used for comparison ######
def onsets_frames_note_detection(frame_output, onset_output, offset_output, 
    velocity_output, threshold):
    """Process pedal prediction matrices to note events information. onset_ouput 
    is used to detect the presence of notes. frame_output is used to detect the 
    offset of notes.
    
    Args:
      frame_output: (frames_num,)
      onset_output: (frames_num,)
      threshold: float
    
    Returns: 
      bgn_fin_pairs: list of [bgn, fin, velocity]. E.g. 
        [[1821, 1909, 0.47498, 0.72119445], 
         [1909, 1947, 0.30730522, 0.64200014], 
         ...]
    """
    output_tuples = []

    loct = None
    for i in range(onset_output.shape[0]):
        # Use onset_output is used to detect the presence of notes
        if onset_output[i] > threshold:
            if loct:
                output_tuples.append([loct, i, velocity_output[loct]])
            loct = i
        if loct and i > loct:
            # Use frame_output is used to detect the offset of notes
            if frame_output[i] <= threshold:
                output_tuples.append([loct, i, velocity_output[loct]])
                loct = None

    output_tuples.sort(key=lambda pair: pair[0])

    return output_tuples


def onsets_frames_pedal_detection(frame_output, offset_output, frame_threshold):
    """Process pedal prediction matrices to pedal events information.
    
    Args:
      frame_output: (frames_num,)
      offset_output: (frames_num,)
      offset_shift_output: (frames_num,)
      frame_threshold: float

    Returns: 
      output_tuples: list of [bgn, fin], 
      e.g., [
        [1821, 1909], 
        [1909, 1947], 
        ...]
    """
    output_tuples = []
    bgn = None
    frame_disappear = None
    offset_occur = None

    for i in range(1, frame_output.shape[0]):
        if frame_output[i] >= frame_threshold and frame_output[i] > frame_output[i - 1]:
            if bgn:
                pass
            else:
                bgn = i

        if bgn and i > bgn:
            """If onset found, then search offset"""
            if frame_output[i] <= frame_threshold and not frame_disappear:
                """Frame disappear detected"""
                frame_disappear = i

            if offset_output[i] == 1 and not offset_occur:
                """Offset detected"""
                offset_occur = i

            if offset_occur:
                fin = offset_occur
                output_tuples.append([bgn, fin])
                bgn, frame_disappear, offset_occur = None, None, None

            if frame_disappear and i - frame_disappear >= 10:
                """offset not detected but frame disappear"""
                fin = frame_disappear
                output_tuples.append([bgn, fin])
                bgn, frame_disappear, offset_occur = None, None, None

    # Sort pairs by onsets
    output_tuples.sort(key=lambda pair: pair[0])

    return output_tuples

## utilities.py ----------------------------------------------------------------------------------------------------------------------

In [7]:
import os
import logging
import h5py
import soundfile
import librosa
import audioread
import numpy as np
import pandas as pd
import csv
import datetime
import collections
import pickle
from mido import MidiFile


def create_folder(fd):
    if not os.path.exists(fd):
        os.makedirs(fd)
        
        
def get_filename(path):
    path = os.path.realpath(path)
    na_ext = path.split('/')[-1]
    na = os.path.splitext(na_ext)[0]
    return na


def traverse_folder(folder):
    paths = []
    names = []
    
    for root, dirs, files in os.walk(folder):
        for name in files:
            filepath = os.path.join(root, name)
            names.append(name)
            paths.append(filepath)
            
    return names, paths


def note_to_freq(piano_note):
    return 2 ** ((piano_note - 39) / 12) * 440

    
def create_logging(log_dir, filemode):
    create_folder(log_dir)
    i1 = 0

    while os.path.isfile(os.path.join(log_dir, '{:04d}.log'.format(i1))):
        i1 += 1
        
    log_path = os.path.join(log_dir, '{:04d}.log'.format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=log_path,
        filemode=filemode)

    # Print to console
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    
    return logging


def float32_to_int16(x):
    assert np.max(np.abs(x)) <= 1.
    return (x * 32767.).astype(np.int16)


def int16_to_float32(x):
    return (x / 32767.).astype(np.float32)
    

def pad_truncate_sequence(x, max_len):
    if len(x) < max_len:
        return np.concatenate((x, np.zeros(max_len - len(x))))
    else:
        return x[0 : max_len]


def read_metadata(csv_path):
    """Read metadata of MAESTRO dataset from csv file.

    Args:
      csv_path: str

    Returns:
      meta_dict, dict, e.g. {
        'canonical_composer': ['Alban Berg', ...], 
        'canonical_title': ['Sonata Op. 1', ...], 
        'split': ['train', ...], 
        'year': ['2018', ...]
        'midi_filename': ['2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.midi', ...], 
        'audio_filename': ['2018/MIDI-Unprocessed_Chamber3_MID--AUDIO_10_R3_2018_wav--1.wav', ...],
        'duration': [698.66116031, ...]}
    """

    with open(csv_path, 'r') as fr:
        reader = csv.reader(fr, delimiter=',')
        lines = list(reader)

    meta_dict = {'canonical_composer': [], 'canonical_title': [], 'split': [], 
        'year': [], 'midi_filename': [], 'audio_filename': [], 'duration': []}

    for n in range(1, len(lines)):
        meta_dict['canonical_composer'].append(lines[n][0])
        meta_dict['canonical_title'].append(lines[n][1])
        meta_dict['split'].append(lines[n][2])
        meta_dict['year'].append(lines[n][3])
        meta_dict['midi_filename'].append(lines[n][4])
        meta_dict['audio_filename'].append(lines[n][5])
        meta_dict['duration'].append(float(lines[n][6]))

    for key in meta_dict.keys():
        meta_dict[key] = np.array(meta_dict[key])
    
    return meta_dict


def read_midi(midi_path):
    """Parse MIDI file.

    Args:
      midi_path: str

    Returns:
      midi_dict: dict, e.g. {
        'midi_event': [
            'program_change channel=0 program=0 time=0', 
            'control_change channel=0 control=64 value=127 time=0', 
            'control_change channel=0 control=64 value=63 time=236', 
            ...],
        'midi_event_time': [0., 0, 0.98307292, ...]}
    """

    midi_file = MidiFile(midi_path)
    ticks_per_beat = midi_file.ticks_per_beat

    assert len(midi_file.tracks) == 2
    """The first track contains tempo, time signature. The second track 
    contains piano events."""

    microseconds_per_beat = midi_file.tracks[0][0].tempo
    beats_per_second = 1e6 / microseconds_per_beat
    ticks_per_second = ticks_per_beat * beats_per_second

    message_list = []

    ticks = 0
    time_in_second = []

    for message in midi_file.tracks[1]:
        message_list.append(str(message))
        ticks += message.time
        time_in_second.append(ticks / ticks_per_second)

    midi_dict = {
        'midi_event': np.array(message_list), 
        'midi_event_time': np.array(time_in_second)}

    return midi_dict


def read_maps_midi(midi_path):
    """Parse MIDI file of MAPS dataset. Not used anymore.

    Args:
      midi_path: str

    Returns:
      midi_dict: dict, e.g. {
        'midi_event': [
            '<meta message set_tempo tempo=439440 time=0>',
            'control_change channel=0 control=64 value=0 time=0',
            'control_change channel=0 control=64 value=0 time=7531',
            ...],
        'midi_event_time': [0., 0.53200309, 0.53200309, ...]}
    """

    midi_file = MidiFile(midi_path)
    ticks_per_beat = midi_file.ticks_per_beat

    assert len(midi_file.tracks) == 1

    microseconds_per_beat = midi_file.tracks[0][0].tempo
    beats_per_second = 1e6 / microseconds_per_beat
    ticks_per_second = ticks_per_beat * beats_per_second

    message_list = []

    ticks = 0
    time_in_second = []

    for message in midi_file.tracks[0]:
        message_list.append(str(message))
        ticks += message.time
        time_in_second.append(ticks / ticks_per_second)

    midi_dict = {
        'midi_event': np.array(message_list), 
        'midi_event_time': np.array(time_in_second)}

    return midi_dict


class TargetProcessor(object):
    def __init__(self, segment_seconds, frames_per_second, begin_note, 
        classes_num):
        """Class for processing MIDI events to target.

        Args:
          segment_seconds: float
          frames_per_second: int
          begin_note: int, A0 MIDI note of a piano
          classes_num: int
        """
        self.segment_seconds = segment_seconds
        self.frames_per_second = frames_per_second
        self.begin_note = begin_note
        self.classes_num = classes_num
        self.max_piano_note = self.classes_num - 1

    def process(self, start_time, midi_events_time, midi_events, 
        extend_pedal=True, note_shift=0):
        """Process MIDI events of an audio segment to target for training, 
        includes: 
        1. Parse MIDI events
        2. Prepare note targets
        3. Prepare pedal targets

        Args:
          start_time: float, start time of a segment
          midi_events_time: list of float, times of MIDI events of a recording, 
            e.g. [0, 3.3, 5.1, ...]
          midi_events: list of str, MIDI events of a recording, e.g.
            ['note_on channel=0 note=75 velocity=37 time=14',
             'control_change channel=0 control=64 value=54 time=20',
             ...]
          extend_pedal, bool, True: Notes will be set to ON until pedal is 
            released. False: Ignore pedal events.

        Returns:
          target_dict: {
            'onset_roll': (frames_num, classes_num), 
            'offset_roll': (frames_num, classes_num), 
            'reg_onset_roll': (frames_num, classes_num), 
            'reg_offset_roll': (frames_num, classes_num), 
            'frame_roll': (frames_num, classes_num), 
            'velocity_roll': (frames_num, classes_num), 
            'mask_roll':  (frames_num, classes_num), 
            'pedal_onset_roll': (frames_num,), 
            'pedal_offset_roll': (frames_num,), 
            'reg_pedal_onset_roll': (frames_num,), 
            'reg_pedal_offset_roll': (frames_num,), 
            'pedal_frame_roll': (frames_num,)}

          note_events: list of dict, e.g. [
            {'midi_note': 51, 'onset_time': 696.64, 'offset_time': 697.00, 'velocity': 44}, 
            {'midi_note': 58, 'onset_time': 697.00, 'offset_time': 697.19, 'velocity': 50}
            ...]

          pedal_events: list of dict, e.g. [
            {'onset_time': 149.37, 'offset_time': 150.35}, 
            {'onset_time': 150.54, 'offset_time': 152.06}, 
            ...]
        """

        # ------ 1. Parse MIDI events ------
        # Search the begin index of a segment
        for bgn_idx, event_time in enumerate(midi_events_time):
            if event_time > start_time:
                break
        """E.g., start_time: 709.0, bgn_idx: 18003, event_time: 709.0146"""

        # Search the end index of a segment
        for fin_idx, event_time in enumerate(midi_events_time):
            if event_time > start_time + self.segment_seconds:
                break
        """E.g., start_time: 709.0, bgn_idx: 18196, event_time: 719.0115"""

        note_events = []
        """E.g. [
            {'midi_note': 51, 'onset_time': 696.63544, 'offset_time': 696.9948, 'velocity': 44}, 
            {'midi_note': 58, 'onset_time': 696.99585, 'offset_time': 697.18646, 'velocity': 50}
            ...]"""

        pedal_events = []
        """E.g. [
            {'onset_time': 696.46875, 'offset_time': 696.62604}, 
            {'onset_time': 696.8063, 'offset_time': 698.50836}, 
            ...]"""

        buffer_dict = {}    # Used to store onset of notes to be paired with offsets
        pedal_dict = {}     # Used to store onset of pedal to be paired with offset of pedal

        # Backtrack bgn_idx to earlier indexes: ex_bgn_idx, which is used for 
        # searching cross segment pedal and note events. E.g.: bgn_idx: 1149, 
        # ex_bgn_idx: 981
        _delta = int((fin_idx - bgn_idx) * 1.)  
        ex_bgn_idx = max(bgn_idx - _delta, 0)
        
        for i in range(ex_bgn_idx, fin_idx):
            # Parse MIDI messiage
            attribute_list = midi_events[i].split(' ')

            # Note
            if attribute_list[0] in ['note_on', 'note_off']:
                """E.g. attribute_list: ['note_on', 'channel=0', 'note=41', 'velocity=0', 'time=10']"""

                midi_note = int(attribute_list[2].split('=')[1])
                velocity = int(attribute_list[3].split('=')[1])

                # Onset
                if attribute_list[0] == 'note_on' and velocity > 0:
                    buffer_dict[midi_note] = {
                        'onset_time': midi_events_time[i], 
                        'velocity': velocity}

                # Offset
                else:
                    if midi_note in buffer_dict.keys():
                        note_events.append({
                            'midi_note': midi_note, 
                            'onset_time': buffer_dict[midi_note]['onset_time'], 
                            'offset_time': midi_events_time[i], 
                            'velocity': buffer_dict[midi_note]['velocity']})
                        del buffer_dict[midi_note]

            # Pedal
            elif attribute_list[0] == 'control_change' and attribute_list[2] == 'control=64':
                """control=64 corresponds to pedal MIDI event. E.g. 
                attribute_list: ['control_change', 'channel=0', 'control=64', 'value=45', 'time=43']"""

                ped_value = int(attribute_list[3].split('=')[1])
                if ped_value >= 64:
                    if 'onset_time' not in pedal_dict:
                        pedal_dict['onset_time'] = midi_events_time[i]
                else:
                    if 'onset_time' in pedal_dict:
                        pedal_events.append({
                            'onset_time': pedal_dict['onset_time'], 
                            'offset_time': midi_events_time[i]})
                        pedal_dict = {}

        # Add unpaired onsets to events
        for midi_note in buffer_dict.keys():
            note_events.append({
                'midi_note': midi_note, 
                'onset_time': buffer_dict[midi_note]['onset_time'], 
                'offset_time': start_time + self.segment_seconds, 
                'velocity': buffer_dict[midi_note]['velocity']})

        # Add unpaired pedal onsets to data
        if 'onset_time' in pedal_dict.keys():
            pedal_events.append({
                'onset_time': pedal_dict['onset_time'], 
                'offset_time': start_time + self.segment_seconds})

        # Set notes to ON until pedal is released
        if extend_pedal:
            note_events = self.extend_pedal(note_events, pedal_events)
        
        # Prepare targets
        frames_num = int(round(self.segment_seconds * self.frames_per_second)) + 1
        onset_roll = np.zeros((frames_num, self.classes_num))
        offset_roll = np.zeros((frames_num, self.classes_num))
        reg_onset_roll = np.ones((frames_num, self.classes_num))
        reg_offset_roll = np.ones((frames_num, self.classes_num))
        frame_roll = np.zeros((frames_num, self.classes_num))
        velocity_roll = np.zeros((frames_num, self.classes_num))
        mask_roll = np.ones((frames_num, self.classes_num))
        """mask_roll is used for masking out cross segment notes"""

        pedal_onset_roll = np.zeros(frames_num)
        pedal_offset_roll = np.zeros(frames_num)
        reg_pedal_onset_roll = np.ones(frames_num)
        reg_pedal_offset_roll = np.ones(frames_num)
        pedal_frame_roll = np.zeros(frames_num)

        # ------ 2. Get note targets ------
        # Process note events to target
        for note_event in note_events:
            """note_event: e.g., {'midi_note': 60, 'onset_time': 722.0719, 'offset_time': 722.47815, 'velocity': 103}"""

            piano_note = np.clip(note_event['midi_note'] - self.begin_note + note_shift, 0, self.max_piano_note) 
            """There are 88 keys on a piano"""

            if 0 <= piano_note <= self.max_piano_note:
                bgn_frame = int(round((note_event['onset_time'] - start_time) * self.frames_per_second))
                fin_frame = int(round((note_event['offset_time'] - start_time) * self.frames_per_second))

                if fin_frame >= 0:
                    frame_roll[max(bgn_frame, 0) : fin_frame + 1, piano_note] = 1

                    offset_roll[fin_frame, piano_note] = 1
                    velocity_roll[max(bgn_frame, 0) : fin_frame + 1, piano_note] = note_event['velocity']

                    # Vector from the center of a frame to ground truth offset
                    reg_offset_roll[fin_frame, piano_note] = \
                        (note_event['offset_time'] - start_time) - (fin_frame / self.frames_per_second)

                    if bgn_frame >= 0:
                        onset_roll[bgn_frame, piano_note] = 1

                        # Vector from the center of a frame to ground truth onset
                        reg_onset_roll[bgn_frame, piano_note] = \
                            (note_event['onset_time'] - start_time) - (bgn_frame / self.frames_per_second)
                
                    # Mask out segment notes
                    else:
                        mask_roll[: fin_frame + 1, piano_note] = 0

        for k in range(self.classes_num):
            """Get regression targets"""
            reg_onset_roll[:, k] = self.get_regression(reg_onset_roll[:, k])
            reg_offset_roll[:, k] = self.get_regression(reg_offset_roll[:, k])

        # Process unpaired onsets to target
        for midi_note in buffer_dict.keys():
            piano_note = np.clip(midi_note - self.begin_note + note_shift, 0, self.max_piano_note)
            if 0 <= piano_note <= self.max_piano_note:
                bgn_frame = int(round((buffer_dict[midi_note]['onset_time'] - start_time) * self.frames_per_second))
                mask_roll[bgn_frame :, piano_note] = 0     

        # ------ 3. Get pedal targets ------
        # Process pedal events to target
        for pedal_event in pedal_events:
            bgn_frame = int(round((pedal_event['onset_time'] - start_time) * self.frames_per_second))
            fin_frame = int(round((pedal_event['offset_time'] - start_time) * self.frames_per_second))

            if fin_frame >= 0:
                pedal_frame_roll[max(bgn_frame, 0) : fin_frame + 1] = 1

                pedal_offset_roll[fin_frame] = 1
                reg_pedal_offset_roll[fin_frame] = \
                    (pedal_event['offset_time'] - start_time) - (fin_frame / self.frames_per_second)

                if bgn_frame >= 0:
                    pedal_onset_roll[bgn_frame] = 1
                    reg_pedal_onset_roll[bgn_frame] = \
                        (pedal_event['onset_time'] - start_time) - (bgn_frame / self.frames_per_second)

        # Get regresssion padal targets
        reg_pedal_onset_roll = self.get_regression(reg_pedal_onset_roll)
        reg_pedal_offset_roll = self.get_regression(reg_pedal_offset_roll)

        target_dict = {
            'onset_roll': onset_roll, 'offset_roll': offset_roll,
            'reg_onset_roll': reg_onset_roll, 'reg_offset_roll': reg_offset_roll,
            'frame_roll': frame_roll, 'velocity_roll': velocity_roll, 
            'mask_roll': mask_roll, 'reg_pedal_onset_roll': reg_pedal_onset_roll, 
            'pedal_onset_roll': pedal_onset_roll, 'pedal_offset_roll': pedal_offset_roll, 
            'reg_pedal_offset_roll': reg_pedal_offset_roll, 'pedal_frame_roll': pedal_frame_roll
            }

        return target_dict, note_events, pedal_events

    def extend_pedal(self, note_events, pedal_events):
        """Update the offset of all notes until pedal is released.

        Args:
          note_events: list of dict, e.g., [
            {'midi_note': 51, 'onset_time': 696.63544, 'offset_time': 696.9948, 'velocity': 44}, 
            {'midi_note': 58, 'onset_time': 696.99585, 'offset_time': 697.18646, 'velocity': 50}
            ...]
          pedal_events: list of dict, e.g., [
            {'onset_time': 696.46875, 'offset_time': 696.62604}, 
            {'onset_time': 696.8063, 'offset_time': 698.50836}, 
            ...]

        Returns:
          ex_note_events: list of dict, e.g., [
            {'midi_note': 51, 'onset_time': 696.63544, 'offset_time': 696.9948, 'velocity': 44}, 
            {'midi_note': 58, 'onset_time': 696.99585, 'offset_time': 697.18646, 'velocity': 50}
            ...]
        """
        note_events = collections.deque(note_events)
        pedal_events = collections.deque(pedal_events)
        ex_note_events = []

        idx = 0     # Index of note events
        while pedal_events: # Go through all pedal events
            pedal_event = pedal_events.popleft()
            buffer_dict = {}    # keys: midi notes, value for each key: event index

            while note_events:
                note_event = note_events.popleft()

                # If a note offset is between the onset and offset of a pedal, 
                # Then set the note offset to when the pedal is released.
                if pedal_event['onset_time'] < note_event['offset_time'] < pedal_event['offset_time']:
                    
                    midi_note = note_event['midi_note']

                    if midi_note in buffer_dict.keys():
                        """Multiple same note inside a pedal"""
                        _idx = buffer_dict[midi_note]
                        del buffer_dict[midi_note]
                        ex_note_events[_idx]['offset_time'] = note_event['onset_time']

                    # Set note offset to pedal offset
                    note_event['offset_time'] = pedal_event['offset_time']
                    buffer_dict[midi_note] = idx
                
                ex_note_events.append(note_event)
                idx += 1

                # Break loop and pop next pedal
                if note_event['offset_time'] > pedal_event['offset_time']:
                    break

        while note_events:
            """Append left notes"""
            ex_note_events.append(note_events.popleft())

        return ex_note_events

    def get_regression(self, input):
        """Get regression target. See Fig. 2 of [1] for an example.
        [1] Q. Kong, et al., High-resolution Piano Transcription with Pedals by 
        Regressing Onsets and Offsets Times, 2020.

        input:
          input: (frames_num,)

        Returns: (frames_num,), e.g., [0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 0.9, 0.7, 0.5, 0.3, 0.1, 0, 0, ...]
        """
        step = 1. / self.frames_per_second
        output = np.ones_like(input)
        
        locts = np.where(input < 0.5)[0] 
        if len(locts) > 0:
            for t in range(0, locts[0]):
                output[t] = step * (t - locts[0]) - input[locts[0]]

            for i in range(0, len(locts) - 1):
                for t in range(locts[i], (locts[i] + locts[i + 1]) // 2):
                    output[t] = step * (t - locts[i]) - input[locts[i]]

                for t in range((locts[i] + locts[i + 1]) // 2, locts[i + 1]):
                    output[t] = step * (t - locts[i + 1]) - input[locts[i]]

            for t in range(locts[-1], len(input)):
                output[t] = step * (t - locts[-1]) - input[locts[-1]]

        output = np.clip(np.abs(output), 0., 0.05) * 20
        output = (1. - output)

        return output


def write_events_to_midi(start_time, note_events, pedal_events, midi_path):
    """Write out note events to MIDI file.

    Args:
      start_time: float
      note_events: list of dict, e.g. [
        {'midi_note': 51, 'onset_time': 696.63544, 'offset_time': 696.9948, 'velocity': 44}, 
        {'midi_note': 58, 'onset_time': 696.99585, 'offset_time': 697.18646, 'velocity': 50}
        ...]
      midi_path: str
    """
    from mido import Message, MidiFile, MidiTrack, MetaMessage
    
    # This configuration is the same as MIDIs in MAESTRO dataset
    ticks_per_beat = 384
    beats_per_second = 2
    ticks_per_second = ticks_per_beat * beats_per_second
    microseconds_per_beat = int(1e6 // beats_per_second)

    midi_file = MidiFile()
    midi_file.ticks_per_beat = ticks_per_beat

    # Track 0
    track0 = MidiTrack()
    track0.append(MetaMessage('set_tempo', tempo=microseconds_per_beat, time=0))
    track0.append(MetaMessage('time_signature', numerator=4, denominator=4, time=0))
    track0.append(MetaMessage('end_of_track', time=1))
    midi_file.tracks.append(track0)

    # Track 1
    track1 = MidiTrack()
    
    # Message rolls of MIDI
    message_roll = []

    for note_event in note_events:
        # Onset
        message_roll.append({
            'time': note_event['onset_time'], 
            'midi_note': note_event['midi_note'], 
            'velocity': note_event['velocity']})

        # Offset
        message_roll.append({
            'time': note_event['offset_time'], 
            'midi_note': note_event['midi_note'], 
            'velocity': 0})

    if pedal_events:
        for pedal_event in pedal_events:
            message_roll.append({'time': pedal_event['onset_time'], 'control_change': 64, 'value': 127})
            message_roll.append({'time': pedal_event['offset_time'], 'control_change': 64, 'value': 0})

    # Sort MIDI messages by time
    message_roll.sort(key=lambda note_event: note_event['time'])

    previous_ticks = 0
    for message in message_roll:
        this_ticks = int((message['time'] - start_time) * ticks_per_second)
        if this_ticks >= 0:
            diff_ticks = this_ticks - previous_ticks
            previous_ticks = this_ticks
            if 'midi_note' in message.keys():
                track1.append(Message('note_on', note=message['midi_note'], velocity=message['velocity'], time=diff_ticks))
            elif 'control_change' in message.keys():
                track1.append(Message('control_change', channel=0, control=message['control_change'], value=message['value'], time=diff_ticks))
    track1.append(MetaMessage('end_of_track', time=1))
    midi_file.tracks.append(track1)

    midi_file.save(midi_path)


def plot_waveform_midi_targets(data_dict, start_time, note_events):
    """For debugging. Write out waveform, MIDI and plot targets for an 
    audio segment.

    Args:
      data_dict: {
        'waveform': (samples_num,),
        'onset_roll': (frames_num, classes_num), 
        'offset_roll': (frames_num, classes_num), 
        'reg_onset_roll': (frames_num, classes_num), 
        'reg_offset_roll': (frames_num, classes_num), 
        'frame_roll': (frames_num, classes_num), 
        'velocity_roll': (frames_num, classes_num), 
        'mask_roll':  (frames_num, classes_num), 
        'reg_pedal_onset_roll': (frames_num,),
        'reg_pedal_offset_roll': (frames_num,),
        'pedal_frame_roll': (frames_num,)}
      start_time: float
      note_events: list of dict, e.g. [
        {'midi_note': 51, 'onset_time': 696.63544, 'offset_time': 696.9948, 'velocity': 44}, 
        {'midi_note': 58, 'onset_time': 696.99585, 'offset_time': 697.18646, 'velocity': 50}
    """
    import matplotlib.pyplot as plt

    create_folder('debug')
    audio_path = 'debug/debug.wav'
    midi_path = 'debug/debug.mid'
    fig_path = 'debug/debug.pdf'

    librosa.output.write_wav(audio_path, data_dict['waveform'], sr=config.sample_rate)
    write_events_to_midi(start_time, note_events, midi_path)
    x = librosa.core.stft(y=data_dict['waveform'], n_fft=2048, hop_length=160, window='hann', center=True)
    x = np.abs(x) ** 2

    fig, axs = plt.subplots(11, 1, sharex=True, figsize=(30, 30))
    fontsize = 20
    axs[0].matshow(np.log(x), origin='lower', aspect='auto', cmap='jet')
    axs[1].matshow(data_dict['onset_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[2].matshow(data_dict['offset_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[3].matshow(data_dict['reg_onset_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[4].matshow(data_dict['reg_offset_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[5].matshow(data_dict['frame_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[6].matshow(data_dict['velocity_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[7].matshow(data_dict['mask_roll'].T, origin='lower', aspect='auto', cmap='jet')
    axs[8].matshow(data_dict['reg_pedal_onset_roll'][:, None].T, origin='lower', aspect='auto', cmap='jet')
    axs[9].matshow(data_dict['reg_pedal_offset_roll'][:, None].T, origin='lower', aspect='auto', cmap='jet')
    axs[10].matshow(data_dict['pedal_frame_roll'][:, None].T, origin='lower', aspect='auto', cmap='jet')
    axs[0].set_title('Log spectrogram', fontsize=fontsize)
    axs[1].set_title('onset_roll', fontsize=fontsize)
    axs[2].set_title('offset_roll', fontsize=fontsize)
    axs[3].set_title('reg_onset_roll', fontsize=fontsize)
    axs[4].set_title('reg_offset_roll', fontsize=fontsize)
    axs[5].set_title('frame_roll', fontsize=fontsize)
    axs[6].set_title('velocity_roll', fontsize=fontsize)
    axs[7].set_title('mask_roll', fontsize=fontsize)
    axs[8].set_title('reg_pedal_onset_roll', fontsize=fontsize)
    axs[9].set_title('reg_pedal_offset_roll', fontsize=fontsize)
    axs[10].set_title('pedal_frame_roll', fontsize=fontsize)
    axs[10].set_xlabel('frames')
    axs[10].xaxis.set_label_position('bottom')
    axs[10].xaxis.set_ticks_position('bottom')
    plt.tight_layout(1, 1, 1)
    plt.savefig(fig_path)

    print('Write out to {}, {}, {}!'.format(audio_path, midi_path, fig_path))


class RegressionPostProcessor(object):
    def __init__(self, frames_per_second, classes_num, onset_threshold, 
        offset_threshold, frame_threshold, pedal_offset_threshold):
        """Postprocess the output probabilities of a transription model to MIDI 
        events.

        Args:
          frames_per_second: int
          classes_num: int
          onset_threshold: float
          offset_threshold: float
          frame_threshold: float
          pedal_offset_threshold: float
        """
        self.frames_per_second = frames_per_second
        self.classes_num = classes_num
        self.onset_threshold = onset_threshold
        self.offset_threshold = offset_threshold
        self.frame_threshold = frame_threshold
        self.pedal_offset_threshold = pedal_offset_threshold
        self.begin_note = config.begin_note
        self.velocity_scale = config.velocity_scale

    def output_dict_to_midi_events(self, output_dict):
        """Main function. Post process model outputs to MIDI events.

        Args:
          output_dict: {
            'reg_onset_output': (segment_frames, classes_num), 
            'reg_offset_output': (segment_frames, classes_num), 
            'frame_output': (segment_frames, classes_num), 
            'velocity_output': (segment_frames, classes_num), 
            'reg_pedal_onset_output': (segment_frames, 1), 
            'reg_pedal_offset_output': (segment_frames, 1), 
            'pedal_frame_output': (segment_frames, 1)}

        Outputs:
          est_note_events: list of dict, e.g. [
            {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83}, 
            {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]

          est_pedal_events: list of dict, e.g. [
            {'onset_time': 0.17, 'offset_time': 0.96}, 
            {'osnet_time': 1.17, 'offset_time': 2.65}]
        """

        # Post process piano note outputs to piano note and pedal events information
        (est_on_off_note_vels, est_pedal_on_offs) = \
            self.output_dict_to_note_pedal_arrays(output_dict)
        """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity], 
        est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""

        # Reformat notes to MIDI events
        est_note_events = self.detected_notes_to_events(est_on_off_note_vels)

        if est_pedal_on_offs is None:
            est_pedal_events = None
        else:
            est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)

        return est_note_events, est_pedal_events

    def output_dict_to_note_pedal_arrays(self, output_dict):
        """Postprocess the output probabilities of a transription model to MIDI 
        events.

        Args:
          output_dict: dict, {
            'reg_onset_output': (frames_num, classes_num), 
            'reg_offset_output': (frames_num, classes_num), 
            'frame_output': (frames_num, classes_num), 
            'velocity_output': (frames_num, classes_num), 
            ...}

        Returns:
          est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time, 
            offset_time, piano_note and velocity. E.g. [
             [39.74, 39.87, 27, 0.65], 
             [11.98, 12.11, 33, 0.69], 
             ...]

          est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time 
            and offset_time. E.g. [
             [0.17, 0.96], 
             [1.17, 2.65], 
             ...]
        """

        # ------ 1. Process regression outputs to binarized outputs ------
        # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
        # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]

        # Calculate binarized onset output from regression output
        (onset_output, onset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_onset_output'], 
                threshold=self.onset_threshold, neighbour=2)

        output_dict['onset_output'] = onset_output  # Values are 0 or 1
        output_dict['onset_shift_output'] = onset_shift_output  

        # Calculate binarized offset output from regression output
        (offset_output, offset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_offset_output'], 
                threshold=self.offset_threshold, neighbour=4)

        output_dict['offset_output'] = offset_output  # Values are 0 or 1
        output_dict['offset_shift_output'] = offset_shift_output

        if 'reg_pedal_onset_output' in output_dict.keys():
            """Pedal onsets are not used in inference. Instead, frame-wise pedal
            predictions are used to detect onsets. We empirically found this is 
            more accurate to detect pedal onsets."""
            pass

        if 'reg_pedal_offset_output' in output_dict.keys():
            # Calculate binarized pedal offset output from regression output
            (pedal_offset_output, pedal_offset_shift_output) = \
                self.get_binarized_output_from_regression(
                    reg_output=output_dict['reg_pedal_offset_output'], 
                    threshold=self.pedal_offset_threshold, neighbour=4)

            output_dict['pedal_offset_output'] = pedal_offset_output  # Values are 0 or 1
            output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output

        # ------ 2. Process matrices results to event results ------
        # Detect piano notes from output_dict
        est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)

        if 'reg_pedal_onset_output' in output_dict.keys():
            # Detect piano pedals from output_dict
            est_pedal_on_offs = self.output_dict_to_detected_pedals(output_dict)
 
        else:
            est_pedal_on_offs = None    

        return est_on_off_note_vels, est_pedal_on_offs

    def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
        """Calculate binarized output and shifts of onsets or offsets from the
        regression results.

        Args:
          reg_output: (frames_num, classes_num)
          threshold: float
          neighbour: int

        Returns:
          binary_output: (frames_num, classes_num)
          shift_output: (frames_num, classes_num)
        """
        binary_output = np.zeros_like(reg_output)
        shift_output = np.zeros_like(reg_output)
        (frames_num, classes_num) = reg_output.shape
        
        for k in range(classes_num):
            x = reg_output[:, k]
            for n in range(neighbour, frames_num - neighbour):
                if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
                    binary_output[n, k] = 1

                    """See Section III-D in [1] for deduction.
                    [1] Q. Kong, et al., High-resolution Piano Transcription 
                    with Pedals by Regressing Onsets and Offsets Times, 2020."""
                    if x[n - 1] > x[n + 1]:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
                    else:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
                    shift_output[n, k] = shift

        return binary_output, shift_output

    def is_monotonic_neighbour(self, x, n, neighbour):
        """Detect if values are monotonic in both side of x[n].

        Args:
          x: (frames_num,)
          n: int
          neighbour: int

        Returns:
          monotonic: bool
        """
        monotonic = True
        for i in range(neighbour):
            if x[n - i] < x[n - i - 1]:
                monotonic = False
            if x[n + i] < x[n + i + 1]:
                monotonic = False

        return monotonic

    def output_dict_to_detected_notes(self, output_dict):
        """Postprocess output_dict to piano notes.

        Args:
          output_dict: dict, e.g. {
            'onset_output': (frames_num, classes_num),
            'onset_shift_output': (frames_num, classes_num),
            'offset_output': (frames_num, classes_num),
            'offset_shift_output': (frames_num, classes_num),
            'frame_output': (frames_num, classes_num),
            'onset_output': (frames_num, classes_num),
            ...}

        Returns:
          est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets, 
          MIDI notes and velocities. E.g.,
            [[39.7375, 39.7500, 27., 0.6638],
             [11.9824, 12.5000, 33., 0.6892],
             ...]
        """
        est_tuples = []
        est_midi_notes = []
        classes_num = output_dict['frame_output'].shape[-1]
 
        for piano_note in range(classes_num):
            """Detect piano notes"""
            est_tuples_per_note = note_detection_with_onset_offset_regress(
                frame_output=output_dict['frame_output'][:, piano_note], 
                onset_output=output_dict['onset_output'][:, piano_note], 
                onset_shift_output=output_dict['onset_shift_output'][:, piano_note], 
                offset_output=output_dict['offset_output'][:, piano_note], 
                offset_shift_output=output_dict['offset_shift_output'][:, piano_note], 
                velocity_output=output_dict['velocity_output'][:, piano_note], 
                frame_threshold=self.frame_threshold)
            
            est_tuples += est_tuples_per_note
            est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)

        est_tuples = np.array(est_tuples)   # (notes, 5)
        """(notes, 5), the five columns are onset, offset, onset_shift, 
        offset_shift and normalized_velocity"""

        est_midi_notes = np.array(est_midi_notes) # (notes,)

        onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
        offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
        velocities = est_tuples[:, 4]
        
        est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
        """(notes, 3), the three columns are onset_times, offset_times and velocity."""

        est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)

        return est_on_off_note_vels

    def output_dict_to_detected_pedals(self, output_dict):
        """Postprocess output_dict to piano pedals.

        Args:
          output_dict: dict, e.g. {
            'pedal_frame_output': (frames_num,),
            'pedal_offset_output': (frames_num,),
            'pedal_offset_shift_output': (frames_num,),
            ...}

        Returns:
          est_on_off: (notes, 2), the two columns are pedal onsets and pedal
            offsets. E.g.,
              [[0.1800, 0.9669],
               [1.1400, 2.6458],
               ...]
        """
        frames_num = output_dict['pedal_frame_output'].shape[0]
        
        est_tuples = pedal_detection_with_onset_offset_regress(
            frame_output=output_dict['pedal_frame_output'][:, 0], 
            offset_output=output_dict['pedal_offset_output'][:, 0], 
            offset_shift_output=output_dict['pedal_offset_shift_output'][:, 0], 
            frame_threshold=0.5)

        est_tuples = np.array(est_tuples)
        """(notes, 2), the two columns are pedal onsets and pedal offsets"""
        
        if len(est_tuples) == 0:
            return np.array([])

        else:
            onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
            offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
            est_on_off = np.stack((onset_times, offset_times), axis=-1)
            est_on_off = est_on_off.astype(np.float32)
            return est_on_off

    def detected_notes_to_events(self, est_on_off_note_vels):
        """Reformat detected notes to midi events.

        Args:
          est_on_off_vels: (notes, 3), the three columns are onset_times, 
            offset_times and velocity. E.g.
            [[32.8376, 35.7700, 0.7932],
             [37.3712, 39.9300, 0.8058],
             ...]
        
        Returns:
          midi_events, list, e.g.,
            [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
             {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
             ...]
        """
        midi_events = []
        for i in range(est_on_off_note_vels.shape[0]):
            midi_events.append({
                'onset_time': est_on_off_note_vels[i][0], 
                'offset_time': est_on_off_note_vels[i][1], 
                'midi_note': int(est_on_off_note_vels[i][2]), 
                'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})

        return midi_events

    def detected_pedals_to_events(self, pedal_on_offs):
        """Reformat detected pedal onset and offsets to events.

        Args:
          pedal_on_offs: (notes, 2), the two columns are pedal onsets and pedal
          offsets. E.g., 
            [[0.1800, 0.9669],
             [1.1400, 2.6458],
             ...]

        Returns:
          pedal_events: list of dict, e.g.,
            [{'onset_time': 0.1800, 'offset_time': 0.9669}, 
             {'onset_time': 1.1400, 'offset_time': 2.6458},
             ...]
        """
        pedal_events = []
        for i in range(len(pedal_on_offs)):
            pedal_events.append({
                'onset_time': pedal_on_offs[i, 0], 
                'offset_time': pedal_on_offs[i, 1]})
        
        return pedal_events


class OnsetsFramesPostProcessor(object):
    def __init__(self, frames_per_second, classes_num):
        """Postprocess the Googl's onsets and frames system output. Only used
        for comparison.

        Args:
          frames_per_second: int
          classes_num: int
        """
        self.frames_per_second = frames_per_second
        self.classes_num = classes_num
        self.begin_note = config.begin_note
        self.velocity_scale = config.velocity_scale
        
        self.frame_threshold = 0.5
        self.onset_threshold = 0.1
        self.offset_threshold = 0.3

    def output_dict_to_midi_events(self, output_dict):
        """Main function. Post process model outputs to MIDI events.

        Args:
          output_dict: {
            'reg_onset_output': (segment_frames, classes_num), 
            'reg_offset_output': (segment_frames, classes_num), 
            'frame_output': (segment_frames, classes_num), 
            'velocity_output': (segment_frames, classes_num), 
            'reg_pedal_onset_output': (segment_frames, 1), 
            'reg_pedal_offset_output': (segment_frames, 1), 
            'pedal_frame_output': (segment_frames, 1)}

        Outputs:
          est_note_events: list of dict, e.g. [
            {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83}, 
            {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]

          est_pedal_events: list of dict, e.g. [
            {'onset_time': 0.17, 'offset_time': 0.96}, 
            {'osnet_time': 1.17, 'offset_time': 2.65}]
        """

        # Post process piano note outputs to piano note and pedal events information
        (est_on_off_note_vels, est_pedal_on_offs) = \
            self.output_dict_to_note_pedal_arrays(output_dict)
        """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity], 
        est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""
        
        # Reformat notes to MIDI events
        est_note_events = self.detected_notes_to_events(est_on_off_note_vels)

        if est_pedal_on_offs is None:
            est_pedal_events = None
        else:
            est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)

        return est_note_events, est_pedal_events

    def output_dict_to_note_pedal_arrays(self, output_dict):
        """Postprocess the output probabilities of a transription model to MIDI 
        events.

        Args:
          output_dict: dict, {
            'reg_onset_output': (frames_num, classes_num), 
            'reg_offset_output': (frames_num, classes_num), 
            'frame_output': (frames_num, classes_num), 
            'velocity_output': (frames_num, classes_num), 
            ...}

        Returns:
          est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time, 
            offset_time, piano_note and velocity. E.g. [
             [39.74, 39.87, 27, 0.65], 
             [11.98, 12.11, 33, 0.69], 
             ...]

          est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time 
            and offset_time. E.g. [
             [0.17, 0.96], 
             [1.17, 2.65], 
             ...]
        """

        # Sharp onsets and offsets
        output_dict = self.sharp_output_dict(
            output_dict, onset_threshold=self.onset_threshold, 
            offset_threshold=self.offset_threshold)

        # Post process output_dict to piano notes
        est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict, 
            frame_threshold=self.frame_threshold)

        if 'reg_pedal_onset_output' in output_dict.keys():
            # Detect piano pedals from output_dict
            est_pedal_on_offs = self.output_dict_to_detected_pedals(output_dict)
 
        else:
            est_pedal_on_offs = None    

        return est_on_off_note_vels, est_pedal_on_offs

    def sharp_output_dict(self, output_dict, onset_threshold, offset_threshold):
        """Sharp onsets and offsets. E.g. when threshold=0.3, for a note, 
        [0, 0.1, 0.4, 0.7, 0, 0] will be sharped to [0, 0, 0, 1, 0, 0]
        [0., 0., 1., 0., 0., 0.]

        Args:
          output_dict: {
            'reg_onset_output': (frames_num, classes_num), 
            'reg_offset_output': (frames_num, classes_num), 
            ...}
          onset_threshold: float
          offset_threshold: float

        Returns:
          output_dict: {
            'onset_output': (frames_num, classes_num), 
            'offset_output': (frames_num, classes_num)}
        """
        if 'reg_onset_output' in output_dict.keys():
            output_dict['onset_output'] = self.sharp_output(
                output_dict['reg_onset_output'], 
                threshold=onset_threshold)

        if 'reg_offset_output' in output_dict.keys():
            output_dict['offset_output'] = self.sharp_output(
                output_dict['reg_offset_output'], 
                threshold=offset_threshold)

        return output_dict

    def sharp_output(self, input, threshold=0.3):
        """Used for sharping onset or offset. E.g. when threshold=0.3, for a note, 
        [0, 0.1, 0.4, 0.7, 0, 0] will be sharped to [0, 0, 0, 1, 0, 0]

        Args:
          input: (frames_num, classes_num)

        Returns:
          output: (frames_num, classes_num)
        """
        (frames_num, classes_num) = input.shape
        output = np.zeros_like(input)

        for piano_note in range(classes_num):
            loct = None
            for i in range(1, frames_num - 1):
                if input[i, piano_note] > threshold and input[i, piano_note] > input[i - 1, piano_note] and input[i, piano_note] > input[i + 1, piano_note]:
                    loct = i
                else:
                    if loct is not None:
                        output[loct, piano_note] = 1
                        loct = None

        return output

    def output_dict_to_detected_notes(self, output_dict, frame_threshold):
        """Postprocess output_dict to piano notes.

        Args:
          output_dict: dict, e.g. {
            'onset_output': (frames_num, classes_num),
            'onset_shift_output': (frames_num, classes_num),
            'offset_output': (frames_num, classes_num),
            'offset_shift_output': (frames_num, classes_num),
            'frame_output': (frames_num, classes_num),
            'onset_output': (frames_num, classes_num),
            ...}

        Returns:
          est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets, 
          MIDI notes and velocities. E.g.,
            [[39.7375, 39.7500, 27., 0.6638],
             [11.9824, 12.5000, 33., 0.6892],
             ...]
        """

        est_tuples = []
        est_midi_notes = []

        for piano_note in range(self.classes_num):
            
            est_tuples_per_note = onsets_frames_note_detection(
                frame_output=output_dict['frame_output'][:, piano_note], 
                onset_output=output_dict['onset_output'][:, piano_note], 
                offset_output=output_dict['offset_output'][:, piano_note], 
                velocity_output=output_dict['velocity_output'][:, piano_note], 
                threshold=frame_threshold)

            est_tuples += est_tuples_per_note
            est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)

        est_tuples = np.array(est_tuples)   # (notes, 3)
        """(notes, 5), the five columns are onset, offset, onset_shift, 
        offset_shift and normalized_velocity"""

        est_midi_notes = np.array(est_midi_notes) # (notes,)
        
        if len(est_midi_notes) == 0:
            return []
        else:
            onset_times = est_tuples[:, 0] / self.frames_per_second
            offset_times = est_tuples[:, 1] / self.frames_per_second
            velocities = est_tuples[:, 2]
        
            est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
            """(notes, 3), the three columns are onset_times, offset_times and velocity."""

            est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)

            return est_on_off_note_vels

    def output_dict_to_detected_pedals(self, output_dict):
        """Postprocess output_dict to piano pedals.

        Args:
          output_dict: dict, e.g. {
            'pedal_frame_output': (frames_num,),
            'pedal_offset_output': (frames_num,),
            'pedal_offset_shift_output': (frames_num,),
            ...}

        Returns:
          est_on_off: (notes, 2), the two columns are pedal onsets and pedal
            offsets. E.g.,
              [[0.1800, 0.9669],
               [1.1400, 2.6458],
               ...]
        """

        frames_num = output_dict['pedal_frame_output'].shape[0]
        
        est_tuples = onsets_frames_pedal_detection(
            frame_output=output_dict['pedal_frame_output'][:, 0], 
            offset_output=output_dict['reg_pedal_offset_output'][:, 0], 
            frame_threshold=0.5)

        est_tuples = np.array(est_tuples)
        """(notes, 2), the two columns are pedal onsets and pedal offsets"""
        
        if len(est_tuples) == 0:
            return np.array([])

        else:
            onset_times = est_tuples[:, 0] / self.frames_per_second
            offset_times = est_tuples[:, 1] / self.frames_per_second
            est_on_off = np.stack((onset_times, offset_times), axis=-1)
            est_on_off = est_on_off.astype(np.float32)
            return est_on_off

    def detected_notes_to_events(self, est_on_off_note_vels):
        """Reformat detected notes to midi events.

        Args:
          est_on_off_vels: (notes, 3), the three columns are onset_times, 
            offset_times and velocity. E.g.
            [[32.8376, 35.7700, 0.7932],
             [37.3712, 39.9300, 0.8058],
             ...]
        
        Returns:
          midi_events, list, e.g.,
            [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
             {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
             ...]
        """
        midi_events = []
        for i in range(len(est_on_off_note_vels)):
            midi_events.append({
                'onset_time': est_on_off_note_vels[i][0], 
                'offset_time': est_on_off_note_vels[i][1], 
                'midi_note': int(est_on_off_note_vels[i][2]), 
                'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})

        return midi_events

    def detected_pedals_to_events(self, pedal_on_offs):
        """Reformat detected pedal onset and offsets to events.

        Args:
          pedal_on_offs: (notes, 2), the two columns are pedal onsets and pedal
          offsets. E.g., 
            [[0.1800, 0.9669],
             [1.1400, 2.6458],
             ...]

        Returns:
          pedal_events: list of dict, e.g.,
            [{'onset_time': 0.1800, 'offset_time': 0.9669}, 
             {'onset_time': 1.1400, 'offset_time': 2.6458},
             ...]
        """
        pedal_events = []
        for i in range(len(pedal_on_offs)):
            pedal_events.append({
                'onset_time': pedal_on_offs[i, 0], 
                'offset_time': pedal_on_offs[i, 1]})
        
        return pedal_events


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        """Contain statistics of different training iterations.
        """
        self.statistics_path = statistics_path

        self.backup_statistics_path = '{}_{}.pkl'.format(
            os.path.splitext(self.statistics_path)[0], 
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

        self.statistics_dict = {'train': [], 'validation': [], 'test': []}

    def append(self, iteration, statistics, data_type):
        statistics['iteration'] = iteration
        self.statistics_dict[data_type].append(statistics)
        
    def dump(self):
        pickle.dump(self.statistics_dict, open(self.statistics_path, 'wb'))
        pickle.dump(self.statistics_dict, open(self.backup_statistics_path, 'wb'))
        logging.info('    Dump statistics to {}'.format(self.statistics_path))
        logging.info('    Dump statistics to {}'.format(self.backup_statistics_path))
        
    def load_state_dict(self, resume_iteration):
        self.statistics_dict = pickle.load(open(self.statistics_path, 'rb'))

        resume_statistics_dict = {'train': [], 'validation': [], 'test': []}
        
        for key in self.statistics_dict.keys():
            for statistics in self.statistics_dict[key]:
                if statistics['iteration'] <= resume_iteration:
                    resume_statistics_dict[key].append(statistics)
                
        self.statistics_dict = resume_statistics_dict


def load_audio(path, sr=22050, mono=True, offset=0.0, duration=None,
    dtype=np.float32, res_type='kaiser_best', 
    backends=[audioread.ffdec.FFmpegAudioFile]):
    """Load audio. Copied from librosa.core.load() except that ffmpeg backend is 
    always used in this function."""

    y = []
    with audioread.audio_open(os.path.realpath(path), backends=backends) as input_file:
        sr_native = input_file.samplerate
        n_channels = input_file.channels

        s_start = int(np.round(sr_native * offset)) * n_channels

        if duration is None:
            s_end = np.inf
        else:
            s_end = s_start + (int(np.round(sr_native * duration))
                               * n_channels)

        n = 0

        for frame in input_file:
            frame = librosa.core.audio.util.buf_to_float(frame, dtype=dtype)
            n_prev = n
            n = n + len(frame)

            if n < s_start:
                # offset is after the current frame
                # keep reading
                continue

            if s_end < n_prev:
                # we're off the end.  stop reading
                break

            if s_end < n:
                # the end is in this frame.  crop.
                frame = frame[:s_end - n_prev]

            if n_prev <= s_start <= n:
                # beginning is in this frame
                frame = frame[(s_start - n_prev):]

            # tack on the current frame
            y.append(frame)

    if y:
        y = np.concatenate(y)

        if n_channels > 1:
            y = y.reshape((-1, n_channels)).T
            if mono:
                y = librosa.core.audio.to_mono(y)

        if sr is not None:
            y = librosa.core.audio.resample(y, sr_native, sr, res_type=res_type)

        else:
            sr = sr_native

    # Final cleanup for dtype and contiguity
    y = np.ascontiguousarray(y, dtype=dtype)

    return (y, sr)

## pytorch_utils.py ----------------------------------------------------------------------------------------------------------------

In [8]:
import os
import sys
import numpy as np
import time
import librosa
import torch
import torch.nn as nn


def move_data_to_device(x, device):
    if 'float' in str(x.dtype):
        x = torch.Tensor(x)
    elif 'int' in str(x.dtype):
        x = torch.LongTensor(x)
    else:
        return x

    return x.to(device)


def append_to_dict(dict, key, value):
    
    if key in dict.keys():
        dict[key].append(value)
    else:
        dict[key] = [value]


def forward_dataloader(model, dataloader, batch_size, return_target=True):
    """Forward data generated from dataloader to model.

    Args:
      model: object
      dataloader: object, used to generate mini-batches for evaluation.
      batch_size: int
      return_target: bool

    Returns:
      output_dict: dict, e.g. {
        'frame_output': (segments_num, frames_num, classes_num),
        'onset_output': (segments_num, frames_num, classes_num),
        'frame_roll': (segments_num, frames_num, classes_num),
        'onset_roll': (segments_num, frames_num, classes_num),
        ...}
    """

    output_dict = {}
    device = next(model.parameters()).device

    for n, batch_data_dict in enumerate(dataloader):
        
        batch_waveform = move_data_to_device(batch_data_dict['waveform'], device)

        with torch.no_grad():
            model.eval()
            batch_output_dict = model(batch_waveform)

        for key in batch_output_dict.keys():
            if '_list' not in key:
                append_to_dict(output_dict, key, 
                    batch_output_dict[key].data.cpu().numpy())

        if return_target:
            for target_type in batch_data_dict.keys():
                if 'roll' in target_type or 'reg_distance' in target_type or \
                    'reg_tail' in target_type:
                    append_to_dict(output_dict, target_type, 
                        batch_data_dict[target_type])

    for key in output_dict.keys():
        output_dict[key] = np.concatenate(output_dict[key], axis=0)
    
    return output_dict


def forward(model, x, batch_size):
    """Forward data to model in mini-batch. 
    
    Args: 
      model: object
      x: (N, segment_samples)
      batch_size: int

    Returns:
      output_dict: dict, e.g. {
        'frame_output': (segments_num, frames_num, classes_num),
        'onset_output': (segments_num, frames_num, classes_num),
        ...}
    """
    
    output_dict = {}
    device = next(model.parameters()).device
    
    pointer = 0
    while True:
        if pointer >= len(x):
            break

        batch_waveform = move_data_to_device(x[pointer : pointer + batch_size], device)
        pointer += batch_size

        with torch.no_grad():
            model.eval()
            batch_output_dict = model(batch_waveform)

        for key in batch_output_dict.keys():
            # if '_list' not in key:
            append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())

    for key in output_dict.keys():
        output_dict[key] = np.concatenate(output_dict[key], axis=0)

    return output_dict

## models.py ----------------------------------------------------------------------------------------------------------------------

In [9]:
import os
import sys
import math
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchlibrosa.stft import Spectrogram, LogmelFilterBank


def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


def init_gru(rnn):
    """Initialize a GRU layer. """
    
    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)
    
        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in : (i + 1) * fan_in, :])
        
    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
    
    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, 'weight_ih_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)

        _concat_init(
            getattr(rnn, 'weight_hh_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, momentum):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels, momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """
        Args:
          input: (batch_size, in_channels, time_steps, freq_bins)

        Outputs:
          output: (batch_size, out_channels, classes_num)
        """

        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))
        
        if pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        
        return x


class AcousticModelCRnn8Dropout(nn.Module):
    def __init__(self, classes_num, midfeat, momentum):
        super(AcousticModelCRnn8Dropout, self).__init__()

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=48, momentum=momentum)
        self.conv_block2 = ConvBlock(in_channels=48, out_channels=64, momentum=momentum)
        self.conv_block3 = ConvBlock(in_channels=64, out_channels=96, momentum=momentum)
        self.conv_block4 = ConvBlock(in_channels=96, out_channels=128, momentum=momentum)

        self.fc5 = nn.Linear(midfeat, 768, bias=False)
        self.bn5 = nn.BatchNorm1d(768, momentum=momentum)
        
        # 256 -> 64 outputshape for dense is 64*2
        self.gru = nn.GRU(input_size=768, hidden_size=64, num_layers=2, 
            bias=True, batch_first=True, dropout=0., bidirectional=True)

        # 512 -> 128
        self.fc = nn.Linear(128, classes_num, bias=True)
        
        self.init_weight()

    def init_weight(self):
        init_layer(self.fc5)
        init_bn(self.bn5)
        init_gru(self.gru)
        init_layer(self.fc)

    def forward(self, input):
        """
        Args:
          input: (batch_size, channels_num, time_steps, freq_bins)

        Outputs:
          output: (batch_size, time_steps, classes_num)
        """

        x = self.conv_block1(input, pool_size=(1, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(1, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)

        x = x.transpose(1, 2).flatten(2)
        x = F.relu(self.bn5(self.fc5(x).transpose(1, 2)).transpose(1, 2))
        x = F.dropout(x, p=0.5, training=self.training, inplace=False)
        
        (x, _) = self.gru(x)
        x = F.dropout(x, p=0.5, training=self.training, inplace=False)
        output = torch.sigmoid(self.fc(x))
        return output


class Regress_onset_offset_frame_velocity_CRNN(nn.Module):
    def __init__(self, frames_per_second, classes_num):
        super(Regress_onset_offset_frame_velocity_CRNN, self).__init__()

        sample_rate = 16000
        window_size = 2048
        hop_size = sample_rate // frames_per_second
        mel_bins = 229
        fmin = 30
        fmax = sample_rate // 2

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        midfeat = 1792
        momentum = 0.01

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, 
            hop_length=hop_size, win_length=window_size, window=window, 
            center=center, pad_mode=pad_mode, freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, 
            n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, 
            amin=amin, top_db=top_db, freeze_parameters=True)

        self.bn0 = nn.BatchNorm2d(mel_bins, momentum)

        self.frame_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
        self.reg_onset_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
        self.reg_offset_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
        self.velocity_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)

        # change all gru hidden size from 256 -> 64
        # change all linear from 512 -> 128
        self.reg_onset_gru = nn.GRU(input_size=88 * 2, hidden_size=64, num_layers=1, 
            bias=True, batch_first=True, dropout=0., bidirectional=True)
        self.reg_onset_fc = nn.Linear(128, classes_num, bias=True)

        self.frame_gru = nn.GRU(input_size=88 * 3, hidden_size=64, num_layers=1, 
            bias=True, batch_first=True, dropout=0., bidirectional=True)
        self.frame_fc = nn.Linear(128, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_gru(self.reg_onset_gru)
        init_gru(self.frame_gru)
        init_layer(self.reg_onset_fc)
        init_layer(self.frame_fc)
 
    def forward(self, input):
        """
        Args:
          input: (batch_size, data_length)

        Outputs:
          output_dict: dict, {
            'reg_onset_output': (batch_size, time_steps, classes_num),
            'reg_offset_output': (batch_size, time_steps, classes_num),
            'frame_output': (batch_size, time_steps, classes_num),
            'velocity_output': (batch_size, time_steps, classes_num)
          }
        """

        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        frame_output = self.frame_model(x)  # (batch_size, time_steps, classes_num)
        reg_onset_output = self.reg_onset_model(x)  # (batch_size, time_steps, classes_num)
        reg_offset_output = self.reg_offset_model(x)    # (batch_size, time_steps, classes_num)
        velocity_output = self.velocity_model(x)    # (batch_size, time_steps, classes_num)
 
        # Use velocities to condition onset regression
        x = torch.cat((reg_onset_output, (reg_onset_output ** 0.5) * velocity_output.detach()), dim=2)
        (x, _) = self.reg_onset_gru(x)
        x = F.dropout(x, p=0.5, training=self.training, inplace=False)
        reg_onset_output = torch.sigmoid(self.reg_onset_fc(x))
        """(batch_size, time_steps, classes_num)"""

        # Use onsets and offsets to condition frame-wise classification
        x = torch.cat((frame_output, reg_onset_output.detach(), reg_offset_output.detach()), dim=2)
        (x, _) = self.frame_gru(x)
        x = F.dropout(x, p=0.5, training=self.training, inplace=False)
        frame_output = torch.sigmoid(self.frame_fc(x))  # (batch_size, time_steps, classes_num)
        """(batch_size, time_steps, classes_num)"""

        output_dict = {
            'reg_onset_output': reg_onset_output, 
            'reg_offset_output': reg_offset_output, 
            'frame_output': frame_output, 
            'velocity_output': velocity_output}

        return output_dict


class Regress_pedal_CRNN(nn.Module):
    def __init__(self, frames_per_second, classes_num):
        super(Regress_pedal_CRNN, self).__init__()

        sample_rate = 16000
        window_size = 2048
        hop_size = sample_rate // frames_per_second
        mel_bins = 229
        fmin = 30
        fmax = sample_rate // 2

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        midfeat = 1792
        momentum = 0.01

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, 
            hop_length=hop_size, win_length=window_size, window=window, 
            center=center, pad_mode=pad_mode, freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, 
            n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, 
            amin=amin, top_db=top_db, freeze_parameters=True)

        self.bn0 = nn.BatchNorm2d(mel_bins, momentum)

        self.reg_pedal_onset_model = AcousticModelCRnn8Dropout(1, midfeat, momentum)
        self.reg_pedal_offset_model = AcousticModelCRnn8Dropout(1, midfeat, momentum)
        self.reg_pedal_frame_model = AcousticModelCRnn8Dropout(1, midfeat, momentum)
        
        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        
    def forward(self, input):
        """
        Args:
          input: (batch_size, data_length)

        Outputs:
          output_dict: dict, {
            'reg_onset_output': (batch_size, time_steps, classes_num),
            'reg_offset_output': (batch_size, time_steps, classes_num),
            'frame_output': (batch_size, time_steps, classes_num),
            'velocity_output': (batch_size, time_steps, classes_num)
          }
        """

        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        reg_pedal_onset_output = self.reg_pedal_onset_model(x)  # (batch_size, time_steps, classes_num)
        reg_pedal_offset_output = self.reg_pedal_offset_model(x)  # (batch_size, time_steps, classes_num)
        pedal_frame_output = self.reg_pedal_frame_model(x)  # (batch_size, time_steps, classes_num)
        
        output_dict = {
            'reg_pedal_onset_output': reg_pedal_onset_output, 
            'reg_pedal_offset_output': reg_pedal_offset_output,
            'pedal_frame_output': pedal_frame_output}

        return output_dict


# This model is not trained, but is combined from the trained note and pedal models.
class Note_pedal(nn.Module):
    def __init__(self, frames_per_second, classes_num):
        """The combination of note and pedal model.
        """
        super(Note_pedal, self).__init__()

        self.note_model = Regress_onset_offset_frame_velocity_CRNN(frames_per_second, classes_num)
        self.pedal_model = Regress_pedal_CRNN(frames_per_second, classes_num)

    def load_state_dict(self, m, strict=False):
        self.note_model.load_state_dict(m['note_model'], strict=strict)
        self.pedal_model.load_state_dict(m['pedal_model'], strict=strict)

    def forward(self, input):
        note_output_dict = self.note_model(input)
        pedal_output_dict = self.pedal_model(input)

        full_output_dict = {}
        full_output_dict.update(note_output_dict)
        full_output_dict.update(pedal_output_dict)
        return full_output_dict

## data_generator.py -------------------------------------------------------------------------------------------------------------

In [10]:
import os
import sys
import numpy as np
import h5py
import csv
import time
import collections
import librosa
import sox
import logging


class MaestroDataset(object):
    def __init__(self, hdf5s_dir, segment_seconds, frames_per_second, 
        max_note_shift=0, augmentor=None):
        """This class takes the meta of an audio segment as input, and return 
        the waveform and targets of the audio segment. This class is used by 
        DataLoader. 
        
        Args:
          feature_hdf5s_dir: str
          segment_seconds: float
          frames_per_second: int
          max_note_shift: int, number of semitone for pitch augmentation
          augmentor: object
        """
        self.hdf5s_dir = hdf5s_dir
        self.segment_seconds = segment_seconds
        self.frames_per_second = frames_per_second
        self.sample_rate = config.sample_rate
        self.max_note_shift = max_note_shift
        self.begin_note = config.begin_note
        self.classes_num = config.classes_num
        self.segment_samples = int(self.sample_rate * self.segment_seconds)
        self.augmentor = augmentor

        self.random_state = np.random.RandomState(1234)

        self.target_processor = TargetProcessor(self.segment_seconds, 
            self.frames_per_second, self.begin_note, self.classes_num)
        """Used for processing MIDI events to target."""

    def __getitem__(self, meta):
        """Prepare input and target of a segment for training.
        
        Args:
          meta: dict, e.g. {
            'year': '2004', 
            'hdf5_name': 'MIDI-Unprocessed_SMF_12_01_2004_01-05_ORIG_MID--AUDIO_12_R1_2004_10_Track10_wav.h5, 
            'start_time': 65.0}

        Returns:
          data_dict: {
            'waveform': (samples_num,)
            'onset_roll': (frames_num, classes_num), 
            'offset_roll': (frames_num, classes_num), 
            'reg_onset_roll': (frames_num, classes_num), 
            'reg_offset_roll': (frames_num, classes_num), 
            'frame_roll': (frames_num, classes_num), 
            'velocity_roll': (frames_num, classes_num), 
            'mask_roll':  (frames_num, classes_num), 
            'pedal_onset_roll': (frames_num,), 
            'pedal_offset_roll': (frames_num,), 
            'reg_pedal_onset_roll': (frames_num,), 
            'reg_pedal_offset_roll': (frames_num,), 
            'pedal_frame_roll': (frames_num,)}
        """
        [year, hdf5_name, start_time] = meta
        hdf5_path = os.path.join(self.hdf5s_dir, year, hdf5_name)
         
        data_dict = {}

        note_shift = self.random_state.randint(low=-self.max_note_shift, 
            high=self.max_note_shift + 1)

        # Load hdf5
        with h5py.File(hdf5_path, 'r') as hf:
            start_sample = int(start_time * self.sample_rate)
            end_sample = start_sample + self.segment_samples

            if end_sample >= hf['waveform'].shape[0]:
                start_sample -= self.segment_samples
                end_sample -= self.segment_samples

            waveform = int16_to_float32(hf['waveform'][start_sample : end_sample])

            if self.augmentor:
                waveform = self.augmentor.augment(waveform)

            if note_shift != 0:
                """Augment pitch"""
                waveform = librosa.effects.pitch_shift(waveform, self.sample_rate, 
                    note_shift, bins_per_octave=12)

            data_dict['waveform'] = waveform

            midi_events = [e.decode() for e in hf['midi_event'][:]]
            midi_events_time = hf['midi_event_time'][:]

            # Process MIDI events to target
            (target_dict, note_events, pedal_events) = \
                self.target_processor.process(start_time, midi_events_time, 
                    midi_events, extend_pedal=True, note_shift=note_shift)

        # Combine input and target
        for key in target_dict.keys():
            data_dict[key] = target_dict[key]

        debugging = False
        if debugging:
            plot_waveform_midi_targets(data_dict, start_time, note_events)
            exit()

        return data_dict


class Augmentor(object):
    def __init__(self):
        """Data augmentor."""
        
        self.sample_rate = config.sample_rate
        self.random_state = np.random.RandomState(1234)

    def augment(self, x):
        clip_samples = len(x)

        logger = logging.getLogger('sox')
        logger.propagate = False

        tfm = sox.Transformer()
        tfm.set_globals(verbosity=0)

        tfm.pitch(self.random_state.uniform(-0.1, 0.1, 1)[0])
        tfm.contrast(self.random_state.uniform(0, 100, 1)[0])

        tfm.equalizer(frequency=self.loguniform(32, 4096, 1)[0], 
            width_q=self.random_state.uniform(1, 2, 1)[0], 
            gain_db=self.random_state.uniform(-30, 10, 1)[0])

        tfm.equalizer(frequency=self.loguniform(32, 4096, 1)[0], 
            width_q=self.random_state.uniform(1, 2, 1)[0], 
            gain_db=self.random_state.uniform(-30, 10, 1)[0])
        
        tfm.reverb(reverberance=self.random_state.uniform(0, 70, 1)[0])

        aug_x = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate)
        aug_x = pad_truncate_sequence(aug_x, clip_samples)
        
        return aug_x

    def loguniform(self, low, high, size):
        return np.exp(self.random_state.uniform(np.log(low), np.log(high), size))


class Sampler(object):
    def __init__(self, hdf5s_dir, split, segment_seconds, hop_seconds, 
            batch_size, mini_data, random_seed=1234):
        """Sampler is used to sample segments for training or evaluation.

        Args:
          hdf5s_dir: str
          split: 'train' | 'validation' | 'test'
          segment_seconds: float
          hop_seconds: float
          batch_size: int
          mini_data: bool, sample from a small amount of data for debugging
        """
        assert split in ['train', 'validation', 'test']
        self.hdf5s_dir = hdf5s_dir
        self.segment_seconds = segment_seconds
        self.hop_seconds = hop_seconds
        self.sample_rate = config.sample_rate
        self.batch_size = batch_size
        self.random_state = np.random.RandomState(random_seed)

        (hdf5_names, hdf5_paths) = traverse_folder(hdf5s_dir)
        self.segment_list = []

        n = 0
        for hdf5_path in hdf5_paths:
            with h5py.File(hdf5_path, 'r') as hf:
                if hf.attrs['split'].decode() == split:
                    audio_name = hdf5_path.split('/')[-1]
                    year = hf.attrs['year'].decode()
                    start_time = 0
                    while (start_time + self.segment_seconds < hf.attrs['duration']):
                        self.segment_list.append([year, audio_name, start_time])
                        start_time += self.hop_seconds
                    
                    n += 1
                    if mini_data and n == 10:
                        break
        """self.segment_list looks like:
        [['2004', 'MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_17_Track17_wav.h5', 0], 
         ['2004', 'MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_17_Track17_wav.h5', 1.0], 
         ['2004', 'MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_17_Track17_wav.h5', 2.0]
         ...]"""

        logging.info('{} segments: {}'.format(split, len(self.segment_list)))

        self.pointer = 0
        self.segment_indexes = np.arange(len(self.segment_list))
        self.random_state.shuffle(self.segment_indexes)

    def __iter__(self):
        while True:
            batch_segment_list = []
            i = 0
            while i < self.batch_size:
                index = self.segment_indexes[self.pointer]
                self.pointer += 1

                if self.pointer >= len(self.segment_indexes):
                    self.pointer = 0
                    self.random_state.shuffle(self.segment_indexes)

                batch_segment_list.append(self.segment_list[index])
                i += 1

            yield batch_segment_list

    def __len__(self):
        return -1
        
    def state_dict(self):
        state = {
            'pointer': self.pointer, 
            'segment_indexes': self.segment_indexes}
        return state
            
    def load_state_dict(self, state):
        self.pointer = state['pointer']
        self.segment_indexes = state['segment_indexes']


class TestSampler(object):
    def __init__(self, hdf5s_dir, split, segment_seconds, hop_seconds, 
            batch_size, mini_data, random_seed=1234):
        """Sampler for testing.

        Args:
          hdf5s_dir: str
          split: 'train' | 'validation' | 'test'
          segment_seconds: float
          hop_seconds: float
          batch_size: int
          mini_data: bool, sample from a small amount of data for debugging
        """
        assert split in ['train', 'validation', 'test']
        self.hdf5s_dir = hdf5s_dir
        self.segment_seconds = segment_seconds
        self.hop_seconds = hop_seconds
        self.sample_rate = config.sample_rate
        self.batch_size = batch_size
        self.random_state = np.random.RandomState(random_seed)
        self.max_evaluate_iteration = 20    # Number of mini-batches to validate

        (hdf5_names, hdf5_paths) = traverse_folder(hdf5s_dir)
        self.segment_list = []

        n = 0
        for hdf5_path in hdf5_paths:
            with h5py.File(hdf5_path, 'r') as hf:
                if hf.attrs['split'].decode() == split:
                    audio_name = hdf5_path.split('/')[-1]
                    year = hf.attrs['year'].decode()
                    start_time = 0
                    while (start_time + self.segment_seconds < hf.attrs['duration']):
                        self.segment_list.append([year, audio_name, start_time])
                        start_time += self.hop_seconds
                    
                    n += 1
                    if mini_data and n == 10:
                        break
        """self.segment_list looks like:
        [['2004', 'MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_17_Track17_wav.h5', 0], 
         ['2004', 'MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_17_Track17_wav.h5', 1.0], 
         ['2004', 'MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_17_Track17_wav.h5', 2.0]
         ...]"""

        logging.info('Evaluate {} segments: {}'.format(split, len(self.segment_list)))

        self.segment_indexes = np.arange(len(self.segment_list))
        self.random_state.shuffle(self.segment_indexes)

    def __iter__(self):
        pointer = 0
        iteration = 0

        while True:
            if iteration == self.max_evaluate_iteration:
                break

            batch_segment_list = []
            i = 0
            while i < self.batch_size:
                index = self.segment_indexes[pointer]
                pointer += 1
                
                batch_segment_list.append(self.segment_list[index])
                i += 1

            iteration += 1

            yield batch_segment_list

    def __len__(self):
        return -1


def collate_fn(list_data_dict):
    """Collate input and target of segments to a mini-batch.

    Args:
      list_data_dict: e.g. [
        {'waveform': (segment_samples,), 'frame_roll': (segment_frames, classes_num), ...}, 
        {'waveform': (segment_samples,), 'frame_roll': (segment_frames, classes_num), ...}, 
        ...]

    Returns:
      np_data_dict: e.g. {
        'waveform': (batch_size, segment_samples)
        'frame_roll': (batch_size, segment_frames, classes_num), 
        ...}
    """
    np_data_dict = {}
    for key in list_data_dict[0].keys():
        np_data_dict[key] = np.array([data_dict[key] for data_dict in list_data_dict])
    
    return np_data_dict

/bin/sh: 1: sox: not found


## losses.py -----------------------------------------------------------------------------------------------------------------------

In [11]:
import torch
import torch.nn.functional as F


def bce(output, target, mask):
    """Binary crossentropy (BCE) with mask. The positions where mask=0 will be 
    deactivated when calculation BCE.
    """
    eps = 1e-7
    output = torch.clamp(output, eps, 1. - eps)
    matrix = - target * torch.log(output) - (1. - target) * torch.log(1. - output)
    return torch.sum(matrix * mask) / torch.sum(mask)

############ High-resolution regression loss ############
def regress_onset_offset_frame_velocity_bce(model, output_dict, target_dict):
    """High-resolution piano note regression loss, including onset regression, 
    offset regression, velocity regression and frame-wise classification losses.
    """
    onset_loss = bce(output_dict['reg_onset_output'], target_dict['reg_onset_roll'], target_dict['mask_roll'])
    offset_loss = bce(output_dict['reg_offset_output'], target_dict['reg_offset_roll'], target_dict['mask_roll'])
    frame_loss = bce(output_dict['frame_output'], target_dict['frame_roll'], target_dict['mask_roll'])
    velocity_loss = bce(output_dict['velocity_output'], target_dict['velocity_roll'] / 128, target_dict['onset_roll'])
    total_loss = onset_loss + offset_loss + frame_loss + velocity_loss
    return total_loss


def regress_pedal_bce(model, output_dict, target_dict):
    """High-resolution piano pedal regression loss, including pedal onset 
    regression, pedal offset regression and pedal frame-wise classification losses.
    """
    onset_pedal_loss = F.binary_cross_entropy(output_dict['reg_pedal_onset_output'], target_dict['reg_pedal_onset_roll'][:, :, None])
    offset_pedal_loss = F.binary_cross_entropy(output_dict['reg_pedal_offset_output'], target_dict['reg_pedal_offset_roll'][:, :, None])
    frame_pedal_loss = F.binary_cross_entropy(output_dict['pedal_frame_output'], target_dict['pedal_frame_roll'][:, :, None])
    total_loss = onset_pedal_loss + offset_pedal_loss + frame_pedal_loss
    return total_loss

############ Google's onsets and frames system loss ############
def google_onset_offset_frame_velocity_bce(model, output_dict, target_dict):
    """Google's onsets and frames system piano note loss. Only used for comparison.
    """
    onset_loss = bce(output_dict['reg_onset_output'], target_dict['onset_roll'], target_dict['mask_roll'])
    offset_loss = bce(output_dict['reg_offset_output'], target_dict['offset_roll'], target_dict['mask_roll'])
    frame_loss = bce(output_dict['frame_output'], target_dict['frame_roll'], target_dict['mask_roll'])
    velocity_loss = bce(output_dict['velocity_output'], target_dict['velocity_roll'] / 128, target_dict['onset_roll'])
    total_loss = onset_loss + offset_loss + frame_loss + velocity_loss
    return total_loss


def google_pedal_bce(model, output_dict, target_dict):
    """Google's onsets and frames system piano pedal loss. Only used for comparison.
    """
    onset_pedal_loss = F.binary_cross_entropy(output_dict['reg_pedal_onset_output'], target_dict['pedal_onset_roll'][:, :, None])
    offset_pedal_loss = F.binary_cross_entropy(output_dict['reg_pedal_offset_output'], target_dict['pedal_offset_roll'][:, :, None])
    frame_pedal_loss = F.binary_cross_entropy(output_dict['pedal_frame_output'], target_dict['pedal_frame_roll'][:, :, None])
    total_loss = onset_pedal_loss + offset_pedal_loss + frame_pedal_loss
    return total_loss


def get_loss_func(loss_type):
    if loss_type == 'regress_onset_offset_frame_velocity_bce':
        return regress_onset_offset_frame_velocity_bce

    elif loss_type == 'regress_pedal_bce':
        return regress_pedal_bce

    elif loss_type == 'google_onset_offset_frame_velocity_bce':
        return google_onset_offset_frame_velocity_bce

    elif loss_type == 'google_pedal_bce':
        return google_pedal_bce

    else:
        raise Exception('Incorrect loss_type!')

In [12]:
import os
import sys
import numpy as np
import torch
import h5py
import time
import mir_eval
import librosa
import logging
from sklearn import metrics


def mae(target, output, mask):
    if mask is None:
        return np.mean(np.abs(target - output))
    else:
        target *= mask
        output *= mask
        return np.sum(np.abs(target - output)) / np.clip(np.sum(mask), 1e-8, np.inf)


class SegmentEvaluator(object):
    def __init__(self, model, batch_size):
        """Evaluate segment-wise metrics.

        Args:
          model: object
          batch_size: int
        """
        self.model = model
        self.batch_size = batch_size

    def evaluate(self, dataloader):
        """Evaluate over a few mini-batches.

        Args:
          dataloader: object, used to generate mini-batches for evaluation.

        Returns:
          statistics: dict, e.g. {
            'frame_f1': 0.800, 
            (if exist) 'onset_f1': 0.500, 
            (if exist) 'offset_f1': 0.300, 
            ...}
        """

        statistics = {}
        output_dict = forward_dataloader(self.model, dataloader, self.batch_size)
        
        # Frame and onset evaluation
        if 'frame_output' in output_dict.keys():
            statistics['frame_ap'] = metrics.average_precision_score(
                output_dict['frame_roll'].flatten(), 
                output_dict['frame_output'].flatten(), average='macro')
        
        if 'onset_output' in output_dict.keys():
            statistics['onset_macro_ap'] = metrics.average_precision_score(
                output_dict['onset_roll'].flatten(), 
                output_dict['onset_output'].flatten(), average='macro')

        if 'offset_output' in output_dict.keys():
            statistics['offset_ap'] = metrics.average_precision_score(
                output_dict['offset_roll'].flatten(), 
                output_dict['offset_output'].flatten(), average='macro')

        if 'reg_onset_output' in output_dict.keys():
            """Mask indictes only evaluate where either prediction or ground truth exists"""
            mask = (np.sign(output_dict['reg_onset_output'] + output_dict['reg_onset_roll'] - 0.01) + 1) / 2
            statistics['reg_onset_mae'] = mae(output_dict['reg_onset_output'], 
                output_dict['reg_onset_roll'], mask)

        if 'reg_offset_output' in output_dict.keys():
            """Mask indictes only evaluate where either prediction or ground truth exists"""
            mask = (np.sign(output_dict['reg_offset_output'] + output_dict['reg_offset_roll'] - 0.01) + 1) / 2
            statistics['reg_offset_mae'] = mae(output_dict['reg_offset_output'], 
                output_dict['reg_offset_roll'], mask)

        if 'velocity_output' in output_dict.keys():
            """Mask indictes only evaluate where onset exists"""
            statistics['velocity_mae'] = mae(output_dict['velocity_output'], 
                output_dict['velocity_roll'] / 128, output_dict['onset_roll'])

        if 'reg_pedal_onset_output' in output_dict.keys():
            statistics['reg_pedal_onset_mae'] = mae(
                output_dict['reg_pedal_onset_roll'].flatten(), 
                output_dict['reg_pedal_onset_output'].flatten(), 
                mask=None)

        if 'reg_pedal_offset_output' in output_dict.keys():
            statistics['reg_pedal_offset_mae'] = mae(
                output_dict['reg_pedal_offset_output'].flatten(), 
                output_dict['reg_pedal_offset_roll'].flatten(), 
                mask=None)

        if 'pedal_frame_output' in output_dict.keys():
            statistics['pedal_frame_mae'] = mae(
                output_dict['pedal_frame_output'].flatten(), 
                output_dict['pedal_frame_roll'].flatten(), 
                mask=None)

        for key in statistics.keys():
            statistics[key] = np.around(statistics[key], decimals=4)

        return statistics

## evaluate.py ---------------------------------------------------------------------------------------------------------------------

## main.py ------------------------------------------------------------------------------------------------------------------------

In [13]:
import os
import sys
import numpy as np
import argparse
import h5py
import math
import time
import librosa
import logging
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data


def train(args):
    """Train a piano transcription system.

    Args:
      workspace: str, directory of your workspace
      model_type: str, e.g. 'Regressonset_regressoffset_frame_velocity_CRNN'
      loss_type: str, e.g. 'regress_onset_offset_frame_velocity_bce'
      augmentation: str, e.g. 'none'
      batch_size: int
      learning_rate: float
      reduce_iteration: int
      resume_iteration: int
      early_stop: int
      device: 'cuda' | 'cpu'
      mini_data: bool
    """

    # Arugments & parameters
    workspace = args.workspace
    model_type = args.model_type
    loss_type = args.loss_type
    augmentation = args.augmentation
    max_note_shift = args.max_note_shift
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    reduce_iteration = args.reduce_iteration
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
    mini_data = args.mini_data
    filename = args.filename

    sample_rate = config.sample_rate
    segment_seconds = config.segment_seconds
    hop_seconds = config.hop_seconds
    segment_samples = int(segment_seconds * sample_rate)
    frames_per_second = config.frames_per_second
    classes_num = config.classes_num
    num_workers = 4  # change this as needed

    # Loss function
    loss_func = get_loss_func(loss_type)

    # Paths
    # modify the hdf5s_dir to input directory ----------------------------------------------------------------------------------------------
    h5_files_input_path = '/kaggle/input/studio-maestro-hdf5'
#     hdf5s_dir = os.path.join(workspace, 'hdf5s', 'maestro')
#     print(hdf5s_dir)
    hdf5s_dir = os.path.join(h5_files_input_path, 'hdf5s', 'maestro')
    print(hdf5s_dir)

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename, 
        model_type, 'loss_type={}'.format(loss_type), 
        'augmentation={}'.format(augmentation), 
        'max_note_shift={}'.format(max_note_shift),
        'batch_size={}'.format(batch_size))
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename, 
        model_type, 'loss_type={}'.format(loss_type), 
        'augmentation={}'.format(augmentation), 
        'max_note_shift={}'.format(max_note_shift), 
        'batch_size={}'.format(batch_size), 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, 
        model_type, 'loss_type={}'.format(loss_type), 
        'augmentation={}'.format(augmentation), 
        'max_note_shift={}'.format(max_note_shift), 
        'batch_size={}'.format(batch_size))
    
    create_folder(logs_dir)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU.')
        device = 'cpu'
    
    # Model
    Model = eval(model_type)
    model = Model(frames_per_second=frames_per_second, classes_num=classes_num)
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))

    if augmentation == 'none':
        augmentor = None
    elif augmentation == 'aug':
        augmentor = Augmentor()
    else:
        raise Exception('Incorrect argumentation!')
    
    # Dataset
    train_dataset = MaestroDataset(hdf5s_dir=hdf5s_dir, 
        segment_seconds=segment_seconds, frames_per_second=frames_per_second, 
        max_note_shift=max_note_shift, augmentor=augmentor)

    evaluate_dataset = MaestroDataset(hdf5s_dir=hdf5s_dir, 
        segment_seconds=segment_seconds, frames_per_second=frames_per_second, 
        max_note_shift=0)

    # Sampler for training
    train_sampler = Sampler(hdf5s_dir=hdf5s_dir, split='train', 
        segment_seconds=segment_seconds, hop_seconds=hop_seconds, 
        batch_size=batch_size, mini_data=mini_data)

    # Sampler for evaluation
    evaluate_train_sampler = TestSampler(hdf5s_dir=hdf5s_dir, 
        split='train', segment_seconds=segment_seconds, hop_seconds=hop_seconds, 
        batch_size=batch_size, mini_data=mini_data)

    evaluate_validate_sampler = TestSampler(hdf5s_dir=hdf5s_dir, 
        split='validation', segment_seconds=segment_seconds, hop_seconds=hop_seconds, 
        batch_size=batch_size, mini_data=mini_data)

    evaluate_test_sampler = TestSampler(hdf5s_dir=hdf5s_dir, 
        split='test', segment_seconds=segment_seconds, hop_seconds=hop_seconds, 
        batch_size=batch_size, mini_data=mini_data)

    # Dataloader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
        batch_sampler=train_sampler, collate_fn=collate_fn, 
        num_workers=num_workers, pin_memory=True)

    evaluate_train_loader = torch.utils.data.DataLoader(dataset=evaluate_dataset, 
        batch_sampler=evaluate_train_sampler, collate_fn=collate_fn, 
        num_workers=num_workers, pin_memory=True)

    validate_loader = torch.utils.data.DataLoader(dataset=evaluate_dataset, 
        batch_sampler=evaluate_validate_sampler, collate_fn=collate_fn, 
        num_workers=num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=evaluate_dataset, 
        batch_sampler=evaluate_test_sampler, collate_fn=collate_fn, 
        num_workers=num_workers, pin_memory=True)

    # Evaluator
    evaluator = SegmentEvaluator(model, batch_size)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)
    
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, 
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

    # Resume training
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(workspace, 'checkpoints', filename, 
            model_type, 'loss_type={}'.format(loss_type), 
            'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size), 
                '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']

    else:
        iteration = 0
    
    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)

    train_bgn_time = time.time()

    for batch_data_dict in train_loader:
        
        # Evaluation 
        if iteration % 500 == 0:  # and iteration > 0:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            evaluate_train_statistics = evaluator.evaluate(evaluate_train_loader)
            validate_statistics = evaluator.evaluate(validate_loader)
            test_statistics = evaluator.evaluate(test_loader)

            logging.info('    Train statistics: {}'.format(evaluate_train_statistics))
            logging.info('    Validation statistics: {}'.format(validate_statistics))
            logging.info('    Test statistics: {}'.format(test_statistics))

            statistics_container.append(iteration, evaluate_train_statistics, data_type='train')
            statistics_container.append(iteration, validate_statistics, data_type='validation')
            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'Train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(train_time, validate_time))

            train_bgn_time = time.time()
        
        # Save model
        if iteration % 2000 == 0:
            checkpoint = {
                'iteration': iteration, 
                'model': model.module.state_dict(), 
                'sampler': train_sampler.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))
                
            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))
        
        # Reduce learning rate
        if iteration % reduce_iteration == 0 and iteration > 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.9
        
        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)
         
        model.train()
        batch_output_dict = model(batch_data_dict['waveform'])

        loss = loss_func(model, batch_output_dict, batch_data_dict)

        print(iteration, loss)

        # Backward
        torch.autograd.set_detect_anomaly(True)
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1


class Arguments(object):
    def __init__(self, workspace, model_type, loss_type, augmentation, max_note_shift, batch_size, learning_rate, reduce_iteration, resume_iteration, early_stop, mini_data=False, filename='SAP_P02_train'):
        self.workspace = workspace
        self.model_type = model_type
        self.loss_type = loss_type
        self.augmentation = augmentation
        self.max_note_shift = max_note_shift
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.reduce_iteration = reduce_iteration
        self.resume_iteration = resume_iteration
        self.early_stop = early_stop
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.cuda = torch.cuda.is_available()
        self.mini_data = mini_data
        self.filename = filename

args = Arguments(workspace='/kaggle/working/', model_type='Regress_onset_offset_frame_velocity_CRNN', loss_type='regress_onset_offset_frame_velocity_bce', augmentation='none', max_note_shift=0, batch_size=12, learning_rate=5e-4, reduce_iteration=250, resume_iteration=0, early_stop=10000)
# args = Arguments(workspace='/kaggle/working/', model_type='Regress_pedal_CRNN', loss_type='regress_pedal_bce', augmentation='none', max_note_shift=0, batch_size=12, learning_rate=5e-4, reduce_iteration=250, resume_iteration=0, early_stop=10000)

train(args)

root        : INFO     <__main__.Arguments object at 0x7943ab94cf70>
root        : INFO     Using GPU.


/kaggle/input/studio-maestro-hdf5/hdf5s/maestro
9311642


root        : INFO     train segments: 564137
root        : INFO     Evaluate train segments: 564137
root        : INFO     Evaluate validation segments: 68646
root        : INFO     Evaluate test segments: 70246


GPU number: 2


root        : INFO     ------------------------------------
root        : INFO     Iteration: 0
root        : INFO         Train statistics: {'frame_ap': 0.0574, 'reg_onset_mae': 0.5111, 'reg_offset_mae': 0.5009, 'velocity_mae': 0.1701}
root        : INFO         Validation statistics: {'frame_ap': 0.0542, 'reg_onset_mae': 0.5111, 'reg_offset_mae': 0.501, 'velocity_mae': 0.1685}
root        : INFO         Test statistics: {'frame_ap': 0.0599, 'reg_onset_mae': 0.511, 'reg_offset_mae': 0.5009, 'velocity_mae': 0.1703}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics

0 tensor(2.9084, device='cuda:0', grad_fn=<AddBackward0>)
1 tensor(2.8259, device='cuda:0', grad_fn=<AddBackward0>)
2 tensor(2.7837, device='cuda:0', grad_fn=<AddBackward0>)
3 tensor(2.6827, device='cuda:0', grad_fn=<AddBackward0>)
4 tensor(2.5931, device='cuda:0', grad_fn=<AddBackward0>)
5 tensor(2.4947, device='cuda:0', grad_fn=<AddBackward0>)
6 tensor(2.4236, device='cuda:0', grad_fn=<AddBackward0>)
7 tensor(2.3472, device='cuda:0', grad_fn=<AddBackward0>)
8 tensor(2.2844, device='cuda:0', grad_fn=<AddBackward0>)
9 tensor(2.2219, device='cuda:0', grad_fn=<AddBackward0>)
10 tensor(2.1706, device='cuda:0', grad_fn=<AddBackward0>)
11 tensor(2.1055, device='cuda:0', grad_fn=<AddBackward0>)
12 tensor(2.0568, device='cuda:0', grad_fn=<AddBackward0>)
13 tensor(2.0156, device='cuda:0', grad_fn=<AddBackward0>)
14 tensor(1.9565, device='cuda:0', grad_fn=<AddBackward0>)
15 tensor(1.9491, device='cuda:0', grad_fn=<AddBackward0>)
16 tensor(1.8878, device='cuda:0', grad_fn=<AddBackward0>)
17 tens

root        : INFO     ------------------------------------
root        : INFO     Iteration: 500
root        : INFO         Train statistics: {'frame_ap': 0.1137, 'reg_onset_mae': 0.4787, 'reg_offset_mae': 0.3671, 'velocity_mae': 0.1158}
root        : INFO         Validation statistics: {'frame_ap': 0.1151, 'reg_onset_mae': 0.478, 'reg_offset_mae': 0.3654, 'velocity_mae': 0.1161}
root        : INFO         Test statistics: {'frame_ap': 0.1185, 'reg_onset_mae': 0.479, 'reg_offset_mae': 0.3699, 'velocity_mae': 0.1188}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statisti

500 tensor(0.9426, device='cuda:0', grad_fn=<AddBackward0>)
501 tensor(0.9642, device='cuda:0', grad_fn=<AddBackward0>)
502 tensor(0.9709, device='cuda:0', grad_fn=<AddBackward0>)
503 tensor(0.9169, device='cuda:0', grad_fn=<AddBackward0>)
504 tensor(0.9365, device='cuda:0', grad_fn=<AddBackward0>)
505 tensor(0.9411, device='cuda:0', grad_fn=<AddBackward0>)
506 tensor(0.9987, device='cuda:0', grad_fn=<AddBackward0>)
507 tensor(0.9237, device='cuda:0', grad_fn=<AddBackward0>)
508 tensor(0.9795, device='cuda:0', grad_fn=<AddBackward0>)
509 tensor(0.9200, device='cuda:0', grad_fn=<AddBackward0>)
510 tensor(0.9563, device='cuda:0', grad_fn=<AddBackward0>)
511 tensor(0.9414, device='cuda:0', grad_fn=<AddBackward0>)
512 tensor(0.9025, device='cuda:0', grad_fn=<AddBackward0>)
513 tensor(0.9278, device='cuda:0', grad_fn=<AddBackward0>)
514 tensor(0.9316, device='cuda:0', grad_fn=<AddBackward0>)
515 tensor(0.9286, device='cuda:0', grad_fn=<AddBackward0>)
516 tensor(0.9321, device='cuda:0', grad

root        : INFO     ------------------------------------
root        : INFO     Iteration: 1000
root        : INFO         Train statistics: {'frame_ap': 0.1137, 'reg_onset_mae': 0.481, 'reg_offset_mae': 0.2099, 'velocity_mae': 0.1137}
root        : INFO         Validation statistics: {'frame_ap': 0.1146, 'reg_onset_mae': 0.4802, 'reg_offset_mae': 0.2085, 'velocity_mae': 0.1141}
root        : INFO         Test statistics: {'frame_ap': 0.1183, 'reg_onset_mae': 0.4812, 'reg_offset_mae': 0.2129, 'velocity_mae': 0.1156}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

1000 tensor(0.9497, device='cuda:0', grad_fn=<AddBackward0>)
1001 tensor(0.9569, device='cuda:0', grad_fn=<AddBackward0>)
1002 tensor(0.9405, device='cuda:0', grad_fn=<AddBackward0>)
1003 tensor(0.9125, device='cuda:0', grad_fn=<AddBackward0>)
1004 tensor(0.8966, device='cuda:0', grad_fn=<AddBackward0>)
1005 tensor(0.9358, device='cuda:0', grad_fn=<AddBackward0>)
1006 tensor(0.9493, device='cuda:0', grad_fn=<AddBackward0>)
1007 tensor(0.9139, device='cuda:0', grad_fn=<AddBackward0>)
1008 tensor(1.0087, device='cuda:0', grad_fn=<AddBackward0>)
1009 tensor(0.9265, device='cuda:0', grad_fn=<AddBackward0>)
1010 tensor(0.9323, device='cuda:0', grad_fn=<AddBackward0>)
1011 tensor(0.9457, device='cuda:0', grad_fn=<AddBackward0>)
1012 tensor(0.9264, device='cuda:0', grad_fn=<AddBackward0>)
1013 tensor(0.9177, device='cuda:0', grad_fn=<AddBackward0>)
1014 tensor(0.9499, device='cuda:0', grad_fn=<AddBackward0>)
1015 tensor(0.9003, device='cuda:0', grad_fn=<AddBackward0>)
1016 tensor(0.9197, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 1500
root        : INFO         Train statistics: {'frame_ap': 0.1138, 'reg_onset_mae': 0.2433, 'reg_offset_mae': 0.135, 'velocity_mae': 0.1131}
root        : INFO         Validation statistics: {'frame_ap': 0.1154, 'reg_onset_mae': 0.243, 'reg_offset_mae': 0.1339, 'velocity_mae': 0.1132}
root        : INFO         Test statistics: {'frame_ap': 0.1173, 'reg_onset_mae': 0.2468, 'reg_offset_mae': 0.1371, 'velocity_mae': 0.1147}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statist

1500 tensor(0.9304, device='cuda:0', grad_fn=<AddBackward0>)
1501 tensor(0.9453, device='cuda:0', grad_fn=<AddBackward0>)
1502 tensor(0.9180, device='cuda:0', grad_fn=<AddBackward0>)
1503 tensor(0.9235, device='cuda:0', grad_fn=<AddBackward0>)
1504 tensor(0.9056, device='cuda:0', grad_fn=<AddBackward0>)
1505 tensor(0.9272, device='cuda:0', grad_fn=<AddBackward0>)
1506 tensor(0.9735, device='cuda:0', grad_fn=<AddBackward0>)
1507 tensor(0.9607, device='cuda:0', grad_fn=<AddBackward0>)
1508 tensor(0.8759, device='cuda:0', grad_fn=<AddBackward0>)
1509 tensor(0.8934, device='cuda:0', grad_fn=<AddBackward0>)
1510 tensor(0.9061, device='cuda:0', grad_fn=<AddBackward0>)
1511 tensor(0.9653, device='cuda:0', grad_fn=<AddBackward0>)
1512 tensor(0.9369, device='cuda:0', grad_fn=<AddBackward0>)
1513 tensor(0.9184, device='cuda:0', grad_fn=<AddBackward0>)
1514 tensor(0.9243, device='cuda:0', grad_fn=<AddBackward0>)
1515 tensor(0.9376, device='cuda:0', grad_fn=<AddBackward0>)
1516 tensor(0.9297, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 2000
root        : INFO         Train statistics: {'frame_ap': 0.1134, 'reg_onset_mae': 0.1268, 'reg_offset_mae': 0.1164, 'velocity_mae': 0.1129}
root        : INFO         Validation statistics: {'frame_ap': 0.1151, 'reg_onset_mae': 0.1264, 'reg_offset_mae': 0.1154, 'velocity_mae': 0.1131}
root        : INFO         Test statistics: {'frame_ap': 0.1178, 'reg_onset_mae': 0.1291, 'reg_offset_mae': 0.1182, 'velocity_mae': 0.1144}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/stati

2000 tensor(0.9263, device='cuda:0', grad_fn=<AddBackward0>)
2001 tensor(0.8997, device='cuda:0', grad_fn=<AddBackward0>)
2002 tensor(0.9500, device='cuda:0', grad_fn=<AddBackward0>)
2003 tensor(0.9118, device='cuda:0', grad_fn=<AddBackward0>)
2004 tensor(0.9079, device='cuda:0', grad_fn=<AddBackward0>)
2005 tensor(0.9581, device='cuda:0', grad_fn=<AddBackward0>)
2006 tensor(0.9473, device='cuda:0', grad_fn=<AddBackward0>)
2007 tensor(0.9010, device='cuda:0', grad_fn=<AddBackward0>)
2008 tensor(0.9286, device='cuda:0', grad_fn=<AddBackward0>)
2009 tensor(0.9174, device='cuda:0', grad_fn=<AddBackward0>)
2010 tensor(0.9133, device='cuda:0', grad_fn=<AddBackward0>)
2011 tensor(0.9225, device='cuda:0', grad_fn=<AddBackward0>)
2012 tensor(0.9401, device='cuda:0', grad_fn=<AddBackward0>)
2013 tensor(0.9785, device='cuda:0', grad_fn=<AddBackward0>)
2014 tensor(0.9362, device='cuda:0', grad_fn=<AddBackward0>)
2015 tensor(0.9525, device='cuda:0', grad_fn=<AddBackward0>)
2016 tensor(0.9255, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 2500
root        : INFO         Train statistics: {'frame_ap': 0.1133, 'reg_onset_mae': 0.1015, 'reg_offset_mae': 0.0927, 'velocity_mae': 0.1136}
root        : INFO         Validation statistics: {'frame_ap': 0.1151, 'reg_onset_mae': 0.1011, 'reg_offset_mae': 0.0919, 'velocity_mae': 0.1139}
root        : INFO         Test statistics: {'frame_ap': 0.1147, 'reg_onset_mae': 0.1033, 'reg_offset_mae': 0.0942, 'velocity_mae': 0.1144}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/stati

2500 tensor(0.9643, device='cuda:0', grad_fn=<AddBackward0>)
2501 tensor(0.9577, device='cuda:0', grad_fn=<AddBackward0>)
2502 tensor(0.8951, device='cuda:0', grad_fn=<AddBackward0>)
2503 tensor(0.8953, device='cuda:0', grad_fn=<AddBackward0>)
2504 tensor(0.9711, device='cuda:0', grad_fn=<AddBackward0>)
2505 tensor(0.9565, device='cuda:0', grad_fn=<AddBackward0>)
2506 tensor(0.9778, device='cuda:0', grad_fn=<AddBackward0>)
2507 tensor(0.9259, device='cuda:0', grad_fn=<AddBackward0>)
2508 tensor(0.9195, device='cuda:0', grad_fn=<AddBackward0>)
2509 tensor(0.9594, device='cuda:0', grad_fn=<AddBackward0>)
2510 tensor(0.9398, device='cuda:0', grad_fn=<AddBackward0>)
2511 tensor(0.9602, device='cuda:0', grad_fn=<AddBackward0>)
2512 tensor(0.9736, device='cuda:0', grad_fn=<AddBackward0>)
2513 tensor(0.9291, device='cuda:0', grad_fn=<AddBackward0>)
2514 tensor(0.9365, device='cuda:0', grad_fn=<AddBackward0>)
2515 tensor(0.9083, device='cuda:0', grad_fn=<AddBackward0>)
2516 tensor(0.9129, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 3000
root        : INFO         Train statistics: {'frame_ap': 0.1128, 'reg_onset_mae': 0.077, 'reg_offset_mae': 0.0974, 'velocity_mae': 0.1144}
root        : INFO         Validation statistics: {'frame_ap': 0.1142, 'reg_onset_mae': 0.0768, 'reg_offset_mae': 0.0966, 'velocity_mae': 0.1144}
root        : INFO         Test statistics: {'frame_ap': 0.1162, 'reg_onset_mae': 0.0785, 'reg_offset_mae': 0.0989, 'velocity_mae': 0.1147}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

3000 tensor(0.9051, device='cuda:0', grad_fn=<AddBackward0>)
3001 tensor(0.9271, device='cuda:0', grad_fn=<AddBackward0>)
3002 tensor(0.9722, device='cuda:0', grad_fn=<AddBackward0>)
3003 tensor(0.9900, device='cuda:0', grad_fn=<AddBackward0>)
3004 tensor(0.9603, device='cuda:0', grad_fn=<AddBackward0>)
3005 tensor(0.9385, device='cuda:0', grad_fn=<AddBackward0>)
3006 tensor(0.9379, device='cuda:0', grad_fn=<AddBackward0>)
3007 tensor(0.9655, device='cuda:0', grad_fn=<AddBackward0>)
3008 tensor(0.9331, device='cuda:0', grad_fn=<AddBackward0>)
3009 tensor(0.9427, device='cuda:0', grad_fn=<AddBackward0>)
3010 tensor(0.9353, device='cuda:0', grad_fn=<AddBackward0>)
3011 tensor(0.9700, device='cuda:0', grad_fn=<AddBackward0>)
3012 tensor(0.9300, device='cuda:0', grad_fn=<AddBackward0>)
3013 tensor(0.9438, device='cuda:0', grad_fn=<AddBackward0>)
3014 tensor(0.9707, device='cuda:0', grad_fn=<AddBackward0>)
3015 tensor(0.9339, device='cuda:0', grad_fn=<AddBackward0>)
3016 tensor(0.8927, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 3500
root        : INFO         Train statistics: {'frame_ap': 0.1132, 'reg_onset_mae': 0.0688, 'reg_offset_mae': 0.0862, 'velocity_mae': 0.1146}
root        : INFO         Validation statistics: {'frame_ap': 0.114, 'reg_onset_mae': 0.0685, 'reg_offset_mae': 0.0854, 'velocity_mae': 0.1145}
root        : INFO         Test statistics: {'frame_ap': 0.1161, 'reg_onset_mae': 0.07, 'reg_offset_mae': 0.0875, 'velocity_mae': 0.1147}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statisti

3500 tensor(0.9333, device='cuda:0', grad_fn=<AddBackward0>)
3501 tensor(0.9473, device='cuda:0', grad_fn=<AddBackward0>)
3502 tensor(0.9402, device='cuda:0', grad_fn=<AddBackward0>)
3503 tensor(0.9196, device='cuda:0', grad_fn=<AddBackward0>)
3504 tensor(0.9123, device='cuda:0', grad_fn=<AddBackward0>)
3505 tensor(0.9413, device='cuda:0', grad_fn=<AddBackward0>)
3506 tensor(0.9263, device='cuda:0', grad_fn=<AddBackward0>)
3507 tensor(0.9191, device='cuda:0', grad_fn=<AddBackward0>)
3508 tensor(0.9180, device='cuda:0', grad_fn=<AddBackward0>)
3509 tensor(0.9676, device='cuda:0', grad_fn=<AddBackward0>)
3510 tensor(0.9549, device='cuda:0', grad_fn=<AddBackward0>)
3511 tensor(0.9539, device='cuda:0', grad_fn=<AddBackward0>)
3512 tensor(0.9545, device='cuda:0', grad_fn=<AddBackward0>)
3513 tensor(0.8867, device='cuda:0', grad_fn=<AddBackward0>)
3514 tensor(0.9578, device='cuda:0', grad_fn=<AddBackward0>)
3515 tensor(0.9637, device='cuda:0', grad_fn=<AddBackward0>)
3516 tensor(0.9403, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 4000
root        : INFO         Train statistics: {'frame_ap': 0.1132, 'reg_onset_mae': 0.0631, 'reg_offset_mae': 0.0905, 'velocity_mae': 0.1146}
root        : INFO         Validation statistics: {'frame_ap': 0.1145, 'reg_onset_mae': 0.0629, 'reg_offset_mae': 0.0897, 'velocity_mae': 0.1145}
root        : INFO         Test statistics: {'frame_ap': 0.1173, 'reg_onset_mae': 0.0643, 'reg_offset_mae': 0.0919, 'velocity_mae': 0.1148}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/stati

4000 tensor(0.9390, device='cuda:0', grad_fn=<AddBackward0>)
4001 tensor(0.9139, device='cuda:0', grad_fn=<AddBackward0>)
4002 tensor(0.9157, device='cuda:0', grad_fn=<AddBackward0>)
4003 tensor(0.9789, device='cuda:0', grad_fn=<AddBackward0>)
4004 tensor(0.9253, device='cuda:0', grad_fn=<AddBackward0>)
4005 tensor(0.9514, device='cuda:0', grad_fn=<AddBackward0>)
4006 tensor(0.9544, device='cuda:0', grad_fn=<AddBackward0>)
4007 tensor(0.9759, device='cuda:0', grad_fn=<AddBackward0>)
4008 tensor(0.9832, device='cuda:0', grad_fn=<AddBackward0>)
4009 tensor(0.9908, device='cuda:0', grad_fn=<AddBackward0>)
4010 tensor(0.9070, device='cuda:0', grad_fn=<AddBackward0>)
4011 tensor(0.9172, device='cuda:0', grad_fn=<AddBackward0>)
4012 tensor(0.9650, device='cuda:0', grad_fn=<AddBackward0>)
4013 tensor(1.0014, device='cuda:0', grad_fn=<AddBackward0>)
4014 tensor(0.9280, device='cuda:0', grad_fn=<AddBackward0>)
4015 tensor(0.9129, device='cuda:0', grad_fn=<AddBackward0>)
4016 tensor(0.9300, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 4500
root        : INFO         Train statistics: {'frame_ap': 0.113, 'reg_onset_mae': 0.0467, 'reg_offset_mae': 0.0521, 'velocity_mae': 0.1148}
root        : INFO         Validation statistics: {'frame_ap': 0.117, 'reg_onset_mae': 0.0466, 'reg_offset_mae': 0.0517, 'velocity_mae': 0.1147}
root        : INFO         Test statistics: {'frame_ap': 0.1157, 'reg_onset_mae': 0.0475, 'reg_offset_mae': 0.0529, 'velocity_mae': 0.115}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statisti

4500 tensor(0.8829, device='cuda:0', grad_fn=<AddBackward0>)
4501 tensor(0.9228, device='cuda:0', grad_fn=<AddBackward0>)
4502 tensor(0.9430, device='cuda:0', grad_fn=<AddBackward0>)
4503 tensor(0.9170, device='cuda:0', grad_fn=<AddBackward0>)
4504 tensor(0.9071, device='cuda:0', grad_fn=<AddBackward0>)
4505 tensor(0.9007, device='cuda:0', grad_fn=<AddBackward0>)
4506 tensor(0.8889, device='cuda:0', grad_fn=<AddBackward0>)
4507 tensor(0.9458, device='cuda:0', grad_fn=<AddBackward0>)
4508 tensor(0.9312, device='cuda:0', grad_fn=<AddBackward0>)
4509 tensor(0.9331, device='cuda:0', grad_fn=<AddBackward0>)
4510 tensor(0.8962, device='cuda:0', grad_fn=<AddBackward0>)
4511 tensor(1.0135, device='cuda:0', grad_fn=<AddBackward0>)
4512 tensor(0.9073, device='cuda:0', grad_fn=<AddBackward0>)
4513 tensor(0.9225, device='cuda:0', grad_fn=<AddBackward0>)
4514 tensor(0.9362, device='cuda:0', grad_fn=<AddBackward0>)
4515 tensor(0.9386, device='cuda:0', grad_fn=<AddBackward0>)
4516 tensor(0.8990, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 5000
root        : INFO         Train statistics: {'frame_ap': 0.1136, 'reg_onset_mae': 0.0623, 'reg_offset_mae': 0.0816, 'velocity_mae': 0.1147}
root        : INFO         Validation statistics: {'frame_ap': 0.1154, 'reg_onset_mae': 0.0621, 'reg_offset_mae': 0.0808, 'velocity_mae': 0.1147}
root        : INFO         Test statistics: {'frame_ap': 0.117, 'reg_onset_mae': 0.0634, 'reg_offset_mae': 0.0828, 'velocity_mae': 0.1149}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

5000 tensor(0.9723, device='cuda:0', grad_fn=<AddBackward0>)
5001 tensor(0.9177, device='cuda:0', grad_fn=<AddBackward0>)
5002 tensor(0.8739, device='cuda:0', grad_fn=<AddBackward0>)
5003 tensor(0.9251, device='cuda:0', grad_fn=<AddBackward0>)
5004 tensor(0.8995, device='cuda:0', grad_fn=<AddBackward0>)
5005 tensor(1.0144, device='cuda:0', grad_fn=<AddBackward0>)
5006 tensor(0.9184, device='cuda:0', grad_fn=<AddBackward0>)
5007 tensor(0.9022, device='cuda:0', grad_fn=<AddBackward0>)
5008 tensor(0.8863, device='cuda:0', grad_fn=<AddBackward0>)
5009 tensor(0.8952, device='cuda:0', grad_fn=<AddBackward0>)
5010 tensor(0.9732, device='cuda:0', grad_fn=<AddBackward0>)
5011 tensor(0.9042, device='cuda:0', grad_fn=<AddBackward0>)
5012 tensor(0.9341, device='cuda:0', grad_fn=<AddBackward0>)
5013 tensor(0.9665, device='cuda:0', grad_fn=<AddBackward0>)
5014 tensor(0.8623, device='cuda:0', grad_fn=<AddBackward0>)
5015 tensor(0.9164, device='cuda:0', grad_fn=<AddBackward0>)
5016 tensor(0.9339, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 5500
root        : INFO         Train statistics: {'frame_ap': 0.1136, 'reg_onset_mae': 0.0625, 'reg_offset_mae': 0.0792, 'velocity_mae': 0.115}
root        : INFO         Validation statistics: {'frame_ap': 0.1151, 'reg_onset_mae': 0.0623, 'reg_offset_mae': 0.0785, 'velocity_mae': 0.1149}
root        : INFO         Test statistics: {'frame_ap': 0.1156, 'reg_onset_mae': 0.0636, 'reg_offset_mae': 0.0805, 'velocity_mae': 0.1151}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

5500 tensor(0.9129, device='cuda:0', grad_fn=<AddBackward0>)
5501 tensor(0.9159, device='cuda:0', grad_fn=<AddBackward0>)
5502 tensor(0.9284, device='cuda:0', grad_fn=<AddBackward0>)
5503 tensor(0.9390, device='cuda:0', grad_fn=<AddBackward0>)
5504 tensor(0.9262, device='cuda:0', grad_fn=<AddBackward0>)
5505 tensor(0.9546, device='cuda:0', grad_fn=<AddBackward0>)
5506 tensor(0.9440, device='cuda:0', grad_fn=<AddBackward0>)
5507 tensor(0.9387, device='cuda:0', grad_fn=<AddBackward0>)
5508 tensor(0.9254, device='cuda:0', grad_fn=<AddBackward0>)
5509 tensor(0.9859, device='cuda:0', grad_fn=<AddBackward0>)
5510 tensor(0.9874, device='cuda:0', grad_fn=<AddBackward0>)
5511 tensor(0.9369, device='cuda:0', grad_fn=<AddBackward0>)
5512 tensor(0.9321, device='cuda:0', grad_fn=<AddBackward0>)
5513 tensor(0.9255, device='cuda:0', grad_fn=<AddBackward0>)
5514 tensor(0.9581, device='cuda:0', grad_fn=<AddBackward0>)
5515 tensor(0.9610, device='cuda:0', grad_fn=<AddBackward0>)
5516 tensor(0.9156, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 6000
root        : INFO         Train statistics: {'frame_ap': 0.1136, 'reg_onset_mae': 0.0623, 'reg_offset_mae': 0.0805, 'velocity_mae': 0.1152}
root        : INFO         Validation statistics: {'frame_ap': 0.1152, 'reg_onset_mae': 0.0621, 'reg_offset_mae': 0.0797, 'velocity_mae': 0.1152}
root        : INFO         Test statistics: {'frame_ap': 0.1169, 'reg_onset_mae': 0.0635, 'reg_offset_mae': 0.0817, 'velocity_mae': 0.1154}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/stati

6000 tensor(0.9082, device='cuda:0', grad_fn=<AddBackward0>)
6001 tensor(0.9505, device='cuda:0', grad_fn=<AddBackward0>)
6002 tensor(0.9245, device='cuda:0', grad_fn=<AddBackward0>)
6003 tensor(1.0040, device='cuda:0', grad_fn=<AddBackward0>)
6004 tensor(0.9086, device='cuda:0', grad_fn=<AddBackward0>)
6005 tensor(0.9382, device='cuda:0', grad_fn=<AddBackward0>)
6006 tensor(0.9914, device='cuda:0', grad_fn=<AddBackward0>)
6007 tensor(0.9174, device='cuda:0', grad_fn=<AddBackward0>)
6008 tensor(0.9627, device='cuda:0', grad_fn=<AddBackward0>)
6009 tensor(0.9193, device='cuda:0', grad_fn=<AddBackward0>)
6010 tensor(0.9162, device='cuda:0', grad_fn=<AddBackward0>)
6011 tensor(0.9150, device='cuda:0', grad_fn=<AddBackward0>)
6012 tensor(0.8877, device='cuda:0', grad_fn=<AddBackward0>)
6013 tensor(0.9609, device='cuda:0', grad_fn=<AddBackward0>)
6014 tensor(0.9215, device='cuda:0', grad_fn=<AddBackward0>)
6015 tensor(0.9124, device='cuda:0', grad_fn=<AddBackward0>)
6016 tensor(0.9685, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 6500
root        : INFO         Train statistics: {'frame_ap': 0.1138, 'reg_onset_mae': 0.069, 'reg_offset_mae': 0.1038, 'velocity_mae': 0.1151}
root        : INFO         Validation statistics: {'frame_ap': 0.1156, 'reg_onset_mae': 0.0687, 'reg_offset_mae': 0.1029, 'velocity_mae': 0.1152}
root        : INFO         Test statistics: {'frame_ap': 0.1168, 'reg_onset_mae': 0.0703, 'reg_offset_mae': 0.1053, 'velocity_mae': 0.1153}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

6500 tensor(0.9506, device='cuda:0', grad_fn=<AddBackward0>)
6501 tensor(0.8974, device='cuda:0', grad_fn=<AddBackward0>)
6502 tensor(0.9351, device='cuda:0', grad_fn=<AddBackward0>)
6503 tensor(0.9676, device='cuda:0', grad_fn=<AddBackward0>)
6504 tensor(0.9253, device='cuda:0', grad_fn=<AddBackward0>)
6505 tensor(0.9020, device='cuda:0', grad_fn=<AddBackward0>)
6506 tensor(0.9405, device='cuda:0', grad_fn=<AddBackward0>)
6507 tensor(0.9354, device='cuda:0', grad_fn=<AddBackward0>)
6508 tensor(0.9401, device='cuda:0', grad_fn=<AddBackward0>)
6509 tensor(0.9240, device='cuda:0', grad_fn=<AddBackward0>)
6510 tensor(1.0471, device='cuda:0', grad_fn=<AddBackward0>)
6511 tensor(0.9612, device='cuda:0', grad_fn=<AddBackward0>)
6512 tensor(0.9274, device='cuda:0', grad_fn=<AddBackward0>)
6513 tensor(0.9407, device='cuda:0', grad_fn=<AddBackward0>)
6514 tensor(0.9318, device='cuda:0', grad_fn=<AddBackward0>)
6515 tensor(0.9496, device='cuda:0', grad_fn=<AddBackward0>)
6516 tensor(0.9441, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 7000
root        : INFO         Train statistics: {'frame_ap': 0.1133, 'reg_onset_mae': 0.0574, 'reg_offset_mae': 0.0663, 'velocity_mae': 0.1155}
root        : INFO         Validation statistics: {'frame_ap': 0.1151, 'reg_onset_mae': 0.0573, 'reg_offset_mae': 0.0657, 'velocity_mae': 0.1156}
root        : INFO         Test statistics: {'frame_ap': 0.1176, 'reg_onset_mae': 0.0585, 'reg_offset_mae': 0.0673, 'velocity_mae': 0.1156}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/stati

7000 tensor(0.9714, device='cuda:0', grad_fn=<AddBackward0>)
7001 tensor(0.8901, device='cuda:0', grad_fn=<AddBackward0>)
7002 tensor(0.9315, device='cuda:0', grad_fn=<AddBackward0>)
7003 tensor(0.9199, device='cuda:0', grad_fn=<AddBackward0>)
7004 tensor(0.9793, device='cuda:0', grad_fn=<AddBackward0>)
7005 tensor(0.9210, device='cuda:0', grad_fn=<AddBackward0>)
7006 tensor(0.9595, device='cuda:0', grad_fn=<AddBackward0>)
7007 tensor(0.9481, device='cuda:0', grad_fn=<AddBackward0>)
7008 tensor(0.9314, device='cuda:0', grad_fn=<AddBackward0>)
7009 tensor(0.9019, device='cuda:0', grad_fn=<AddBackward0>)
7010 tensor(0.9624, device='cuda:0', grad_fn=<AddBackward0>)
7011 tensor(0.9226, device='cuda:0', grad_fn=<AddBackward0>)
7012 tensor(0.9181, device='cuda:0', grad_fn=<AddBackward0>)
7013 tensor(0.9425, device='cuda:0', grad_fn=<AddBackward0>)
7014 tensor(0.9588, device='cuda:0', grad_fn=<AddBackward0>)
7015 tensor(0.8948, device='cuda:0', grad_fn=<AddBackward0>)
7016 tensor(0.9280, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 7500
root        : INFO         Train statistics: {'frame_ap': 0.1142, 'reg_onset_mae': 0.0575, 'reg_offset_mae': 0.0813, 'velocity_mae': 0.1151}
root        : INFO         Validation statistics: {'frame_ap': 0.1154, 'reg_onset_mae': 0.0573, 'reg_offset_mae': 0.0806, 'velocity_mae': 0.1151}
root        : INFO         Test statistics: {'frame_ap': 0.117, 'reg_onset_mae': 0.0585, 'reg_offset_mae': 0.0826, 'velocity_mae': 0.1153}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

7500 tensor(0.9551, device='cuda:0', grad_fn=<AddBackward0>)
7501 tensor(0.8962, device='cuda:0', grad_fn=<AddBackward0>)
7502 tensor(0.9353, device='cuda:0', grad_fn=<AddBackward0>)
7503 tensor(0.9882, device='cuda:0', grad_fn=<AddBackward0>)
7504 tensor(0.9210, device='cuda:0', grad_fn=<AddBackward0>)
7505 tensor(0.9218, device='cuda:0', grad_fn=<AddBackward0>)
7506 tensor(0.9359, device='cuda:0', grad_fn=<AddBackward0>)
7507 tensor(0.9995, device='cuda:0', grad_fn=<AddBackward0>)
7508 tensor(0.8866, device='cuda:0', grad_fn=<AddBackward0>)
7509 tensor(0.9463, device='cuda:0', grad_fn=<AddBackward0>)
7510 tensor(0.9063, device='cuda:0', grad_fn=<AddBackward0>)
7511 tensor(0.9407, device='cuda:0', grad_fn=<AddBackward0>)
7512 tensor(0.9288, device='cuda:0', grad_fn=<AddBackward0>)
7513 tensor(0.8946, device='cuda:0', grad_fn=<AddBackward0>)
7514 tensor(0.8994, device='cuda:0', grad_fn=<AddBackward0>)
7515 tensor(0.8952, device='cuda:0', grad_fn=<AddBackward0>)
7516 tensor(0.9538, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 8000
root        : INFO         Train statistics: {'frame_ap': 0.1142, 'reg_onset_mae': 0.0576, 'reg_offset_mae': 0.0738, 'velocity_mae': 0.1152}
root        : INFO         Validation statistics: {'frame_ap': 0.1157, 'reg_onset_mae': 0.0574, 'reg_offset_mae': 0.0732, 'velocity_mae': 0.1152}
root        : INFO         Test statistics: {'frame_ap': 0.1168, 'reg_onset_mae': 0.0586, 'reg_offset_mae': 0.075, 'velocity_mae': 0.1153}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

8000 tensor(0.8827, device='cuda:0', grad_fn=<AddBackward0>)
8001 tensor(0.9750, device='cuda:0', grad_fn=<AddBackward0>)
8002 tensor(0.9506, device='cuda:0', grad_fn=<AddBackward0>)
8003 tensor(0.9390, device='cuda:0', grad_fn=<AddBackward0>)
8004 tensor(0.9724, device='cuda:0', grad_fn=<AddBackward0>)
8005 tensor(0.9333, device='cuda:0', grad_fn=<AddBackward0>)
8006 tensor(0.9450, device='cuda:0', grad_fn=<AddBackward0>)
8007 tensor(0.9400, device='cuda:0', grad_fn=<AddBackward0>)
8008 tensor(0.9235, device='cuda:0', grad_fn=<AddBackward0>)
8009 tensor(0.8903, device='cuda:0', grad_fn=<AddBackward0>)
8010 tensor(0.9295, device='cuda:0', grad_fn=<AddBackward0>)
8011 tensor(0.9675, device='cuda:0', grad_fn=<AddBackward0>)
8012 tensor(0.9928, device='cuda:0', grad_fn=<AddBackward0>)
8013 tensor(0.9742, device='cuda:0', grad_fn=<AddBackward0>)
8014 tensor(0.9168, device='cuda:0', grad_fn=<AddBackward0>)
8015 tensor(0.9396, device='cuda:0', grad_fn=<AddBackward0>)
8016 tensor(0.9472, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 8500
root        : INFO         Train statistics: {'frame_ap': 0.1138, 'reg_onset_mae': 0.0582, 'reg_offset_mae': 0.0708, 'velocity_mae': 0.1156}
root        : INFO         Validation statistics: {'frame_ap': 0.1156, 'reg_onset_mae': 0.058, 'reg_offset_mae': 0.0702, 'velocity_mae': 0.1156}
root        : INFO         Test statistics: {'frame_ap': 0.1167, 'reg_onset_mae': 0.0592, 'reg_offset_mae': 0.0719, 'velocity_mae': 0.1156}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statis

8500 tensor(0.9014, device='cuda:0', grad_fn=<AddBackward0>)
8501 tensor(0.9714, device='cuda:0', grad_fn=<AddBackward0>)
8502 tensor(0.9342, device='cuda:0', grad_fn=<AddBackward0>)
8503 tensor(0.9146, device='cuda:0', grad_fn=<AddBackward0>)
8504 tensor(0.9828, device='cuda:0', grad_fn=<AddBackward0>)
8505 tensor(0.9413, device='cuda:0', grad_fn=<AddBackward0>)
8506 tensor(0.9245, device='cuda:0', grad_fn=<AddBackward0>)
8507 tensor(0.9053, device='cuda:0', grad_fn=<AddBackward0>)
8508 tensor(0.9296, device='cuda:0', grad_fn=<AddBackward0>)
8509 tensor(0.9062, device='cuda:0', grad_fn=<AddBackward0>)
8510 tensor(0.9897, device='cuda:0', grad_fn=<AddBackward0>)
8511 tensor(1.0078, device='cuda:0', grad_fn=<AddBackward0>)
8512 tensor(0.9486, device='cuda:0', grad_fn=<AddBackward0>)
8513 tensor(0.9017, device='cuda:0', grad_fn=<AddBackward0>)
8514 tensor(0.9534, device='cuda:0', grad_fn=<AddBackward0>)
8515 tensor(0.9277, device='cuda:0', grad_fn=<AddBackward0>)
8516 tensor(0.9803, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 9000
root        : INFO         Train statistics: {'frame_ap': 0.1136, 'reg_onset_mae': 0.0577, 'reg_offset_mae': 0.0698, 'velocity_mae': 0.1157}
root        : INFO         Validation statistics: {'frame_ap': 0.1152, 'reg_onset_mae': 0.0575, 'reg_offset_mae': 0.0692, 'velocity_mae': 0.1157}
root        : INFO         Test statistics: {'frame_ap': 0.1165, 'reg_onset_mae': 0.0588, 'reg_offset_mae': 0.0709, 'velocity_mae': 0.1157}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/stati

9000 tensor(0.9150, device='cuda:0', grad_fn=<AddBackward0>)
9001 tensor(0.9670, device='cuda:0', grad_fn=<AddBackward0>)
9002 tensor(0.9135, device='cuda:0', grad_fn=<AddBackward0>)
9003 tensor(0.9311, device='cuda:0', grad_fn=<AddBackward0>)
9004 tensor(0.9439, device='cuda:0', grad_fn=<AddBackward0>)
9005 tensor(0.9223, device='cuda:0', grad_fn=<AddBackward0>)
9006 tensor(0.9956, device='cuda:0', grad_fn=<AddBackward0>)
9007 tensor(0.9368, device='cuda:0', grad_fn=<AddBackward0>)
9008 tensor(0.9476, device='cuda:0', grad_fn=<AddBackward0>)
9009 tensor(0.9153, device='cuda:0', grad_fn=<AddBackward0>)
9010 tensor(0.9312, device='cuda:0', grad_fn=<AddBackward0>)
9011 tensor(0.9103, device='cuda:0', grad_fn=<AddBackward0>)
9012 tensor(0.9211, device='cuda:0', grad_fn=<AddBackward0>)
9013 tensor(0.9622, device='cuda:0', grad_fn=<AddBackward0>)
9014 tensor(0.9456, device='cuda:0', grad_fn=<AddBackward0>)
9015 tensor(0.9384, device='cuda:0', grad_fn=<AddBackward0>)
9016 tensor(0.9513, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 9500
root        : INFO         Train statistics: {'frame_ap': 0.1137, 'reg_onset_mae': 0.0577, 'reg_offset_mae': 0.065, 'velocity_mae': 0.1157}
root        : INFO         Validation statistics: {'frame_ap': 0.1156, 'reg_onset_mae': 0.0575, 'reg_offset_mae': 0.0644, 'velocity_mae': 0.1156}
root        : INFO         Test statistics: {'frame_ap': 0.1166, 'reg_onset_mae': 0.0587, 'reg_offset_mae': 0.066, 'velocity_mae': 0.1156}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statist

9500 tensor(0.9404, device='cuda:0', grad_fn=<AddBackward0>)
9501 tensor(1.0002, device='cuda:0', grad_fn=<AddBackward0>)
9502 tensor(0.9711, device='cuda:0', grad_fn=<AddBackward0>)
9503 tensor(0.9099, device='cuda:0', grad_fn=<AddBackward0>)
9504 tensor(0.9305, device='cuda:0', grad_fn=<AddBackward0>)
9505 tensor(0.9140, device='cuda:0', grad_fn=<AddBackward0>)
9506 tensor(0.9369, device='cuda:0', grad_fn=<AddBackward0>)
9507 tensor(0.9759, device='cuda:0', grad_fn=<AddBackward0>)
9508 tensor(0.9322, device='cuda:0', grad_fn=<AddBackward0>)
9509 tensor(0.9357, device='cuda:0', grad_fn=<AddBackward0>)
9510 tensor(0.9754, device='cuda:0', grad_fn=<AddBackward0>)
9511 tensor(0.9318, device='cuda:0', grad_fn=<AddBackward0>)
9512 tensor(0.9136, device='cuda:0', grad_fn=<AddBackward0>)
9513 tensor(0.9406, device='cuda:0', grad_fn=<AddBackward0>)
9514 tensor(0.9574, device='cuda:0', grad_fn=<AddBackward0>)
9515 tensor(0.9363, device='cuda:0', grad_fn=<AddBackward0>)
9516 tensor(0.9854, devi

root        : INFO     ------------------------------------
root        : INFO     Iteration: 10000
root        : INFO         Train statistics: {'frame_ap': 0.114, 'reg_onset_mae': 0.058, 'reg_offset_mae': 0.0747, 'velocity_mae': 0.1157}
root        : INFO         Validation statistics: {'frame_ap': 0.1158, 'reg_onset_mae': 0.0578, 'reg_offset_mae': 0.074, 'velocity_mae': 0.1156}
root        : INFO         Test statistics: {'frame_ap': 0.1167, 'reg_onset_mae': 0.0591, 'reg_offset_mae': 0.0759, 'velocity_mae': 0.1156}
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statistics.pkl
root        : INFO         Dump statistics to /kaggle/working/statistics/SAP_P02_train/Regress_onset_offset_frame_velocity_CRNN/loss_type=regress_onset_offset_frame_velocity_bce/augmentation=none/max_note_shift=0/batch_size=12/statist

10000 tensor(0.9199, device='cuda:0', grad_fn=<AddBackward0>)
