### Generating A Drum Rhythm with NN

#### ToDo's

**17.05.2024** (Friday)
- Connect to VS Code! ✅
- Choose framework - JAX / PyTorch / Keras (not recommended because it is too simple) ✅
- Pick dataset ✅

    > Groove MIDI Dataset [Link](https://magenta.tensorflow.org/datasets/groove#download) + [Paper](https://arxiv.org/pdf/1905.06118)

- Pick model - RNN (+ MLP?) ✅
- Study the dataset and plan how to pre-process it
- Study framework


**21.05.2024** (Tuesday)
- Define problem
- Define goal
- Start setting up a model training template


**23.05.2024** (Thursday)
- We code a model that acts as a generator: after running the code, it should produce a simple beat pattern
- Keep things as simple as possible: use only one sound for now
- Pre-processing and choosing the right format/encoding is going to be the hardest part
    - Time series: discretize, and align all samples so they start on the same bar
    - Filter out other sounds (keep only one drum sound)
    - 4/4 bar, but leave space (32 time slices) for more complex patterns
    - Keep the same duration for all samples with a Feed-Forward Network; an RNN can handle varying durations
- To keep it simple, maybe generate our own data in the desired format
- Regular meetings!!!
- Keras is also fine
- Understand the dataset really well before setting up the network
- RNN not recommended for a first project
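
The discretization idea from the notes above (one 4/4 bar split into 32 time slices, one drum sound) can be sketched roughly as follows. This is only an illustration: the function name `onsets_to_bit_vector` and its parameters are made up here and are not part of `pre_processing_tools`, and it assumes the onsets are already measured from the start of the bar.

```python
def onsets_to_bit_vector(onset_times, bpm, slots_per_bar=32, beats_per_bar=4):
    """Quantize note onsets (seconds from bar start) to a binary grid."""
    bar_duration = beats_per_bar * 60.0 / bpm      # seconds per bar
    slot_duration = bar_duration / slots_per_bar   # seconds per time slice
    bits = [0] * slots_per_bar
    for t in onset_times:
        slot = int(round(t / slot_duration))       # nearest time slice
        if 0 <= slot < slots_per_bar:              # drop onsets outside the bar
            bits[slot] = 1
    return bits


# Kick on every beat at 120 BPM: one hit every 0.5 s
pattern = onsets_to_bit_vector([0.0, 0.5, 1.0, 1.5], bpm=120)
print(pattern)
```

With 32 slices per bar, each beat spans 8 slices, so there is room for finer subdivisions once the patterns get more complex.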


**27.05.2024** (Monday)
- Dataset?
- Framework division
- Data format
- Feed-Forward or RNN
- Regular Meetings
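
For the open dataset and data format questions, one option raised earlier is to generate our own data directly in the target format. A minimal sketch, assuming 32 time slices per bar, a kick on every beat, and a few random extra hits; `make_synthetic_bar` and `extra_hit_prob` are hypothetical names chosen for this example.

```python
import random


def make_synthetic_bar(rng, slots_per_bar=32, extra_hit_prob=0.1):
    """Return one bar as a list of 0/1 values with a kick on every beat."""
    slots_per_beat = slots_per_bar // 4
    bar = [0] * slots_per_bar
    for slot in range(slots_per_bar):
        on_beat = slot % slots_per_beat == 0
        # Always hit on the beat; occasionally add an off-beat hit
        if on_beat or rng.random() < extra_hit_prob:
            bar[slot] = 1
    return bar


rng = random.Random(0)  # Fixed seed so the toy dataset is reproducible
dataset = [make_synthetic_bar(rng) for _ in range(100)]
print(len(dataset), len(dataset[0]))
```

Every sample has the same length, so such data would suit a feed-forward network without further padding.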


#### Pre-processing 

**Resources:**

> [Short IBM Article: Data preprocessing in detail](https://developer.ibm.com/articles/data-preprocessing-in-detail/)  
> 
> [Medium Detailed Article with Example: Data pre-processing: A step-by-step guide](https://towardsdatascience.com/data-pre-processing-a-step-by-step-guide-541b083912b5)
>
> [Easy Kaggle Article: Machine Learning Episode #1: Data Pre-processing steps using Python](https://www.kaggle.com/discussions/getting-started/151612)





#### Frameworks

**JAX**

> [Extensive Article with Example for RNN, MLP, and CNN: Getting started with JAX (MLPs, CNNs & RNNs)](https://roberttlange.com/posts/2020/03/blog-post-10/)


### Main Code

In [1]:
class DataPoint:
    def __init__(self,
                 midi_data,
                 features,
                 audio_file_path,
                 midi_file_path,
                 purpose) -> None:
        self.midi_data = midi_data
        self.features = features
        self.audio_file_path = audio_file_path
        self.midi_file_path = midi_file_path
        self.purpose = purpose

feature_names = ["style", "bpm", "beat_type", "time_signature", "duration"]
feature_to_column_idx = {
    "session" : 0,
    "drummer" : 1,
    "id" : 2,
    "style" : 3,
    "bpm" : 4,
    "beat_type" : 5,
    "time_signature" : 6,
    "midi_filename" : 7,
    "audio_filename" : 8,
    "duration" : 9,
    "split" : 10
}

import csv
from pathlib import Path

from pretty_midi import PrettyMIDI


# Init for data storage
X = []

# Define filename and its path
data_filename = "info.csv"
data_folder_path = Path.cwd() / "data"
data_path = data_folder_path / data_filename

# Open file with data information
with data_path.open(newline="") as file:
    reader = csv.reader(file)
    next(reader)  # Skip the header row
    for row in reader:
        # Keep only recordings in 4/4 time
        if row[feature_to_column_idx["time_signature"]] != "4-4":
            continue

        # Attach absolute path to midi & audio files
        midi_data_path = data_folder_path / row[feature_to_column_idx["midi_filename"]]
        audio_data_path = data_folder_path / row[feature_to_column_idx["audio_filename"]]

        # Load the MIDI file and collect the selected metadata features
        midi_data = PrettyMIDI(str(midi_data_path))
        features = {name: row[feature_to_column_idx[name]] for name in feature_names}

        # Create custom data point with all info
        data_point = DataPoint(
            midi_data=midi_data,
            audio_file_path=audio_data_path,
            midi_file_path=midi_data_path,
            features=features,
            purpose=row[feature_to_column_idx["split"]]
        )

        X.append(data_point)
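
As an alternative to maintaining column indices by hand, the metadata could be read by column name with `csv.DictReader`. The sketch below uses a tiny in-memory stand-in for `info.csv`: the two rows are invented for illustration, and the header simply mirrors the names in `feature_to_column_idx`.

```python
import csv
import io

# In-memory stand-in for info.csv; both rows are made up for this example
sample_csv = io.StringIO(
    "session,drummer,id,style,bpm,beat_type,time_signature,"
    "midi_filename,audio_filename,duration,split\n"
    "s1,d1,d1/s1/1,funk,120,beat,4-4,a.mid,a.wav,10.0,train\n"
    "s1,d1,d1/s1/2,jazz,90,fill,3-4,b.mid,b.wav,5.0,test\n"
)

feature_names = ["style", "bpm", "beat_type", "time_signature", "duration"]

rows = []
for row in csv.DictReader(sample_csv):
    # Keep only 4/4 recordings, as in the loading loop above
    if row["time_signature"] != "4-4":
        continue
    features = {name: row[name] for name in feature_names}
    rows.append((row["midi_filename"], row["split"], features))

print(rows)
```

Looking values up by name keeps the code correct even if the column order in the file differs from what the index table assumes.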

In [3]:
from pretty_midi import PrettyMIDI
import numpy as np
from pre_processing_tools import filter_to_kick_drum
from pre_processing_tools import align_all_notes_to_origin
from pre_processing_tools import is_first_note_aligned
from pre_processing_tools import make_all_notes_same_duartion
from pre_processing_tools import make_all_notes_same_volume
from pre_processing_tools import make_instrument_not_drum
from pre_processing_tools import slice_midi_to_batches
from pre_processing_tools import midi_to_bit_vector


# Initiate data structures
X_train = []
X_test = []
X_validation = []
# X_train = [0] * 897
# X_test = [0] * 129
# X_validation = [0] * 124
# y_train = np.zeros(())
# y_test = np.zeros(())

pipeline = [
    filter_to_kick_drum,
    make_all_notes_same_duartion,
    make_all_notes_same_volume,
    make_instrument_not_drum
]

for x in X:
    x_midi = PrettyMIDI(str(x.midi_file_path))
    try:
        # Apply each pre-processing step in order
        for process in pipeline:
            x_midi = process(x_midi)
        # Shift the notes so the first one starts at the origin
        if not is_first_note_aligned(x_midi):
            x_midi = align_all_notes_to_origin(x_midi)
    except Exception:
        # Skip files that a pre-processing step cannot handle
        continue

    batches = slice_midi_to_batches(x_midi)
    if len(batches) == 1:
        bit_vectors = [midi_to_bit_vector(x_midi)]
    else:
        bit_vectors = []
        for batch in batches:
            # Debug output: show each batch before encoding it
            print()
            print(batch)
            if not is_first_note_aligned(batch):
                batch = align_all_notes_to_origin(batch)
            bit_vectors.append(midi_to_bit_vector(batch))

    # Sort the encoded batches into the dataset's predefined split
    if x.purpose == "test":
        X_test.extend(bit_vectors)
    elif x.purpose == "train":
        X_train.extend(bit_vectors)
    elif x.purpose == "validation":
        X_validation.extend(bit_vectors)
    else:
        print("Data not categorized!")

print(len(X_train))
print(len(X_test))
print(len(X_validation))


<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD23E4F610>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8110>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8210>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD83D0>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8590>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8550>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8710>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8890>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD8C50>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD9310>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBD9B10>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBDA290>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BBDAA10>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BE14050>

<pretty_midi.pretty_midi.PrettyMIDI object at 0x000001FD5BE14

AttributeError: 'int' object has no attribute 'instruments'
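
Because the loop above silently skips any file on which a pipeline step raises, errors like this one can be hard to trace. One possible way to make failures visible is to record which step failed. This is a generic sketch with toy steps standing in for the real pre-processing functions; `run_pipeline` is a hypothetical helper, not part of `pre_processing_tools`.

```python
def run_pipeline(sample, steps):
    """Apply steps in order; return (result, None) or (None, error message)."""
    for step in steps:
        try:
            sample = step(sample)
        except Exception as err:
            # Report the failing step by name instead of discarding the error
            return None, f"{step.__name__}: {type(err).__name__}: {err}"
    return sample, None


# Toy steps standing in for the real pre-processing functions
def double(x):
    return x * 2


def fail_on_big(x):
    if x > 10:
        raise ValueError("value too large")
    return x


print(run_pipeline(3, [double, fail_on_big]))  # (6, None)
print(run_pipeline(8, [double, fail_on_big]))  # (None, 'fail_on_big: ValueError: value too large')
```

Collecting these messages per file would show whether a given step returns something unexpected before the next step consumes it.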

In [None]:
from pre_processing_tools import slice_midi_to_batches
test_midi = X_train[0].midi_data
# print(*test_midi.instruments[0].notes, sep='\n')
print(test_midi.get_end_time() == test_midi.instruments[0].notes[-1].end)
print(test_midi.instruments[0].notes[-1].end)
print(test_midi.get_end_time())
batches = slice_midi_to_batches(test_midi)
print(len(test_midi.instruments[0].notes))
print(len(batches))
for batch in batches:
    print(batch)
    # print(batch.get_end_time())
    # print("Num of notes: ", len(batch.instruments[0].notes))
# print(*batches, sep="\n")
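
The implementation of `slice_midi_to_batches` is not shown in this notebook. As a rough illustration of the idea only, note onsets could be grouped into 4/4 bars from the tempo as sketched below; `group_onsets_by_bar` is a hypothetical helper that works on plain onset times in seconds rather than on `PrettyMIDI` objects, and it assumes a constant tempo.

```python
def group_onsets_by_bar(onset_times, bpm, beats_per_bar=4):
    """Map bar index -> onsets (seconds) measured from the start of that bar."""
    bar_duration = beats_per_bar * 60.0 / bpm  # seconds per bar
    bars = {}
    for t in sorted(onset_times):
        bar_idx = int(t // bar_duration)
        # Store the onset relative to its own bar so every bar starts at 0
        bars.setdefault(bar_idx, []).append(t - bar_idx * bar_duration)
    return bars


# At 120 BPM a 4/4 bar lasts 2 seconds
bars = group_onsets_by_bar([0.0, 1.0, 2.5, 4.0], bpm=120)
print(bars)
```

Bars without any notes simply do not appear in the result, which is one way to avoid placeholder entries among the batches.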

False
83.45625000000001
87.0
Note #0/120
Batch idx: 0
Note #1/120
Batch idx: 0
Note #2/120
Batch idx: 0
Note #3/120
Batch idx: 0
Note #4/120
Batch idx: 0
	 Batch idx: 0
	 Num of notes: 4
Note #5/120
Batch idx: 1
Note #6/120
Batch idx: 1
	 Batch idx: 1
	 Num of notes: 1
Note #7/120
Batch idx: 2
Note #8/120
Batch idx: 2
Note #9/120
Batch idx: 2
	 Batch idx: 2
	 Num of notes: 2
Note #10/120
Batch idx: 3
Note #11/120
Batch idx: 3
Note #12/120
Batch idx: 3
	 Batch idx: 3
	 Num of notes: 2
Note #13/120
Batch idx: 4
Note #14/120
Batch idx: 4
Note #15/120
Batch idx: 4
	 Batch idx: 4
	 Num of notes: 2
Note #16/120
Batch idx: 5
Note #17/120
Batch idx: 5
	 Batch idx: 5
	 Num of notes: 1
Note #18/120
Batch idx: 6
Note #19/120
Batch idx: 6
Note #20/120
Batch idx: 6
	 Batch idx: 6
	 Num of notes: 2
Note #21/120
Batch idx: 7
Note #22/120
Batch idx: 7
Note #23/120
Batch idx: 7
	 Batch idx: 7
	 Num of notes: 2
Note #24/120
Batch idx: 8
Note #25/120
Batch idx: 8
	 Batch idx: 8
	 Num of notes: 1
Note #26