In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import music21
from music21 import converter, note, chord, stream, tempo, key, pitch, duration, environment
import multiprocessing
from collections import Counter
from tqdm.auto import tqdm
import warnings

# Suppress specific warnings if needed
warnings.filterwarnings('ignore', category=FutureWarning, module='seaborn')
warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')
# music21 can be verbose with some warnings, suppress if necessary for cleanliness
# warnings.filterwarnings('ignore', module='music21')


pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
sns.set_theme(style="whitegrid")

# music21 environment configuration (if necessary, e.g., for MuseScore path for visualizations)
# This is often needed for.show() calls to render scores, but not strictly for data extraction.
# try:
#     env = environment.UserSettings()
#     # env.getSettingsPath() # to find where the settings file is
#     # env = '/Applications/MuseScore 3.app/Contents/MacOS/mscore' # Example for macOS
#     # env['musicxmlPath'] = '/Applications/MuseScore 3.app/Contents/MacOS/mscore' # Example for macOS
# except music21.environment.EnvironmentException:
#     print("music21 environment settings could not be configured (e.g., MuseScore path). This is usually fine for headless EDA.")

# Define base directory for the dataset
# The user query states: "the maestro unzipped dataset is in data/"
DATA_DIR = Path('../data/')
MAESTRO_METADATA_FILE = DATA_DIR / 'maestro-v3.0.0.csv' # Common version, adjust if different
MIDI_BASE_PATH = DATA_DIR # MIDI files are in subdirectories like data/2004/, data/2006/ etc.

In [None]:
def load_maestro_metadata(file_path: Path) -> pd.DataFrame | None:
    if not file_path.exists():
        print(f"Metadata file not found: {file_path}")
        return None
    try:
        if file_path.suffix == '.csv':
            df = pd.read_csv(file_path)
        elif file_path.suffix == '.json':
            # MAESTRO JSON is typically one JSON object per line, or a list of objects.
            # Adjust if it's a single JSON object with a top-level key.
            df = pd.read_json(file_path, lines=True if 'v1.0.0' in file_path.name else False) # v1.0.0 was lines=True
            if 'v3.0.0.json' in file_path.name: # v3.0.0 json is not line-delimited
                 df = pd.read_json(file_path)

        else:
            print(f"Unsupported metadata file format: {file_path.suffix}")
            return None
        
        # Construct full MIDI path. MIDI files are in subdirectories named by year,
        # e.g., data/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-04_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.midi
        # The 'midi_filename' column in maestro-v3.0.0.csv already contains the year prefix.
        df['midi_filepath'] = df['midi_filename'].apply(lambda x: MIDI_BASE_PATH / x)
        
        # Optional: Check file existence early, though this can be slow for large datasets
        # For now, we'll assume paths are correct and handle errors during parsing.
        # df['midi_exists'] = df['midi_filepath'].apply(lambda x: x.exists())
        # num_listed = len(df)
        # num_found = df['midi_exists'].sum()
        # if num_found < num_listed:
        #     print(f"Warning: {num_listed - num_found} MIDI files listed in metadata not found on disk.")
        # df = df[df['midi_exists']].copy() # Filter for existing files to prevent issues

        return df
    except Exception as e:
        print(f"Error loading metadata from {file_path}: {e}")
        return None

maestro_df = load_maestro_metadata(MAESTRO_METADATA_FILE)

if maestro_df is not None:
    print("MAESTRO Metadata Loaded Successfully:")
    print(maestro_df.head())
    print("\nDataFrame Info:")
    maestro_df.info()
    print(f"\nTotal pieces in metadata: {len(maestro_df)}")
    # A quick check for actual file existence for a small sample if needed for debugging
    # sample_exists = maestro_df.head()['midi_filepath'].apply(lambda x: x.exists())
    # print(f"\nExistence check for first 5 MIDI files:\n{sample_exists}")
else:
    print("Failed to load MAESTRO metadata. Some metadata-driven EDA parts will be skipped.")
    # Initialize an empty DataFrame to prevent errors in later cells if they expect maestro_df
    maestro_df = pd.DataFrame(columns=['canonical_composer', 'canonical_title', 'split', 'year', 'midi_filename', 'duration', 'midi_filepath'])

In [None]:
def load_maestro_metadata(file_path: Path) -> pd.DataFrame | None:
    if not file_path.exists():
        print(f"Metadata file not found: {file_path}")
        return None
    try:
        if file_path.suffix == '.csv':
            df = pd.read_csv(file_path)
        elif file_path.suffix == '.json':
            # MAESTRO JSON is typically one JSON object per line, or a list of objects.
            # Adjust if it's a single JSON object with a top-level key.
            df = pd.read_json(file_path, lines=True if 'v1.0.0' in file_path.name else False) # v1.0.0 was lines=True
            if 'v3.0.0.json' in file_path.name: # v3.0.0 json is not line-delimited
                 df = pd.read_json(file_path)

        else:
            print(f"Unsupported metadata file format: {file_path.suffix}")
            return None
        
        # Construct full MIDI path. MIDI files are in subdirectories named by year,
        # e.g., data/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-04_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.midi
        # The 'midi_filename' column in maestro-v3.0.0.csv already contains the year prefix.
        df['midi_filepath'] = df['midi_filename'].apply(lambda x: MIDI_BASE_PATH / x)
        
        # Optional: Check file existence early, though this can be slow for large datasets
        # For now, we'll assume paths are correct and handle errors during parsing.
        # df['midi_exists'] = df['midi_filepath'].apply(lambda x: x.exists())
        # num_listed = len(df)
        # num_found = df['midi_exists'].sum()
        # if num_found < num_listed:
        #     print(f"Warning: {num_listed - num_found} MIDI files listed in metadata not found on disk.")
        # df = df[df['midi_exists']].copy() # Filter for existing files to prevent issues

        return df
    except Exception as e:
        print(f"Error loading metadata from {file_path}: {e}")
        return None

maestro_df = load_maestro_metadata(MAESTRO_METADATA_FILE)

if maestro_df is not None:
    print("MAESTRO Metadata Loaded Successfully:")
    print(maestro_df.head())
    print("\nDataFrame Info:")
    maestro_df.info()
    print(f"\nTotal pieces in metadata: {len(maestro_df)}")
    # A quick check for actual file existence for a small sample if needed for debugging
    # sample_exists = maestro_df.head()['midi_filepath'].apply(lambda x: x.exists())
    # print(f"\nExistence check for first 5 MIDI files:\n{sample_exists}")
else:
    print("Failed to load MAESTRO metadata. Some metadata-driven EDA parts will be skipped.")
    # Initialize an empty DataFrame to prevent errors in later cells if they expect maestro_df
    maestro_df = pd.DataFrame(columns=['canonical_composer', 'canonical_title', 'split', 'year', 'midi_filename', 'duration', 'midi_filepath'])

