In [2]:
!pip install pretty_midi numpy 
!pip install -q tensorflow
!pip install 'keras>=3.5.0'

import pretty_midi
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint
import os

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m48.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=76ab1671ed89497ff912c2ae86699176d72f2cedca94c238c4a5f98752be1797
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty_midi
Installing collected packages: mido, pr

2025-07-04 08:32:27.990537: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751617948.212756      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751617948.276355      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### Step 1: Load Data 
Loading and printing basic information about the MIDI dataset


In [3]:
# Configuration shared by the cells below.
# NOTE(review): this value looks like a directory, but load_and_inspect_midi passes its
# argument straight to pretty_midi.PrettyMIDI — confirm it should name a single MIDI file.
MIDI_FILE_PATH = "/kaggle/input/lakh-midi-clean/"
SEQUENCE_LENGTH = 50 # number of events in each input sequence (sliding window) for the RNN
BATCH_SIZE = 64 # default batch size used by train_model
EPOCHS = 50 # default number of training epochs used by train_model

In [4]:
# load and print basic information about its contents
def load_and_inspect_midi(midi_path):
    """
    Load a MIDI file and collect the notes of every non-drum instrument.

    Prints a short summary (instrument list, note count, first/last notes)
    as a side effect.

    Args:
        midi_path: Path handed directly to pretty_midi.PrettyMIDI.

    Returns:
        tuple: (notes, midi_data)
            - notes: list of (pitch, start, end, velocity) tuples from all
              non-drum instruments, sorted by start time.
            - midi_data: the loaded PrettyMIDI object.
        Returns (None, None) if anything fails while loading or inspecting
        the file; the error is printed rather than raised.
    """
    try:
        midi_data = pretty_midi.PrettyMIDI(midi_path)
        print(f"Successfully loaded MIDI file: {midi_path}")
        print(f"MIDI file contains {len(midi_data.instruments)} instruments.")

        notes = []
        for instrument in midi_data.instruments:
            # Drum tracks are excluded: their "pitches" select percussion sounds, not melody notes.
            if instrument.is_drum:
                continue
            print(f" Instrument: {instrument.name}, Program: {instrument.program} ({pretty_midi.program_to_instrument_name(instrument.program)})")
            notes.extend(
                (note.pitch, note.start, note.end, note.velocity)
                for note in instrument.notes
            )

        # Notes from different instruments are interleaved, so order them chronologically.
        notes.sort(key=lambda x: x[1])

        print(f"\nTotal notes extracted (across all non_drum instruments): {len(notes)}")
        if notes:
            print(f"First 5 extracted notes: {notes[:5]}")
            print(f"Last 5 extracted notes: {notes[-5:]}")
        else:
            print("No notes found in the non_drum tracks")

        return notes, midi_data

    except Exception as e:
        # Deliberately best-effort: callers test for None instead of handling exceptions.
        print(f"Error loading MIDI files: {e}")
        return None, None

# if __name__ == "__main__":
#     extracted_notes, _ = load_and_inspect_midi(MIDI_FILE_PATH)
#     if extracted_notes:
#         print("\nProceeding to next steps for data processing and model building....")
#     else:
#         print("Please ensure your MIDI_FILE_PATH is correct and the file is valid.")

### Step 2: Data Extraction and Simplification

**Goal of this step:**

1. Extract relevant features: From the (pitch, start, end, velocity) tuples, we'll primarily extract pitch and calculate a simplified duration.

2. Handle polyphony (simplification): Since our goal is monophonic melodies, we need to decide how to handle cases where multiple notes overlap (i.e., chords). For simplicity, we'll likely pick only one note at a given time or just iterate through all notes, effectively flattening any chords into a sequence of individual notes played close together. The pretty_midi library already gives us individual notes, so if a chord is played, they'll just appear as separate notes with the same start time. For very simple generation, we might even just take the highest or lowest note of a chord, but for now, let's keep all individual notes.

3. Quantize time: Musical durations are often quantized (e.g., quarter notes, eighth notes, sixteenth notes). We'll convert continuous start and end times into a more discrete duration representation. A simple way is to calculate duration = note.end - note.start. Another approach is to calculate the time between the onset of consecutive notes. For a simple RNN, predicting a sequence of (pitch, time_until_next_note) pairs is a common and effective strategy.

