<a href="https://colab.research.google.com/github/RuchirS-spec/Music_Generation-model/blob/main/RNNs_LSTM_Music_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

!pip install music21
import numpy as np
import os
from music21 import converter, instrument, note, chord, stream
from google.colab import files
from tqdm.notebook import tqdm



In [10]:
# Working directories: uploaded MIDI files, preprocessed tensors, and saved model weights.
MIDI_DATA_PATH = "/content/midi_data"
OUTPUT_PATH = "/content/processed_data"
MODEL_SAVE_PATH = "/content/trained_model"

# exist_ok=True makes this cell safe to re-run and removes the gap between
# checking for a directory and creating it.
for _directory in (MIDI_DATA_PATH, OUTPUT_PATH, MODEL_SAVE_PATH):
    os.makedirs(_directory, exist_ok=True)


In [4]:
# Hyperparameters, grouped by the stage that consumes them.

# Data windowing: number of consecutive events in one input sequence
SEQUENCE_LENGTH = 100

# Network architecture
EMBED_SIZE, HIDDEN_SIZE = 256, 512
NUM_LAYERS, DROPOUT = 3, 0.3

# Optimisation
BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS = 128, 0.001, 90

In [5]:
# Ask the user for MIDI files via the Colab upload helper.
print(f"Please upload your MIDI files. They will be saved to the '{MIDI_DATA_PATH}' directory.")
uploaded = files.upload()

# Move every uploaded file from the working directory into the dataset directory.
for filename in uploaded:
    destination = os.path.join(MIDI_DATA_PATH, filename)
    os.rename(filename, destination)



Please upload your MIDI files. They will be saved to the '/content/midi_data' directory.


