In [1]:
import pandas as pd

In [2]:
df = pd.read_csv('archive/train/chorale_000.csv')

In [3]:
df

Unnamed: 0,note0,note1,note2,note3
0,74,70,65,58
1,74,70,65,58
2,74,70,65,58
3,74,70,65,58
4,75,70,58,55
...,...,...,...,...
187,70,65,62,46
188,70,65,62,46
189,70,65,62,46
190,70,65,62,46


In [4]:
import os

train_files = sorted([os.path.join('archive', 'train', f) for f in os.listdir(os.path.join('archive', 'train')) if f.endswith('.csv')])
test_files = sorted([os.path.join('archive', 'test', f) for f in os.listdir(os.path.join('archive', 'test')) if f.endswith('.csv')])
valid_files = sorted([os.path.join('archive', 'valid', f) for f in os.listdir(os.path.join('archive', 'valid')) if f.endswith('.csv')])

In [5]:
train_files

['archive/train/chorale_000.csv',
 'archive/train/chorale_001.csv',
 'archive/train/chorale_002.csv',
 'archive/train/chorale_003.csv',
 'archive/train/chorale_004.csv',
 'archive/train/chorale_005.csv',
 'archive/train/chorale_006.csv',
 'archive/train/chorale_007.csv',
 'archive/train/chorale_008.csv',
 'archive/train/chorale_009.csv',
 'archive/train/chorale_010.csv',
 'archive/train/chorale_011.csv',
 'archive/train/chorale_012.csv',
 'archive/train/chorale_013.csv',
 'archive/train/chorale_014.csv',
 'archive/train/chorale_015.csv',
 'archive/train/chorale_016.csv',
 'archive/train/chorale_017.csv',
 'archive/train/chorale_018.csv',
 'archive/train/chorale_019.csv',
 'archive/train/chorale_020.csv',
 'archive/train/chorale_021.csv',
 'archive/train/chorale_022.csv',
 'archive/train/chorale_023.csv',
 'archive/train/chorale_024.csv',
 'archive/train/chorale_025.csv',
 'archive/train/chorale_026.csv',
 'archive/train/chorale_027.csv',
 'archive/train/chorale_028.csv',
 'archive/trai

In [6]:
train_data = [pd.read_csv(f).values.tolist() for f in train_files]
test_data = [pd.read_csv(f).values.tolist() for f in test_files]
valid_data = [pd.read_csv(f).values.tolist() for f in valid_files]

In [7]:
train_data

[[[74, 70, 65, 58],
  [74, 70, 65, 58],
  [74, 70, 65, 58],
  [74, 70, 65, 58],
  [75, 70, 58, 55],
  [75, 70, 58, 55],
  [75, 70, 60, 55],
  [75, 70, 60, 55],
  [77, 69, 62, 50],
  [77, 69, 62, 50],
  [77, 69, 62, 50],
  [77, 69, 62, 50],
  [77, 70, 62, 55],
  [77, 70, 62, 55],
  [77, 69, 62, 55],
  [77, 69, 62, 55],
  [75, 67, 63, 48],
  [75, 67, 63, 48],
  [75, 69, 63, 48],
  [75, 69, 63, 48],
  [74, 70, 65, 46],
  [74, 70, 65, 46],
  [74, 70, 65, 46],
  [74, 70, 65, 46],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [72, 69, 65, 53],
  [74, 70, 65, 46],
  [74, 70, 65, 46],
  [74, 70, 65, 46],
  [74, 70, 65, 46],
  [75, 69, 63, 48],
  [75, 69, 63, 48],
  [75, 67, 63, 48],
  [75, 67, 63, 48],
  [77, 65, 62, 50],
  [77, 65, 62, 50],
  [77, 65, 60, 50],
  [77, 65, 60, 50],
  [74, 67, 58, 55],
  [74, 67, 58, 55],
  [74, 67, 58, 53],
  [74, 67, 58, 53],
  [72, 67, 58, 51],
  [72, 67, 58, 51],


36 = C1
81 = A5
0 -> silence

In [8]:
from music21 import stream, chord

chorale = train_data[20]

s = stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))

s.show('midi')

### Preprocessing

In [9]:
import numpy as np

min_note, max_note = 36, 81
window_size, window_offset, batch_size = 32, 16, 32

def make_xy(chorales):
    windows = [c[i:i + window_size + 1] for c in chorales for i in range(0, len(c) - window_size, window_offset)]

    data = np.array(windows, dtype=int)

    data = np.where(data == 0, 0, data - min_note + 1)
    data = np.clip(data, 0, max_note - min_note + 1)

    flat = data.reshape(data.shape[0], -1)

    return flat[:, :-1], flat[:, 1:]

X_train, y_train = make_xy(train_data)
X_test, y_test = make_xy(test_data)
X_valid, y_valid = make_xy(valid_data)

In [11]:
X_train.shape, y_train.shape

((3111, 131), (3111, 131))

### Training the Model

In [12]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, Dense, Embedding, LSTM, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Nadam

In [14]:
model = Sequential()

model.add(Embedding(input_dim=47, output_dim=5, input_shape=[None]))

model.add(Conv1D(32, kernel_size=2, padding='causal', activation='relu'))
model.add(BatchNormalization())
model.add(Conv1D(48, kernel_size=2, padding='causal', activation='relu', dilation_rate=2))
model.add(BatchNormalization())
model.add(Conv1D(48, kernel_size=2, padding='causal', activation='relu', dilation_rate=4))
model.add(BatchNormalization())
model.add(Conv1D(96, kernel_size=2, padding='causal', activation='relu', dilation_rate=8))
model.add(BatchNormalization())
model.add(Conv1D(128, kernel_size=2, padding='causal', activation='relu', dilation_rate=16))

model.add(Dropout(0.05))
model.add(LSTM(256, return_sequences=True))

model.add(Dense(47, activation='softmax'))

model.summary()

In [15]:
optimizer = Nadam(learning_rate=1e-3)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
model.fit(X_train, y_train, epochs=20, validation_data=[X_valid, y_valid], batch_size=batch_size)

Epoch 1/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 721ms/step - accuracy: 0.4934 - loss: 1.9278 - val_accuracy: 0.0163 - val_loss: 4.0772
Epoch 2/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 594ms/step - accuracy: 0.7543 - loss: 0.9346 - val_accuracy: 0.0605 - val_loss: 3.6559
Epoch 3/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 591ms/step - accuracy: 0.7834 - loss: 0.7807 - val_accuracy: 0.1329 - val_loss: 3.1481
Epoch 4/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 596ms/step - accuracy: 0.7981 - loss: 0.7039 - val_accuracy: 0.2553 - val_loss: 2.5657
Epoch 5/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 663ms/step - accuracy: 0.8095 - loss: 0.6514 - val_accuracy: 0.3681 - val_loss: 1.9841
Epoch 6/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 669ms/step - accuracy: 0.8177 - loss: 0.6151 - val_accuracy: 0.6939 - val_loss: 1.0254
Epoch 7/20
[1m98/98[

<keras.src.callbacks.history.History at 0x72d51c132dd0>

In [16]:
def sample_next_note(probs):
    probabilities = np.asarray(probs, dtype=float)

    prob_sum = probabilities.sum()

    if prob_sum <= 0 or not np.isfinite(prob_sum):
        return int(np.argmax(probabilities))

    probabilities /= prob_sum

    return np.random.choice(len(probabilities), p=probabilities)

In [22]:
def generate_chorale(model, seed_chrods, length):
    token_sequence = np.array(seed_chrods, dtype=int)
    token_sequence = np.where(token_sequence == 0, 0, token_sequence - min_note + 1)
    token_sequence = token_sequence.reshape(1, -1)

    for _ in range(length * 4):
        next_token_probabilities = model.predict(token_sequence)[0, -1]
        next_token = sample_next_note(next_token_probabilities)
        token_sequence = np.concatenate([token_sequence, [[next_token]]], axis=1)

    token_sequence = np.where(token_sequence == 0, 0, token_sequence + min_note - 1)

    return token_sequence.reshape(-1, 4)

In [23]:
seed_chords = test_data[2]

chorale = seed_chords

s = stream.Stream()

for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))

s.show('midi')

In [24]:
seed_chords = test_data[2][:8]

new_chorale = generate_chorale(model, seed_chords, 56)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 570ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 77ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 77ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4

In [25]:
new_chorale

array([[73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [73, 68, 61, 53],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [69, 66, 61, 54],
       [74, 66, 59, 56],
       [74, 66, 59, 56],
       [74, 66, 59, 56],
       [74, 66, 59, 56],
       [74, 64, 59, 56],
       [74, 64, 59, 56],
       [73, 64, 57, 52],
       [73, 64, 57, 52],
       [71, 64, 57, 52],
       [71, 64, 57, 52],
       [71, 64, 56, 52],
       [71, 64, 56, 52],
       [69, 64, 57, 52],
       [69, 64, 57, 52],
       [69, 64, 57, 52],
       [69, 64, 57, 52],
       [71, 66, 57, 51],
       [71, 66, 58, 51],
       [73, 64, 58, 48],
       [73, 64, 58, 48],
       [66, 63, 58, 47],
       [66, 63, 58, 47],
       [66, 63, 57, 46],
       [66, 63, 57, 46],
       [67, 59, 55, 40],
       [67, 59, 55, 40],
       [67, 59, 55, 43],
       [67, 59, 55, 43],
       [67, 64, 59, 43],
       [67, 64, 59, 55],
       [67, 64, 59, 55],
       [67, 64, 59, 55],


In [26]:
chorale = new_chorale.tolist()

s = stream.Stream()

for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))

s.show('midi')

In [27]:
def generate_random_chorale(length, rest_probability=0.2, pitch_low=36, pitch_high=81, seed=None):
    rng = np.random.default_rng(seed)   # random number generator
    random_pitches = rng.integers(pitch_low, pitch_high + 1, size = (length, 4)) # generate random notes

    # some masking to have both silence and random pitches
    rest_mask = rng.random((length, 4)) < float(rest_probability)
    chorale = np.where(rest_mask, 0, random_pitches).astype(int)

    return chorale

In [29]:
chorale = generate_random_chorale(56).tolist()
s = stream.Stream()
for row in chorale:
    s.append(chord.Chord([n for n in row if n], quarterLength=1))
s.show('midi')