In [None]:
# Data extraction and simplification
def process_notes_for_rnn(notes_list, quantization_steps_per_beat=4):
    """
    Converts a list of raw (pitch, start, end, velocity) tuples into
    a simplified sequence of (pitch, duration) events suitable for an RNN.

    Times are used as given (seconds from the loader) with no tempo conversion:
    every duration is snapped to the nearest multiple of
    1 / quantization_steps_per_beat. Silence before a note (measured from the
    end of the preceding note in the list, or from time 0 for the first note)
    becomes a ('Rest', duration) event when it quantizes to more than zero.
    Overlapping notes are not merged; each one is emitted in list order.

    Args:
        notes_list: A list of tuples, each representing a note.
                    (pitch, start_time, end_time, velocity), sorted by start time.
        quantization_steps_per_beat: Number of grid steps per time unit used
                    when quantizing durations.

    Returns:
        tuple: A tuple containing:
            - sequence_of_events: A list of (pitch, duration) tuples, where
              pitch is either the note pitch or the string 'Rest'.
            - unique_pitches: All unique pitch values found (rests excluded).
            - unique_time_deltas: All unique quantized durations found
              (notes and rests).
    """
    if not notes_list:
        print("No notes to process.")
        return [], set(), set()

    def _quantize(span):
        # Snap a time span to the nearest grid step.
        return round(span * quantization_steps_per_beat) / quantization_steps_per_beat

    # Notes that quantize to zero length are bumped to one grid step so they stay audible.
    minimum_duration = 1.0 / quantization_steps_per_beat

    sequence_of_events = []
    unique_pitches = set()
    unique_time_deltas = set()

    # Starting at 0.0 makes the leading silence before the first note an ordinary rest.
    previous_end = 0.0
    for pitch, start, end, _velocity in notes_list:
        if start > previous_end:
            quantized_rest = _quantize(start - previous_end)
            if quantized_rest > 0:
                sequence_of_events.append(('Rest', quantized_rest))
                unique_time_deltas.add(quantized_rest)

        # Emit the note itself for EVERY note, including the first one
        # (previously this was nested under `if i > 0` and the first note was lost).
        quantized_duration = _quantize(end - start)
        if quantized_duration <= 0:
            quantized_duration = minimum_duration

        sequence_of_events.append((pitch, quantized_duration))
        unique_pitches.add(pitch)
        unique_time_deltas.add(quantized_duration)

        previous_end = end

    print(f"\nProcessed {len(sequence_of_events)} events (pitch, duration) for RNN input.")
    print(f"Unique pitches found: {sorted(list(unique_pitches))}")
    print(f"Unique durations found: {sorted(list(unique_time_deltas))}")
    print(f"First 10 processed events: {sequence_of_events[:10]}")

    return sequence_of_events, unique_pitches, unique_time_deltas

# if __name__ == "__main__":
#     extracted_notes, midi_obj = load_and_inspect_midi(MIDI_FILE_PATH)
    
#     if extracted_notes:
#         # For simplicity, let's assume a default tempo for quantization if not explicit
#         # You can get tempo from midi_obj.estimate_tempos() if needed, but for simple files
#         # a fixed quantization is fine. Let's use 4 steps per beat (quarter notes).
        
#         processed_events, unique_pitches_set, unique_time_deltas_set = process_notes_for_rnn(
#             extracted_notes, quantization_steps_per_beat=4
#         )
#         processed_events = [f"{pitch}_{duration}" for pitch, duration in processed_events]
        
#         # Now you have `processed_events` which is a list of (pitch_duration) tuples.
#         # `unique_pitches_set` and `unique_time_deltas_set` contain your vocabulary.
        
#         # For the next step (Step 3: Sequence Creation), you'll combine these into a single vocabulary
#         # and create integer mappings.
        
#         print("\nStep 2: Data Extraction & Simplification Complete.")
#         print("You now have a sequence of (pitch, duration) events and their unique values.")
#     else:
#         print("MIDI file could not be processed. Please check the path and file validity.")

