In [None]:
import torch
import pathlib
import zipfile
from model_handler import ModelHandler

In [None]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The handler is constructed with the chosen device; it is reused as `mh`
# by every later cell in this notebook.
mh = ModelHandler(device)
print(f"Running on: {device}")

In [None]:

# Fetch and unpack the MAESTRO v2.0.0 MIDI archive into ../data, unless the
# extracted dataset directory is already present.
data_path = pathlib.Path("../data")
# exist_ok makes the cell idempotent (safe to re-run); parents guards against
# a missing intermediate directory.
data_path.mkdir(parents=True, exist_ok=True)

# The existence of this directory is used as the "already downloaded" marker.
# NOTE(review): assumes the archive unpacks into a top-level `maestro-v2.0.0`
# folder - confirm against the zip layout.
maestro_data = data_path / "maestro-v2.0.0"
if not maestro_data.exists():
    maestro_zip = data_path / "maestro-v2.0.0-midi.zip"
    torch.hub.download_url_to_file(
        "https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip",
        str(maestro_zip),
    )

    with zipfile.ZipFile(maestro_zip, "r") as zip_ref:
        zip_ref.extractall(data_path)

    # Remove the archive only after the ZipFile handle has been closed:
    # deleting a file that is still open raises PermissionError on Windows.
    maestro_zip.unlink()

In [None]:
# Register the MIDI files with the handler. The pattern must point at the
# directory the download cell creates and checks for (`../data/maestro-v2.0.0`);
# the previous `../data/maestro/` prefix did not match that directory.
mh.midi.load_files("../data/maestro-v2.0.0/**/*.mid*")

# Fail fast with a clear message instead of an IndexError later, when the
# generation cell reads `mh.midi.files[0]`.
if len(mh.midi.files) == 0:
    raise FileNotFoundError(
        "No MIDI files matched '../data/maestro-v2.0.0/**/*.mid*'; "
        "run the download cell first."
    )
print(f"Number of Samples: {len(mh.midi.files)}")

In [None]:
# Obtain the training notes from the handler.
# NOTE(review): "../data/notes.pt" looks like a cache file for the parsed notes
# and (0, 5) like a range over the loaded files - confirm against
# ModelHandler.get_train_notes.
train_notes = mh.get_train_notes(
    "../data/notes.pt",
    0,
    5,
)
print("Number of notes parsed:", len(train_notes))

In [None]:
# Training hyperparameters handed to the handler when the model is built.
epochs = 50
batch_size = 16
learning_rate = 5e-4

# create_model returns five objects, unpacked here so later cells can
# inspect them directly.
(
    data,
    loader,
    model,
    criterion,
    optimizer,
) = mh.create_model(epochs, batch_size, learning_rate)

In [None]:
# Start training through the handler. No arguments are passed here, so the
# settings presumably come from the earlier mh.create_model(...) call -
# TODO confirm against ModelHandler.train_model.
mh.train_model()

In [None]:
# Seed material: the notes extracted from the first loaded MIDI file.
raw_notes = mh.midi.get_notes(mh.midi.files[0])

# Generation settings passed straight through to the handler.
# NOTE(review): presumably the number of notes to generate, the input window
# length and the sampling temperature - confirm against
# ModelHandler.generate_notes.
num_predictions = 120
seq_length = 25
temperature = 4.0

# generate_notes returns the generated notes together with `out_pm`, which the
# audio cell below plays back.
generated_notes, out_pm = mh.generate_notes(
    raw_notes,
    num_predictions,
    seq_length,
    temperature,
)

In [None]:
# Hand the `out_pm` object returned by mh.generate_notes to the handler's
# audio display helper (presumably renders an in-notebook player - TODO confirm).
mh.midi.display_audio(out_pm)

In [None]:
# Display the MIDI file at "output.mid" via the handler.
# NOTE(review): nothing in this notebook writes "output.mid" explicitly;
# presumably mh.generate_notes saves it as a side effect - confirm.
mh.midi.display_midi("output.mid")