Saving x (1).mid to x (1).mid
Saving x (2).mid to x (2).mid
Saving x (3).mid to x (3).mid
Saving x (4).mid to x (4).mid
Saving x (5).mid to x (5).mid
Saving x (6).mid to x (6).mid
Saving x (7).mid to x (7).mid
Saving x (8).mid to x (8).mid
Saving x (9).mid to x (9).mid
Saving x (10).mid to x (10).mid
Saving x (11).mid to x (11).mid
Saving x (12).mid to x (12).mid
Saving x (13).mid to x (13).mid
Saving x (14).mid to x (14).mid
Saving x (15).mid to x (15).mid
Saving x (16).mid to x (16).mid
Saving x (17).mid to x (17).mid
Saving x (18).mid to x (18).mid
Saving x (19).mid to x (19).mid
Saving x (20).mid to x (20).mid
Saving x (21).mid to x (21).mid
Saving x (22).mid to x (22).mid
Saving x (23).mid to x (23).mid
Saving x (24).mid to x (24).mid
Saving x (25).mid to x (25).mid
Saving x (26).mid to x (26).mid
Saving x (27).mid to x (27).mid
Saving x (28).mid to x (28).mid
Saving x (29).mid to x (29).mid
Saving x (30).mid to x (30).mid
Saving x (31).mid to x (31).mid
Saving x (32).mid to x (32

In [6]:
# Parse all MIDI files, extract the main melody from each, and convert the melodies into a format the model can learn from.
# Code adapted from Gemini.
def get_melody_from_midi_files(midi_dir=None):
    """
    Parse every MIDI file in a directory and collect one melody line per file.

    Each score is partitioned by instrument, and the part with the highest
    average pitch is taken as the melody (a heuristic). When partitioning
    yields nothing, the notes and chords of the flattened score are used.

    Args:
        midi_dir: Directory to scan for '.mid' / '.midi' files (matched
            case-insensitively). Defaults to the module-level MIDI_DATA_PATH.

    Returns:
        A single flat list of event strings, with files processed in sorted
        name order: the pitch name for a note, the dot-joined normal-order
        integers for a chord, and 'rest' for a rest.

    Files that raise during parsing or extraction are reported with print()
    and skipped, so one bad file does not abort the whole run.
    """
    if midi_dir is None:
        midi_dir = MIDI_DATA_PATH

    def _flatten(music_stream):
        # Prefer Stream.flatten(); fall back to the older `.flat` property on
        # music21 versions that do not provide flatten().
        flatten = getattr(music_stream, "flatten", None)
        return flatten() if callable(flatten) else music_stream.flat

    all_musical_events = []
    print("Parsing MIDI files and extracting melodies...")

    # Sorting makes the concatenated event order independent of the order in
    # which os.listdir() happens to return entries.
    midi_files = sorted(
        name for name in os.listdir(midi_dir)
        if name.lower().endswith((".mid", ".midi"))
    )

    for file in tqdm(midi_files):
        midi_path = os.path.join(midi_dir, file)
        try:
            score = converter.parse(midi_path)

            # --- Melody extraction ---
            best_part = None
            highest_avg_pitch = float("-inf")

            # Partition by instrument to analyse tracks separately
            parts = instrument.partitionByInstrument(score)
            if parts:  # the file has distinct instrument parts
                for part in parts:
                    flat_part = _flatten(part)
                    pitches = [p.ps for p in flat_part.pitches]
                    if not pitches:
                        # Nothing pitched in this part; it cannot be the melody
                        continue

                    avg_pitch = sum(pitches) / len(pitches)

                    # The part with the highest average pitch is treated as the melody
                    if avg_pitch > highest_avg_pitch:
                        highest_avg_pitch = avg_pitch
                        # Keep the flattened view so events are read from the
                        # same elements that were used to score the part
                        best_part = flat_part
            else:  # no partition available: use the whole score
                best_part = _flatten(score).notes

            if best_part is not None:
                # Convert the melody into a sequence of string tokens
                for element in best_part:
                    if isinstance(element, note.Note):
                        all_musical_events.append(str(element.pitch))
                    elif isinstance(element, chord.Chord):
                        # Chords become dot-separated normal-order integers
                        all_musical_events.append('.'.join(str(n) for n in element.normalOrder))
                    elif isinstance(element, note.Rest):
                        all_musical_events.append('rest')

        except Exception as e:
            # Deliberately best-effort: report the file and continue with the rest
            print(f"Could not process {file}: {e}")

    return all_musical_events

def create_sequences(events, event_to_int, sequence_length=None):
    """
    Build next-event training pairs with a sliding window over `events`.

    Args:
        events: Sequence of event tokens; every token must be a key of
            `event_to_int`.
        event_to_int: Mapping from token to integer id.
        sequence_length: Window size. Defaults to the module-level
            SEQUENCE_LENGTH when omitted.

    Returns:
        (network_input, network_output): int64 tensors of shape
        (N, sequence_length) and (N,), where
        N = max(len(events) - sequence_length, 0). Row i of network_input
        encodes events[i : i + sequence_length], and network_output[i] encodes
        events[i + sequence_length].

    Raises:
        KeyError: if a token is missing from `event_to_int`.
        ValueError: if `sequence_length` is smaller than 1.
    """
    if sequence_length is None:
        sequence_length = SEQUENCE_LENGTH
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    print(f"\nCreating sequences of length {sequence_length}...")

    # Encode every event once, then take overlapping windows of the encoded
    # tensor instead of re-encoding each window in a Python loop.
    encoded = torch.tensor([event_to_int[event] for event in events], dtype=torch.long)

    num_sequences = len(events) - sequence_length
    if num_sequences <= 0:
        # Not enough events for a single (window, target) pair: return empty
        # tensors that still have the documented shapes.
        return (
            torch.empty((0, sequence_length), dtype=torch.long),
            torch.empty((0,), dtype=torch.long),
        )

    # unfold() yields len(events) - sequence_length + 1 windows; the final
    # window has no following event to predict, so it is dropped.
    # clone() detaches the results from the shared `encoded` storage.
    network_input = encoded.unfold(0, sequence_length, 1)[:num_sequences].clone()
    network_output = encoded[sequence_length:].clone()
    return network_input, network_output

# --- Main Preprocessing Execution ---
print("Starting preprocessing...")
musical_events = get_melody_from_midi_files()

if not musical_events:
    print("⚠️ No musical events found. Please check your uploaded MIDI files.")
else:
    print(f"\nFound {len(musical_events)} total musical events (notes, chords, rests).")

    # Vocabulary: each distinct event once, in sorted order
    pitchnames = sorted(set(musical_events))
    vocab_size = len(pitchnames)
    print(f"Vocabulary size: {vocab_size}")

    # Lookup tables in both directions (id -> event, event -> id)
    int_to_event = dict(enumerate(pitchnames))
    event_to_int = {event: number for number, event in int_to_event.items()}

    # Sliding-window training pairs
    inputs, outputs = create_sequences(musical_events, event_to_int)

    # Persist everything the later cells load back from OUTPUT_PATH
    print("\nSaving processed data...")
    for _artifact, _filename in (
        (inputs, "inputs.pt"),
        (outputs, "outputs.pt"),
        (int_to_event, "int_to_event.pt"),
        (vocab_size, "vocab_size.pt"),
    ):
        torch.save(_artifact, os.path.join(OUTPUT_PATH, _filename))

    print(f"✅ Preprocessing complete. Data saved to '{OUTPUT_PATH}'")

Starting preprocessing...
Parsing MIDI files and extracting melodies...


  0%|          | 0/50 [00:00<?, ?it/s]

  return self.iter().getElementsByClass(classFilterList)



Found 12405 total musical events (notes, chords, rests).
Vocabulary size: 133

Creating sequences of length 100...


  0%|          | 0/12305 [00:00<?, ?it/s]


Saving processed data...
✅ Preprocessing complete. Data saved to '/content/processed_data'


In [11]:
# Create the LSTM model
class MusicLSTM(nn.Module):
    """
    Next-event classifier: embedding -> stacked LSTM -> dropout -> linear.

    Args:
        vocab_size: Number of distinct event ids (embedding rows and output classes).
        embed_size: Dimension of each event embedding.
        hidden_size: Hidden state size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability, applied between LSTM layers (when there
            is more than one) and before the final linear layer.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)

        # nn.LSTM only applies dropout between stacked layers and warns when a
        # non-zero value is passed with a single layer, so pass 0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=lstm_dropout,
            bidirectional=False,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """
        Score the next event for each input sequence.

        Args:
            x: Integer tensor of event ids, shape (batch, seq_len).

        Returns:
            Tensor of unnormalised scores (logits), shape (batch, vocab_size).
        """
        embedded = self.embedding(x)

        out, _ = self.lstm(embedded)

        # Only the output at the last time step is used for the prediction
        # (batch_first=True, so time is dimension 1).
        last_step = out[:, -1, :]

        return self.fc(self.dropout(last_step))

In [None]:
# Train the model

# Load the tensors written by the preprocessing cell
print("Loading preprocessed data...")
inputs = torch.load(os.path.join(OUTPUT_PATH, 'inputs.pt'))
outputs = torch.load(os.path.join(OUTPUT_PATH, 'outputs.pt'))
vocab_size = torch.load(os.path.join(OUTPUT_PATH, 'vocab_size.pt'))

# Dataset and dataloader for mini-batch training
dataset = TensorDataset(inputs, outputs)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# With drop_last=True, a dataset smaller than one batch yields zero batches.
# Training would then do nothing and the average-loss division would fail.
if len(dataloader) == 0:
    raise ValueError(
        f"Not enough training sequences ({len(dataset)}) for one batch of {BATCH_SIZE}; "
        "add more MIDI data or lower BATCH_SIZE."
    )

# Train on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialise model, loss, and optimizer
model = MusicLSTM(vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0

    # tqdm progress bar over the batches of this epoch
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)

    for batch_inputs, batch_targets in progress_bar:
        # Batches must live on the same device as the model
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        output = model(batch_inputs)
        loss = loss_fn(output, batch_targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=f'{loss.item():.4f}')

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Average Loss: {avg_loss:.4f}")


# Save only the weights; the generation cell rebuilds the model and loads them
final_model_path = os.path.join(MODEL_SAVE_PATH, 'music_model_final.pth')
torch.save(model.state_dict(), final_model_path)
print(f"\n✅ Training finished. Final model saved to {final_model_path}")

In [1]:
# Generate Music
# (code from Gemini)

# File name of the MIDI file that generation writes
OUTPUT_FILENAME = 'generated_music.mid'

# How many events to generate; adjust freely
NOTES_TO_GENERATE = 500

# Weights file to load: the final model inside the model save directory
MODEL_TO_LOAD = os.path.join(MODEL_SAVE_PATH, 'music_model_final.pth')


def _events_to_music21(events):
    """Convert generated event strings into a list of music21 notes, chords, and rests."""
    output_notes = []
    for event in events:
        if event == 'rest':
            output_notes.append(note.Rest())
        elif ('.' in event) or event.isdigit():
            # Chord token: dot-separated members, e.g. '0.4.7'; a lone integer
            # is treated as a one-note chord.
            chord_notes = []
            for member in event.split('.'):
                try:
                    pitch_spec = int(member)
                except ValueError:
                    # Not an integer: treat the member as a pitch name (e.g. 'C#4')
                    pitch_spec = member
                new_note = note.Note(pitch_spec)
                new_note.storedInstrument = instrument.Piano()
                chord_notes.append(new_note)
            output_notes.append(chord.Chord(chord_notes))
        else:
            # Single note given by pitch name (e.g. 'C#4')
            new_note = note.Note(event)
            new_note.storedInstrument = instrument.Piano()
            output_notes.append(new_note)
    return output_notes


def generate_music(num_notes=None, output_filename=None):
    """
    Generate an event sequence with the trained model and write it as a MIDI file.

    The model is seeded with a randomly chosen training sequence from
    'inputs.pt'. At each step the next event is sampled from the softmax of
    the model output, recorded, and appended to the window while the oldest
    event is dropped.

    Args:
        num_notes: Number of events to generate. Defaults to NOTES_TO_GENERATE.
        output_filename: Path of the MIDI file to write. Defaults to
            OUTPUT_FILENAME.

    Returns:
        None. Progress and errors are reported with print(); exceptions raised
        during generation are caught and printed rather than propagated.
    """
    if num_notes is None:
        num_notes = NOTES_TO_GENERATE
    if output_filename is None:
        output_filename = OUTPUT_FILENAME

    try:
        # --- Load required files ---
        print("Loading model and data mappings...")
        inputs = torch.load(os.path.join(OUTPUT_PATH, 'inputs.pt'))
        int_to_event = torch.load(os.path.join(OUTPUT_PATH, 'int_to_event.pt'))
        vocab_size = torch.load(os.path.join(OUTPUT_PATH, 'vocab_size.pt'))

        if len(inputs) == 0:
            raise ValueError("'inputs.pt' contains no sequences to seed the model with.")

        # --- Initialise model ---
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = MusicLSTM(vocab_size, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT).to(device)
        model.load_state_dict(torch.load(MODEL_TO_LOAD, map_location=device))
        model.eval()  # disable dropout for inference

        # --- Start generation ---
        print("🎶 Generating new music...")

        # Seed with a random training sequence. randint's upper bound is
        # exclusive, so len(inputs) makes every index (including the last) eligible.
        start_index = np.random.randint(0, len(inputs))
        pattern = inputs[start_index].unsqueeze(0).to(device)  # add a batch dimension of 1

        prediction_output = []

        with torch.no_grad():
            for _ in tqdm(range(num_notes), desc="Generating notes"):
                prediction = model(pattern)

                # Sample from the softmax distribution instead of taking the
                # argmax, so the output is not always the single likeliest event.
                probabilities = torch.nn.functional.softmax(prediction, dim=1)
                predicted_index = torch.multinomial(probabilities, 1)

                prediction_output.append(int_to_event[predicted_index.item()])

                # Slide the window: drop the oldest event, append the new one
                pattern = torch.cat((pattern[:, 1:], predicted_index), dim=1)

        print("\nConverting generated sequence to MIDI file...")
        output_notes = _events_to_music21(prediction_output)

        # append() lays the elements end to end, so the timing does not depend
        # on how the Stream constructor places elements that all carry the
        # default offset.
        midi_stream = stream.Stream()
        midi_stream.append(output_notes)
        midi_stream.write('midi', fp=output_filename)

        print(f"\n✅ Success! MIDI file '{output_filename}' created.")
        print("You can find it in the file browser on the left. Download it and enjoy!")

    except FileNotFoundError:
        print("⚠️ Model or data not found. Please run all previous steps first.")
    except Exception as e:
        print(f"An error occurred during generation: {e}")

# --- Run the generation ---
# generate_music() reads names defined in earlier cells (the imports, the path
# constants, the hyperparameters, and MusicLSTM), so run those cells first in
# the same kernel session.
generate_music()

NameError: name 'os' is not defined