In [None]:
if maestro_df is not None and not maestro_df.empty:
    plt.figure(figsize=(10, 6))
    sns.histplot(maestro_df['year'], 
                 bins=len(maestro_df['year'].unique()) if maestro_df['year'].nunique() > 0 else 1, 
                 kde=False)
    plt.title('Distribution of Pieces by Performance Year')
    plt.xlabel('Year of Performance')
    plt.ylabel('Number of Pieces')
    plt.show()

    plt.figure(figsize=(10, 6))
    sns.histplot(maestro_df['duration'] / 60, bins=50, kde=True) # Duration in minutes
    plt.title('Distribution of Piece Durations (Minutes)')
    plt.xlabel('Duration (Minutes)')
    plt.ylabel('Number of Pieces')
    plt.show()

    plt.figure(figsize=(14, 8)) # Increased figure size for better label readability
    top_n_composers = 20 # Show more composers
    composer_counts = maestro_df['canonical_composer'].value_counts()
    # Filter out composers with very few pieces for a cleaner plot, e.g., > 5 pieces
    # composer_counts_filtered = composer_counts[composer_counts > 5]
    # composer_counts_to_plot = composer_counts_filtered.nlargest(top_n_composers)
    composer_counts_to_plot = composer_counts.nlargest(top_n_composers)
    
    sns.barplot(x=composer_counts_to_plot.index, y=composer_counts_to_plot.values, palette="viridis")
    plt.title(f'Top {len(composer_counts_to_plot)} Composers by Number of Pieces')
    plt.xlabel('Composer')
    plt.ylabel('Number of Pieces')
    plt.xticks(rotation=60, ha='right') # Adjusted rotation for clarity
    plt.tight_layout() # Adjust layout to prevent labels from overlapping
    plt.show()

    print("\nSummary statistics for piece duration (seconds):")
    print(maestro_df['duration'].describe())
else:
    print("maestro_df is None or empty, skipping metadata visualizations.")

In [None]:
from midi_processing import process_single_midi_file

with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
    results = list(tqdm(pool.imap_unordered(process_single_midi_file, maestro_df['midi_filepath']), total=len(maestro_df)))

results_df = pd.DataFrame(results)

In [None]:
results_df.head()

In [None]:
if maestro_df is not None and not maestro_df.empty and not results_df.empty:
    maestro_df['midi_filename_basename'] = maestro_df['midi_filename'].apply(lambda x: Path(x).name)
    results_df.rename(columns={'midi_filename_processed': 'midi_filename_basename'}, inplace=True)
    
    combined_df = pd.merge(maestro_df, results_df, on='midi_filename_basename', how='inner', suffixes=('_meta', '_extracted'))
    print(f"Combined DataFrame after merging with metadata (shape: {combined_df.shape}):")
elif not results_df.empty:
    combined_df = results_df
    print("Using features extracted directly from MIDI files (metadata not merged).")
    print(f"Features DataFrame shape: {combined_df.shape}")
else:
    combined_df = pd.DataFrame() 
    print("Resulting combined_df is empty.")


In [None]:
combined_df = pd.read_csv('combined_df.csv') 
combined_df.head()

In [6]:
combined_df.to_csv('combined_df.csv', index=False)

In [None]:
if not combined_df.empty and 'key_tonic' in combined_df.columns and 'key_mode' in combined_df.columns:
    combined_df['key_tonic_str'] = combined_df['key_tonic'].astype(str).fillna('Unknown')
    combined_df['key_mode_str'] = combined_df['key_mode'].astype(str).fillna('')
    
    combined_df['estimated_key_full'] = combined_df['key_tonic_str'] + ' ' + combined_df['key_mode_str']
    combined_df['estimated_key_full'] = combined_df['estimated_key_full'].str.replace('Unknown ', 'Unknown', regex=False).str.strip()

    valid_keys_df = combined_df[~combined_df['estimated_key_full'].isin(['Unknown', 'None None'])]

    if not valid_keys_df.empty:
        plt.figure(figsize=(14, 7))
        key_counts = valid_keys_df['estimated_key_full'].value_counts().nlargest(24)
        sns.barplot(x=key_counts.index, y=key_counts.values, palette="crest")
        plt.title('Distribution of Estimated Keys (Top 24)')
        plt.xlabel('Estimated Key')
        plt.ylabel('Number of Pieces')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

        plt.figure(figsize=(8, 5))
        mode_counts = valid_keys_df['key_mode_str'].value_counts()
        if 'major' not in mode_counts.index and 'minor' in mode_counts.index : mode_counts.loc['major'] = 0
        if 'minor' not in mode_counts.index and 'major' in mode_counts.index : mode_counts.loc['minor'] = 0
        mode_counts = mode_counts.sort_index()

        sns.barplot(x=mode_counts.index, y=mode_counts.values, palette="pastel")
        plt.title('Distribution of Modes (Major/Minor)')
        plt.xlabel('Mode')
        plt.ylabel('Number of Pieces')
        plt.show()

        print("\nFrequency of Top Estimated Keys:")
        print(key_counts)
        print("\nFrequency of Modes:")
        print(mode_counts)
        
        key_freq_table_data = key_counts.reset_index()
        key_freq_table_data.columns = ['Estimated Key', 'Count']
        key_freq_table_data['Percentage'] = (key_freq_table_data['Count'] / len(valid_keys_df)) * 100
        print("\nTable: Frequency of Top N Estimated Keys")
        print(key_freq_table_data.to_string(index=False))

    else:
        print("No valid key analysis results to display.")
else:
    print("Key features ('key_tonic', 'key_mode') not available in combined_df for analysis.")

In [None]:
if not combined_df.empty and 'pitch_min' in combined_df.columns and combined_df['pitch_min'].notna().all():

    if 'pitch_range' in combined_df.columns and combined_df['pitch_range'].notna().any():
        plt.figure(figsize=(10, 6))
        sns.histplot(combined_df['pitch_range'].dropna().astype(float), bins=30, kde=True, color="skyblue")
        plt.title('Distribution of Pitch Ranges (in semitones) Per Piece')
        plt.xlabel('Pitch Range (Semitones)')
        plt.ylabel('Number of Pieces')
        plt.show()

    if 'pitch_mean' in combined_df.columns and combined_df['pitch_mean'].notna().any():
        plt.figure(figsize=(10, 6))
        sns.histplot(combined_df['pitch_mean'].dropna().astype(float), bins=30, kde=True, color="salmon")
        plt.title('Distribution of Average MIDI Pitch Per Piece')
        plt.xlabel('Average MIDI Pitch (C4=60)')
        plt.ylabel('Number of Pieces')
        plt.axvline(60, color='k', linestyle='--', label='C4 (Middle C)')
        plt.legend()
        plt.show()

    if 'pitch_min' in combined_df.columns and combined_df['pitch_min'].notna().any():
        plt.figure(figsize=(10, 6))
        sns.histplot(combined_df['pitch_min'].dropna().astype(float), bins=30, kde=True, color="lightgreen", label='Min Pitch')
        plt.title('Distribution of Minimum MIDI Pitches Per Piece')
        plt.xlabel('Minimum MIDI Pitch')
        plt.ylabel('Number of Pieces')
        plt.legend()
        plt.show()

    if 'pitch_max' in combined_df.columns and combined_df['pitch_max'].notna().any():
        plt.figure(figsize=(10, 6))
        sns.histplot(combined_df['pitch_max'].dropna().astype(float), bins=30, kde=True, color="gold", label='Max Pitch')
        plt.title('Distribution of Maximum MIDI Pitches Per Piece')
        plt.xlabel('Maximum MIDI Pitch')
        plt.ylabel('Number of Pieces')
        plt.legend()
        plt.show()

    print("\nSummary statistics for per-piece pitch features:")
    pitch_summary_cols = ['num_notes', 'pitch_min', 'pitch_max', 'pitch_mean', 'pitch_median', 'pitch_std', 'pitch_range']
    existing_pitch_summary_cols = [col for col in pitch_summary_cols if col in combined_df.columns]
    if existing_pitch_summary_cols:
        pitch_stats_df = combined_df[existing_pitch_summary_cols].describe()
        print(pitch_stats_df)
        
        print("\nTable: Summary Statistics for Per-Piece Pitch Features")
        print(pitch_stats_df.to_string())
    else:
        print("No pitch summary columns available for describe().")

    # For a global pitch class distribution, one would need to aggregate all pitch events.
    # This is computationally more intensive if not done during initial parsing.
    # Example (if all_pitches_global was populated from samples or full data):
    # all_pitches_global = # Placeholder: populate this by iterating through all notes in all scores
    # if all_pitches_global:
    #    all_pitch_classes = [int(round(p)) % 12 for p in all_pitches_global]
    #    pc_counts = Counter(all_pitch_classes)
    #    pc_names =
    #    plt.figure(figsize=(10, 6))
    #    sns.barplot(x=[pc_names[i] for i in sorted(pc_counts.keys())], y=[pc_counts[i] for i in sorted(pc_counts.keys())])
    #    plt.title('Overall Pitch Class Distribution (Sampled/Full)')
    #    plt.xlabel('Pitch Class')
    #    plt.ylabel('Frequency')
    #    plt.show()

