# MIDI Generation Using an LSTM Model

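This notebook generates a new MIDI file with a pre-trained LSTM model, using a note sequence from an example MIDI file as the seed. The example's distinct notes and chords also form the vocabulary that maps the model's predictions back to notes, so it should match the vocabulary the model was trained on.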

In [None]:
import music21
import numpy as np
from keras.models import load_model
from music21 import *
import random
import matplotlib.pyplot as plt
from IPython.display import Audio
import os

## Configuration
Set your file paths and generation parameters here:

In [None]:
# --- File locations (edit these before running) ---
MODEL_PATH: str = "path_to_your_model.keras"    # trained Keras model to load
EXAMPLE_MIDI: str = "path_to_your_example.mid"  # MIDI file whose notes seed generation
OUTPUT_MIDI: str = "generated_output.mid"       # destination of the generated MIDI

# --- Generation settings ---
SEQUENCE_LENGTH: int = 100  # tokens per input window fed to the model
N_NOTES: int = 500          # number of notes/chords to generate
TEMPERATURE: float = 1.0    # sampling temperature: lower = more predictable, higher = more random

## Helper Functions

In [None]:
def extract_notes(file_path):
    """Parse a MIDI file and return its notes and chords as string tokens.

    A single note becomes the string form of its pitch. A chord becomes the
    pitch classes of its ``normalOrder`` joined with dots (e.g. ``'0.4.7'``).

    Args:
        file_path: Path to the MIDI file to parse.

    Returns:
        ``(notes, True)`` on success, or ``([], False)`` if anything failed
        (the error message is printed).
    """
    notes = []
    try:
        score = converter.parse(file_path)
        # partitionByInstrument returns None when the score contains no
        # instrument objects; read the score as a single part in that case
        # instead of failing on `None.parts`.
        instruments = instrument.partitionByInstrument(score)
        parts = instruments.parts if instruments is not None else [score]

        for part in parts:
            for element in part.flat.notes:
                if isinstance(element, note.Note):
                    notes.append(str(element.pitch))
                elif isinstance(element, chord.Chord):
                    notes.append('.'.join(str(n) for n in element.normalOrder))

        return notes, True
    except Exception as e:
        print(f"Error loading MIDI: {str(e)}")
        return [], False

In [None]:
def prepare_sequences(notes, sequence_length=SEQUENCE_LENGTH):
    """Turn a token list into fixed-length windows of integer ids.

    The vocabulary is the sorted set of distinct tokens in ``notes``; token
    ``vocab[i]`` gets id ``i``. A window of ``sequence_length`` ids slides over
    the encoded list, and each window is paired with the id that follows it.

    Returns:
        ``(inputs, targets, vocab_size, vocab)``: a NumPy array of id windows,
        a NumPy array of the id following each window, the number of distinct
        tokens, and the sorted token list.
    """
    vocab = sorted(set(notes))
    token_to_id = {token: idx for idx, token in enumerate(vocab)}
    encoded = [token_to_id[token] for token in notes]

    window_count = len(notes) - sequence_length
    inputs = [encoded[start:start + sequence_length] for start in range(window_count)]
    targets = [encoded[start + sequence_length] for start in range(window_count)]

    return (np.array(inputs), np.array(targets), len(vocab), vocab)

In [None]:
def generate_notes(model, network_input, pitchnames, n_vocab, n_notes=N_NOTES, temperature=TEMPERATURE):
    """Sample ``n_notes`` new note/chord tokens from ``model``.

    A randomly chosen row of ``network_input`` seeds the generation. On each
    step the current window is reshaped to ``(1, window_length, 1)``, divided
    by ``n_vocab`` and passed to ``model.predict``. The returned distribution
    is re-weighted by ``temperature`` and one index is drawn from it. The index
    is mapped back to a token through ``pitchnames`` and appended to the
    window, whose oldest entry is dropped.

    Args:
        model: Object with a Keras-style ``predict(x, verbose=0)`` returning a
            probability row per input, i.e. shape ``(1, vocab)`` here.
        network_input: Sequence of seed windows of integer token ids (as built
            by ``prepare_sequences``).
        pitchnames: Token list; position ``i`` holds the token for id ``i``.
        n_vocab: Vocabulary size used to scale the inputs. It must equal the
            model's output size. NOTE(review): the scaling assumes the model
            was trained on inputs divided by the same value — confirm.
        n_notes: Number of tokens to generate.
        temperature: Values > 0 re-weight the distribution (lower = more
            conservative). Values <= 0 pick the most likely token every step.

    Returns:
        List of the generated tokens, taken from ``pitchnames``.

    Raises:
        ValueError: If ``network_input`` is empty, or if the model's output
            size differs from ``n_vocab`` (the indices it produces could not
            be mapped through ``pitchnames``).
    """
    if len(network_input) == 0:
        raise ValueError(
            "network_input is empty; the example must contain more notes than "
            "the sequence length to provide a seed"
        )

    int_to_note = dict(enumerate(pitchnames))
    start = random.randint(0, len(network_input) - 1)
    pattern = list(network_input[start])

    prediction_output = []
    for _ in range(n_notes):
        prediction_input = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        prediction = model.predict(prediction_input, verbose=0)

        # float64 avoids float32 underflow when probabilities are sharpened.
        probs = np.asarray(prediction, dtype=np.float64)[0]
        if probs.shape[0] != n_vocab:
            raise ValueError(
                f"model predicts {probs.shape[0]} classes but the vocabulary "
                f"has {n_vocab} entries; the model must be used with the "
                "vocabulary it was trained on"
            )

        if temperature <= 0:
            # Zero/negative temperature: greedy decoding (avoids dividing by 0).
            index = int(np.argmax(probs))
        else:
            with np.errstate(divide="ignore"):  # log(0) -> -inf is fine here
                logits = np.log(probs) / temperature
            # Shift so the largest logit is 0; otherwise exp() can underflow to
            # all zeros for small temperatures and the distribution becomes NaN.
            logits -= logits.max()
            weights = np.exp(logits)
            index = int(np.random.choice(len(weights), p=weights / weights.sum()))

        prediction_output.append(int_to_note[index])
        # Slide the window: drop the oldest id, append the new one.
        pattern = pattern[1:] + [index]

    return prediction_output

In [None]:
def create_midi(prediction_output, filename=OUTPUT_MIDI):
    """Convert generated tokens into a music21 stream and write it as MIDI.

    A token that contains '.' or consists only of digits is treated as a chord
    whose members are integer pitch values (the chord encoding produced by
    ``extract_notes``). Any other token is treated as a pitch name for a single
    note. Consecutive events are placed 0.5 apart in offset.

    Args:
        prediction_output: Iterable of token strings.
        filename: Destination path of the MIDI file.

    Returns:
        The music21 stream that was written.
    """
    offset = 0
    output_notes = []

    for token in prediction_output:
        if ('.' in token) or token.isdigit():
            members = [note.Note(int(pitch_value)) for pitch_value in token.split('.')]
            event = chord.Chord(members)
        else:
            event = note.Note(token)
        event.offset = offset
        output_notes.append(event)
        offset += 0.5

    midi_stream = stream.Stream(output_notes)
    # NOTE(review): the previous per-note `storedInstrument` attribute is not
    # something music21 reads when writing MIDI. An Instrument object placed in
    # the stream is how the instrument is recorded.
    midi_stream.insert(0, instrument.Piano())
    midi_stream.write('midi', fp=filename)
    return midi_stream

## Load and Verify Model

In [None]:
# Load the pre-trained network. On failure `model` is set to None, so the
# failure is explicit rather than leaving the name undefined.
try:
    model = load_model(MODEL_PATH)
except Exception as e:
    model = None
    print(f"Error loading model: {str(e)}")
else:
    # Kept outside the try so a summary problem is not misreported as a
    # failure to load the model.
    print("Model loaded successfully")
    model.summary()

## Process Example MIDI

In [None]:
# Extract the example's note/chord tokens and plot how often each one occurs.
notes, success = extract_notes(EXAMPLE_MIDI)
if success:
    print(f"Total notes extracted: {len(notes)}")
    print(f"Unique notes/chords: {len(set(notes))}")

    if notes:
        # Sorted so the token -> index mapping is stable across runs (set order
        # for strings varies with hash randomization) and matches the ids
        # prepare_sequences assigns.
        unique_notes = sorted(set(notes))
        token_index = {token: idx for idx, token in enumerate(unique_notes)}

        plt.figure(figsize=(12, 6))
        # One bin per distinct token, centred on its index.
        plt.hist(
            [token_index[token] for token in notes],
            bins=np.arange(len(unique_notes) + 1) - 0.5,
        )
        plt.title('Note Distribution in Example MIDI')
        plt.xlabel('Note Index')
        plt.ylabel('Frequency')
        plt.show()

## Generate New MIDI

In [None]:
if success and model is not None:
    # Build integer seed windows from the example. Its vocabulary (pitchnames)
    # maps the model's sampled indices back to note/chord tokens.
    network_input, _, n_vocab, pitchnames = prepare_sequences(notes)

    if len(network_input) == 0:
        print(
            f"Not enough notes to seed generation: need more than "
            f"{SEQUENCE_LENGTH}, got {len(notes)}"
        )
    else:
        # Generate new notes
        prediction_output = generate_notes(model, network_input, pitchnames, n_vocab)

        # Create MIDI file
        midi_stream = create_midi(prediction_output)

        print(f"Generated MIDI saved as: {OUTPUT_MIDI}")

        # Play the generated MIDI. IPython's Audio element cannot play MIDI, and
        # an expression nested in an `if` is never auto-displayed. music21 can
        # embed a MIDI player in the notebook output instead.
        midi_stream.show('midi')