In [47]:
VOICE = 1
LAG = 128
NUM_PREDICT = 20 * 8

In [48]:
import pandas as pd
import numpy as np
from sklearn.multioutput import RegressorChain
from sklearn.linear_model import Ridge
import midiutil

# Import the csv
piano_input = pd.read_csv('F.txt', sep='\t', header=None, names=['v1', 'v2', 'v3', 'v4'])
# Transform to midi values
midi_input = piano_input.where(piano_input == 0, piano_input + 8)
l, w = midi_input.shape
# Lagged input
X = midi_input.to_numpy()
for lag in range(1, LAG + 1):
    next_lag = midi_input[lag:]
    X = np.delete(X, l - lag, axis=0)
    X = np.append(X, next_lag, axis=1)

# Remove y values from x
np.delete(X, 0, axis=1)

# Multi-output
Y = midi_input.to_numpy()

# Use sklearn's regressorchain to fit 4 correlated regressors
chain = RegressorChain(Ridge(alpha=100000)).fit(X, Y[:-LAG])
print(f"Score: {chain.score(X, Y[:-LAG])}")

# Predict some notes
for i in range(NUM_PREDICT):
    next_notes = chain.predict(X[-lag:])
    next_notes = next_notes.reshape(lag, 4)[-1]
    next_row = np.append(next_notes, X[-1, :-4]).reshape(1, -1)
    X = np.append(X, next_row, axis=0)

predicted_notes = pd.DataFrame(np.round(X[-NUM_PREDICT:, 0:4]))
predicted_notes.describe()

Score: 0.9912141401047205


Unnamed: 0,0,1,2,3
count,160.0,160.0,160.0,160.0
mean,72.76875,67.43125,51.475,45.86875
std,5.216353,1.604325,1.625079,2.271292
min,67.0,64.0,48.0,41.0
25%,68.0,67.0,51.0,44.0
50%,71.0,67.0,51.0,45.0
75%,76.0,69.0,53.0,47.0
max,83.0,70.0,55.0,52.0


In [49]:
from dataclasses import dataclass

note_len = .25

@dataclass
class MidiNote:
    note: int
    time: float
    dura: float

# Generate midi file
midi = midiutil.MIDIFile(4)
midi.addTempo(0, 0, 120)

cur = [MidiNote(note, 0, note_len) for note in predicted_notes[0]]

for time, step in enumerate(predicted_notes.to_numpy()):
    for track, note in enumerate(step):
        if cur[track].note == int(note):
            cur[track].dura += note_len
        else:
            if cur[track].note != 0:
                midi.addNote(track, 0, int(cur[track].note), cur[track].time, cur[track].dura, 95)
            cur[track] = MidiNote(int(note), (time + 1) * note_len, note_len)

with open('simplified_generated.mid', 'wb') as output:
    midi.writeFile(output)


[time: 0000][track: 0][note: 76]
[time: 0000][track: 1][note: 69]
[time: 0000][track: 2][note: 50]
[time: 0000][track: 3][note: 47]
[time: 0001][track: 0][note: 76]
[time: 0001][track: 1][note: 69]
[time: 0001][track: 2][note: 50]
[time: 0001][track: 3][note: 48]
[time: 0002][track: 0][note: 76]
[time: 0002][track: 1][note: 68]
[time: 0002][track: 2][note: 49]
[time: 0002][track: 3][note: 49]
[time: 0003][track: 0][note: 76]
[time: 0003][track: 1][note: 68]
[time: 0003][track: 2][note: 49]
[time: 0003][track: 3][note: 49]
[time: 0004][track: 0][note: 76]
[time: 0004][track: 1][note: 68]
[time: 0004][track: 2][note: 50]
[time: 0004][track: 3][note: 48]
[time: 0005][track: 0][note: 77]
[time: 0005][track: 1][note: 68]
[time: 0005][track: 2][note: 50]
[time: 0005][track: 3][note: 48]
[time: 0006][track: 0][note: 78]
[time: 0006][track: 1][note: 68]
[time: 0006][track: 2][note: 50]
[time: 0006][track: 3][note: 48]
[time: 0007][track: 0][note: 79]
[time: 0007][track: 1][note: 69]
[time: 000