else:
    print("Pitch features ('pitch_min', etc.) not available in combined_df for analysis.")

In [None]:
if not combined_df.empty and 'durations_ql_sample' in combined_df.columns and 'iois_ql_sample' in combined_df.columns:
    all_durations_ql = []
    for i, row in combined_df.iterrows():
        if isinstance(row['durations_ql_sample'], list):
            all_durations_ql.extend(row['durations_ql_sample'])

    all_iois_ql = []
    for i, row in combined_df.iterrows():
        if isinstance(row['iois_ql_sample'], list):
            all_iois_ql.extend(row['iois_ql_sample'])
    
    all_durations_ql = [d for d in all_durations_ql if d > 1e-6]
    all_iois_ql = [i for i in all_iois_ql if i > 1e-6]

    common_qls = {
        '1/64': 0.0625, '1/32': 0.125, '1/16': 0.25, 'triplet 8th (approx)': 1/3, 
        '8th': 0.5, 'dotted 8th': 0.75, 'quarter': 1.0, 
        'triplet quarter (approx)': 2/3, 'dotted quarter': 1.5, 'half': 2.0, 
        'dotted half': 3.0, 'whole': 4.0
    }

    if all_durations_ql:
        plt.figure(figsize=(14, 7))
        min_dur, max_dur = np.min(all_durations_ql), np.max(all_durations_ql)
        log_bins = np.logspace(np.log10(min_dur if min_dur > 0 else 0.01), np.log10(max_dur if max_dur > 0 else 4.0), 75)
        sns.histplot(all_durations_ql, bins=log_bins, kde=False, color='teal')
        plt.xscale('log')
        plt.title('Overall Distribution of Note/Rest Durations (Quarter Lengths, Log Scale)')
        plt.xlabel('Duration (Quarter Lengths)')
        plt.ylabel('Frequency')
        
        for name, val in common_qls.items():
            if min_dur < val < max_dur:
                plt.axvline(val, color='r', linestyle='--', alpha=0.6, label=f'{name} ({val:.3f})')
        plt.legend(fontsize='small', ncol=2)
        plt.grid(True, which="both", ls="-", alpha=0.5)
        plt.show()

        duration_bins = sorted(list(set([0, 0.0625, 0.125, 0.25, 1/3, 0.5, 2/3, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, float('inf')])))
        duration_labels = [
            '0 (<1/64)',  # Or 'Almost Zero' for [0, 0.0625)
            '1/64 (0.0625-0.125)',
            '1/32 (0.125-0.25)',
            '1/16 (0.25-1/3)',
            'triplet 8th (1/3-0.5)',
            '8th (0.5-2/3)',
            'triplet qtr / dotted 8th (2/3-0.75)', # Bin [2/3, 0.75)
            'dotted 8th / quarter (0.75-1.0)',   # Bin [0.75, 1.0)
            'quarter / dotted qtr (1.0-1.5)',    # Bin [1.0, 1.5)
            'dotted quarter / half (1.5-2.0)',   # Bin [1.5, 2.0)
            'half / dotted half (2.0-3.0)',      # Bin [2.0, 3.0)
            'dotted half / whole (3.0-4.0)',     # Bin [3.0, 4.0)
            'whole (4.0-8.0)',                   # Bin [4.0, 8.0)
            '> 2 wholes (>8.0)'                  # Label for the last bin [8.0, inf)
        ]


        
        quantized_durs = pd.cut([d for d in all_durations_ql if d > 0], bins=duration_bins, labels=duration_labels, right=False, include_lowest=True)
        plt.figure(figsize=(14, 7))
        quantized_durs.value_counts().plot(kind='bar', color='lightcoral')
        plt.title('Distribution of Quantized Note/Rest Durations')
        plt.xlabel('Quantized Duration Interval (Quarter Length)')
        plt.ylabel('Frequency')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()
        
        quantized_dur_counts = quantized_durs.value_counts().nlargest(10)
        quantized_dur_table = quantized_dur_counts.reset_index()
        quantized_dur_table.columns = ['Quantized Duration Interval', 'Count']
        quantized_dur_table['Percentage'] = (quantized_dur_table['Count'] / len(quantized_durs)) * 100
        print("\nTable: Top 10 Most Frequent Quantized Note/Rest Durations")
        print(quantized_dur_table.to_string(index=False))

    else:
        print("No duration data to plot.")

    if all_iois_ql:
        plt.figure(figsize=(14, 7))
        min_ioi, max_ioi = np.min(all_iois_ql), np.max(all_iois_ql)
        log_bins_ioi = np.logspace(np.log10(min_ioi if min_ioi > 0 else 0.01), np.log10(max_ioi if max_ioi > 0 else 4.0), 75)
        sns.histplot(all_iois_ql, bins=log_bins_ioi, kde=False, color='darkorange')
        plt.xscale('log')
        plt.title('Overall Distribution of Inter-Onset Intervals (Quarter Lengths, Log Scale)')
        plt.xlabel('IOI (Quarter Lengths)')
        plt.ylabel('Frequency')
        for name, val in common_qls.items():
            if min_ioi < val < max_ioi:
                plt.axvline(val, color='b', linestyle=':', alpha=0.6, label=f'{name} ({val:.3f})')
        plt.legend(fontsize='small', ncol=2)
        plt.grid(True, which="both", ls="-", alpha=0.5)
        plt.show()
    else:
        print("No IOI data to plot.")

    rhythm_summary_cols = ['avg_duration_ql', 'median_duration_ql', 'std_duration_ql', 'num_rhythmic_elements',
                           'avg_ioi_ql', 'median_ioi_ql', 'std_ioi_ql', 'num_iois']
    existing_rhythm_summary_cols = [col for col in rhythm_summary_cols if col in combined_df.columns]
    if existing_rhythm_summary_cols and not combined_df[existing_rhythm_summary_cols].dropna().empty:
        print("\nSummary statistics for per-piece rhythmic features (averages, counts):")
        print(combined_df[existing_rhythm_summary_cols].describe())
    else:
        print("Not enough per-piece rhythmic summary data for description.")
else:
    print("Rhythmic features ('durations_ql_sample', 'iois_ql_sample') not available in combined_df for analysis.")