### Step 3: Sequence Creation

In this step, we'll prepare the data in a format directly usable by TensorFlow/Keras for training an RNN. This involves:

1. Creating a Unified Vocabulary: Combining all unique pitches and durations into a single set of unique "events."

2. Mapping to Integers: Assigning a unique integer ID to each unique event.

3. Creating Input-Output Pairs: Sliding a window over your sequence of events to create the X (input) and y (target/output) pairs for your neural network.

In [None]:
# Sequence Creation
def create_rnn_sequences(processed_events, seq_length=SEQUENCE_LENGTH):
    """
    Creates input (X) and output (y) sequences for RNN training.
    Each distinct event is mapped to a unique integer, then a sliding window
    of seq_length events is paired with the event that follows it.

    Args:
        processed_events (list): Sequence of hashable, mutually sortable events
                          (e.g. "pitch_duration" strings).
        seq_length (int): The length of the input sequence for the RNN.
                          The RNN will learn to predict the (seq_length + 1)th event.

    Returns:
        tuple: (X, y, event_to_int, int_to_event, vocab_size)
            - X (np.array): Integer input sequences, shape (n_sequences, seq_length).
              n_sequences is 0 when there are not more events than seq_length.
            - y (np.array): Integer ID of the next event for each row of X,
              shape (n_sequences,). Kept as integer IDs for use with
              sparse_categorical_crossentropy.
            - event_to_int (dict): Mapping from event to integer ID.
            - int_to_event (dict): Mapping from integer ID to event.
            - vocab_size (int): Total number of unique events.
        Returns (None, None, None, None, 0) when processed_events is empty.
    """
    if not processed_events:
        print("No processed events to create sequences from.")
        return None, None, None, None, 0

    # 1. Unified vocabulary; sorting keeps the event <-> ID mapping stable between runs.
    unique_events = sorted(set(processed_events))
    event_to_int = {event: i for i, event in enumerate(unique_events)}
    int_to_event = dict(enumerate(unique_events))

    vocab_size = len(unique_events)
    print(f"\nCreated a vocabulary of {vocab_size} unique_events.")
    print(f"Example mapping: {list(event_to_int.items())[:5]}...")

    numerical_events = [event_to_int[event] for event in processed_events]

    # 2. Sliding window. There are zero windows when the data is too short.
    n_sequences = max(len(numerical_events) - seq_length, 0)
    windows = [numerical_events[i: i + seq_length] for i in range(n_sequences)]
    targets = [numerical_events[i + seq_length] for i in range(n_sequences)]

    # Explicit dtype/shape so an empty result is still a well-formed (0, seq_length) integer array.
    X = np.array(windows, dtype=int).reshape(n_sequences, seq_length)
    y = np.array(targets, dtype=int)

    print(f"Created {len(X)} input-output sequence pairs.")
    print(f"Shape of X: {X.shape} (Number of sequences, Sequence Length)")
    print(f"Shape of y: {y.shape} (Number of target events)")
    if n_sequences > 0:
        print(f"First input sequence (X[0]): {X[0]}")
        print(f"First target event (y[0]): {y[0]}")
    else:
        # Previously this path crashed with IndexError on X[0].
        print(
            f"Not enough events ({len(numerical_events)}) to build a sequence of "
            f"length {seq_length}; at least {seq_length + 1} are required."
        )

    return X, y, event_to_int, int_to_event, vocab_size

# if __name__ == "__main__":
#     extracted_notes, midi_obj = load_and_inspect_midi(MIDI_FILE_PATH)
    
#     if extracted_notes:
#         processed_events, unique_pitches_set, unique_durations_set = process_notes_for_rnn(
#             extracted_notes, quantization_steps_per_beat=4
#         )
        
#         processed_events = [f"{pitch}_{duration:.3f}" for pitch, duration in processed_events]

        
#         if processed_events:
#             X, y, event_to_int, int_to_event, vocab_size = create_rnn_sequences(
#                 processed_events, seq_length=SEQUENCE_LENGTH
#             )
            
