In [38]:
import matplotlib.pyplot as plt
import mido
import numpy as np
import os
import pandas as pd
import seaborn as sns
import warnings
from music21 import chord, converter, instrument, note, pitch
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

## Data Collection
This section collects MIDI file paths and their corresponding composer labels from a specified dataset directory, organized into different splits (dev, test, train). It filters the files based on a predefined list of composers, appending the paths and labels to respective lists. Finally, it prints out the number of collected files and samples of file paths and labels for verification.

In [44]:
# Define the path to the dataset
dataset_path = 'Composer_Dataset/NN_midi_files_extended'

# Initialize dictionaries to hold the data
data = {'dev': [], 'test': [], 'train': []}
labels = {'dev': [], 'test': [], 'train': []}

# List of composers to include
composers_to_include = ['bach', 'beethoven', 'chopin', 'mozart']

# Data Collection
for dataset_split in ['dev', 'test', 'train']:
    split_path = os.path.join(dataset_path, dataset_split)
    for composer_folder in os.listdir(split_path):
        if composer_folder.lower() in composers_to_include:
            composer_path = os.path.join(split_path, composer_folder)
            if os.path.isdir(composer_path):
                for midi_file in os.listdir(composer_path):
                    if midi_file.endswith('.mid'):
                        midi_path = os.path.join(composer_path, midi_file)
                        labels[dataset_split].append(composer_folder)
                        data[dataset_split].append(midi_path)

# Verify 
for dataset_split in ['dev', 'test', 'train']:
    print(f"Collected {len(data[dataset_split])} MIDI files in {dataset_split} split.")
    print("Sample data paths:")
    for path in data[dataset_split][:5]:
        print(f"  - {path}")
    print("Sample labels:")
    for label in labels[dataset_split][:5]:
        print(f"  - {label}")

Collected 12 MIDI files in dev split.
Sample data paths:
  - Composer_Dataset/NN_midi_files_extended/dev/mozart/mozart039.mid
  - Composer_Dataset/NN_midi_files_extended/dev/mozart/mozart035.mid
  - Composer_Dataset/NN_midi_files_extended/dev/mozart/mozart020.mid
  - Composer_Dataset/NN_midi_files_extended/dev/mozart/mozart040.mid
  - Composer_Dataset/NN_midi_files_extended/dev/chopin/chopin069.mid
Sample labels:
  - mozart
  - mozart
  - mozart
  - mozart
  - chopin
Collected 12 MIDI files in test split.
Sample data paths:
  - Composer_Dataset/NN_midi_files_extended/test/mozart/mozart014.mid
  - Composer_Dataset/NN_midi_files_extended/test/mozart/mozart038.mid
  - Composer_Dataset/NN_midi_files_extended/test/mozart/mozart004.mid
  - Composer_Dataset/NN_midi_files_extended/test/mozart/mozart025.mid
  - Composer_Dataset/NN_midi_files_extended/test/chopin/chopin053.mid
Sample labels:
  - mozart
  - mozart
  - mozart
  - mozart
  - chopin
Collected 124 MIDI files in train split.
Sample da

## Data Pre-processing
This section preprocesses MIDI files by extracting note and chord information, handling potential parsing errors and ignoring warnings. It processes the files in batches for better debugging and tracks successfully processed files as well as those that failed. The script prints the progress, number of successfully processed files, and logs any files that encountered errors.