In [None]:
if not combined_df.empty and 'initial_tempo_bpm' in combined_df.columns:
    valid_initial_tempos = combined_df['initial_tempo_bpm'].dropna()
    valid_initial_tempos = valid_initial_tempos[valid_initial_tempos > 0]

    if not valid_initial_tempos.empty:
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_initial_tempos, bins=40, kde=True, color="purple")
        plt.title('Distribution of Initial Tempos (BPM)')
        plt.xlabel('Tempo (BPM)')
        plt.ylabel('Number of Pieces')
        tempo_markings = {'Largo': 50, 'Adagio': 70, 'Andante': 90, 'Moderato': 110, 'Allegro': 130, 'Presto': 180}
        for mark, bpm_val in tempo_markings.items():
            if valid_initial_tempos.min() < bpm_val < valid_initial_tempos.max():
                 plt.axvline(bpm_val, color='gray', linestyle='--', alpha=0.8, label=f'{mark} ({bpm_val} BPM)')
        plt.legend(fontsize='small')
        plt.show()

        print("\nSummary statistics for initial tempo (BPM):")
        print(valid_initial_tempos.describe())
    else:
        print("No valid initial tempo data to plot after filtering.")
    
    if 'num_distinct_tempos' in combined_df.columns:
        num_tempo_changes_series = combined_df['num_distinct_tempos'].dropna().astype(int)
        pieces_with_tempo_changes = num_tempo_changes_series[num_tempo_changes_series > 1]
        
        plt.figure(figsize=(10,6))
        sns.histplot(num_tempo_changes_series, discrete=True, stat="count", shrink=0.8) 
        plt.title('Distribution of Number of Distinct Tempos within Pieces')
        plt.xlabel('Number of Distinct Tempos (1 implies stable or single marked tempo)')
        plt.ylabel('Number of Pieces')
        if not num_tempo_changes_series.empty:
             plt.xticks(sorted(num_tempo_changes_series.unique()))
        plt.show()
        
        print(f"\nNumber of pieces with more than one distinct tempo marking: {len(pieces_with_tempo_changes)}")
        if len(pieces_with_tempo_changes) > 0:
            print("Examples of pieces with multiple distinct tempos:")
            display_cols = ['midi_filename_meta', 'num_distinct_tempos']
            if 'all_tempos_bpm_sample' in combined_df.columns:
                display_cols.append('all_tempos_bpm_sample')
            
            print(combined_df[combined_df['num_distinct_tempos'] > 1][display_cols].head())
else:
    print("Tempo features ('initial_tempo_bpm', 'num_distinct_tempos') not available in combined_df for analysis.")

