## Package Setup

In [None]:
!pip install pretty_midi

In [None]:
import os
import requests
import collections
from zipfile import ZipFile

import pretty_midi
import numpy as np
import glob

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset,DataLoader,random_split
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Download Dataset

In [None]:
def download_dataset(dataset_url, save_path, extracted_name='POP909'):
    """Download a zip archive into ``save_path`` and extract it there.

    The archive is stored under the last path segment of ``dataset_url``.
    Both steps are skipped when their result is already on disk, so the
    function can be called repeatedly.

    Args:
        dataset_url: URL of the zip archive.
        save_path: Directory that receives the archive and its contents
            (created if missing).
        extracted_name: Entry inside ``save_path`` whose presence means the
            archive has already been extracted.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(save_path, exist_ok=True)

    file_name = dataset_url.split('/')[-1]
    zip_path = os.path.join(save_path, file_name)

    if os.path.exists(zip_path):
        print("Dataset already downloaded.")
    else:
        print("Downloading dataset...")
        # Download to a temporary name and rename only on success, so a
        # failed or interrupted transfer is never mistaken for a cached zip.
        partial_path = zip_path + '.part'
        try:
            with requests.get(dataset_url, stream=True, timeout=60) as response:
                # Fail loudly on 4xx/5xx instead of saving the error page.
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(partial_path, zip_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    if os.path.exists(os.path.join(save_path, extracted_name)):
        print("Dataset already extracted.")
    else:
        print("Extracting dataset...")
        with ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_path)

    print("Dataset downloaded and extracted successfully.")

In [None]:
# POP909 Dataset: archive location and the local folder that receives it.
data_URL = (
    "https://raw.githubusercontent.com/music-x-lab/POP909-Dataset/master/POP909.zip"
)
data_path = "../data"
download_dataset(dataset_url=data_URL, save_path=data_path)

Downloading dataset...
Extracting dataset...
Dataset downloaded and extracted successfully.


In [None]:
# Recursively collect every file under data_path whose name matches '*.mid*'.
musicFiles = glob.glob(
    os.path.join(data_path, '**/*.mid*'),
    recursive=True,
)
print('Number of files:', len(musicFiles))

Number of files: 2898


## Data Processing

In [None]:
def find_main_versions(data_path):
    """Return the main MIDI file of every song folder inside ``data_path``.

    A song folder ``<name>`` contributes ``<name>/<name>.mid`` when that file
    exists; top-level files and folders without such a file are ignored.

    Args:
        data_path: Directory whose immediate sub-directories are song folders.

    Returns:
        List of file paths, ordered by song folder name.
    """
    main_files = []

    # os.listdir returns entries in arbitrary order; sorting makes the result
    # (and everything indexed from it) identical across runs and machines.
    for song_dir in sorted(os.listdir(data_path)):
        song_path = os.path.join(data_path, song_dir)
        if not os.path.isdir(song_path):
            continue
        # The main version is named after its folder.
        main_file_path = os.path.join(song_path, song_dir + '.mid')
        if os.path.isfile(main_file_path):
            main_files.append(main_file_path)

    return main_files

In [None]:
# Set the data path
# Point data_path at the extracted POP909 folder. The guard keeps this cell
# idempotent: re-running it must not append 'POP909' a second time.
if os.path.basename(os.path.normpath(data_path)) != 'POP909':
    data_path = os.path.join(data_path, 'POP909')
mainFiles = find_main_versions(data_path)
print("The number of main versions is", len(mainFiles))

The number of main versions is 909


## Extract Main Melody & Split data

In [None]:
class MelodyDataset(torch.utils.data.Dataset):
    """Sliding windows of MIDI notes, labelled by whether the last note is melody.

    Every note of every instrument becomes ``(start_time, pitch, duration,
    label)``, where ``label`` is 1 for instruments named "MELODY"
    (case-insensitive) and 0 otherwise. Notes are sorted by start time and cut
    into windows of ``window_size`` consecutive notes.

    Each item is ``(features, target)``: ``features`` is a float32 tensor of
    shape ``(window_size, 3)`` holding start time, pitch and duration, and
    ``target`` is the label of the window's last note.

    Args:
        file_list: Paths of the MIDI files to load.
        window_size: Number of notes per window.
        window_step: Stride, in notes, between consecutive windows.
    """

    def __init__(self, file_list, window_size=20, window_step=1):
        self.file_list = file_list
        self.window_size = window_size
        self.window_step = window_step
        self.data = self.process_midi_files()

    def __len__(self):
        """Return the number of windows across all files."""
        return len(self.data)

    def _make_windows(self, melody_sequence):
        """Cut one time-sorted note list into fixed-size windows.

        The list is front-padded with ``window_size - 1`` all-zero notes so the
        first window already ends on the first real note.
        """
        padded = [(0, 0, 0, 0)] * (self.window_size - 1) + melody_sequence
        last_start = len(padded) - self.window_size
        # "+ 1" keeps the window that ends on the final note; without it the
        # last note of every song could never be a prediction target.
        return [
            padded[j:j + self.window_size]
            for j in range(0, last_start + 1, self.window_step)
        ]

    def process_midi_files(self):
        """Parse every file in ``file_list`` and return the list of windows."""
        data = []
        for path in tqdm(self.file_list, desc="Processing MIDI files"):
            midi_data = pretty_midi.PrettyMIDI(path)
            melody_sequence = []

            for instrument in midi_data.instruments:
                label = 1 if instrument.name.upper() == "MELODY" else 0
                for note in instrument.notes:
                    duration = note.end - note.start
                    melody_sequence.append((note.start, note.pitch, duration, label))

            # Interleave all instruments chronologically before windowing.
            melody_sequence.sort(key=lambda x: x[0])
            data.extend(self._make_windows(melody_sequence))
        return data

    def __getitem__(self, idx):
        """Return ``(features, target)`` for window ``idx``."""
        item = torch.tensor(self.data[idx], dtype=torch.float32)
        # Drop the label column from the features; the target is the label of
        # the last note in the window.
        return item[:, :-1], item[-1, -1]

In [None]:
# Notes per window; reused later when windowing a file for inference.
window_size = 20
full_dataset = MelodyDataset(
    file_list=mainFiles,
    window_size=window_size,
    window_step=2,
)

Processing MIDI files: 100%|██████████| 909/909 [01:10<00:00, 12.89it/s]


In [None]:
# split dataset (80% train / 20% validation)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Seed the split so the same windows land in train/val on every run;
# otherwise validation numbers are not comparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)
# NOTE(review): the split is per window, and windows from one song overlap
# (window_step < window_size), so validation windows share notes with training
# windows. Consider splitting by song for an unbiased validation score.

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [None]:
# Sanity check: show the first (features, target) pair and the total number
# of windows in the dataset.
print(full_dataset[0])
print(len(full_dataset))

(tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.6364, 62.0000,  0.4716]]), tensor(0.))
766480


## Bidirectional LSTM Model

In [None]:
class biLSTM(nn.Module):
    """Three stacked bidirectional LSTMs with a per-time-step sigmoid head.

    Each LSTM stage is followed by batch normalisation over the feature
    dimension and dropout (p=0.3). A linear layer maps every time step to one
    value, a sigmoid is applied, and only the last time step is returned.

    Args:
        input_shape: Number of features per time step.
    """

    def __init__(self, input_shape):
        super().__init__()
        # Stage 1: input_shape -> 2 * 256 features (forward + backward).
        self.lstm1 = nn.LSTM(input_shape, 256, batch_first=True, bidirectional=True)
        self.batch_norm1 = nn.BatchNorm1d(2 * 256)
        self.dropout1 = nn.Dropout(0.3)
        # Stage 2: 2 * 256 -> 2 * 256 features.
        self.lstm2 = nn.LSTM(2 * 256, 256, batch_first=True, bidirectional=True)
        self.batch_norm2 = nn.BatchNorm1d(2 * 256)
        self.dropout2 = nn.Dropout(0.3)
        # Stage 3: 2 * 256 -> 2 * 128 features.
        self.lstm3 = nn.LSTM(2 * 256, 128, batch_first=True, bidirectional=True)
        self.batch_norm3 = nn.BatchNorm1d(2 * 128)
        self.dropout3 = nn.Dropout(0.3)
        # Head applied independently to every time step.
        self.time_distributed = nn.Linear(2 * 128, 1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _normalize(features, batch_norm):
        """Apply ``batch_norm`` to (batch, time, feature) data.

        BatchNorm1d expects the feature axis second, so the last two axes are
        swapped before the call and swapped back afterwards.
        """
        return batch_norm(features.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """Return the sigmoid output of the last time step, one column per sample."""
        stages = (
            (self.lstm1, self.batch_norm1, self.dropout1),
            (self.lstm2, self.batch_norm2, self.dropout2),
            (self.lstm3, self.batch_norm3, self.dropout3),
        )
        for lstm, batch_norm, dropout in stages:
            x, _ = lstm(x)
            x = dropout(self._normalize(x, batch_norm))

        probabilities = self.sigmoid(self.time_distributed(x))
        return probabilities[:, -1, :]

In [None]:
# Define model: three input features per note, placed on the selected device.
model = biLSTM(input_shape=3).to(device)

In [None]:
# Loss function
# The model ends in a sigmoid and emits one probability per sample, so the
# matching loss is binary cross-entropy on probabilities. CrossEntropyLoss
# expects per-class logits and would read a squeezed (batch,) output as the
# class scores of a single sample.
criterion = nn.BCELoss()

# Adam optimizer over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

## Training Part

In [None]:
def _run_epoch(model, loader, criterion, optimizer, desc):
    """Run one pass over ``loader`` and return ``(mean batch loss, accuracy)``.

    The pass trains the model when ``optimizer`` is given and only evaluates it
    (eval mode, gradients disabled) when ``optimizer`` is ``None``.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for inputs, targets in tqdm(loader, desc=desc):
            inputs, targets = inputs.to(device), targets.to(device)
            if training:
                optimizer.zero_grad()

            # Drop only the trailing size-1 dimension: a bare squeeze() would
            # also remove the batch dimension of a single-sample batch and
            # break the shape match with ``targets``.
            outputs = model(inputs).squeeze(-1)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()

            # Calculate accuracy
            preds = outputs >= 0.5  # Assuming binary classification
            correct += (preds == targets).sum().item()
            seen += targets.size(0)

    return total_loss / len(loader), correct / seen


