In [7]:
import os
import sys
import yaml
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
parent_dir = os.path.dirname(os.path.dirname(notebook_dir))
sys.path.append(parent_dir)

import torch
from mido import Message, MidiFile, MidiTrack
from src.models import *
from src.datasets import *

In [10]:
# Path to the experiment's .yaml config file; edit it to load a different config.
yaml_path = "../../configs/dmm/jsb/standard.yaml"
with open(yaml_path, encoding="utf-8", mode="r") as config_file:
    # Parsed config dict; later cells read keys such as "model_name" and "dataset_params".
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [12]:
def generate_midi_file(sample, midi_path, ticks_per_step=200, velocity=30, lowest_note=21):
    """Write a piano-roll sample to a single-track MIDI file.

    Each row of ``sample`` is one time step and each column is one key.
    Column ``i`` maps to MIDI note ``lowest_note + i``; with the default of 21
    (A0), 88 columns span a piano keyboard. A key sounds while its value is > 0.

    A ``note_on`` is emitted when a key turns on and a ``note_off`` when it
    turns off. Steps with no changes still advance the clock, so held notes
    keep their full length. Keys still sounding after the last step are
    released one step later, so no note is left hanging.

    Args:
        sample: 2-D iterable (time steps x keys) of numbers, e.g. a NumPy array.
        midi_path: Destination path passed to ``MidiFile.save``.
        ticks_per_step: MIDI ticks between consecutive time steps.
        velocity: Fixed ``note_on`` velocity. Sample values only gate on/off.
        lowest_note: MIDI note number assigned to column 0.
    """
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)

    active = set()  # column indices of keys currently sounding
    # MIDI message ``time`` is a delta since the previous message. Accumulate
    # ticks across steps that emit nothing, so that holding a chord (or a rest)
    # does not collapse the timeline.
    pending_ticks = 0

    def emit(kind, column, vel):
        nonlocal pending_ticks
        track.append(Message(kind, note=column + lowest_note, velocity=vel, time=pending_ticks))
        pending_ticks = 0

    for time_step in sample:
        pending_ticks += ticks_per_step
        for column, value in enumerate(time_step):
            is_on = value > 0
            if is_on and column not in active:
                emit("note_on", column, velocity)
                active.add(column)
            elif not is_on and column in active:
                emit("note_off", column, 0)
                active.discard(column)

    # Release keys still held after the final step so they don't sustain forever.
    pending_ticks += ticks_per_step
    for column in sorted(active):
        emit("note_off", column, 0)

    # Save the MIDI file
    mid.save(midi_path)

In [None]:
# Rebuild the model and dataloaders named in the config, load the trained
# weights, generate one sample, and export it plus a ground-truth sequence.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `eval` executes arbitrary expressions taken from the YAML config.
# Only use trusted configs. The names presumably resolve against the
# `src.models` / `src.datasets` star-imports in the first cell.
model = eval(config["model_name"])(**config["model_params"])
dataloaders = eval(config["dataset_name"])(**config["dataset_params"])

# save_path is joined onto "../../" (presumably the repo root relative to this notebook).
model_path = os.path.join("../../", config["trainer_params"]["save_path"])

# torch.load unpickles the checkpoint, so only load trusted files.
model.load_state_dict(torch.load(model_path, map_location=device))
# load_state_dict copies into the model's existing parameters, so the model must
# be moved to `device` explicitly. eval() disables training-only behaviour
# (e.g. dropout) during generation.
model.to(device)
model.eval()

# Seed generation with a single all-zero frame.
# NOTE(review): shape presumably (batch, time, keys); generate() is assumed to
# create its internal tensors on the same device as `start` — confirm.
start = torch.zeros((1, 1, 88), device=device)
with torch.no_grad():
    out = model.generate(start, 100)
sample = out[0].detach().cpu().numpy()
generate_midi_file(sample, "generated.mid")

# Export the first sequence of one training batch for comparison.
gt_samples = next(iter(dataloaders["train"]))[0]
gt_sample = gt_samples[0].detach().cpu().numpy()
generate_midi_file(gt_sample, "ground_truth.mid")