# Load from CQT files

In [2]:
!pip install reservoirpy

Collecting reservoirpy
  Downloading reservoirpy-0.3.11-py3-none-any.whl (176 kB)
[K     |████████████████████████████████| 176 kB 656 kB/s eta 0:00:01
[?25hCollecting dill>=0.3.1.1
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[K     |████████████████████████████████| 116 kB 959 kB/s eta 0:00:01
Installing collected packages: dill, reservoirpy
Successfully installed dill-0.3.8 reservoirpy-0.3.11


In [53]:
import pickle
import reservoirpy

model_file_path = 'esn_model.pkl'

# SECURITY: pickle.load can execute arbitrary code while deserialising, so only
# open model files that come from a trusted source.
# The handle is called `model_file` rather than `file` because `file` is used
# further down for a directory entry name (a plain string).
with open(model_file_path, 'rb') as model_file:
    loaded_model = pickle.load(model_file)

In [111]:
import os
import numpy as np
train_cqt_dir = 'datasets/musicnet/musicnet/train_cqt/'

# os.listdir() returns entries in arbitrary order, so sort the .npy files to make
# "the first file" the same one on every run; the filter also skips stray
# non-array entries in the directory.
npy_files = sorted(name for name in os.listdir(train_cqt_dir) if name.endswith('.npy'))
if not npy_files:
    raise FileNotFoundError(f'no .npy files found in {train_cqt_dir}')

# load the array
file = npy_files[0]
filepath = os.path.join(train_cqt_dir, file)
print(filepath)

# np.load already returns an ndarray for a .npy file, so no extra copy is made
cqt_array = np.load(filepath)
print(cqt_array)
print(cqt_array.shape) # freq_bins X time_intervals (one interval = 0.1s)

datasets/musicnet/musicnet/train_cqt/1727.npy
[[ 2.57428765e-04-2.22782124e-04j -5.93174540e-04-3.43217507e-05j
   5.20157162e-04+5.71272685e-04j ...  5.13322302e-05-6.84659483e-07j
  -5.35671388e-05-8.05060699e-06j  5.88376097e-05+4.16767107e-05j]
 [-2.69114738e-04+3.92691763e-05j -3.82055550e-05+6.25704299e-04j
   9.04196466e-04+1.21861180e-04j ... -6.72054157e-05+1.06477870e-04j
   7.02706820e-05+1.08719127e-04j  1.41021737e-04-3.36765697e-05j]
 [ 1.82399323e-04+4.40537144e-04j  8.38531647e-04+2.32781895e-04j
   1.03155267e-03-7.93580606e-04j ... -5.60147600e-05+2.10249462e-04j
   1.74279354e-04+1.81913929e-04j  2.47208431e-04-4.66376296e-05j]
 ...
 [ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j ...  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j]
 [ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j
   0.00000000e+00+0.00000000e+00j ...  0.00000000e+00+0.00000000e+00j
 

In [112]:
# Put time intervals on the first axis and keep only the magnitude of each
# complex coefficient, so that row i is the frame for interval i.
cqt_array_transposed = np.absolute(cqt_array.T)
print(cqt_array_transposed[0].shape)

(128,)


# Inference on reservoir model

In [98]:
# Call the loaded model once per frame, in order, and keep every result.
out = [loaded_model(frame) for frame in cqt_array_transposed]

In [99]:
# Stack the per-frame results, drop the length-1 axis 1, then transpose.
# NOTE: rebinds `out`, so this cell is meant to run once after the inference loop.
out = np.array(out).squeeze(axis=1).T

In [100]:
# Binarise: entries strictly greater than 0.5 become 1, everything else 0.
above_threshold = out > 0.5
out = np.where(above_threshold, 1, 0)
out.shape

(128, 3257)

# Alternatively, load the transformer's inference output

In [113]:
# Rebinds `out` to the array stored in out_transformer.npy, discarding whatever
# value `out` held before this cell ran.
out = np.load('out_transformer.npy')

In [117]:
# Show the array together with its shape as the cell result.
out, out.shape

(array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]),
 (128, 5000))

In [114]:
def convert_to_notes(matrix, timestep):
    """Turn a pitch-by-time activation matrix into a list of notes.

    Each row of ``matrix`` is one pitch (the row index is used as the pitch)
    and each column is one time step of length ``timestep``. A maximal run of
    cells equal to 1 in a row becomes one note; any other value ends the run.

    Args:
        matrix: iterable of rows, each a sized sequence of cell values.
        timestep: duration of one column, in the time unit of the result.

    Returns:
        A list of ``[start, end, pitch]`` entries, with start and end rounded
        to two decimals, ordered by start time (ties keep row order).
    """
    collected = []

    for pitch, activations in enumerate(matrix):
        final_col = len(activations) - 1
        onset = None  # start time of the note currently open, if any

        for col, cell in enumerate(activations):
            active = cell == 1

            # A run begins at the first active cell after silence.
            if active and onset is None:
                onset = col * timestep

            # Silence with no open note: nothing to close.
            if onset is None:
                continue

            if not active:
                # The run ended on the previous column.
                offset = col * timestep
            elif col == final_col:
                # The run reaches the end of the row, so it includes this column.
                offset = (col + 1) * timestep
            else:
                # Still inside the run.
                continue

            collected.append([round(onset, 2), round(offset, 2), pitch])
            onset = None

    # Stable sort by start time.
    collected.sort(key=lambda note: note[0])
    return collected

In [115]:
from midiutil import MIDIFile

def create_midi_from_notes(note_list, output_file="inference.mid", tempo=100):
    """Write a list of notes to a single-track MIDI file.

    Args:
        note_list: iterable of ``(start, end, pitch)`` triples. ``start`` and
            ``end`` are treated as seconds; ``pitch`` is passed unchanged to
            ``MIDIFile.addNote``.
        output_file: path of the MIDI file to write (opened in binary mode).
        tempo: tempo in beats per minute; also used to convert seconds into
            quarter notes.
    """
    track_idx = 0
    origin = 0

    midi = MIDIFile(1, deinterleave=False)
    midi.addTrackName(track_idx, origin, "Sample Track")
    midi.addTempo(track_idx, origin, tempo)

    # No instrument is selected here; a program change could be added at this
    # point with midi.addProgramChange (program 40 is violin).

    # Length of one quarter note in seconds: the tempo expressed in beats per
    # second (tempo / 60), inverted. It does not depend on the note, so it is
    # computed once.
    seconds_per_quarter = 1 / (tempo / 60)

    for start_s, end_s, pitch in note_list:
        # Convert times in seconds to times in quarter notes.
        onset_quarters = start_s / seconds_per_quarter
        length_quarters = (end_s - start_s) / seconds_per_quarter
        midi.addNote(track_idx, 0, pitch, onset_quarters, length_quarters, volume=100)

    with open(output_file, "wb") as handle:
        midi.writeFile(handle)

In [118]:
# Convert the binary matrix in `out` to notes using a 0.1 s timestep per column,
# then write them to the default MIDI output file.
note_events = convert_to_notes(out, timestep=0.1)
create_midi_from_notes(note_list=note_events)

[[0.0, 0.3, 61], [0.0, 0.1, 69], [0.1, 0.4, 64], [0.1, 0.2, 81], [0.2, 0.3, 69], [0.3, 0.4, 68], [0.3, 0.4, 71], [0.3, 0.5, 80], [0.5, 0.6, 64], [0.6, 0.7, 71], [0.9, 1.0, 69], [0.9, 1.0, 81], [1.0, 1.6, 64], [1.1, 1.4, 61], [1.1, 1.4, 69], [1.1, 1.2, 76], [1.2, 1.3, 81], [1.4, 1.5, 73], [1.4, 1.5, 85], [1.5, 1.6, 45], [1.6, 1.7, 73], [1.6, 1.7, 85], [1.8, 1.9, 32], [1.8, 1.9, 65], [1.9, 2.0, 44], [1.9, 2.0, 61], [2.1, 2.3, 42], [2.4, 2.6, 61], [2.4, 2.5, 79], [2.5, 2.6, 73], [2.6, 2.7, 62], [2.8, 2.9, 62], [3.0, 3.1, 58], [3.0, 3.1, 70], [3.2, 3.4, 59], [3.4, 3.5, 35], [3.4, 3.5, 62], [3.6, 3.7, 61], [3.6, 3.7, 66], [3.6, 3.7, 73], [3.7, 3.8, 40], [3.7, 3.8, 71], [4.3, 4.8, 61], [4.3, 4.4, 69], [4.5, 4.6, 45], [4.7, 4.8, 64], [4.8, 5.0, 40], [4.8, 5.1, 62], [5.1, 5.2, 40], [5.3, 5.4, 40], [5.6, 5.9, 76], [5.7, 5.9, 69], [6.0, 6.3, 40], [6.0, 6.2, 62], [6.5, 6.6, 69], [6.5, 6.8, 73], [6.5, 7.0, 85], [6.6, 6.7, 81], [6.8, 6.9, 69], [6.8, 7.0, 76], [7.0, 7.1, 64], [7.0, 7.1, 69], [7.0, 7