In [None]:
if not combined_df.empty and 'avg_polyphony' in combined_df.columns and 'max_polyphony' in combined_df.columns:
    valid_avg_polyphony = combined_df['avg_polyphony'].dropna()
    valid_avg_polyphony = valid_avg_polyphony[valid_avg_polyphony >= 0]

    valid_max_polyphony = combined_df['max_polyphony'].dropna()
    valid_max_polyphony = valid_max_polyphony[valid_max_polyphony >= 0]

    if not valid_avg_polyphony.empty:
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_avg_polyphony, bins=30, kde=True, color="green")
        plt.title('Distribution of Average Polyphony Per Piece')
        plt.xlabel('Average Number of Simultaneous Notes/Pitches')
        plt.ylabel('Number of Pieces')
        plt.show()
        print("\nSummary statistics for average polyphony:")
        print(valid_avg_polyphony.describe())

    if not valid_max_polyphony.empty:
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_max_polyphony, bins=max(1, int(valid_max_polyphony.max()) - int(valid_max_polyphony.min())), kde=False, color="orange", discrete=True if valid_max_polyphony.nunique()<30 else False)
        plt.title('Distribution of Maximum Polyphony Per Piece')
        plt.xlabel('Maximum Number of Simultaneous Notes/Pitches')
        plt.ylabel('Number of Pieces')
        if not valid_max_polyphony.empty:
            plt.xticks(np.arange(int(valid_max_polyphony.min()), int(valid_max_polyphony.max())+1, step=max(1, int(valid_max_polyphony.max())//10)))
        plt.show()
        print("\nSummary statistics for maximum polyphony:")
        print(valid_max_polyphony.describe())
    
    all_polyphony_levels = []
    if 'polyphony_levels_sample' in combined_df.columns:
        for poly_list in combined_df['polyphony_levels_sample'].dropna():
            if isinstance(poly_list, list):
                all_polyphony_levels.extend(poly_list)
    
    all_polyphony_levels = [p for p in all_polyphony_levels if p > 0]

    if all_polyphony_levels:
        plt.figure(figsize=(12, 7))
        poly_counts = pd.Series(all_polyphony_levels).value_counts().sort_index()
        poly_counts_to_plot = poly_counts[poly_counts.index <= 12]
        
        sns.barplot(x=poly_counts_to_plot.index, y=poly_counts_to_plot.values, color="cyan")
        plt.title('Overall Distribution of Polyphony Levels (at onsets, up to 12 voices)')
        plt.xlabel('Number of Simultaneous Notes/Pitches')
        plt.ylabel('Frequency of Onsets')
        plt.show()
    else:
        print("No aggregated polyphony level data to plot.")
else:
    print("Polyphony features ('avg_polyphony', 'max_polyphony') not available in combined_df for analysis.")

In [None]:
if not combined_df.empty:
    numeric_features_for_corr = [
        'duration', 
        'year', 
        'total_quarter_length', 
        'initial_tempo_bpm', 
        'mean_tempo_bpm',
        'num_distinct_tempos',
        'num_notes',
        'pitch_min', 'pitch_max', 'pitch_mean', 'pitch_median', 'pitch_std', 'pitch_range',
        'avg_duration_ql', 'median_duration_ql', 'std_duration_ql', 'num_rhythmic_elements',
        'avg_ioi_ql', 'median_ioi_ql', 'std_ioi_ql', 'num_iois',
        'avg_polyphony', 'max_polyphony', 'median_polyphony',
        'key_confidence'
    ]
    
    existing_numeric_cols_for_corr = [col for col in numeric_features_for_corr if col in combined_df.columns and pd.api.types.is_numeric_dtype(combined_df[col])]
    
    if not existing_numeric_cols_for_corr or len(existing_numeric_cols_for_corr) < 2:
        print("Not enough numeric features available for correlation analysis.")
    else:
        corr_df_subset = combined_df[existing_numeric_cols_for_corr].copy()
        corr_df_subset.dropna(inplace=True)

        if not corr_df_subset.empty and len(corr_df_subset.columns) > 1 and len(corr_df_subset) > 1:
            correlation_matrix = corr_df_subset.corr()
            
            plt.figure(figsize=(18, 15))
            sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", linewidths=.5, annot_kws={"size": 8})
            plt.title('Correlation Matrix of Numeric Musical Features', fontsize=16)
            plt.xticks(fontsize=10)
            plt.yticks(fontsize=10)
            plt.tight_layout()
            plt.show()

            if 'total_quarter_length' in corr_df_subset.columns and 'pitch_range' in corr_df_subset.columns:
                plt.figure(figsize=(8, 6))
                sns.scatterplot(data=corr_df_subset, x='total_quarter_length', y='pitch_range', alpha=0.4, color='blue')
                sns.regplot(data=corr_df_subset, x='total_quarter_length', y='pitch_range', scatter=False, color='red')
                plt.title('Piece Length (Total QL) vs. Pitch Range')
                plt.xlabel('Total Quarter Length')
                plt.ylabel('Pitch Range (Semitones)')
                plt.grid(True)
                plt.show()

            if 'initial_tempo_bpm' in corr_df_subset.columns and 'avg_ioi_ql' in corr_df_subset.columns:
                plt.figure(figsize=(8, 6))
                sns.scatterplot(data=corr_df_subset, x='initial_tempo_bpm', y='avg_ioi_ql', alpha=0.4, color='green')
                sns.regplot(data=corr_df_subset, x='initial_tempo_bpm', y='avg_ioi_ql', scatter=False, color='red')
                plt.title('Initial Tempo (BPM) vs. Average Inter-Onset Interval (QL)')
                plt.xlabel('Initial Tempo (BPM)')
                plt.ylabel('Average IOI (Quarter Lengths)')
                plt.grid(True)
                plt.show()
                
            if 'avg_polyphony' in corr_df_subset.columns and 'num_notes' in corr_df_subset.columns:
                plt.figure(figsize=(8, 6))
                sns.scatterplot(data=corr_df_subset, x='num_notes', y='avg_polyphony', alpha=0.4, color='purple')
                sns.regplot(data=corr_df_subset, x='num_notes', y='avg_polyphony', scatter=False, color='red')
                plt.title('Number of Notes vs. Average Polyphony')
                plt.xlabel('Total Number of Notes in Piece')
                plt.ylabel('Average Polyphony')
                plt.xscale('log')
                plt.grid(True)
                plt.show()

        else:
            print("Not enough data or numeric columns for correlation analysis after cleaning NaNs.")
else:
    print("Combined DataFrame is empty. Cannot perform correlation analysis.")

# Task 1 - Symbolic Music Generation Unconditioned

In [32]:
import numpy as np
import pandas as pd
from pathlib import Path
import music21
from music21 import converter, note, chord, stream, instrument, tempo
from collections import defaultdict, Counter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import multiprocessing
# tqdm.auto
from tqdm.auto import tqdm
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning, module='music21')
warnings.filterwarnings('ignore', category=FutureWarning) # General FutureWarnings
# print torch GPU available
print(torch.cuda.is_available())
# --- Configuration ---
DATA_DIR = Path('../data/')  # Adjust if your data is elsewhere
MAESTRO_METADATA_FILE = DATA_DIR / 'maestro-v3.0.0.csv'
MIDI_BASE_PATH = DATA_DIR # MIDI files are in subdirectories like data/2004/

# Model and Generation Parameters
MARKOV_ORDER = 3  # Order of the Markov chain
LSTM_SEQUENCE_LENGTH = 50
LSTM_EMBEDDING_DIM = 128
LSTM_UNITS = 256
LSTM_EPOCHS = 15 # Adjust based on dataset size and convergence
LSTM_BATCH_SIZE = 2048
GENERATION_LENGTH = 1000 # Number of events to generate
QUANTIZED_DURATIONS_QN = np.array([
    0.0625,  # 64th note
    0.125,   # 32nd note
    0.1666,  # Approx. 16th note triplet component (1/6)
    0.25,    # 16th note
    0.3333,  # Approx. 8th note triplet component (1/3)
    0.5,     # 8th note
    0.6666,  # Approx. Quarter note triplet component (2/3)
    0.75,    # Dotted 8th note
    1.0,     # Quarter note
    1.5,     # Dotted Quarter note
    2.0,     # Half note
    3.0,     # Dotted Half note
    4.0      # Whole note
])

def get_closest_quantized_duration(actual_ql: float) -> float:
    """Finds the closest duration from our predefined QUANTIZED_DURATIONS_QN list."""
    # Handle rests or very short notes gracefully, map them to the smallest positive duration
    if actual_ql <= 0:
        return QUANTIZED_DURATIONS_QN[0]
    idx = np.abs(QUANTIZED_DURATIONS_QN - actual_ql).argmin()
    return QUANTIZED_DURATIONS_QN[idx]

True


In [2]:
def load_maestro_metadata(file_path: Path) -> pd.DataFrame | None:
    """Loads MAESTRO metadata, adapted from the provided PDF."""
    if not file_path.exists():
        print(f"Metadata file not found: {file_path}")
        return None
    try:
        df = pd.read_csv(file_path)
        # Construct full MIDI path [cite: 7]
        df['midi_filepath'] = df['midi_filename'].apply(lambda x: MIDI_BASE_PATH / x)
        # Basic check for file existence (optional, can be slow)
        # df['midi_exists'] = df['midi_filepath'].apply(lambda x: x.exists())
        # print(f"Found {df['midi_exists'].sum()} of {len(df)} MIDI files on disk.")
        # df = df[df['midi_exists']]
        return df
    except Exception as e:
        print(f"Error loading metadata from {file_path}: {e}")
        return None

def extract_events_from_midi(midi_path: Path) -> list[str] | None:
    """Extracts notes/chords and their quantized durations as string events."""
    events = []
    try:
        score = converter.parse(midi_path)
        notes_to_parse = score.flat.notesAndRests # Include rests if desired, or stick to .notes
        
        for element in notes_to_parse:
            # Skip elements with no duration or zero duration if they cause issues
            if element.duration is None or element.duration.quarterLength == 0:
                continue

            quantized_ql = get_closest_quantized_duration(element.duration.quarterLength)
            
            if isinstance(element, note.Note):
                # Event format: "pitch_duration" e.g., "60_0.5"
                events.append(f"{element.pitch.midi}_{quantized_ql}")
            elif isinstance(element, chord.Chord):
                # Event format: "pitch1.pitch2.pitch3_duration" e.g., "60.64.67_1.0"
                chord_pitches_str = '.'.join(str(p.midi) for p in sorted(element.pitches))
                events.append(f"{chord_pitches_str}_{quantized_ql}")
            # Optionally handle rests:
            # elif isinstance(element, note.Rest):
            #     events.append(f"Rest_{quantized_ql}")

    except Exception as e:
        # print(f"Could not parse {midi_path}: {e}")
        return None
    return events

def create_vocabulary(all_event_sequences: list[list[str]]):
    """Creates event_to_int and int_to_event mappings."""
    all_events = [event for seq in all_event_sequences for event in seq]
    event_counts = Counter(all_events)
    unique_events = sorted(event_counts.keys()) # Sort for consistency
    event_to_int = {event: i for i, event in enumerate(unique_events)}
    int_to_event = {i: event for i, event in enumerate(unique_events)}
    return event_to_int, int_to_event, len(unique_events)

# --- 2. MIDI Generation Utility ---

def create_midi_from_events(events: list[str], output_filename: str): # Removed default_duration_qn
    """Creates a MIDI file from a list of string events (pitch/chord_duration)."""
    output_stream = stream.Stream()
    output_stream.append(instrument.Piano())
    bpm = 120 # Default tempo, can also be learned or configured
    mm = tempo.MetronomeMark(number=bpm)
    output_stream.append(mm)

    # current_offset = 0.0 # If precise offsets are needed based on appending.
                           # For simple sequential generation, music21 handles appending.

    for event_str in events:
        try:
            parts = event_str.split('_')
            if len(parts) != 2:
                # print(f"Skipping malformed event: {event_str}")
                continue # Skip malformed events

            item_str = parts[0]
            duration_ql_str = parts[1]
            
            try:
                element_duration_ql = float(duration_ql_str)
            except ValueError:
                # print(f"Skipping event with unparsable duration: {event_str}")
                continue
            
            element_duration = music21.duration.Duration(element_duration_ql)

            if "Rest" in item_str: # Handling rests if they are in vocabulary
                r = note.Rest(duration=element_duration)
                output_stream.append(r)
            elif '.' in item_str: # Chord
                pitch_strings = item_str.split('.')
                pitches = [int(p_str) for p_str in pitch_strings]
                c = chord.Chord(pitches, duration=element_duration)
                output_stream.append(c)
            elif item_str.isdigit(): # Note
                pitch = int(item_str)
                n = note.Note(pitch, duration=element_duration)
                output_stream.append(n)
            else:
                # print(f"Skipping unknown event type: {item_str}")
                continue
            
            # output_stream.append(m21_event)
            # current_offset += element_duration_ql # Advance offset if managing manually

        except Exception as e:
            # print(f"Error processing event '{event_str}': {e}")
            continue
            
    try:
        output_stream.write('midi', fp=output_filename)
        print(f"Generated MIDI file: {output_filename}")
    except Exception as e:
        print(f"Error writing MIDI file {output_filename}: {e}")

In [4]:
class MarkovChain:
    def __init__(self, order=1):
        self.order = order
        self.transitions = defaultdict(Counter)
        self.initial_states = Counter() # For generating the first 'order' states
        self.vocab_size = 0

    def train(self, sequences: list[list[int]], vocab_size: int):
        self.vocab_size = vocab_size
        for seq in sequences:
            if len(seq) <= self.order:
                continue
            
            # Store initial states for generation
            initial_context = tuple(seq[:self.order])
            self.initial_states[initial_context] += 1

            for i in range(len(seq) - self.order):
                context = tuple(seq[i : i + self.order])
                next_event = seq[i + self.order]
                self.transitions[context][next_event] += 1
        
        # Normalize initial_states for generation
        total_initials = sum(self.initial_states.values())
        if total_initials > 0:
             for context in self.initial_states:
                self.initial_states[context] /= total_initials


    def get_prob(self, context: tuple[int, ...], next_event: int):
        """Calculates P(next_event | context) with Laplace smoothing."""
        context_counts = sum(self.transitions[context].values())
        event_in_context_count = self.transitions[context][next_event]
        return (event_in_context_count + 1) / (context_counts + self.vocab_size)

    def generate(self, length: int) -> list[int]:
        if not self.transitions:
            print("Markov chain not trained or no transitions learned.")
            return []

        # Start with a probable initial context
        if not self.initial_states: # Fallback if no initial states were captured
             current_context = list(np.random.choice(list(self.transitions.keys()))) if self.transitions else []
        else:
            # Sample from learned initial states
            initial_contexts_list = list(self.initial_states.keys())
            initial_probs = [self.initial_states[ctx] for ctx in initial_contexts_list]
            if not initial_contexts_list or sum(initial_probs) == 0: # further fallback
                 current_context_tuple = np.random.choice(list(self.transitions.keys())) if self.transitions else tuple()
                 current_context = list(current_context_tuple)
            else:
                start_idx = np.random.choice(len(initial_contexts_list), p=initial_probs)
                current_context = list(initial_contexts_list[start_idx])
        
        generated_sequence = list(current_context)

        for _ in range(length - self.order):
            if tuple(current_context) not in self.transitions or not self.transitions[tuple(current_context)]:
                # If context is unknown or has no continuations, break or pick random
                # For simplicity, picking a random event from vocab. A better way is backoff or random context.
                next_event = np.random.randint(0, self.vocab_size)
                # print(f"Warning: Unknown context {tuple(current_context)}, choosing random next event.")
            else:
                possible_next_events = list(self.transitions[tuple(current_context)].keys())
                probabilities = [self.get_prob(tuple(current_context), event) for event in possible_next_events]
                # Normalize probabilities if they don't sum to 1 (due to smoothing in get_prob calculation)
                prob_sum = sum(probabilities)
                if prob_sum > 0:
                    probabilities = [p/prob_sum for p in probabilities]
                else: # Fallback if all probabilities are zero (should not happen with Laplace)
                    probabilities = [1.0/len(possible_next_events)] * len(possible_next_events)
                
                if not possible_next_events: # Should not happen if context in transitions
                    next_event = np.random.randint(0, self.vocab_size)
                else:
                    next_event = np.random.choice(possible_next_events, p=probabilities)

            generated_sequence.append(next_event)
            current_context = generated_sequence[-self.order:]
        
        return generated_sequence

    def calculate_perplexity(self, sequences: list[list[int]]):
        log_likelihood_sum = 0
        total_events = 0
        for seq in sequences:
            if len(seq) <= self.order:
                continue
            for i in range(len(seq) - self.order):
                context = tuple(seq[i : i + self.order])
                next_event = seq[i + self.order]
                prob = self.get_prob(context, next_event)
                if prob > 1e-9: # Avoid log(0)
                    log_likelihood_sum += np.log2(prob)
                else: # Penalize heavily for zero probability events
                    log_likelihood_sum -= 20 # Arbitrary large penalty
                total_events += 1
        
        if total_events == 0: return float('inf')
        avg_log_likelihood = log_likelihood_sum / total_events
        perplexity = np.power(2, -avg_log_likelihood)
        return perplexity


In [33]:
def prepare_sequences_for_lstm_pytorch(sequences: list[list[int]], seq_len: int):
    X, y = [], []
    for seq in sequences:
        for i in range(len(seq) - seq_len):
            X.append(seq[i : i + seq_len])
            y.append(seq[i + seq_len])
    X = torch.tensor(X, dtype=torch.long)  # LongTensor for embeddings
    y = torch.tensor(y, dtype=torch.long)  # AdaptiveLogSoftmax expects class indices
    return X, y

def build_lstm_model(vocab_size: int, embedding_dim: int, lstm_units: int, num_layers: int = 2):
    class LSTMModel(nn.Module):
        def __init__(self, vocab_size, embedding_dim, hidden_dim, cutoffs=[1000, 10000, 50000]):
            super(LSTMModel, self).__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, dropout=0.2, num_layers=num_layers)
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=hidden_dim,
                n_classes=vocab_size,
                cutoffs=cutoffs
            )

        def forward(self, x, target=None):
            x = self.embedding(x)
            out, _ = self.lstm(x)
            out = out[:, -1, :]  # last time step

            if self.training and target is not None:
                return self.adaptive_softmax(out, target)
            else:
                return out  # raw output for generation/logits

    return LSTMModel(vocab_size, embedding_dim, lstm_units)