#             # These variables (X, y, event_to_int, int_to_event, vocab_size)
#             # are now ready to be used in Step 4 for building and training the RNN.
            
#             print("\nStep 3: Sequence Creation Complete.")
#             print("Your data is now prepared into input-output sequences for the RNN.")
#         else:
#             print("No processed events to create sequences from. Check previous steps.")
#     else:
#         print("MIDI file could not be processed. Please check the path and file validity.")

### Step 4: Model Building

**Goal of this step:**

1. Define a sequential Keras model.

2. Add an Embedding layer to convert integer IDs into dense vectors.

3. Add one or more LSTM layers to learn temporal patterns.

4. Add a Dense output layer with softmax activation to predict the next event.

5. Compile the model with an appropriate loss function and optimizer.

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding, Dropout 
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint

EMBEDDING_DIM = 100  # default size of the dense vector learned for each event ID
LSTM_UNITS = 256  # default number of units in each LSTM layer

# Model Building
def build_rnn_model(vocab_size, seq_length=SEQUENCE_LENGTH, embedding_dim=EMBEDDING_DIM, lstm_units=LSTM_UNITS):
    """
    Builds a Sequential Keras RNN model for music generation.

    Architecture: Input -> Embedding -> LSTM (sequence output) -> Dropout(0.3)
    -> LSTM -> Dense(softmax over the vocabulary).

    Args:
        vocab_size (int): The total number of unique events in our vocabulary.
        seq_length (int): The length of the input sequences (timesteps).
        embedding_dim (int): The dimension of the embedding vector for each event.
        lstm_units (int): The number of LSTM units in each of the two LSTM layers.

    Returns:
        tf.keras.Model: The compiled Keras model (Adam optimizer,
        sparse categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential([
        # Explicit input spec: replaces the Embedding `input_length` argument, which is
        # deprecated in Keras 3, and lets the model be built before summary() is called.
        tf.keras.Input(shape=(seq_length,), dtype="int32"),

        # Embedding: integer event IDs -> dense vectors.
        # input_dim is the vocabulary size (max integer index + 1).
        Embedding(input_dim=vocab_size, output_dim=embedding_dim),

        # First LSTM returns the full output sequence so the second LSTM can consume it.
        LSTM(lstm_units, return_sequences=True),
        Dropout(0.3),  # regularization to prevent overfitting
        # Second LSTM returns only its final output, summarising the whole window.
        LSTM(lstm_units),

        # Probability distribution over the vocabulary for the next event.
        Dense(vocab_size, activation='softmax'),
    ])

    # sparse_categorical_crossentropy matches integer-encoded targets (0, 1, 2, ...).
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print("\n___ RNN Model Architecture ___")
    model.summary()
    print("_______________________________")

    return model

# if __name__ == "__main__":
#     extracted_notes, midi_obj = load_and_inspect_midi(MIDI_FILE_PATH)
    
#     if extracted_notes:
#         processed_events, unique_pitches_set, unique_durations_set = process_notes_for_rnn(
#             extracted_notes, quantization_steps_per_beat=4
#         )

#         processed_events = [f"{pitch}_{duration:.3f}" for pitch, duration in processed_events]

        
#         if processed_events:
#             X, y, event_to_int, int_to_event, vocab_size = create_rnn_sequences(
#                 processed_events, seq_length=SEQUENCE_LENGTH
#             )
            
#             if X is not None and y is not None:
#                 model = build_rnn_model(vocab_size, seq_length=SEQUENCE_LENGTH)
                
#                 # Now the `model` object is ready for training in Step 5.
#                 print("\nStep 4: Model Building Complete.")
#                 print("Your Keras RNN model has been defined and compiled.")
#             else:
#                 print("Could not create sequences. Check previous steps.")
#         else:
#             print("No processed events to create sequences from. Check previous steps.")
#     else:
#         print("MIDI file could not be processed. Please check the path and file validity.")

### Step 5: Model Training

**Goal of this step:**

1. Train the model object using model.fit().

2. Set up callbacks to save the best model weights during training.

In [None]:
# Model Training
MODEL_WEIGHTS_PATH = '/kaggle/working/music_model_weights.h5'  # checkpoint file written during training

def train_model(model, X, y, epochs=EPOCHS, batch_size=BATCH_SIZE, model_weights_path=MODEL_WEIGHTS_PATH):
    """
    Trains the model and checkpoints it whenever the training loss improves.

    Args:
        model: A compiled Keras model.
        X: Input sequences passed to model.fit.
        y: Target values passed to model.fit.
        epochs (int): Number of training epochs.
        batch_size (int): Batch size for model.fit.
        model_weights_path (str): File the checkpoint callback writes to.

    Returns:
        The History object returned by model.fit.
    """
    # No validation data is passed to fit(), so the checkpoint tracks the training loss.
    checkpoint_callback = ModelCheckpoint(
        filepath=model_weights_path,
        monitor='loss',
        verbose=1,
        save_best_only=True,
        mode='min'
    )

    print(f"\n___ Starting Model Training ({epochs} epochs) ___")
    history = model.fit(
        X,
        y,
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[checkpoint_callback],
        verbose=1
    )
    print("___ Model Training Complete ___")
    return history

# if __name__ == "__main__":
#     extracted_notes, midi_obj = load_and_inspect_midi(MIDI_FILE_PATH)
    
#     if extracted_notes:
#         processed_events, unique_pitches_set, unique_durations_set = process_notes_for_rnn(
#             extracted_notes, quantization_steps_per_beat=4
#         )

#         processed_events = [f"{pitch}_{duration:.3f}" for pitch, duration in processed_events]
        
#         if processed_events:
#             X, y, event_to_int, int_to_event, vocab_size = create_rnn_sequences(
#                 processed_events, seq_length=SEQUENCE_LENGTH
#             )
            
#             if X is not None and y is not None and len(X) > 0: # Ensure sequences were created
#                 model = build_rnn_model(vocab_size, seq_length=SEQUENCE_LENGTH)
                
#                 # Check if enough data for training
#                 if len(X) < BATCH_SIZE:
#                     print(f"Warning: Not enough data for one batch ({len(X)} sequences vs {BATCH_SIZE} batch size). "
#                           "Training might be unstable or fail. Consider reducing BATCH_SIZE or increasing dataset size.")
#                     # Adjust batch size if dataset is too small
#                     if len(X) > 0:
#                         BATCH_SIZE = len(X) # Use all data as one batch
#                         print(f"Adjusted BATCH_SIZE to {BATCH_SIZE}.")
#                     else:
#                         print("Cannot train: No sequences created after preprocessing.")
#                         exit()

    #             history = train_model(model, X, y, epochs=EPOCHS, batch_size=BATCH_SIZE, model_weights_path=MODEL_WEIGHTS_PATH)
                
    #             print("\nStep 5: Model Training Complete.")
    #             print(f"Trained model weights saved to: {MODEL_WEIGHTS_PATH}")
    #             # You can access training history: history.history['loss'], history.history['accuracy']
    #         else:
    #             print("Could not create sufficient sequences for training. Check previous steps and dataset size.")
    #     else:
    #         print("No processed events to create sequences from. Check previous steps.")
    # else:
    #     print("MIDI file could not be processed. Please check the path and file validity.")

### Step 6: Music Generation

**Goal of this step:**

1. Load the trained model weights.

2. Choose a "seed" sequence to start the generation.

3. Implement a generation loop:

   - Predict the next event based on the current sequence.

   - Sample an event from the model's probability distribution (to add creativity).

   - Append the sampled event to the generated sequence.

   - Repeat for a desired number of new events.

4. Convert the generated sequence of integer IDs back to (pitch, duration) tuples.

5. Construct a pretty_midi.PrettyMIDI object from these tuples.

6. Save the generated music as a new MIDI file.

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model # Import load_model
from tensorflow.keras.layers import LSTM, Dense, Embedding, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint
import os
import random

GENERATED_MIDI_PATH = '/kaggle/working/generated_melody.mid' # Output path for the generated MIDI file
GENERATION_LENGTH = 100 # Default number of new events generate_music samples after the seed

In [None]:
# Music Generation
def generate_music(model, seed_sequence, int_to_event, vocab_size, generation_length=GENERATION_LENGTH, temperature=1.0):
    """
    Samples a continuation of seed_sequence from the trained model.

    At every step the model is given the current window of event IDs, its
    output distribution is reshaped by `temperature`, one event ID is drawn
    from it with np.random.choice, and the window slides forward by one.

    Args:
        model: Trained model exposing predict().
        seed_sequence: Integer event IDs used as the starting window.
        int_to_event: Mapping from integer ID back to its event.
        vocab_size (int): Number of IDs to sample from.
        generation_length (int): Number of new events to sample.
        temperature (float): Divides the log-probabilities before sampling.

    Returns:
        list: The seed events followed by the sampled events, decoded
        through int_to_event.
    """
    event_ids = list(seed_sequence)
    # Single-row batch holding the current input window.
    window = np.array(seed_sequence).reshape(1, -1)

    print(f"\n__ Generating {generation_length} new events ___")

    for _ in range(generation_length):
        probabilities = model.predict(window, verbose=0)[0]

        # Temperature scaling in log space; the small epsilon keeps log() finite.
        scaled_log_probs = np.log(probabilities + 1e-10) / temperature
        reweighted = np.exp(scaled_log_probs)
        reweighted = reweighted / np.sum(reweighted)

        sampled_id = np.random.choice(vocab_size, p=reweighted)
        event_ids.append(sampled_id)

        # Drop the oldest ID and append the one just sampled.
        window = np.concatenate((window[:, 1:], [[sampled_id]]), axis=1)

    # Decode integer IDs back into their events.
    decoded_events = []
    for event_id in event_ids:
        decoded_events.append(int_to_event[event_id])

    print("___ Music Generation Complete ___")
    return decoded_events


def convert_events_to_midi(events_list, output_midi_path=GENERATED_MIDI_PATH, tempo=120):
    """
    Converts a list of (pitch, duration) tuples into a MIDI file.

    Events are laid out back to back on a single piano track. Each duration is
    used directly as the note's length on the pretty_midi timeline; a 'Rest'
    event only advances the clock.

    Args:
    events_list (list): List of (pitch, duration) tuples, where pitch is a
        number or the string 'Rest'.
    output_midi_path (str): Path to save the generated MIDI file.
    tempo (int): Tempo in beats per minute (BPM) written as the initial tempo
        of the generated MIDI.

    Returns:
    None. Failures while writing the file are printed, not raised.
    """
    if not events_list:
        print("No events to convert to MIDI")
        return

    # Pass the requested tempo through; previously the parameter was accepted but ignored.
    midi = pretty_midi.PrettyMIDI(initial_tempo=tempo)
    # for a simple melody, lets use a piano instrument(program 0)
    piano_program = pretty_midi.instrument_name_to_program("Acoustic Grand Piano")
    piano_instrument = pretty_midi.Instrument(program=piano_program)

    current_time = 0.0
    for event in events_list:
        pitch, duration = event

        # A 'Rest' token produces no note; it just moves the clock forward.
        if isinstance(pitch, str) and pitch == 'Rest':
            current_time += duration
        else:
            # ensure pitch is within valid MIDI range (0-127)
            pitch = int(max(0, min(127, pitch)))
            # ensure duration is positive
            note_duration_seconds = max(0.01, duration)

            note = pretty_midi.Note(
                velocity=100,
                pitch=pitch,
                start=current_time,
                end=current_time + note_duration_seconds
            )
            piano_instrument.notes.append(note)
            current_time += note_duration_seconds
    midi.instruments.append(piano_instrument)

    try:
        midi.write(output_midi_path)
        print(f"Generated MIDI saved to: {output_midi_path}")
    except Exception as e:
        print(f"Error saving MIDI file: {e}")

# if __name__ == "__main__":
#     # --- Music Generation Execution ---
#     # Pick a random seed sequence from your training data
#     start_index = np.random.randint(0, len(X) - 1)
#     seed_sequence_int = X[start_index] # Get a sequence from your training inputs
    
#     generated_music_events = generate_music(
#         model, 
#         seed_sequence_int, 
#         int_to_event, 
#         vocab_size, 
#         generation_length=GENERATION_LENGTH,
#         temperature=0.8 # Experiment with temperature!
#     )
    
#     # Convert generated event strings back to (pitch, duration) tuples
#     generated_music_events = [
#         (int(event.split('_')[0]) if event.split('_')[0] != 'Rest' else 'Rest', 
#          float(event.split('_')[1])) 
#         for event in generated_music_events
#     ]

#     # Convert to MIDI and save
#     convert_events_to_midi(generated_music_events, output_midi_path=GENERATED_MIDI_PATH)
    
#     print("\nStep 6: Music Generation Complete!")
#     print("You can now open 'generated_melody.mid' with a MIDI player or convert it to WAV/MP3.")

In [None]:
if __name__ == "__main__":
    # End-to-end run: load MIDI -> preprocess -> build sequences ->
    # load or train the model -> sample new events -> write a MIDI file.
    extracted_notes, midi_obj = load_and_inspect_midi(MIDI_FILE_PATH)

    if extracted_notes:
        processed_events, unique_pitches_set, unique_durations_set = process_notes_for_rnn(
            extracted_notes, quantization_steps_per_beat=4
        )

        # Encode each (pitch, duration) pair as one "pitch_duration" string token.
        processed_events = [f"{pitch}_{duration:.3f}" for pitch, duration in processed_events]

        if processed_events:
            X, y, event_to_int, int_to_event, vocab_size = create_rnn_sequences(
                processed_events, seq_length=SEQUENCE_LENGTH
            )

            if X is not None and y is not None and len(X) > 0: # Ensure sequences were created
                # Build model first (even if loading weights, need architecture)
                model = build_rnn_model(vocab_size, seq_length=SEQUENCE_LENGTH)

                # Check if weights file exists. If not, train the model.
                if os.path.exists(MODEL_WEIGHTS_PATH):
                    print(f"\nLoading trained model weights from: {MODEL_WEIGHTS_PATH}")
                    model.load_weights(MODEL_WEIGHTS_PATH)
                    print("Model weights loaded.")
                else:
                    print(f"\nModel weights not found at {MODEL_WEIGHTS_PATH}. Training the model...")
                    # len(X) > 0 is guaranteed by the enclosing check, so capping the batch
                    # size is all that is needed (no kernel-killing exit() branch).
                    BATCH_SIZE_ADJUSTED = min(BATCH_SIZE, len(X))
                    if BATCH_SIZE_ADJUSTED != BATCH_SIZE:
                        print(f"Adjusted BATCH_SIZE to {BATCH_SIZE_ADJUSTED}.")

                    train_model(model, X, y, epochs=EPOCHS, batch_size=BATCH_SIZE_ADJUSTED, model_weights_path=MODEL_WEIGHTS_PATH)

                # --- Music Generation Execution ---
                # Pick a random seed sequence from the training data. randint's upper bound is
                # exclusive, so len(X) covers every row and also works when len(X) == 1.
                start_index = np.random.randint(0, len(X))
                seed_sequence_int = X[start_index] # Get a sequence from your training inputs

                generated_music_events = generate_music(
                    model,
                    seed_sequence_int,
                    int_to_event,
                    vocab_size,
                    generation_length=GENERATION_LENGTH,
                    temperature=0.8 # Experiment with temperature!
                )

                # Convert generated "pitch_duration" strings back to (pitch, duration) tuples
                generated_music_events = [
                    ('Rest' if parts[0] == 'Rest' else int(parts[0]), float(parts[1]))
                    for parts in (event.split('_') for event in generated_music_events)
                ]
                # Convert to MIDI and save
                convert_events_to_midi(generated_music_events, output_midi_path=GENERATED_MIDI_PATH)

                print("\nStep 6: Music Generation Complete!")
                print("You can now open 'generated_melody.mid' with a MIDI player or convert it to WAV/MP3.")

            else:
                print("Could not create sufficient sequences for training/generation. Check previous steps and dataset size.")
        else:
            print("No processed events to create sequences from. Check previous steps.")
    else:
        print("MIDI file could not be processed. Please check the path and file validity.")