In [None]:
# Install dependencies
!sudo apt install -y fluidsynth
!pip install --upgrade pyfluidsynth
!pip install pretty_midi

# Imports
import collections
import datetime
import fluidsynth
import glob
import numpy as np
import pathlib
import pandas as pd
import pretty_midi
import seaborn as sns
import tensorflow as tf

from IPython import display
from matplotlib import pyplot as plt
from typing import Optional

# Seed both TensorFlow's and NumPy's global random generators with the same
# value so that code drawing from either of them is repeatable across runs.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

# Collect the MIDI files that sit one directory level below the dataset
# root (without recursive=True, '**' behaves like a single '*').
# glob returns paths in arbitrary, filesystem-dependent order, so the list
# is sorted to make index-based selection below pick the same file on every
# machine.
data_dir = pathlib.Path('data/maestro-v3.0.0')
filenames = sorted(glob.glob(str(data_dir / '**/*.mid*')))
print('Number of files:', len(filenames))

# Pick one file to explore in the cells that follow. Fail with an
# actionable message instead of a bare IndexError when the dataset is
# missing or incomplete.
if len(filenames) < 2:
    raise FileNotFoundError(
        f'Expected at least two MIDI files under {data_dir}, '
        f'found {len(filenames)}'
    )
sample_file = filenames[1]
pm = pretty_midi.PrettyMIDI(sample_file)

# Display audio
def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30, sampling_rate: int = 16000):
    """Synthesize a MIDI object and show an inline audio player for its start.

    Args:
        pm: The MIDI object to render; its ``fluidsynth`` method produces
            the waveform.
        seconds: Length of the clip to keep, counted from the beginning.
            May be fractional.
        sampling_rate: Sampling rate in Hz used both for synthesis and for
            playback. Defaults to 16000.

    Returns:
        None. The player is emitted through ``IPython.display``.
    """
    waveform = pm.fluidsynth(fs=sampling_rate)
    # Slice bounds must be integers, so a fractional duration is truncated
    # to a whole number of samples.
    num_samples = int(seconds * sampling_rate)
    display.display(display.Audio(waveform[:num_samples], rate=sampling_rate))

# Play the beginning of the sample file (display_audio's default length).
display_audio(pm)

# Take the first instrument of the sample file and print up to its first
# ten notes: pitch, start/end times (4 decimal places) and velocity.
instrument = pm.instruments[0]
for i, note in enumerate(instrument.notes[:10]):
    print(f'{i}: pitch={note.pitch}, start={note.start:.4f}, end={note.end:.4f}, velocity={note.velocity}')

# Extract notes from MIDI
def extract_notes(midi_file: str) -> pd.DataFrame:
    """Load a MIDI file and tabulate the notes of its first instrument.

    Args:
        midi_file: Path of the MIDI file to read.

    Returns:
        A DataFrame with one row per note, in the order the instrument
        stores them, and the columns ``pitch``, ``start``, ``end`` and
        ``velocity``. All four columns are present even when the instrument
        has no notes, so callers can index them unconditionally.

    Raises:
        IndexError: If the file contains no instruments.
    """
    pm = pretty_midi.PrettyMIDI(midi_file)
    # Only the first instrument is read; any further instruments are ignored.
    notes = pm.instruments[0].notes

    # An explicit dict (rather than a defaultdict filled in a loop) fixes
    # the column set up front, so an empty note list still yields the
    # expected schema.
    return pd.DataFrame({
        'pitch': [note.pitch for note in notes],
        'start': [note.start for note in notes],
        'end': [note.end for note in notes],
        'velocity': [note.velocity for note in notes],
    })

# Tabulate the sample file's notes and preview the first rows of the table.
raw_notes = extract_notes(sample_file)
print(raw_notes.head())

# Plot notes
def Plot_notes(notes: pd.DataFrame, count: Optional[int] = None):
    """Draw a scatter plot of note pitch against start time.

    Points are coloured by the ``velocity`` column and a colour bar is added
    for it.

    Args:
        notes: Table providing the ``start``, ``pitch`` and ``velocity``
            columns.
        count: When truthy, only the first ``count`` rows are plotted. A
            value of ``None`` or ``0`` plots every row.
    """
    shown = notes.iloc[:count] if count else notes

    plt.figure(figsize=(15, 5))
    plt.scatter(
        shown['start'],
        shown['pitch'],
        c=shown['velocity'],
        cmap='viridis',
        alpha=0.7,
    )
    plt.title("MIDI Note Visualization")
    plt.xlabel("Time")
    plt.ylabel("Pitch")
    plt.colorbar(label='Velocity')
    plt.show()

# Visualize the first 100 notes of the sample file.
Plot_notes(raw_notes, count=100)

# Distribution plots
def plot_dist(notes: pd.DataFrame, drop_percentile=2.5):
    """Plot a histogram with a KDE overlay for each note attribute.

    The ``pitch``, ``start``, ``end`` and ``velocity`` columns are plotted
    one after another, each shown with its own ``plt.show()`` call. Before
    plotting, every column is trimmed to the values lying between its
    ``drop_percentile``-th and ``(100 - drop_percentile)``-th percentiles.

    Args:
        notes: Table providing the four columns named above.
        drop_percentile: Percentage cut from each tail of a column before
            plotting. ``0`` keeps every value.
    """
    for col in ['pitch', 'start', 'end', 'velocity']:
        values = notes[col]
        lower = np.percentile(values, drop_percentile)
        upper = np.percentile(values, 100 - drop_percentile)
        # The bounds are inclusive: with strict comparisons every value equal
        # to a percentile would be discarded, which removes far more than the
        # requested tails on discrete columns and empties a constant column.
        trimmed = values[(values >= lower) & (values <= upper)]
        sns.histplot(trimmed, bins=50, kde=True)
        plt.title(f'Distribution of {col}')
        plt.show()

# Plot the per-column distributions of the sample file's notes.
plot_dist(raw_notes)

# Convert notes back to MIDI
def construct_midi(notes: pd.DataFrame, out_file: Optional[str] = None) -> pretty_midi.PrettyMIDI:
    """Build a single-instrument MIDI object from a table of notes.

    Args:
        notes: Table with one row per note and the columns ``start``,
            ``end``, ``pitch`` and ``velocity``. ``pitch`` and ``velocity``
            are converted with ``int``, ``start`` and ``end`` with ``float``.
        out_file: Path to write the MIDI file to. Nothing is written when
            it is ``None`` or an empty string.

    Returns:
        The constructed ``PrettyMIDI`` object, holding one instrument
        created with program number 0.
    """
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)

    # itertuples keeps each column's own dtype and avoids building a Series
    # per row, which iterrows does.
    for row in notes.itertuples(index=False):
        instrument.notes.append(
            pretty_midi.Note(
                velocity=int(row.velocity),
                pitch=int(row.pitch),
                start=float(row.start),
                end=float(row.end),
            )
        )

    pm.instruments.append(instrument)
    if out_file:
        pm.write(out_file)
    return pm

# Rebuild a MIDI object from the first 50 extracted notes and play it back.
example_pm = construct_midi(raw_notes.iloc[:50])
display_audio(example_pm)