def generate_sequence_lstm(model, seed_sequence: list[int], length: int, vocab_size: int, temperature=1.0, device='cpu'):
    model.eval()
    generated_sequence = list(seed_sequence)
    current_input_sequence = torch.tensor(seed_sequence, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(length):
            logits = model(current_input_sequence)  # (1, hidden_dim)
            probs = torch.softmax(model.adaptive_softmax.log_prob(logits), dim=-1).squeeze(0).cpu().numpy()

            if temperature > 0:
                scaled_probs = np.power(probs, 1.0 / temperature)
                scaled_probs /= np.sum(scaled_probs)
            else:
                scaled_probs = np.zeros_like(probs)
                scaled_probs[np.argmax(probs)] = 1.0

            if np.isnan(scaled_probs).any() or np.sum(scaled_probs) == 0:
                next_event_int = np.random.choice(vocab_size)
            else:
                next_event_int = np.random.choice(vocab_size, p=scaled_probs)

            generated_sequence.append(next_event_int)
            current_input_sequence = torch.cat(
                [current_input_sequence[:, 1:], torch.tensor([[next_event_int]], device=device)], dim=1
            )

    return generated_sequence

def calculate_perplexity_lstm(model, dataloader, device='cpu'):
    model.eval()
    total_loss = 0
    total_tokens = 0

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            output = model(X_batch, y_batch)
            loss = output.loss
            total_loss += loss.item() * X_batch.size(0)
            total_tokens += y_batch.size(0)

    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)
    return perplexity

