## Package Setup

In [None]:
!pip install pretty_midi

In [2]:
import collections
import glob
import os
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pretty_midi
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from tqdm import tqdm

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's RNG (used by weight init, dropout, random splitting and DataLoader
# shuffling) so Restart & Run All reproduces the same split and learning curves.
SEED = 42
torch.manual_seed(SEED)

## Download Dataset

In [4]:
def download_dataset(dataset_url, save_path):
    """Download a zip archive into `save_path` (once) and extract it there (once).

    The archive is stored under the last path component of `dataset_url`. The
    download is skipped when that file already exists. Extraction is skipped when
    every top-level entry of the archive already exists in `save_path`.

    Args:
        dataset_url: HTTP(S) URL of a .zip archive.
        save_path: Directory that receives the archive and its extracted
            contents; created if missing.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        zipfile.BadZipFile: if the saved file is not a valid zip archive.
    """
    os.makedirs(save_path, exist_ok=True)

    file_name = dataset_url.split('/')[-1]
    zip_path = os.path.join(save_path, file_name)

    if os.path.exists(zip_path):
        print("Dataset already downloaded.")
    else:
        print("Downloading dataset...")
        # Stream to a temporary file and rename it only after success. An interrupted
        # or failed download therefore never leaves a file at `zip_path` that later
        # runs would mistake for a complete archive.
        partial_path = zip_path + '.part'
        with requests.get(dataset_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, zip_path)

    with ZipFile(zip_path, 'r') as zip_ref:
        # Decide from the archive's own contents whether it was already extracted,
        # rather than from a hard-coded folder name.
        top_level = {name.split('/')[0] for name in zip_ref.namelist() if name}
        already_extracted = bool(top_level) and all(
            os.path.exists(os.path.join(save_path, entry)) for entry in top_level
        )
        if already_extracted:
            print("Dataset already extracted.")
        else:
            print("Extracting dataset...")
            zip_ref.extractall(save_path)

    print("Dataset downloaded and extracted successfully.")

In [5]:
# MAESTRO v3.0.0, MIDI-only archive, fetched into a local folder.
data_path = "./maestro_dataset"
data_URL = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip"
download_dataset(dataset_url=data_URL, save_path=data_path)

Downloading dataset...
Extracting dataset...
Dataset downloaded and extracted successfully.


In [6]:
# glob returns files in filesystem-dependent order. Sorting makes the file order,
# and therefore musicFile[0] used by the prediction demo, identical on every machine.
musicFile = sorted(glob.glob(os.path.join(data_path, '**/*.mid*'), recursive=True))
print('Number of files:', len(musicFile))

Number of files: 1276


## Data Processing

In [7]:
def midi_to_notes(midi_file: str) -> pd.DataFrame:
    """Load the first instrument of a MIDI file as a note table sorted by onset.

    Args:
        midi_file: Path to a MIDI file readable by ``pretty_midi.PrettyMIDI``.

    Returns:
        A DataFrame with columns ``pitch``, ``start``, ``end``, ``velocity`` in
        that order (velocity is deliberately last), one row per note of
        ``instruments[0]``, sorted by start time. Files with no instruments or
        no notes yield an empty frame with the same columns instead of raising.
    """
    pm = pretty_midi.PrettyMIDI(midi_file)
    # Only the first instrument is used; presumably MAESTRO files hold a single
    # piano track. TODO confirm for other datasets.
    source_notes = pm.instruments[0].notes if pm.instruments else []
    # sorted() is stable, so notes sharing a start time keep their file order.
    sorted_notes = sorted(source_notes, key=lambda note: note.start)

    return pd.DataFrame({
        'pitch': np.array([note.pitch for note in sorted_notes], dtype=np.int64),
        'start': np.array([note.start for note in sorted_notes], dtype=np.float64),
        'end': np.array([note.end for note in sorted_notes], dtype=np.float64),
        'velocity': np.array([note.velocity for note in sorted_notes], dtype=np.int64),
    })

In [8]:
num_files = len(musicFile)  # lower this to parse only the first N files
all_notes = [
    midi_to_notes(midi_path)
    for midi_path in tqdm(musicFile[:num_files], desc='Processing files')
]

Processing files: 100%|██████████| 1276/1276 [07:17<00:00,  2.91it/s]


In [9]:
# Inspect the note table parsed from the first file.
first_song = all_notes[0]
print(first_song)

      pitch       start         end  velocity
0        48    0.957031   10.829427        56
1        36    0.970052   10.820312        52
2        72    1.617188    1.983073        61
3        74    1.920573    2.225260        66
4        76    2.164062    2.450521        71
...     ...         ...         ...       ...
1920     72  237.759115  239.765625        87
1921     60  237.769531  239.824219        73
1922     67  237.773438  239.816406        69
1923     64  237.773438  239.826823        73
1924     36  237.796875  239.841146        70

[1925 rows x 4 columns]


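## Extract Main Melody & Split Data

Notes whose velocity exceeds `threshold` are labelled as main-melody notes; each sample is the window of notes ending at the labelled note, with velocity removed from the features.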

In [10]:
class MusicDataset(Dataset):
    """Sliding-window dataset that labels each note by whether its velocity exceeds a threshold.

    For every note of every song, the sample is the window of ``window_size``
    notes ending at that note, with the velocity column removed. The label is
    ``True`` when that last note's velocity is greater than ``threshold``. Each
    song is left-padded with ``window_size - 1`` all-zero rows so its first notes
    also get full-length windows. This yields exactly ``len(song)`` samples per
    song, in song order.

    Args:
        notes: Iterable of per-song 2-D tables (e.g. DataFrames) whose LAST
            column is velocity; the remaining columns become window features.
        window_size: Number of notes per window.
        threshold: Velocity above which a note is labelled ``True``. When
            omitted, the notebook-level global ``threshold`` is used (the
            original behaviour); passing it explicitly is preferred.

    Raises:
        ValueError: if no threshold is passed and no global ``threshold`` exists.
    """

    def __init__(self, notes, window_size=10, threshold=None):
        if threshold is None:
            # Backward compatibility: earlier versions read a global `threshold`.
            threshold = globals().get('threshold')
            if threshold is None:
                raise ValueError(
                    "MusicDataset needs a velocity threshold: pass threshold=... "
                    "or define a global `threshold` before constructing it."
                )
        self.notes = notes
        self.window_size = window_size
        self.threshold = threshold
        self.data = self.load_data()

    def load_data(self):
        """Build the list of ``(window, label)`` pairs for all songs."""
        data = []
        for song in self.notes:
            # Prepend window_size-1 zero rows; otherwise the first window_size-1
            # notes of each song could not end a full window.
            song_padding = np.pad(song, ((self.window_size - 1, 0), (0, 0)), mode='constant', constant_values=0)
            # Padded rows i .. i+window_size-1 end on original note i, so there is
            # one window per original note (len(padded) - (window_size-1) windows).
            for i in range(len(song)):
                window = song_padding[i:i + self.window_size][:, :-1]  # drop the velocity column
                label = bool(song_padding[i + self.window_size - 1][-1] > self.threshold)
                data.append((window, label))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        window, label = self.data[idx]
        return torch.tensor(window, dtype=torch.float), torch.tensor(label, dtype=torch.bool)

In [22]:
# Each sample spans `window_size` notes; notes louder than `threshold` are labelled positive.
window_size = 20
threshold = 65

dataset = MusicDataset(notes=all_notes, window_size=window_size)

In [25]:
example_window, example_label = dataset[10]
sample_count = len(dataset)
print(f"Example: {(example_window, example_label)}")
print(f"Number of samples: {sample_count}")

Number of samples: 70401


In [26]:
# Split by SONG, not by window. Neighbouring windows share window_size-1 notes,
# so a window-level random split puts near-duplicates of training samples into
# validation and inflates validation accuracy.
song_lengths = [len(song) for song in all_notes]
# dataset indices [song_offsets[s], song_offsets[s+1]) are the windows of song s
song_offsets = np.cumsum([0] + song_lengths)
assert song_offsets[-1] == len(dataset), "expected one window per note, in all_notes order"

song_order = torch.randperm(len(all_notes)).tolist()
n_train_songs = int(0.8 * len(all_notes))
train_indices = [
    idx for s in song_order[:n_train_songs] for idx in range(song_offsets[s], song_offsets[s + 1])
]
val_indices = [
    idx for s in song_order[n_train_songs:] for idx in range(song_offsets[s], song_offsets[s + 1])
]

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
train_size, val_size = len(train_dataset), len(val_dataset)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

## LSTM Model

In [27]:
class BasicLSTM(nn.Module):
    """Two stacked LSTMs followed by a sigmoid head applied to the final time step.

    Args:
        input_shape: Number of features per time step, i.e. the size of the last
            input dimension (despite the name, not a full shape).
        hidden_size: Hidden units in each LSTM layer.
        dropout: Dropout probability applied after each LSTM layer.

    Input:  ``(batch, seq_len, input_shape)`` float tensor (``batch_first=True``).
    Output: ``(batch, 1)`` values in ``(0, 1)``, one per sequence.
    """

    def __init__(self, input_shape, hidden_size=128, dropout=0.2):
        super().__init__()
        self.lstm1 = nn.LSTM(input_shape, hidden_size, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)
        # Name kept from the original so existing state_dicts still load.
        self.time_distributed = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x, _ = self.lstm1(x)
        x = self.dropout1(x)
        x, _ = self.lstm2(x)
        # Only the last time step is returned, so apply dropout and the linear head
        # to that step alone instead of to every step and discarding the rest.
        last_step = self.dropout2(x[:, -1, :])
        return self.sigmoid(self.time_distributed(last_step))

## Model Definition

In [28]:
# 3 input features per time step: pitch, start, end (the dataset drops velocity).
model = BasicLSTM(input_shape=3).to(device)

In [29]:
# The model already ends in a sigmoid and targets are 0/1 floats, so plain BCE fits.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

## Training

In [30]:
# Per-epoch metrics, appended by train_and_validate and plotted after training.
history = {metric: [] for metric in ('train_loss', 'val_loss', 'train_accuracy', 'val_accuracy')}

In [36]:
def _run_epoch(loader, train, progress_desc=None):
    """Make one pass over `loader` with the global model; return (mean batch loss, accuracy).

    When `train` is True, the model is put in training mode and the global
    optimizer steps after every batch. Otherwise the model runs in eval mode with
    autograd disabled. A tqdm progress bar is shown only when `progress_desc` is given.
    """
    model.train(train)
    total_loss = 0.0
    correct = 0
    seen = 0
    batches = tqdm(loader, desc=progress_desc) if progress_desc is not None else loader

    with torch.set_grad_enabled(train):
        for data, label in batches:
            data, label = data.to(device), label.to(device)
            target = label.unsqueeze(1)  # (batch,) -> (batch, 1), matching BasicLSTM's output
            if train:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target.float())
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            correct += ((output >= 0.5) == target).sum().item()
            seen += label.size(0)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return total_loss / max(len(loader), 1), correct / max(seen, 1)


def train_and_validate(epochs):
    """Train the global `model` for `epochs` epochs, validating after each one.

    Uses the notebook-level `model`, `criterion`, `optimizer`, `train_loader`,
    `val_loader` and `device`. Prints per-epoch metrics and appends them to the
    global `history` dict.
    """
    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(train_loader, train=True, progress_desc=f'Epoch {epoch+1}')
        print(f'Epoch {epoch+1} Loss: {train_loss} Accuracy: {train_accuracy}')

        val_loss, val_accuracy = _run_epoch(val_loader, train=False)
        print(f'Epoch {epoch+1} Validation Loss: {val_loss} Validation Accuracy: {val_accuracy}')

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_accuracy'].append(train_accuracy)
        history['val_accuracy'].append(val_accuracy)

In [None]:
train_and_validate(epochs=10)

In [None]:
# Learning curves for both splits; epochs are numbered from 1 to match the training log.
epochs_axis = range(1, len(history['train_loss']) + 1)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs_axis, history['train_loss'], label='Train Loss')
loss_ax.plot(epochs_axis, history['val_loss'], label='Validation Loss')
loss_ax.set(title='Loss per epoch', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()

acc_ax.plot(epochs_axis, history['train_accuracy'], label='Train Accuracy')
acc_ax.plot(epochs_axis, history['val_accuracy'], label='Validation Accuracy')
acc_ax.set(title='Accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()

## Prediction & MIDI Output

In [None]:
def predict(model, sequence):
    """Return `model(sequence)` computed in eval mode without recording gradients.

    The model is switched to eval mode (disabling dropout) and left in it.
    """
    with torch.no_grad():
        return model.eval()(sequence)

In [None]:
def process_midi_file(path):
    """Collect the notes of every non-drum piano track in a MIDI file, sorted by onset.

    Args:
        path: Path to a MIDI file readable by ``pretty_midi.PrettyMIDI``.

    Returns:
        A list of ``(start_time, pitch, duration)`` tuples sorted by start time
        (stable, so simultaneous notes keep track order).
    """
    midi_data = pretty_midi.PrettyMIDI(path)
    melody_sequence = []
    for instrument in midi_data.instruments:
        # General MIDI programs 0-7 are the piano family. Drum tracks also carry a
        # program number (possibly in that range), so exclude them explicitly.
        if instrument.is_drum or not 0 <= instrument.program <= 7:
            continue
        for note in instrument.notes:
            melody_sequence.append((note.start, note.pitch, note.end - note.start))
    return sorted(melody_sequence, key=lambda item: item[0])

def split_sequence(melody_sequence, window_size):
    """Return every length-`window_size` window of `melody_sequence`, stride 1.

    The sequence is first left-padded with ``window_size - 1`` ``(0, 0, 0)``
    tuples, so each original element ends exactly one window.
    """
    padded = [(0, 0, 0)] * (window_size - 1) + melody_sequence
    window_count = len(padded) - window_size + 1
    return [padded[offset:offset + window_size] for offset in range(window_count)]

In [None]:
midi_sequence = process_midi_file(musicFile[0])

# process_midi_file yields (start, pitch, duration) tuples, but the model was trained
# on (pitch, start, end) rows (midi_to_notes column order with velocity dropped).
# Reorder only the model input; midi_sequence keeps its layout for notes_to_midi.
model_input = [(pitch, start, start + duration) for start, pitch, duration in midi_sequence]
sequence = split_sequence(model_input, window_size)
sequence = torch.tensor(sequence, dtype=torch.float32).to(device)

In [None]:
# One model score per window, i.e. per note of the chosen file.
outputs = predict(model=model, sequence=sequence)

In [None]:
# Convert to MIDI format
def notes_to_midi(midi_sequence, outputs, output_path, threshold=0.5, velocity=100):
    """Write the notes whose predicted score exceeds `threshold` to a piano MIDI file.

    Args:
        midi_sequence: ``(start_time, pitch, duration)`` tuples, one per note.
        outputs: One score per note, aligned with `midi_sequence` (e.g. the
            tensor returned by ``predict``). Tensors on any device and plain
            sequences are accepted.
        output_path: Destination path of the MIDI file.
        threshold: Notes whose score is strictly greater than this are kept.
        velocity: Velocity assigned to every written note.

    Raises:
        ValueError: if `outputs` does not hold exactly one score per note.
    """
    # Copy the scores to host memory once and flatten them, instead of comparing
    # one (possibly GPU) tensor element per note inside the loop.
    scores = torch.as_tensor(outputs).detach().reshape(-1).cpu().tolist()
    if len(scores) != len(midi_sequence):
        raise ValueError(
            f"expected one score per note: got {len(scores)} scores for {len(midi_sequence)} notes"
        )

    midi = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)

    for (start_time, pitch, duration), score in zip(midi_sequence, scores):
        if score > threshold:  # keep only notes predicted as main theme
            note = pretty_midi.Note(
                velocity=velocity, pitch=int(pitch), start=start_time, end=start_time + duration
            )
            piano.notes.append(note)

    midi.instruments.append(piano)
    midi.write(output_path)

In [None]:
notes_to_midi(midi_sequence=midi_sequence, outputs=outputs, output_path="output.mid")