def train_and_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    After every epoch the metrics are printed and the model is saved twice in
    the working directory: its state dict as ``model-<epoch>.pth`` and the
    whole module as ``model-<epoch>.pkl`` (``<epoch>`` is zero-based).

    Args:
        model: Network returning one probability per sample.
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader yielding ``(inputs, targets)`` validation batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of epochs to run.

    Returns:
        Four lists with one entry per epoch: training losses, validation
        losses, training accuracies, validation accuracies.
    """
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, optimizer, f"Epoch {epoch + 1} Training"
        )
        train_losses.append(avg_train_loss)
        train_accuracies.append(train_accuracy)

        avg_val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, None, f"Epoch {epoch + 1} Validation"
        )
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_accuracy)

        # Print loss and accuracy
        print(f"Epoch {epoch + 1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}, Train Acc = {train_accuracy:.4f}, Val Acc = {val_accuracy:.4f}")
        torch.save(model.state_dict(), "model-{}.pth".format(epoch))
        torch.save(model, "model-{}.pkl".format(epoch))

    return train_losses, val_losses, train_accuracies, val_accuracies

In [None]:
# Train for 50 epochs and keep the per-epoch curves for plotting.
train_losses, val_losses, train_accuracies, val_accuracies = train_and_validate(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=50,
)

In [None]:
def plot_losses_and_accuracies(train_losses, val_losses, train_accuracies, val_accuracies):
    """Plot loss and accuracy curves side by side.

    The left panel shows training and validation loss, the right panel shows
    training and validation accuracy, both against the epoch index.
    """
    panels = (
        ('Loss', train_losses, val_losses),
        ('Accuracy', train_accuracies, val_accuracies),
    )

    plt.figure(figsize=(12, 6))
    for position, (metric, train_values, val_values) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_values, label=f'Train {metric}')
        plt.plot(val_values, label=f'Validation {metric}')
        plt.title(f'Training and Validation {metric}')
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
# Visualise the curves collected during training.
plot_losses_and_accuracies(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accuracies=train_accuracies,
    val_accuracies=val_accuracies,
)

## Prediction & Output MIDI file

In [None]:
# Use the trained model to predict
def predict(model, sequence):
    """Return ``model(sequence)`` computed in eval mode with gradients disabled.

    The model is left in eval mode after the call.
    """
    model.eval()
    with torch.no_grad():
        return model(sequence)

In [None]:
def process_midi_file(path):
    """Load a MIDI file and return its notes sorted by start time.

    Only instruments whose program number is between 0 and 7 (inclusive) are
    kept. Each note becomes a ``(start_time, pitch, duration)`` tuple.
    """
    midi_data = pretty_midi.PrettyMIDI(path)

    # Only consider the piano: programs 0-7 are treated as piano here.
    notes = [
        (note.start, note.pitch, note.end - note.start)
        for instrument in midi_data.instruments
        if 0 <= instrument.program <= 7
        for note in instrument.notes
    ]
    # Stable sort: notes with equal start times keep their collection order.
    notes.sort(key=lambda entry: entry[0])
    return notes

def split_sequence(melody_sequence, window_size):
    """Return one window of ``window_size`` notes per note in ``melody_sequence``.

    The sequence is front-padded with ``window_size - 1`` all-zero notes, so
    window ``i`` ends on note ``i``. The input list is not modified.
    """
    padded = [(0, 0, 0)] * (window_size - 1) + melody_sequence
    window_count = len(padded) - window_size + 1
    return [padded[start:start + window_size] for start in range(window_count)]

In [None]:
# Window the notes of one example file and move them to the model's device.
midi_sequence = process_midi_file(mainFiles[0])  # example
sequence = torch.tensor(
    split_sequence(midi_sequence, window_size),
    dtype=torch.float32,
).to(device)

In [None]:
# One model output per window, i.e. per note of the example file.
outputs = predict(model=model, sequence=sequence)

In [None]:
# Convert to MIDI format
def notes_to_midi(midi_sequence, outputs, output_path):
    """Write the notes whose model output exceeds 0.5 to a MIDI file.

    Args:
        midi_sequence: Iterable of ``(start_time, pitch, duration)`` notes.
        outputs: Per-note scores, indexed in the same order as ``midi_sequence``.
        output_path: Destination of the generated MIDI file.
    """
    piano = pretty_midi.Instrument(program=0)

    for index, (start_time, pitch, duration) in enumerate(midi_sequence):
        # Threshold determine the main theme
        if outputs[index] > 0.5:
            piano.notes.append(
                pretty_midi.Note(
                    velocity=100,
                    pitch=int(pitch),
                    start=start_time,
                    end=start_time + duration,
                )
            )

    midi = pretty_midi.PrettyMIDI()
    midi.instruments.append(piano)
    midi.write(output_path)

In [None]:
# Export the notes selected by the model for the example file.
notes_to_midi(
    midi_sequence=midi_sequence,
    outputs=outputs,
    output_path="output.mid",
)