from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import OneCycleLR

def lstm_train(model, train_loader, val_loader, epochs, device):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=5e-3,
        steps_per_epoch=len(train_loader),
        epochs=epochs,
    )
    scaler = GradScaler()

    best_val_perp = float("inf")
    patience = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch = X_batch.to(device, non_blocking=True)
            y_batch = y_batch.to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast():
                # only embed and LSTM under autocast
                logits = model.embedding(X_batch)
                out, _ = model.lstm(logits)
                out = out[:, -1, :]
            # Compute loss outside autocast (in float32)
            output = model.adaptive_softmax(out.float(), y_batch)
            loss = output.loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            train_loss += loss.item() * X_batch.size(0)
        avg_train_loss = train_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val = X_val.to(device, non_blocking=True)
                y_val = y_val.to(device, non_blocking=True)
                out = model.embedding(X_val)
                out, _ = model.lstm(out)
                out = out[:, -1, :]
                output = model.adaptive_softmax(out.float(), y_val)
                loss = output.loss
                val_loss += loss.item() * X_val.size(0)
        avg_val_loss = val_loss / len(val_loader.dataset)
        val_perplexity = torch.exp(torch.tensor(avg_val_loss)).item()

        print(
            f"Epoch {epoch+1}/{epochs}  "
            f"Train Loss: {avg_train_loss:.4f}  "
            f"Val Perplexity: {val_perplexity:.2f}"
        )

        if val_perplexity < best_val_perp:
            best_val_perp = val_perplexity
            patience = 0
            torch.save(model.state_dict(), "best_lstm.pt")
        else:
            patience += 1
            if patience >= 3:
                print("Early stopping (no improvement in 3 epochs).")
                break

    model.load_state_dict(torch.load("best_lstm.pt"))
    return model

In [7]:
print("Loading MAESTRO metadata...")
maestro_df = load_maestro_metadata(MAESTRO_METADATA_FILE)

if maestro_df is None or maestro_df.empty:
    print("Failed to load MAESTRO metadata. Exiting.")
    exit()

print("Extracting musical events from MIDI files...")
all_sequences_str_train, all_sequences_str_val, all_sequences_str_test = [], [], []

# Limiting files for faster demonstration - remove for full run
# max_files_per_split = 50 
num_processes = multiprocessing.cpu_count()
print(f"Using {num_processes} processes for parsing.")

for split_type, target_list in [('train', all_sequences_str_train), 
                                ('validation', all_sequences_str_val), 
                                ('test', all_sequences_str_test)]:
    split_df = maestro_df[maestro_df['split'] == split_type]
    # if max_files_per_split: 
    #    split_df = split_df.head(max_files_per_split)
    
    # First, get a list of MIDI paths that actually exist to avoid errors in the pool
    existing_midi_paths = [path for path in split_df['midi_filepath'] if path.exists()]
    
    if len(existing_midi_paths) < len(split_df['midi_filepath']):
        print(f"Warning: {len(split_df['midi_filepath']) - len(existing_midi_paths)} MIDI file(s) not found for {split_type} split.")

    if not existing_midi_paths:
        print(f"No existing MIDI files to process for {split_type} split.")
        continue
        
    print(f"Processing {len(existing_midi_paths)} files for {split_type} split using multiprocessing...")

    # Create a pool of worker processes
    with multiprocessing.Pool(processes=num_processes * 2) as pool:
        # Use imap_unordered to get results as they complete, which can be faster
        # if processing times vary. Wrap with tqdm for a progress bar.
        # The extract_events_from_midi function will be called for each path in existing_midi_paths.
        results_iterator = pool.imap_unordered(extract_events_from_midi, existing_midi_paths)
        
        for events in tqdm(results_iterator, total=len(existing_midi_paths), desc=f"Parsing {split_type}"):
            if events and len(events) > max(MARKOV_ORDER, LSTM_SEQUENCE_LENGTH): # Ensure sequence is long enough
                target_list.append(events)


if not all_sequences_str_train:
    print("No training data could be extracted. Check MIDI paths and parsing. Exiting.")
    exit()
    
print("Creating vocabulary...")
event_to_int, int_to_event, vocab_size = create_vocabulary(all_sequences_str_train + all_sequences_str_val + all_sequences_str_test)
print(f"Vocabulary size: {vocab_size}")

# Convert string sequences to integer sequences
all_sequences_int_train = [[event_to_int[e] for e in seq if e in event_to_int] for seq in all_sequences_str_train]
all_sequences_int_val = [[event_to_int[e] for e in seq if e in event_to_int] for seq in all_sequences_str_val]
all_sequences_int_test = [[event_to_int[e] for e in seq if e in event_to_int] for seq in all_sequences_str_test]

# Filter out empty sequences after mapping (if any event wasn't in vocab learned from train)
all_sequences_int_train = [s for s in all_sequences_int_train if len(s) > MARKOV_ORDER]
all_sequences_int_val = [s for s in all_sequences_int_val if len(s) > MARKOV_ORDER]
all_sequences_int_test = [s for s in all_sequences_int_test if len(s) > MARKOV_ORDER]


if not all_sequences_int_train:
    print("No valid integer sequences for training after vocabulary mapping. Exiting.")
    exit()
if not all_sequences_int_test:
    print("Warning: No valid integer sequences for testing. Perplexity might not be meaningful.")
    # Fallback: use validation set for testing if test set is empty
    if all_sequences_int_val:
        print("Using validation set for testing perplexity as test set is empty.")
        all_sequences_int_test = all_sequences_int_val
    else:
        print("Both test and validation sets are empty after processing. Cannot calculate perplexity. Exiting.")
        exit()

Loading MAESTRO metadata...
Extracting musical events from MIDI files...
Using 64 processes for parsing.
Processing 962 files for train split using multiprocessing...


Parsing train: 100%|██████████| 962/962 [10:02<00:00,  1.60it/s]


Processing 137 files for validation split using multiprocessing...


Parsing validation: 100%|██████████| 137/137 [02:17<00:00,  1.00s/it]


Processing 177 files for test split using multiprocessing...


Parsing test: 100%|██████████| 177/177 [01:54<00:00,  1.54it/s]


Creating vocabulary...
Vocabulary size: 319111

--- Markov Chain Model ---
Training Markov chain (order 3)...
Markov Chain Perplexity on Test Set: 317247.99
Generating music with Markov chain...
Generated MIDI file: markov_generated.mid


In [None]:
# --- Markov Chain ---
print("\n--- Markov Chain Model ---")
markov_model = MarkovChain(order=MARKOV_ORDER)
print(f"Training Markov chain (order {MARKOV_ORDER})...")
markov_model.train(all_sequences_int_train, vocab_size)

if all_sequences_int_test:
    markov_perplexity = markov_model.calculate_perplexity(all_sequences_int_test)
    print(f"Markov Chain Perplexity on Test Set: {markov_perplexity:.2f}")
else:
    print("No test data for Markov Chain perplexity.")