In [45]:
def preprocess_midi(file_path):
    """
    Preprocess a MIDI file to extract note, chord, and tempo information.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UserWarning)
            midi = converter.parse(file_path)
    except Exception as e:
        print(f"Error parsing {file_path}: {e}")
        return []

    notes = []
    for element in midi.flat.notes:
        if isinstance(element, note.Note):
            notes.append(str(element.pitch))
        elif isinstance(element, chord.Chord):
            notes.append('.'.join(str(n) for n in element.normalOrder))
        else:
            continue  # Skip elements that are not notes or chords

    return notes

# Apply preprocessing and keep separated by splits
preprocessed_data = {'dev': [], 'test': [], 'train': []}
failed_files = {'dev': [], 'test': [], 'train': []}

# Process in smaller batches for debugging
batch_size = 10  # Adjust this value as needed
for dataset_split in ['dev', 'test', 'train']:
    for i, file_path in enumerate(data[dataset_split]):
        if i % batch_size == 0:
            print(f"Processing batch {i//batch_size + 1} for {dataset_split} split...")

        preprocessed = preprocess_midi(file_path)
        if preprocessed:
            preprocessed_data[dataset_split].append(preprocessed)
        else:
            failed_files[dataset_split].append(file_path)

        # Print progress every 10 files
        if (i + 1) % 10 == 0:
            print(f"Processed {i + 1} files in {dataset_split} split...")

# Verify preprocessing results
for dataset_split in ['dev', 'test', 'train']:
    print(f"Processed {len(preprocessed_data[dataset_split])} MIDI files successfully in {dataset_split} split.")
    print(f"Failed to process {len(failed_files[dataset_split])} MIDI files in {dataset_split} split.")
    print(f"Sample preprocessed data from {dataset_split} split: {preprocessed_data[dataset_split][:1]}")

    if failed_files[dataset_split]:
        print(f"Failed files in {dataset_split} split:")
        for file_path in failed_files[dataset_split]:
            print(file_path)

Processing batch 1 for dev split...
Processed 10 files in dev split...
Processing batch 2 for dev split...
Processing batch 1 for test split...
Error parsing Composer_Dataset/NN_midi_files_extended/test/mozart/mozart025.mid: 5513617312
Processed 10 files in test split...
Processing batch 2 for test split...
Processing batch 1 for train split...
Processed 10 files in train split...
Processing batch 2 for train split...
Error parsing Composer_Dataset/NN_midi_files_extended/train/mozart/mozart009.mid: 13400269056
Processed 20 files in train split...
Processing batch 3 for train split...
Processed 30 files in train split...
Processing batch 4 for train split...
Processed 40 files in train split...
Processing batch 5 for train split...
Processed 50 files in train split...
Processing batch 6 for train split...
Processed 60 files in train split...
Processing batch 7 for train split...
Processed 70 files in train split...
Processing batch 8 for train split...
Processed 80 files in train split.

## Feature Extraction
This section extracts numerical features from the preprocessed MIDI notes and chords by converting them into MIDI pitch values. It handles both individual notes and chords, skipping invalid or empty data and logging any errors encountered. The script then prints the total number of processed files and provides a sample of the extracted features for verification.

In [25]:
# Step 3: Feature Extraction
def extract_features(notes):
    """
    Extract features from the notes and chords in a piece of music.
    """
    features = []

    # Convert notes and chords into numerical representation
    for n in notes:
        if n == '':
            continue  # Skip empty strings
        try:
            if '.' in n:
                chord_notes = n.split('.')
                if all(c.isdigit() for c in chord_notes):
                    chord_notes = [pitch.Pitch(int(c)).midi for c in chord_notes]
                    features.append(chord_notes)
                else:
                    print(f"Invalid chord encountered: {n}")
            else:
                if n.isdigit():
                    features.append(pitch.Pitch(int(n)).midi)
                else:
                    features.append(pitch.Pitch(n).midi)
        except pitch.PitchException:
            print(f"Invalid pitch encountered: {n}")
            continue  # Skip invalid pitch data

    return features

# Apply feature extraction
features = [extract_features(notes) for notes in preprocessed_data]

# Verify Step 3
print(f"Extracted features from {len(features)} MIDI files.")
print(f"Sample features: {features[:1]}")

Extracted features from 146 MIDI files.
Sample features: [[69, 81, 78, [62, 64], [62, 66], 73, 74, [62, 64], [67, 71], 73, 74, [62, 66, 69], [61, 64], 67, 79, 76, [61, 62], [64, 69], 71, 73, 69, [61, 64], 76, 81, [62, 64, 67, 69], [62, 66], 74, 78, 50, 62, 79, 64, [69, 71], 66, [66, 67], [62, 64], 74, 50, 78, 62, 79, 64, [69, 71], 66, [66, 67], [62, 64], 74, 83, 55, 67, 85, 64, [61, 62], 66, [69, 71], [66, 67], [69, 71], 67, [66, 67], [62, 64], [69, 61, 64], 69, 69, 81, 78, [62, 64], [62, 66], 73, 74, [62, 64], [67, 71], 73, 74, [62, 66, 69], [61, 64], 67, 79, 76, [69, 61, 62, 64], 71, 73, [69, 61, 64], 76, 81, [62, 64, 67, 69], [62, 66], 74, 78, 50, 62, 79, 64, [69, 71], 66, [66, 67], [62, 64], 79, 43, 83, 55, 85, 57, [62, 64], 59, [71, 61], [67, 69], [71, 61], [64, 67, 71], [67, 69], [64, 66], [69, 71], [66, 67], [62, 64], [67, 69], [64, 66], [61, 62], 74, [62, 66], [62, 66], [62, 66], 69, 45, [66, 69], [66, 69], 74, 50, [69, 62], [69, 62], 78, 54, [62, 66], [62, 66], 81, 57, [62, 66

## Convert Features and Labels into a DataFrame
This section creates a DataFrame from the features and labels, ensuring that both arrays have the same length by trimming the longer array. It prints the lengths of the features and labels, adjusts them if necessary, and then constructs and verifies the DataFrame to ensure proper alignment of the data.

In [42]:
# Step 4: Convert features and labels into a DataFrame
# Example arrays (replace these with your actual data)
features = [[1, 2], [3, 4], [5, 6]]
labels = [0, 1, 0, 1]  # Different length from features

# Check lengths
features_length = len(features)
labels_length = len(labels)

print(f"Length of features: {features_length}")
print(f"Length of labels: {labels_length}")

# Make the lengths equal by trimming the longer array
if features_length > labels_length:
    features = features[:labels_length]
elif labels_length > features_length:
    labels = labels[:features_length]

# Now create the DataFrame
df = pd.DataFrame({'features': features, 'label': labels})

# Verify the DataFrame
print(df.head())

Length of features: 3
Length of labels: 4
  features  label
0   [1, 2]      0
1   [3, 4]      1
2   [5, 6]      0