print("Generating music with Markov chain...")
if all_sequences_int_train and all_sequences_int_train[0]:
        # Ensure there is at least one training sequence to pick a seed from
    seed_for_markov = all_sequences_int_train[0][:MARKOV_ORDER] if len(all_sequences_int_train[0]) >= MARKOV_ORDER else \
                        (all_sequences_int_train[0] + [np.random.randint(0, vocab_size)] * (MARKOV_ORDER - len(all_sequences_int_train[0])))[:MARKOV_ORDER]

    if len(seed_for_markov) == MARKOV_ORDER : # Markov model generates from internal state or random context if no seed
        generated_markov_ints = markov_model.generate(length=GENERATION_LENGTH)
        generated_markov_events = [int_to_event.get(i, str(i)) for i in generated_markov_ints] # Handle if int not in map
        create_midi_from_events(generated_markov_events, "markov_generated.mid")
    else:
        print("Could not create a valid seed for Markov chain generation from training data.")
else:
    print("No training data to seed Markov chain generation.")

In [34]:
print("\n--- LSTM Model ---")
print("Preparing data for LSTM...")

# Prepare training sequences
X_train_lstm, y_train_lstm = prepare_sequences_for_lstm_pytorch(all_sequences_int_train, LSTM_SEQUENCE_LENGTH)

if X_train_lstm.size(0) == 0:
    print("Not enough training data to form sequences for LSTM. Exiting LSTM part.")
else:
    X_test_lstm, y_test_lstm = prepare_sequences_for_lstm_pytorch(all_sequences_int_test, LSTM_SEQUENCE_LENGTH)
    
    if X_test_lstm.size(0) == 0 and all_sequences_int_val:
        print("Warning: Test set too small for LSTM sequence preparation. Trying with validation set for perplexity.")
        X_test_lstm, y_test_lstm = prepare_sequences_for_lstm_pytorch(all_sequences_int_val, LSTM_SEQUENCE_LENGTH)
    
    print(f"LSTM training data shape: X={X_train_lstm.shape}, y={y_train_lstm.shape}")
    if X_test_lstm.size(0) > 0:
        print(f"LSTM test data shape: X={X_test_lstm.shape}, y={y_test_lstm.shape}")
    else:
        print("Warning: Not enough test/validation data to form sequences for LSTM perplexity calculation.")

    # Create Dataset and DataLoader for training and testing
    train_dataset = torch.utils.data.TensorDataset(X_train_lstm, y_train_lstm)
    from torch.utils.data import random_split
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=LSTM_BATCH_SIZE,num_workers=4, shuffle=True, pin_memory=True)
    import time

    start = time.time()
    batch = next(iter(train_loader))
    print(f"First batch loaded in {time.time() - start:.2f} seconds")
    val_loader = DataLoader(val_dataset, batch_size=LSTM_BATCH_SIZE, num_workers=4, pin_memory=True)
    

    test_dataset = torch.utils.data.TensorDataset(X_test_lstm, y_test_lstm) if X_test_lstm.size(0) > 0 else None
    test_loader = DataLoader(test_dataset, batch_size=LSTM_BATCH_SIZE, num_workers=4, pin_memory=True) if test_dataset else None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    lstm_model = build_lstm_model(vocab_size, LSTM_EMBEDDING_DIM, LSTM_UNITS).to(device)
    print(lstm_model)

    print("Training LSTM model...")
    lstm_model = lstm_train(lstm_model, train_loader, val_loader, epochs=LSTM_EPOCHS, device=device)

    if test_loader is not None:
        with torch.no_grad():
            lstm_perplexity = calculate_perplexity_lstm(lstm_model, test_loader, device=device)
        print(f"LSTM Perplexity on Test Set: {lstm_perplexity:.2f}")
    else:
        print("LSTM perplexity cannot be calculated due to insufficient test/validation data meeting sequence length.")

    print("Generating music with LSTM...")
    if all_sequences_int_train and len(all_sequences_int_train[0]) >= LSTM_SEQUENCE_LENGTH:
        import numpy as np
        seed_idx_lstm = np.random.randint(0, len(all_sequences_int_train))
        seed_start_idx_lstm = np.random.randint(0, len(all_sequences_int_train[seed_idx_lstm]) - LSTM_SEQUENCE_LENGTH)
        seed_for_lstm = all_sequences_int_train[seed_idx_lstm][seed_start_idx_lstm : seed_start_idx_lstm + LSTM_SEQUENCE_LENGTH]

        generated_lstm_ints = generate_sequence_lstm(
            lstm_model, seed_for_lstm, GENERATION_LENGTH, vocab_size, temperature=0.7, device=device
        )
        generated_lstm_events = [int_to_event.get(i, str(i)) for i in generated_lstm_ints]  # Fallback if missing
        create_midi_from_events(generated_lstm_events, "lstm_generated.mid")
    else:
        print("Could not find a suitable seed sequence from training data for LSTM generation.")

print("\nDone.")


--- LSTM Model ---
Preparing data for LSTM...
LSTM training data shape: X=torch.Size([2922699, 50]), y=torch.Size([2922699])
LSTM test data shape: X=torch.Size([367368, 50]), y=torch.Size([367368])
First batch loaded in 0.62 seconds
LSTMModel(
  (embedding): Embedding(319111, 128)
  (lstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.2)
  (adaptive_softmax): AdaptiveLogSoftmaxWithLoss(
    (head): Linear(in_features=256, out_features=1003, bias=False)
    (tail): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=256, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=9000, bias=False)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=16, bias=False)
        (1): Linear(in_features=16, out_features=40000, bias=False)
      )
      (2): Sequential(
        (0): Linear(in_features=256, out_features=4, bias=False)
        (1): Linear(in_features=4, out_features=269111, bias=False)
      )
    )
  )
)
T

Epoch 1/15: 100%|██████████| 1142/1142 [02:08<00:00,  8.89it/s]


Epoch 1/15  Train Loss: 11.2574  Val Perplexity: 11488.38


Epoch 2/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.98it/s]


Epoch 2/15  Train Loss: 9.0668  Val Perplexity: 9486.67


Epoch 3/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.96it/s]


Epoch 3/15  Train Loss: 8.9034  Val Perplexity: 8609.28


Epoch 4/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.96it/s]


Epoch 4/15  Train Loss: 8.6221  Val Perplexity: 6767.47


Epoch 5/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.96it/s]


Epoch 5/15  Train Loss: 8.2623  Val Perplexity: 5589.90


Epoch 6/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.97it/s]


Epoch 6/15  Train Loss: 7.9844  Val Perplexity: 5306.89


Epoch 7/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.99it/s]


Epoch 7/15  Train Loss: 7.7834  Val Perplexity: 4965.57


Epoch 8/15: 100%|██████████| 1142/1142 [02:06<00:00,  9.00it/s]


Epoch 8/15  Train Loss: 7.6150  Val Perplexity: 4751.53


Epoch 9/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.99it/s]


Epoch 9/15  Train Loss: 7.4524  Val Perplexity: 4989.01


Epoch 10/15: 100%|██████████| 1142/1142 [02:07<00:00,  8.97it/s]


Epoch 10/15  Train Loss: 7.2899  Val Perplexity: 5633.17


Epoch 11/15: 100%|██████████| 1142/1142 [02:06<00:00,  9.01it/s]


Epoch 11/15  Train Loss: 7.1292  Val Perplexity: 6796.34
Early stopping (no improvement in 3 epochs).


AttributeError: 'Tensor' object has no attribute 'loss'