# Model

## Imports

In [None]:
%pip install mido

Collecting mido
  Downloading mido-1.3.2-py3-none-any.whl.metadata (6.4 kB)
Collecting packaging~=23.1 (from mido)
  Downloading packaging-23.2-py3-none-any.whl.metadata (3.2 kB)
Downloading mido-1.3.2-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m896.3 kB/s[0m eta [36m0:00:00[0m
[?25hDownloading packaging-23.2-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.0/53.0 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: packaging, mido
  Attempting uninstall: packaging
    Found existing installation: packaging 24.1
    Uninstalling packaging-24.1:
      Successfully uninstalled packaging-24.1
Successfully installed mido-1.3.2 packaging-23.2


In [None]:
import os
import argparse

import torchaudio
import pickle
import mido
import json

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import time

import datetime

## Model

In [None]:
class Model_SPEC2MIDI(nn.Module):
    """Spectrogram-to-MIDI network: an encoder followed by a decoder.

    The wrapper adds no computation of its own; it passes the encoder output
    to the decoder and returns the decoder's nine results in the same order.
    """

    def __init__(self, encoder, decoder):
        """Keep references to the two sub-networks.

        Args:
            encoder: module called with the input spectrogram.
            decoder: module called with the encoder output; must return nine
                values (onset/offset/mpe/velocity of the first stage, the
                attention map, then onset/offset/mpe/velocity of the second stage).
        """
        super().__init__()
        self.encoder_spec2midi = encoder
        self.decoder_spec2midi = decoder

    def forward(self, input_spec):
        """Run encoder then decoder.

        With the Encoder_SPEC2MIDI / Decoder_SPEC2MIDI classes of this notebook:
            input_spec : [batch_size, n_bin, margin+n_frame+margin]
            onset/offset/mpe (A and B) : [batch_size, n_frame, n_note]
            velocity (A and B)         : [batch_size, n_frame, n_note, n_velocity]
            attention                  : [batch_size, n_frame, n_heads, n_note, n_bin]

        Returns:
            tuple of nine values: (onset_A, offset_A, mpe_A, velocity_A, attention,
            onset_B, offset_B, mpe_B, velocity_B).
        """
        # encoded: [batch_size, n_frame, n_bin, hid_dim]
        encoded = self.encoder_spec2midi(input_spec)

        (onset_A, offset_A, mpe_A, velocity_A,
         attention,
         onset_B, offset_B, mpe_B, velocity_B) = self.decoder_spec2midi(encoded)

        return (onset_A, offset_A, mpe_A, velocity_A,
                attention,
                onset_B, offset_B, mpe_B, velocity_B)


##
## Encoder
##
class Encoder_SPEC2MIDI(nn.Module):
    """Frequency-axis encoder of the spectrogram-to-MIDI model.

    For every output frame a context window of ``n_proc = 2 * n_margin + 1``
    spectrogram frames is cut out, convolved along the time axis of that window,
    projected to ``hid_dim`` per frequency bin, combined with a learned
    bin-position embedding and passed through ``n_layers`` EncoderLayer blocks
    that attend over the frequency bins.

    Args:
        n_margin: number of context frames on each side of a frame.
        n_frame: number of output frames per example.
        n_bin: number of frequency bins.
        cnn_channel: output channels of the convolution.
        cnn_kernel: kernel width of the convolution (applied along the window axis).
        hid_dim: transformer width.
        n_layers: number of EncoderLayer blocks.
        n_heads: attention heads per block.
        pf_dim: hidden width of the position-wise feed-forward blocks.
        dropout: dropout probability.
        device: device on which the scale constant is initially created.
    """

    def __init__(self, n_margin, n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.device = device
        self.n_frame = n_frame
        self.n_bin = n_bin
        self.cnn_channel = cnn_channel
        self.cnn_kernel = cnn_kernel
        self.hid_dim = hid_dim
        # kernel (1, cnn_kernel): no mixing across bins, only along the context window
        self.conv = nn.Conv2d(1, self.cnn_channel, kernel_size=(1, self.cnn_kernel))
        self.n_proc = n_margin * 2 + 1
        # flattened CNN feature size per bin: channels * (window length after a valid convolution)
        self.cnn_dim = self.cnn_channel * (self.n_proc - (self.cnn_kernel - 1))
        self.tok_embedding_freq = nn.Linear(self.cnn_dim, hid_dim)
        self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim)
        self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer so that .to()/.cuda()/.double() and DataParallel move it
        # together with the weights; non-persistent so the state_dict keys stay unchanged.
        self.register_buffer('scale_freq', torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)

    def forward(self, spec_in):
        """Encode a batch of spectrogram excerpts.

        Args:
            spec_in: [batch_size, n_bin, n_margin+n_frame+n_margin], e.g. (8, 256, 192).

        Returns:
            Tensor [batch_size, n_frame, n_bin, hid_dim], e.g. (8, 128, 256, 256).
        """
        batch_size = spec_in.shape[0]

        # sliding context window for every frame
        # spec: [batch_size, n_frame, n_bin, n_proc] (8, 128, 256, 65)
        spec = spec_in.unfold(2, self.n_proc, 1).permute(0, 2, 1, 3).contiguous()

        # CNN along the context-window axis
        # spec_cnn: [batch_size*n_frame, 1, n_bin, n_proc] (8*128, 1, 256, 65)
        spec_cnn = spec.reshape(batch_size*self.n_frame, self.n_bin, self.n_proc).unsqueeze(1)
        # spec_cnn: [batch_size*n_frame, n_bin, cnn_channel, n_proc-(cnn_kernel-1)] (8*128, 256, 4, 61)
        spec_cnn = self.conv(spec_cnn).permute(0, 2, 1, 3).contiguous()

        ##
        ## frequency
        ##
        # spec_cnn_freq: [batch_size*n_frame, n_bin, cnn_dim] (8*128, 256, 244)
        spec_cnn_freq = spec_cnn.reshape(batch_size*self.n_frame, self.n_bin, self.cnn_dim)

        # token embedding
        # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256)
        spec_emb_freq = self.tok_embedding_freq(spec_cnn_freq)

        # position indices, created directly on the device the activations live on
        # pos_freq: [batch_size*n_frame, n_bin] (8*128, 256)
        pos_freq = torch.arange(0, self.n_bin, device=spec_emb_freq.device).unsqueeze(0).repeat(batch_size*self.n_frame, 1)

        # scaled token embedding + position embedding
        # spec_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256)
        spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq))

        # transformer encoder over the frequency bins
        for layer_freq in self.layers_freq:
            spec_freq = layer_freq(spec_freq)
        # spec_freq: [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256)
        spec_freq = spec_freq.reshape(batch_size, self.n_frame, self.n_bin, self.hid_dim)

        return spec_freq


##
## Decoder
##
class Decoder_SPEC2MIDI(nn.Module):
    """Two-stage decoder of the spectrogram-to-MIDI model.

    Stage 1 ("freq"): learned per-note query embeddings cross-attend to the
    encoded frequency bins of each frame (one DecoderLayer_Zero followed by
    ``n_layers - 1`` DecoderLayer blocks) and produce the first set of outputs.
    Stage 2 ("time"): the stage-1 note representations are regrouped per note,
    combined with a frame-position embedding and passed through ``n_layers``
    EncoderLayer blocks that attend over the frames, producing the second set
    of outputs.

    Onset, offset and mpe outputs go through a sigmoid; velocity outputs are
    returned as raw scores over ``n_velocity`` classes.

    Args:
        n_frame: number of frames per example.
        n_bin: number of frequency bins in the encoder output.
        n_note: number of note queries.
        n_velocity: number of velocity classes.
        hid_dim: transformer width.
        n_layers: number of layers per stage.
        n_heads: attention heads per layer.
        pf_dim: hidden width of the position-wise feed-forward blocks.
        dropout: dropout probability.
        device: device on which the scale constant is initially created.
    """

    def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.device = device
        self.n_note = n_note
        self.n_frame = n_frame
        self.n_velocity = n_velocity
        self.n_bin = n_bin
        self.hid_dim = hid_dim
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)

        # CAfreq: cross-attention from note queries to frequency bins
        self.pos_embedding_freq = nn.Embedding(n_note, hid_dim)
        self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device)
        self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)])

        self.fc_onset_freq = nn.Linear(hid_dim, 1)
        self.fc_offset_freq = nn.Linear(hid_dim, 1)
        self.fc_mpe_freq = nn.Linear(hid_dim, 1)
        self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity)

        # SAtime: self-attention over frames
        # Registered as a buffer so that .to()/.cuda()/.double() and DataParallel move it
        # together with the weights; non-persistent so the state_dict keys stay unchanged.
        self.register_buffer('scale_time', torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        self.pos_embedding_time = nn.Embedding(n_frame, hid_dim)
        self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)])

        self.fc_onset_time = nn.Linear(hid_dim, 1)
        self.fc_offset_time = nn.Linear(hid_dim, 1)
        self.fc_mpe_time = nn.Linear(hid_dim, 1)
        self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity)

    def forward(self, enc_spec):
        """Decode the encoder output into note-level predictions.

        Args:
            enc_spec: [batch_size, n_frame, n_bin, hid_dim], e.g. (8, 128, 256, 256).

        Returns:
            Tuple (onset_freq, offset_freq, mpe_freq, velocity_freq, attention_freq,
            onset_time, offset_time, mpe_time, velocity_time) where
            onset/offset/mpe are [batch_size, n_frame, n_note],
            velocity is [batch_size, n_frame, n_note, n_velocity] and
            attention_freq is [batch_size, n_frame, n_heads, n_note, n_bin]
            (attention of the last stage-1 layer).
        """
        batch_size = enc_spec.shape[0]
        # enc_spec: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256)
        enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim])

        ##
        ## CAfreq bin(256)/note(88)
        ##
        # note queries, created directly on the device the activations live on
        # pos_freq:  [batch_size*n_frame, n_note] (8*128, 88)
        # midi_freq: [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256)
        pos_freq = torch.arange(0, self.n_note, device=enc_spec.device).unsqueeze(0).repeat(batch_size*self.n_frame, 1)
        midi_freq = self.pos_embedding_freq(pos_freq)

        midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq)
        for layer_freq in self.layers_freq:
            midi_freq, attention_freq = layer_freq(enc_spec, midi_freq)
        # split the merged batch*frame axis of the attention map again
        # attention_freq: [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256)
        dim = attention_freq.shape
        attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], dim[2], dim[3]])

        ## output(freq)
        # onset/offset/mpe: [batch_size, n_frame, n_note] (8, 128, 88)
        # velocity:         [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128)
        output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note]))
        output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note]))
        output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note]))
        output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity])

        ##
        ## SAtime frame(128)
        ##
        # midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim]
        midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim])
        # pos_time:  [batch_size*n_note, n_frame] (8*88, 128)
        # midi_time: [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256)
        pos_time = torch.arange(0, self.n_frame, device=midi_time.device).unsqueeze(0).repeat(batch_size*self.n_note, 1)
        midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time))

        for layer_time in self.layers_time:
            midi_time = layer_time(midi_time)

        ## output(time): regroup [batch, note, frame] back to [batch, frame, note]
        # onset/offset/mpe: [batch_size, n_frame, n_note] (8, 128, 88)
        # velocity:         [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128)
        output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous())
        output_offset_time = self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous())
        output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous())
        output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous()

        return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, attention_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time


##
## sub functions
##
class EncoderLayer(nn.Module):
    """Self-attention block followed by a position-wise feed-forward block.

    Both sub-blocks use a residual connection with dropout and are normalised
    by the same LayerNorm instance (one set of normalisation weights is shared).
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Transform ``src`` of shape [batch_size, src_len, hid_dim]; the shape is preserved."""
        # self attention -> dropout -> residual -> layer norm
        attended, _ = self.self_attention(src, src, src)
        hidden = self.layer_norm(src + self.dropout(attended))

        # feed-forward -> dropout -> residual -> layer norm
        expanded = self.positionwise_feedforward(hidden)
        return self.layer_norm(hidden + self.dropout(expanded))

class DecoderLayer_Zero(nn.Module):
    """First decoder block: cross-attention to the encoder output plus feed-forward.

    Unlike DecoderLayer there is no self-attention step. Both sub-blocks use a
    residual connection with dropout and share one LayerNorm instance.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_src, trg):
        """Attend from ``trg`` to ``enc_src``.

        Args:
            enc_src: [batch_size, src_len, hid_dim] keys/values.
            trg: [batch_size, trg_len, hid_dim] queries.

        Returns:
            (trg, attention) with trg [batch_size, trg_len, hid_dim] and
            attention [batch_size, n_heads, trg_len, src_len].
        """
        # encoder attention -> dropout -> residual -> layer norm
        context, attention = self.encoder_attention(trg, enc_src, enc_src)
        hidden = self.layer_norm(trg + self.dropout(context))

        # feed-forward -> dropout -> residual -> layer norm
        expanded = self.positionwise_feedforward(hidden)
        hidden = self.layer_norm(hidden + self.dropout(expanded))

        return hidden, attention

class DecoderLayer(nn.Module):
    """Decoder block: self-attention, cross-attention to the encoder output, feed-forward.

    Each of the three sub-blocks uses a residual connection with dropout, and
    all three are normalised by the same LayerNorm instance.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_src, trg):
        """Refine ``trg`` using itself and ``enc_src``.

        Args:
            enc_src: [batch_size, src_len, hid_dim] keys/values for the cross-attention.
            trg: [batch_size, trg_len, hid_dim] queries.

        Returns:
            (trg, attention) with trg [batch_size, trg_len, hid_dim] and the
            cross-attention map [batch_size, n_heads, trg_len, src_len].
        """
        # self attention -> dropout -> residual -> layer norm
        self_context, _ = self.self_attention(trg, trg, trg)
        hidden = self.layer_norm(trg + self.dropout(self_context))

        # encoder attention -> dropout -> residual -> layer norm
        cross_context, attention = self.encoder_attention(hidden, enc_src, enc_src)
        hidden = self.layer_norm(hidden + self.dropout(cross_context))

        # feed-forward -> dropout -> residual -> layer norm
        expanded = self.positionwise_feedforward(hidden)
        hidden = self.layer_norm(hidden + self.dropout(expanded))

        return hidden, attention

class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention without masking.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        device: device on which the scale constant is initially created.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0, 'hid_dim must be divisible by n_heads'
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # sqrt(head_dim) as a buffer: it follows the module through .to()/.cuda()/.double()
        # and DataParallel replication; non-persistent so the state_dict keys stay unchanged.
        self.register_buffer('scale', torch.sqrt(torch.FloatTensor([self.head_dim])).to(device), persistent=False)

    def forward(self, query, key, value):
        """Compute attention.

        Args:
            query: [batch_size, query_len, hid_dim]
            key:   [batch_size, key_len, hid_dim]
            value: [batch_size, key_len, hid_dim]

        Returns:
            (x, attention) with x [batch_size, query_len, hid_dim] and
            attention [batch_size, n_heads, query_len, key_len]; the returned
            attention is the softmax output before dropout.
        """
        batch_size = query.shape[0]

        # linear projections, still [batch_size, len, hid_dim]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # split into heads: [batch_size, n_heads, len, head_dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # energy: [batch_size, n_heads, query_len, key_len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        # attention: [batch_size, n_heads, query_len, key_len], rows sum to 1
        attention = torch.softmax(energy, dim = -1)

        # x: [batch_size, n_heads, query_len, head_dim]
        x = torch.matmul(self.dropout(attention), V)

        # merge heads: [batch_size, query_len, n_heads, head_dim] -> [batch_size, query_len, hid_dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)

        # output projection: [batch_size, query_len, hid_dim]
        x = self.fc_o(x)

        return x, attention

class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward block applied to the last dimension.

    hid_dim -> pf_dim (ReLU, dropout) -> hid_dim.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x [batch_size, seq_len, hid_dim] to a tensor of the same shape."""
        # expand to [batch_size, seq_len, pf_dim] and apply the non-linearity
        expanded = torch.relu(self.fc_1(x))
        # dropout on the expanded activations, then project back to hid_dim
        return self.fc_2(self.dropout(expanded))

## Training

In [None]:
# Mount Google Drive at /content/drive so that the dataset and checkpoint
# folders used below are reachable. Colab-only: needs the google.colab package.

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# Root folder of the dataset on the mounted Drive. The loading cell reads
# <DATASET_DIR>/{feature,label_onset,label_offset,label_mpe,label_velocity,idx}/<split>.pkl.
DATASET_DIR = '/content/drive/MyDrive/Apollo/dataset-smaller'

In [None]:
# Constants

# Audio / spectrogram settings. Apart from N_BINS these are not referenced by the
# code in this part of the notebook; presumably they describe how the stored
# features were extracted (TODO confirm against the preprocessing code).
SR = 16000
FFT_BINS = 2048
WINDOW_LENGTH = 2048
HOP_SAMPLE = 256
PAD_MODE = 'constant'
MEL_BINS = 256
N_BINS = 256  # passed as n_bin to both the encoder and the decoder
LOG_OFFSET = 1e-8
WINDOWS = 'hann'

# Input window layout, in frames.
MARGIN_B = 32  # context before the labelled frames; also the encoder's n_margin
MARGIN_F = 32  # context after the labelled frames
NUM_FRAME = 128  # labelled frames per example

# Output dimensions. NOTE_MAX - NOTE_MIN + 1 == NUM_NOTE; NOTE_MIN/NOTE_MAX look
# like MIDI note numbers (TODO confirm) and are not used by the code shown here.
NOTE_MIN = 21
NOTE_MAX = 108
NUM_NOTE = 88
NUM_VELOCITY = 128

# Dataset / training-loop settings.
# DIV_* <= 1 makes the loading cell read that split in one piece.
DIV_TRAIN = 4
DIV_VALID = 1
DIV_TEST = 1
EPOCH = 20
BATCH = 8  # DataLoader batch size
N_SLICE = 16  # MyDataset keeps every N_SLICE-th start index
WEIGHT_A = 1.0  # weight of the first-stage ("A") loss
WEIGHT_B = 1.0  # weight of the second-stage ("B") loss
CLIP = 1.0  # recorded in parameter.json; train() as defined below does not clip gradients

# Reproducibility and model size.
SEED = 1234
CNN_CHANNEL = 4
CNN_KERNEL = 5
HID_DIM = 256
ENC_LAYER = 3
ENC_HEAD = 4
PF_DIM = 512
DROPOUT = 0.1
DEC_LAYER = 3
DEC_HEAD = 4
LR = 1e-4  # Adam learning rate
# Resume / evaluation switches; -1 presumably means "start from scratch" (TODO confirm).
RESUME_EPOCH = -1
RESUME_DIV = -1
VALID_TEST = False  # when True the test split is loaded as well

In [None]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
def initialize_weights(m):
    """Xavier-uniform initialise the ``weight`` of a module, in place.

    Intended for ``model.apply(initialize_weights)``. Only weights with more
    than one dimension are touched, so biases and 1-D weights (e.g. LayerNorm)
    keep their default initialisation. Modules without a weight, or whose
    ``weight`` attribute is None (e.g. normalisation layers created without
    affine parameters), are skipped instead of raising AttributeError.

    Args:
        m: any ``nn.Module``.
    """
    weight = getattr(m, 'weight', None)
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        nn.init.xavier_uniform_(weight.data)

In [None]:
t0 = time.time()  # wall-clock reference taken before any setup work

# (1) torch settings: report the runtime, seed the generators, pick the device
has_cuda = torch.cuda.is_available()
print('(1) torch settings')
print(' torch version    : '+torch.__version__)
print(' torch cuda       : '+str(has_cuda))
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
device = 'cuda' if has_cuda else 'cpu'
print(device)


# (2) network settings: build encoder and decoder, wrap them, initialise the weights
print('(2) network settings')
encoder = Encoder_SPEC2MIDI(n_margin=MARGIN_B,
                            n_frame=NUM_FRAME,
                            n_bin=N_BINS,
                            cnn_channel=CNN_CHANNEL,
                            cnn_kernel=CNN_KERNEL,
                            hid_dim=HID_DIM,
                            n_layers=ENC_LAYER,
                            n_heads=ENC_HEAD,
                            pf_dim=PF_DIM,
                            dropout=DROPOUT,
                            device=device)
decoder = Decoder_SPEC2MIDI(n_frame=NUM_FRAME,
                            n_bin=N_BINS,
                            n_note=NUM_NOTE,
                            n_velocity=NUM_VELOCITY,
                            hid_dim=HID_DIM,
                            n_layers=DEC_LAYER,
                            n_heads=DEC_HEAD,
                            pf_dim=PF_DIM,
                            dropout=DROPOUT,
                            device=device)
model = Model_SPEC2MIDI(encoder, decoder).to(device)
# trailing ';' suppresses the module repr that apply() would otherwise display
model.apply(initialize_weights);
print(' The model has {} trainable parameters'.format(count_parameters(model)))

(1) torch settings
 torch version    : 2.4.0+cu121
 torch cuda       : True
cuda
(2) network settings
 The model has 5516574 trainable parameters


In [None]:

# Sanity check: show which device was selected in the setup cell above.
print(device)

cuda


In [None]:
# Output folder on the mounted Drive; the training-settings cell creates it if
# needed and writes parameter.json into it.
CHECKPOINT_DIR = '/content/drive/MyDrive/Apollo/Checkpoint3'

In [None]:
# (3) training settings
# Creates the optimizer, LR scheduler and loss functions, makes sure the output
# folder exists, records the run configuration in parameter.json and initialises
# the bookkeeping used while training.
print('(3) training settings')
optimizer = optim.Adam(model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# "A" criteria score the first-stage outputs, "B" criteria the second-stage outputs.
# Onset/offset/mpe are sigmoid probabilities (BCE); velocity is a class score
# vector per frame and note (cross entropy).
criterion_onset_A = nn.BCELoss()
criterion_offset_A = nn.BCELoss()
criterion_mpe_A = nn.BCELoss()
criterion_velocity_A = nn.CrossEntropyLoss()

criterion_onset_B = nn.BCELoss()
criterion_offset_B = nn.BCELoss()
criterion_mpe_B = nn.BCELoss()
criterion_velocity_B = nn.CrossEntropyLoss()

d_out = CHECKPOINT_DIR
# makedirs also creates missing parent folders and is a no-op when the folder
# already exists, so re-running the cell is safe.
os.makedirs(d_out, exist_ok=True)
parameters = {
    'parameters': count_parameters(model),
    'd_output': CHECKPOINT_DIR,
    'dataset': {
        'd_dataset': DATASET_DIR,
        'n_div_train': DIV_TRAIN,
        'n_div_valid': DIV_VALID,
        'n_div_test': DIV_TEST,
        'n_slice': N_SLICE
    },
    'training': {
        'epoch': EPOCH,
        'batch': BATCH,
        'lr': LR,
        'dropout': DROPOUT,
        'clip': CLIP,
        'seed': SEED,
        'resume_epoch': RESUME_EPOCH,
        'resume_div': RESUME_DIV,
        'loss_weight': {
            '1st': WEIGHT_A,
            '2nd': WEIGHT_B
        },
        'validation': {
            'test': VALID_TEST
        }
    },
    'transformer': {
        'hid_dim': HID_DIM,
        'pf_dim': PF_DIM,
        'encoder': {
            'n_layer': ENC_LAYER,
            'n_head': ENC_HEAD
        },
        'decoder': {
            'n_layer': DEC_LAYER,
            'n_head': DEC_HEAD
        }
    },
    'cnn': {
        'channel': CNN_CHANNEL,
        'kernel': CNN_KERNEL
    }
}
with open(os.path.join(d_out, 'parameter.json'), 'w', encoding='utf-8') as f:
    json.dump(parameters, f, ensure_ascii=False, indent=4, sort_keys=True)

# Training starts from scratch: no epoch/division done yet, no best model yet.
epoch_start = 0
div_start = 0
best_epoch = 0
best_div = 0
best_loss_valid = float('inf')
a_performance = {
    'loss_train': [],
    'loss_valid': [],
    'loss_test': [],
    'datetime': [],
    'current_epoch': 0,
    'current_div': 0,
    'best_epoch': best_epoch,
    'best_div': best_div,
    'best_loss_valid': best_loss_valid
}

(3) training settings


In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Fixed-length excerpts of a feature sequence with frame-aligned labels.

    Every input file is a pickle holding a numpy array. The feature and label
    arrays are indexed by frame along their first axis; the idx array holds the
    start frames of the excerpts.

    Args:
        f_feature: path of the feature pickle.
        f_label_onset: path of the onset-label pickle.
        f_label_offset: path of the offset-label pickle.
        f_label_mpe: path of the mpe-label pickle.
        f_label_velocity: path of the velocity-label pickle, or None to omit
            velocity from the returned items.
        f_idx: path of the pickle with the excerpt start frames.
        n_slice: when > 1, only every n_slice-th start frame is kept (after
            truncating the list to a multiple of n_slice).
        margin_b: context frames before the labelled frames; None (default)
            uses the module-level MARGIN_B.
        margin_f: context frames after the labelled frames; None (default)
            uses the module-level MARGIN_F.
        num_frame: labelled frames per excerpt; None (default) uses the
            module-level NUM_FRAME.
    """

    def __init__(self, f_feature, f_label_onset, f_label_offset, f_label_mpe, f_label_velocity, f_idx, n_slice,
                 margin_b=None, margin_f=None, num_frame=None):
        super().__init__()

        self.feature = torch.from_numpy(self._load_pickle(f_feature))
        self.label_onset = torch.from_numpy(self._load_pickle(f_label_onset))
        self.label_offset = torch.from_numpy(self._load_pickle(f_label_offset))
        self.label_mpe = torch.from_numpy(self._load_pickle(f_label_mpe))
        self.flag_velocity = f_label_velocity is not None
        if self.flag_velocity:
            self.label_velocity = torch.from_numpy(self._load_pickle(f_label_velocity))

        idx_all = torch.from_numpy(self._load_pickle(f_idx))
        if n_slice > 1:
            # drop the tail so the length is a multiple of n_slice, then subsample
            n_keep = (len(idx_all) // n_slice) * n_slice
            self.idx = idx_all[:n_keep][::n_slice]
        else:
            self.idx = idx_all
        self.data_size = len(self.idx)

        # Optional per-instance window sizes; None defers to the module constants
        # at item-access time.
        self.margin_b = margin_b
        self.margin_f = margin_f
        self.num_frame = num_frame

    @staticmethod
    def _load_pickle(path):
        """Return the object stored in the pickle file at ``path``."""
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # dataset files that come from a trusted source.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def __len__(self):
        """Number of excerpts."""
        return self.data_size

    def __getitem__(self, idx):
        """Return excerpt number ``idx``.

        Returns:
            (spec, label_onset, label_offset, label_mpe[, label_velocity]) where
            spec is [n_feature, margin_b+num_frame+margin_f] and every label is
            [num_frame, n_note]. label_mpe is converted to float and
            label_velocity to long; onset/offset keep their stored dtype.
        """
        margin_b = MARGIN_B if self.margin_b is None else self.margin_b
        margin_f = MARGIN_F if self.margin_f is None else self.margin_f
        num_frame = NUM_FRAME if self.num_frame is None else self.num_frame

        # NOTE(review): a start frame smaller than margin_b would give a negative
        # slice start; the stored indices are assumed to leave room for the margins.
        start = self.idx[idx]
        idx_feature_s = start - margin_b
        idx_feature_e = start + num_frame + margin_f

        idx_label_s = start
        idx_label_e = start + num_frame

        # feature rows: [margin+num_frame+margin, n_feature] -(transpose)-> spec: [n_feature, margin+num_frame+margin]
        spec = (self.feature[idx_feature_s:idx_feature_e]).T

        # label_onset / label_offset: [num_frame, n_note]
        label_onset = self.label_onset[idx_label_s:idx_label_e]
        label_offset = self.label_offset[idx_label_s:idx_label_e]

        # label_mpe: [num_frame, n_note], bool -> float
        label_mpe = self.label_mpe[idx_label_s:idx_label_e].float()

        if self.flag_velocity:
            # label_velocity: [num_frame, n_note], int8 -> long
            label_velocity = self.label_velocity[idx_label_s:idx_label_e].long()
            return spec, label_onset, label_offset, label_mpe, label_velocity
        else:
            return spec, label_onset, label_offset, label_mpe

In [None]:
# (5) dataset loading (w/o divide)
def _load_split(split, shuffle):
    """Build the MyDataset and DataLoader for one split ('train', 'valid' or 'test')."""
    dataset = MyDataset(d_dataset + '/feature/' + split + '.pkl',
                        d_dataset + '/label_onset/' + split + '.pkl',
                        d_dataset + '/label_offset/' + split + '.pkl',
                        d_dataset + '/label_mpe/' + split + '.pkl',
                        d_dataset + '/label_velocity/' + split + '.pkl',
                        d_dataset + '/idx/' + split + '.pkl',
                        N_SLICE)
    loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH, shuffle = shuffle)
    return dataset, loader

print('(5) dataset loading')
d_dataset = DATASET_DIR
# A split is only loaded here when it is not divided into several parts (DIV_* <= 1).
if DIV_TRAIN <= 1:
    dataset_train, dataloader_train = _load_split('train', True)
    print('## nstep train: ' + str(len(dataloader_train)))
if DIV_VALID <= 1:
    dataset_valid, dataloader_valid = _load_split('valid', False)
    print('## nstep valid: ' + str(len(dataloader_valid)))
# The test split is additionally gated by VALID_TEST.
if (VALID_TEST is True) and (DIV_TEST <= 1):
    dataset_test, dataloader_test = _load_split('test', False)
    print('## nstep test : ' + str(len(dataloader_test)))

(5) dataset loading
## nstep valid: 11843


In [None]:
def train(model, iterator, optimizer,
          criterion_onset_A, criterion_offset_A, criterion_mpe_A, criterion_velocity_A,
          criterion_onset_B, criterion_offset_B, criterion_mpe_B, criterion_velocity_B,
          weight_A, weight_B,
          device, verbose_flag):
    """Run one optimisation pass over ``iterator`` and return the mean batch loss.

    Args:
        model: called as ``model(input_spec)``; must return nine values
            (onset_A, offset_A, mpe_A, velocity_A, attention,
             onset_B, offset_B, mpe_B, velocity_B).
        iterator: yields (input_spec, label_onset, label_offset, label_mpe,
            label_velocity) batches and supports ``len()``.
        optimizer: optimizer stepping the model parameters.
        criterion_*_A / criterion_*_B: loss functions for the first-stage ("A")
            and second-stage ("B") outputs.
        weight_A, weight_B: weights of the two summed stage losses.
        device: device the batch tensors are moved to.
        verbose_flag: only the value ``True`` enables the debug printing of
            tensors and shapes.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.

    Shapes with the notebook defaults: input_spec (8, 256, 192), labels and
    onset/offset/mpe outputs (8, 128, 88), velocity outputs (8, 128, 88, 128).
    All of them are flattened to one row per (batch, frame, note) cell before
    the losses are evaluated.
    """
    verbose = verbose_flag is True
    model.train()
    loss_sum = 0

    for step, (spec, onset, offset, mpe, velocity) in enumerate(iterator):
        spec, onset, offset, mpe, velocity = (
            t.to(device, non_blocking=True) for t in (spec, onset, offset, mpe, velocity))

        if verbose:
            print('***** train i : '+str(step)+' *****')
            for tag, tensor in (('(1) input_spec  : ', spec),
                                ('(1) label_mpe   : ', mpe),
                                ('(1) label_velocity : ', velocity)):
                print(tag+str(tensor.size()))
                print(tensor)

        optimizer.zero_grad()
        (onset_A, offset_A, mpe_A, velocity_A,
         attention,
         onset_B, offset_B, mpe_B, velocity_B) = model(spec)

        if verbose:
            for tag, tensor in (('(2) output_onset_A : ', onset_A),
                                ('(2) output_onset_B : ', onset_B),
                                ('(2) output_velocity_A : ', velocity_A),
                                ('(2) output_velocity_B : ', velocity_B)):
                print(tag+str(tensor.size()))
                print(tensor)

        # flatten predictions: probabilities to 1-D, velocity scores to [cells, n_velocity]
        onset_A, offset_A, mpe_A = (t.contiguous().view(-1) for t in (onset_A, offset_A, mpe_A))
        velocity_A = velocity_A.contiguous().view(-1, velocity_A.shape[-1])
        onset_B, offset_B, mpe_B = (t.contiguous().view(-1) for t in (onset_B, offset_B, mpe_B))
        velocity_B = velocity_B.contiguous().view(-1, velocity_B.shape[-1])

        if verbose:
            print('(3) output_onset_A : '+str(onset_A.size()))
            print('(3) output_onset_B : '+str(onset_B.size()))
            print('(3) output_velocity_A : '+str(velocity_A.size()))
            print('(3) output_velocity_B : '+str(velocity_B.size()))

        # flatten the targets the same way
        onset, offset, mpe, velocity = (
            t.contiguous().view(-1) for t in (onset, offset, mpe, velocity))

        if verbose:
            print('(4) label_onset   :'+str(onset.size()))
            print(onset)
            print('(4) label_velocity   :'+str(velocity.size()))
            print(velocity)

        # first-stage loss
        stage_A = criterion_onset_A(onset_A, onset)
        stage_A = stage_A + criterion_offset_A(offset_A, offset)
        stage_A = stage_A + criterion_mpe_A(mpe_A, mpe)
        stage_A = stage_A + criterion_velocity_A(velocity_A, velocity)

        # second-stage loss
        stage_B = criterion_onset_B(onset_B, onset)
        stage_B = stage_B + criterion_offset_B(offset_B, offset)
        stage_B = stage_B + criterion_mpe_B(mpe_B, mpe)
        stage_B = stage_B + criterion_velocity_B(velocity_B, velocity)

        loss = weight_A * stage_A + weight_B * stage_B
        if verbose:
            print('(5) loss:'+str(loss.size()))
            print(loss)

        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

    return loss_sum / len(iterator)


##
## validation
##
def valid(model, iterator,
          criterion_onset_A, criterion_offset_A, criterion_mpe_A, criterion_velocity_A,
          criterion_onset_B, criterion_offset_B, criterion_mpe_B, criterion_velocity_B,
          weight_A, weight_B,
          device):
    """Accumulate the weighted A/B loss of `model` over `iterator` without gradients.

    Each batch is a 5-tuple (input_spec, label_onset, label_offset, label_mpe,
    label_velocity). The model is expected to return nine values: the four "A"
    outputs, an attention value (unused here) and the four "B" outputs.
    Onset/offset/mpe outputs and all labels are flattened to 1-D; velocity
    outputs are flattened to 2-D, keeping their last dimension.

    Returns:
        (summed per-batch loss as a Python float, number of batches) so that
        the caller can pool several iterators before averaging.
    """
    def _flat(tensor):
        # collapse every dimension into one
        return tensor.contiguous().view(-1)

    def _flat_keep_last(tensor):
        # collapse all but the last dimension
        return tensor.contiguous().view(-1, tensor.shape[-1])

    model.eval()
    loss_sum = 0

    with torch.no_grad():
        for batch in iterator:
            spec, lab_onset, lab_offset, lab_mpe, lab_velocity = batch
            spec = spec.to(device, non_blocking=True)
            lab_onset = lab_onset.to(device, non_blocking=True)
            lab_offset = lab_offset.to(device, non_blocking=True)
            lab_mpe = lab_mpe.to(device, non_blocking=True)
            lab_velocity = lab_velocity.to(device, non_blocking=True)

            (onset_A, offset_A, mpe_A, velocity_A, _attention,
             onset_B, offset_B, mpe_B, velocity_B) = model(spec)

            onset_A, offset_A, mpe_A = _flat(onset_A), _flat(offset_A), _flat(mpe_A)
            velocity_A = _flat_keep_last(velocity_A)
            onset_B, offset_B, mpe_B = _flat(onset_B), _flat(offset_B), _flat(mpe_B)
            velocity_B = _flat_keep_last(velocity_B)

            lab_onset = _flat(lab_onset)
            lab_offset = _flat(lab_offset)
            lab_mpe = _flat(lab_mpe)
            lab_velocity = _flat(lab_velocity)

            loss_A = (criterion_onset_A(onset_A, lab_onset)
                      + criterion_offset_A(offset_A, lab_offset)
                      + criterion_mpe_A(mpe_A, lab_mpe)
                      + criterion_velocity_A(velocity_A, lab_velocity))
            loss_B = (criterion_onset_B(onset_B, lab_onset)
                      + criterion_offset_B(offset_B, lab_offset)
                      + criterion_mpe_B(mpe_B, lab_mpe)
                      + criterion_velocity_B(velocity_B, lab_velocity))

            batch_loss = weight_A * loss_A + weight_B * loss_B
            loss_sum += batch_loss.item()

    return loss_sum, len(iterator)

In [None]:
 # (7) training
d_dataset = DATASET_DIR


def _load_division(split, div_idx, shuffle):
    """Build the MyDataset and DataLoader for one division of a data split.

    split is the file-name prefix ('train', 'valid' or 'test') and div_idx the
    division number, giving files such as <d_dataset>/feature/train_000.pkl.
    Returns (dataset, dataloader) so the caller can `del` both afterwards.
    """
    fname = '/' + split + '_' + str(div_idx).zfill(3) + '.pkl'
    dataset = MyDataset(d_dataset + '/feature' + fname,
                        d_dataset + '/label_onset' + fname,
                        d_dataset + '/label_offset' + fname,
                        d_dataset + '/label_mpe' + fname,
                        d_dataset + '/label_velocity' + fname,
                        d_dataset + '/idx' + fname,
                        N_SLICE)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH, shuffle=shuffle)
    return dataset, dataloader


def _run_valid(dataloader):
    """Run valid() on one DataLoader; returns (summed batch loss, number of batches)."""
    return valid(model, dataloader,
                 criterion_onset_A, criterion_offset_A, criterion_mpe_A, criterion_velocity_A,
                 criterion_onset_B, criterion_offset_B, criterion_mpe_B, criterion_velocity_B,
                 WEIGHT_A, WEIGHT_B,
                 device)


def _valid_over_divisions(split, n_div):
    """Evaluate every division of `split`, loading one division at a time.

    Returns the pooled (summed batch loss, number of batches) over all divisions.
    """
    loss_sum = 0
    n_batch = 0
    for div_idx in range(n_div):
        dataset, dataloader = _load_division(split, div_idx, shuffle=False)
        print('## nstep ' + split + ': ' + str(len(dataloader)))
        div_loss, div_n_batch = _run_valid(dataloader)
        loss_sum += div_loss
        n_batch += div_n_batch
        # drop the division before loading the next one
        del dataset, dataloader
    return loss_sum, n_batch


def _checkpoint_state(epoch, div, loss_train, loss_valid, loss_test,
                      best_epoch, best_div, best_loss_valid):
    """Collect the dictionary written by torch.save for one training step.

    Reads the global model, optimizer and scheduler. The CUDA RNG entries are
    only included when CUDA is available, because torch.cuda.get_rng_state()
    cannot be called on a machine without a CUDA device.
    """
    rng_state = {
        'torch': torch.get_rng_state(),
        'torch_random': torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        rng_state['cuda'] = torch.cuda.get_rng_state()
        rng_state['cuda_all'] = torch.cuda.get_rng_state_all()
    return {
        'epoch': epoch,
        'div': div,
        'epoch_loss_train': loss_train,
        'epoch_loss_valid': loss_valid,
        'epoch_loss_test': loss_test,
        'best_epoch': best_epoch,
        'best_div': best_div,
        'best_loss_valid': best_loss_valid,
        'optimizer_dict': optimizer.state_dict(),
        'scheduler_dict': scheduler.state_dict(),
        'model_dict': model.state_dict(),
        'random': rng_state,
        'model': model}


print('(7) training')
print(' epoch_start      : '+str(epoch_start))
print(' div_start        : '+str(div_start))
for epoch in range(epoch_start, EPOCH):
    # NOTE(review): the division loop is hard-coded to two passes instead of
    # range(DIV_TRAIN); confirm this is intentional (with DIV_TRAIN <= 1 the
    # same dataloader_train is trained on twice per epoch).
    for div in range(0, 2):
        if div < div_start:
            # resuming: skip the divisions already done in the first epoch
            continue
        tag = str(epoch).zfill(3)+'_'+str(div).zfill(3)
        print('[epoch: '+str(epoch).zfill(3)+' div: '+str(div).zfill(3)+']')

        # (7-1) training
        print('(7-1) training')
        if DIV_TRAIN > 1:
            # divided training set: load only the current division
            dataset_train, dataloader_train = _load_division('train', div, shuffle=True)
            print('## nstep train: '+str(len(dataloader_train)))

        # verbose_flag stays off: the per-step tensor dumps flood the cell output
        epoch_loss_train = train(model, dataloader_train, optimizer,
                                 criterion_onset_A, criterion_offset_A, criterion_mpe_A, criterion_velocity_A,
                                 criterion_onset_B, criterion_offset_B, criterion_mpe_B, criterion_velocity_B,
                                 WEIGHT_A, WEIGHT_B,
                                 device, verbose_flag=False)

        if DIV_TRAIN > 1:
            del dataset_train, dataloader_train

        # (7-2) validation
        print('(7-2) validation')
        if DIV_VALID > 1:
            epoch_loss_valid, num_data_valid = _valid_over_divisions('valid', DIV_VALID)
        else:
            # undivided set: dataloader_valid must have been built by an earlier cell
            epoch_loss_valid, num_data_valid = _run_valid(dataloader_valid)
        epoch_loss_valid /= num_data_valid

        # (7-3) test
        if VALID_TEST is True:
            print('(7-3) test')
            if DIV_TEST > 1:
                epoch_loss_test, num_data_test = _valid_over_divisions('test', DIV_TEST)
            else:
                epoch_loss_test, num_data_test = _run_valid(dataloader_test)
            epoch_loss_test /= num_data_test
        else:
            epoch_loss_test = 0.0
        print('[epoch: '+str(epoch).zfill(3)+' div: '+str(div).zfill(3)+']')
        print(' loss(train) :'+str(epoch_loss_train))
        print(' loss(valid) :'+str(epoch_loss_valid))
        if VALID_TEST is True:
            print(' loss(test) :'+str(epoch_loss_test))

        # Update the best-so-far record BEFORE saving, so that every checkpoint
        # (not only best_model.dat) stores the true best epoch/div/loss.
        is_best = best_loss_valid > epoch_loss_valid
        if is_best:
            best_loss_valid = epoch_loss_valid
            best_epoch = epoch
            best_div = div

        # (7-4) save model
        checkpoint = _checkpoint_state(epoch, div,
                                       epoch_loss_train, epoch_loss_valid, epoch_loss_test,
                                       best_epoch, best_div, best_loss_valid)
        with open(d_out+'/model_'+tag+'.pkl', 'wb') as f:
            pickle.dump(model, f, protocol=4)
        torch.save(checkpoint, d_out+'/model_'+tag+'.dat')

        if is_best:
            with open(d_out+'/best_epoch.txt', 'w') as f:
                f.write(tag)
            with open(d_out+'/best_model.pkl', 'wb') as f:
                pickle.dump(model, f, protocol=4)
            torch.save(checkpoint, d_out+'/best_model.dat')

        # (7-5) save performance
        a_performance['loss_train'].append(epoch_loss_train)
        a_performance['loss_valid'].append(epoch_loss_valid)
        a_performance['loss_test'].append(epoch_loss_test)
        a_performance['datetime'].append(datetime.datetime.now().isoformat())
        a_performance['current_epoch'] = epoch
        a_performance['current_div'] = div
        a_performance['best_epoch'] = best_epoch
        a_performance['best_div'] = best_div
        a_performance['best_loss_valid'] = best_loss_valid
        # written twice: a rolling performance.json plus a per-step snapshot
        for f_performance in (d_out+'/performance.json',
                              d_out+'/performance_'+tag+'.json'):
            with open(f_performance, 'w', encoding='utf-8') as f:
                json.dump(a_performance, f, ensure_ascii=False, indent=4, sort_keys=True)

        # (7-6) scheduler update
        scheduler.step(epoch_loss_valid)

    # div_start only applies to the first (resumed) epoch
    div_start = 0

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
          [ 14.4200,  -4.8892,  -3.0444,  ...,  -7.3819,  -8.2224,  -9.0348],
          ...,
          [ 14.2510,  -5.5109,  -3.3439,  ...,  -7.1876,  -8.1356,  -8.8760],
          [ 13.1814,  -6.5121,  -4.1908,  ...,  -8.1443,  -9.3677,  -9.5948],
          [ 14.7964,  -5.8946,  -3.0240,  ...,  -6.8611,  -8.3545,  -9.2437]],

         ...,

         [[ 14.5062,  -5.7607,  -3.3030,  ...,  -7.2772,  -8.4123,  -9.1504],
          [ 14.3989,  -5.8970,  -3.0267,  ...,  -7.3571,  -8.8878, -10.1048],
          [ 14.9342,  -6.6563,  -4.1795,  ...,  -8.0649,  -9.3904, -10.1061],
          ...,
          [ 14.1322,  -5.5269,  -3.7144,  ...,  -7.7558,  -8.0618,  -7.8360],
          [ 14.8542,  -5.6506,  -3.5231,  ...,  -7.4895,  -8.1800,  -9.3401],
          [ 15.0366,  -5.4807,  -3.1127,  ...,  -7.6602,  -8.5949,  -8.8354]],

         [[ 14.6661,  -6.1065,  -3.4579,  ...,  -7.8350,  -8.3646,  -9.4362],
          [ 14.9014,  -5.758

# Extras

In [None]:
%pip install pretty_midi

In [None]:
# Settings for the inference cell below.
# NOTE(review): these paths start with './content/...', whereas the AMT2 cell
# below loads '/content/drive/MyDrive/Apollo/Checkpoint/best_model.pkl'
# (absolute) - confirm the working directory makes both forms resolve to the
# same Drive folder.
# Printed by the inference cell as the model file.
bestModelPath = './content/drive/MyDrive/Apollo/Checkpoint/best_model.pkl'
# Text file with one file stem per line; read by the inference and evaluation cells.
testList = './content/drive/MyDrive/Apollo/Files/list/test.list'
# 'combination' makes the inference cell also produce/load the "2nd" outputs.
mode = 'combination'
# Directory for the pickled onset/offset/mpe/velocity outputs.
MPE_DIR = './content/drive/MyDrive/Apollo/output/mpe'
# Directory for the note lists written as <fname>_1st.json / <fname>_2nd.json.
OUT_NOTE_DIR = './content/drive/MyDrive/Apollo/output/note'
# Thresholds forwarded to AMT1.mpe2note.
thred_mpe = 0.5
thred_onset = 0.5
thred_offset = 0.5
# True: compute features from the wav files and pickle them; False: load the pickled features.
calc_feature = False
# True: run the model and pickle its outputs; False: load previously pickled outputs.
calc_transcript = True
# Values > 0 select AMT1.transcript_stride instead of AMT1.transcript.
n_stride = 0
# Forwarded to the transcript calls as ablation_flag.
ablation = False

In [None]:
from enought import AMT2

In [None]:
# Create the AMT2 inference wrapper used by the inference cell below.
# NOTE(review): the recorded run of this cell failed with
# "AttributeError: Can't get attribute 'Model_SPEC2MIDI' on <module '__main__'>".
# best_model.pkl is presumably the whole-model pickle written by the training
# cell, so the model classes would have to be defined in __main__ (run the
# "Model" cells first) before it can be loaded - TODO confirm how AMT2 loads it.
# The path is hard-coded here rather than taken from bestModelPath.
AMT1 = AMT2(model_path = '/content/drive/MyDrive/Apollo/Checkpoint/best_model.pkl', verbose_flag = True)

torch version: 2.4.0+cu121
torch cuda   : True
/content/drive/MyDrive/Apollo/Checkpoint/best_model.pkl


AttributeError: Can't get attribute 'Model_SPEC2MIDI' on <module '__main__'>

In [None]:
if __name__ == '__main__':
    # Inference for evaluation: for every file stem in testList, obtain the
    # feature, run (or reload) the model outputs, convert them to note lists
    # and write those as JSON. All settings come from the configuration cells
    # above (they replace the argparse options of the original script), and
    # AMT1 must already have been created by the AMT2 cell above.

    print('** AMT: inference for evaluation **')
    print(' file list      : '+str(testList))
    print(' checkpoint')
    print('  directory     : '+str(CHECKPOINT_DIR))
    print('  model file    : '+str(bestModelPath))
    print(' directories')
    print('  wav           : '+str(WAV_DIR))
    print('  feature       : '+str(FEATURE_DIR))
    print('  onset/mpe     : '+str(MPE_DIR))
    print('  json          : '+str(OUT_NOTE_DIR))
    print(' threshold value')
    print('  onset         : '+str(thred_onset))
    print('  offset        : '+str(thred_offset))
    print('  mpe           : '+str(thred_mpe))
    print(' calculation')
    print('  wav2feature   : '+str(calc_feature))
    print('  transcript    : '+str(calc_transcript))
    print(' stride         : '+str(n_stride))
    print(' ablation mode  : '+str(ablation))

    # parameters
    with open(CHECKPOINT_DIR + '/parameter.json', 'r', encoding='utf-8') as f:
        parameters = json.load(f)

    # list file: one file stem per line; blank lines (e.g. a trailing empty
    # line) are skipped so they do not turn into an empty file name
    with open(testList, 'r', encoding='utf-8') as f:
        a_list = [line.rstrip('\n') for line in f if line.strip() != '']

    # inference
    out_dir_mpe = MPE_DIR
    out_dir_note = OUT_NOTE_DIR
    # make sure the directories written to below exist
    os.makedirs(out_dir_mpe, exist_ok=True)
    os.makedirs(out_dir_note, exist_ok=True)
    if calc_feature is True:
        os.makedirs(FEATURE_DIR, exist_ok=True)

    # file-name extensions of the four model outputs, in unpacking order
    a_target = ('onset', 'offset', 'mpe', 'velocity')

    def _dump_pickle(obj, path):
        """Pickle obj to path (protocol 4)."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f, protocol=4)

    def _load_pickle(path):
        """Unpickle and return the object stored at path."""
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # were produced by this notebook.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def _mpe2note(a_onset, a_offset, a_mpe, a_velocity):
        """Convert one set of model outputs to a note list with the configured thresholds."""
        return AMT1.mpe2note(a_onset=a_onset,
                             a_offset=a_offset,
                             a_mpe=a_mpe,
                             a_velocity=a_velocity,
                             thred_onset=thred_onset,
                             thred_offset=thred_offset,
                             thred_mpe=thred_mpe,
                             mode_velocity='ignore_zero',
                             mode_offset='shorter')

    for fname in a_list:
        print('['+str(fname)+']')

        # feature
        f_feature = FEATURE_DIR + '/' + fname + '.pkl'
        if calc_feature is True:
            a_feature = AMT1.wav2feature(WAV_DIR + '/' + fname + '.wav')
            _dump_pickle(a_feature, f_feature)
        else:
            a_feature = _load_pickle(f_feature)

        # transcript
        prefix_1st = out_dir_mpe+'/'+fname+'_1st.'
        prefix_2nd = out_dir_mpe+'/'+fname+'_2nd.'
        if calc_transcript is True:
            if n_stride > 0:
                a_output = AMT1.transcript_stride(a_feature, n_stride, mode=mode, ablation_flag=ablation)
            else:
                a_output = AMT1.transcript(a_feature, mode=mode, ablation_flag=ablation)

            # 'combination' returns eight arrays (1st + 2nd), any other mode four
            if mode == 'combination':
                (output_1st_onset, output_1st_offset, output_1st_mpe, output_1st_velocity,
                 output_2nd_onset, output_2nd_offset, output_2nd_mpe, output_2nd_velocity) = a_output
            else:
                output_1st_onset, output_1st_offset, output_1st_mpe, output_1st_velocity = a_output

            for target, a_data in zip(a_target, (output_1st_onset, output_1st_offset,
                                                 output_1st_mpe, output_1st_velocity)):
                _dump_pickle(a_data, prefix_1st+target)
            if mode == 'combination':
                for target, a_data in zip(a_target, (output_2nd_onset, output_2nd_offset,
                                                     output_2nd_mpe, output_2nd_velocity)):
                    _dump_pickle(a_data, prefix_2nd+target)

        else:
            # reuse the outputs pickled by a previous run
            output_1st_onset, output_1st_offset, output_1st_mpe, output_1st_velocity = [
                _load_pickle(prefix_1st+target) for target in a_target]
            if mode == 'combination':
                output_2nd_onset, output_2nd_offset, output_2nd_mpe, output_2nd_velocity = [
                    _load_pickle(prefix_2nd+target) for target in a_target]

        # note (mpe2note)
        a_note_1st_predict = _mpe2note(output_1st_onset, output_1st_offset,
                                       output_1st_mpe, output_1st_velocity)
        if mode == 'combination':
            a_note_2nd_predict = _mpe2note(output_2nd_onset, output_2nd_offset,
                                           output_2nd_mpe, output_2nd_velocity)
        with open(out_dir_note+'/'+fname+'_1st.json', 'w', encoding='utf-8') as f:
            json.dump(a_note_1st_predict, f, ensure_ascii=False, indent=4, sort_keys=False)
        if mode == 'combination':
            with open(out_dir_note+'/'+fname+'_2nd.json', 'w', encoding='utf-8') as f:
                json.dump(a_note_2nd_predict, f, ensure_ascii=False, indent=4, sort_keys=False)

    print('** done **')

In [None]:
# Settings for the mir_eval note evaluation cells below.
# Reference annotations, read as <fname>.txt or <fname>_velocity.txt.
d_ref = './corpus/MAESTRO-V3/reference'
# Estimated notes, read as <fname>_<output>.json; the converted .txt files and
# the averaged result JSON are written here as well.
# NOTE(review): the inference cell writes its JSON to OUT_NOTE_DIR
# ('./content/drive/MyDrive/Apollo/output/note') - confirm this path points at
# the same files.
d_est = './output/note'
# Per-file score JSONs are written here.
# NOTE(review): d_out is also the name the training cell uses for its
# checkpoint directory; after this cell runs, re-running training would save
# into './result' - confirm this rebinding is intended.
d_out = './result'
# False: score with mir_eval.transcription; True: mir_eval.transcription_velocity.
velocity = False
# File-name suffix of the estimate to score (the inference cell writes '_1st' and '_2nd' files).
output = '2nd'

In [None]:
import argparse
import json
import mir_eval
import numpy as np

In [None]:
if __name__ == '__main__':
    # Note-level transcription scoring with mir_eval. For every file stem in
    # testList the estimated note JSON is converted to a tab-separated text
    # file, scored against the reference, and the per-file scores are written
    # to d_out; the scores averaged over all files are written to d_est.
    # Settings (testList, d_ref, d_est, d_out, velocity, output) come from the
    # configuration cells above and replace the original argparse options.
    import os  # kept local so the cell does not depend on an earlier import cell

    print('** mir_eval: transcription (note) **')
    print(' file list     : '+str(testList))
    print(' directories')
    print('  reference    : '+str(d_ref))
    print('  estimation   : '+str(d_est))
    print('  output       : '+str(d_out))
    if velocity is True:
        print(' with velocity : ON')
    else:
        print(' with velocity : OFF')
    print(' output        : '+str(output))

    def _load_transcription_velocity(filename):
        """Loader for data in the format start, end, pitch, velocity."""
        starts, ends, pitches, velocities = mir_eval.io.load_delimited(
            filename, [float, float, int, int])
        # Stack into an interval matrix
        intervals = np.array([starts, ends]).T
        # return pitches and velocities as np.ndarray
        pitches = np.array(pitches)
        velocities = np.array(velocities)
        return intervals, pitches, velocities

    # list file: one file stem per line; blank lines are skipped
    with open(testList, 'r', encoding='utf-8') as f:
        a_list = [line.rstrip('\n') for line in f if line.strip() != '']
    if len(a_list) == 0:
        # fail early instead of dividing by zero when averaging below
        raise ValueError('file list is empty: '+str(testList))

    # tag of the averaged result file, derived from the list file name
    if testList.endswith('test.list'):
        valid_data = '_test'
    elif testList.endswith('valid.list'):
        valid_data = '_valid'
    elif testList.endswith('train.list'):
        valid_data = '_train'
    else:
        valid_data = ''

    # metrics accumulated over all files; the key order fixes the JSON layout
    a_metric = ["Precision", "Recall", "F-measure", "Average_Overlap_Ratio",
                "Precision_no_offset", "Recall_no_offset", "F-measure_no_offset",
                "Average_Overlap_Ratio_no_offset"]
    if velocity is False:
        result_name = 'result_note'+valid_data+'_'+str(output)+'.json'
        # the evaluation without velocity also accumulates onset/offset-only scores
        a_metric += ["Onset_Precision", "Onset_Recall", "Onset_F-measure",
                     "Offset_Precision", "Offset_Recall", "Offset_F-measure"]
    else:
        result_name = 'result_note_velocity'+valid_data+'_'+str(output)+'.json'
    print(' result file   : '+str(result_name))
    result = {name: 0.0 for name in a_metric}

    # the per-file score files are written into d_out
    os.makedirs(d_out, exist_ok=True)

    count = 0
    for fname in a_list:
        print(fname)

        # convert estimated file from json to txt
        with open(d_est.rstrip('/')+'/'+fname+'_'+str(output)+'.json', 'r', encoding='utf-8') as f:
            est_obj = json.load(f)

        est_file = d_est.rstrip('/')+'/'+fname+'_'+str(output)+'.txt'
        with open(est_file, 'w', encoding='utf-8') as fo:
            for obj in est_obj:
                # notes without positive duration are dropped
                if obj['offset'] - obj['onset'] > 0.0:
                    if velocity is False:
                        # MIDI note number -> frequency in Hz
                        pitch_freq = 440.0*pow(2.0, (int(obj['pitch']) - 69)/12)
                        fo.write(str(obj['onset'])+'\t'+str(obj['offset'])+'\t'+str(pitch_freq)+'\n')
                    else:
                        # NOTE(review): the pitch is written as a MIDI note
                        # number here, while mir_eval documents pitches in Hz -
                        # confirm the *_velocity.txt references use the same
                        # convention and that the pitch matching is as intended.
                        fo.write(str(obj['onset'])+'\t'+str(obj['offset'])+'\t'+str(obj['pitch'])+'\t'+str(obj['velocity'])+'\n')
        del est_obj

        # calculate score
        if velocity is False:
            ref_file = d_ref.rstrip('/')+'/'+fname+'.txt'
            out_file = d_out.rstrip('/')+'/'+fname+'_result_note_'+str(output)+'.json'
            ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_file)
            est_int, est_pitch = mir_eval.io.load_valued_intervals(est_file)
            scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int, est_pitch)
        else:
            ref_file = d_ref.rstrip('/')+'/'+fname+'_velocity.txt'
            out_file = d_out.rstrip('/')+'/'+fname+'_result_note_velocity_'+str(output)+'.json'
            ref_int, ref_pitch, ref_vel = _load_transcription_velocity(ref_file)
            est_int, est_pitch, est_vel = _load_transcription_velocity(est_file)
            scores = mir_eval.transcription_velocity.evaluate(ref_int, ref_pitch, ref_vel,
                                                              est_int, est_pitch, est_vel)

        # output
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump(scores, f, ensure_ascii=False, indent=4, sort_keys=False)

        # total result
        for attr in scores:
            result[attr] += scores[attr]
        count += 1

    # output (total): average over the files and store next to the estimates
    for attr in result:
        result[attr] /= count
    with open(d_est.rstrip('/')+'/'+result_name, 'w', encoding='utf-8') as fo:
        json.dump(result, fo, ensure_ascii=False, indent=4, sort_keys=False)
    print(result)
    print('** done **')

In [None]:
# Switch the note evaluation to the velocity-aware branch; the next cell reads
# this flag and scores with mir_eval.transcription_velocity.
velocity = True

In [None]:
if __name__ == '__main__':
    # Note-level transcription scoring with mir_eval, re-run with the current
    # value of `velocity`. For every file stem in testList the estimated note
    # JSON is converted to a tab-separated text file, scored against the
    # reference, and the per-file scores are written to d_out; the scores
    # averaged over all files are written to d_est.
    import os  # kept local so the cell does not depend on an earlier import cell

    print('** mir_eval: transcription (note) **')
    print(' file list     : '+str(testList))
    print(' directories')
    print('  reference    : '+str(d_ref))
    print('  estimation   : '+str(d_est))
    print('  output       : '+str(d_out))
    if velocity is True:
        print(' with velocity : ON')
    else:
        print(' with velocity : OFF')
    print(' output        : '+str(output))

    def _load_transcription_velocity(filename):
        """Loader for data in the format start, end, pitch, velocity."""
        starts, ends, pitches, velocities = mir_eval.io.load_delimited(
            filename, [float, float, int, int])
        # Stack into an interval matrix
        intervals = np.array([starts, ends]).T
        # return pitches and velocities as np.ndarray
        pitches = np.array(pitches)
        velocities = np.array(velocities)
        return intervals, pitches, velocities

    # list file: one file stem per line; blank lines are skipped
    with open(testList, 'r', encoding='utf-8') as f:
        a_list = [line.rstrip('\n') for line in f if line.strip() != '']
    if len(a_list) == 0:
        # fail early instead of dividing by zero when averaging below
        raise ValueError('file list is empty: '+str(testList))

    # tag of the averaged result file, derived from the list file name
    if testList.endswith('test.list'):
        valid_data = '_test'
    elif testList.endswith('valid.list'):
        valid_data = '_valid'
    elif testList.endswith('train.list'):
        valid_data = '_train'
    else:
        valid_data = ''

    # metrics accumulated over all files; the key order fixes the JSON layout
    a_metric = ["Precision", "Recall", "F-measure", "Average_Overlap_Ratio",
                "Precision_no_offset", "Recall_no_offset", "F-measure_no_offset",
                "Average_Overlap_Ratio_no_offset"]
    if velocity is False:
        result_name = 'result_note'+valid_data+'_'+str(output)+'.json'
        # the evaluation without velocity also accumulates onset/offset-only scores
        a_metric += ["Onset_Precision", "Onset_Recall", "Onset_F-measure",
                     "Offset_Precision", "Offset_Recall", "Offset_F-measure"]
    else:
        result_name = 'result_note_velocity'+valid_data+'_'+str(output)+'.json'
    print(' result file   : '+str(result_name))
    result = {name: 0.0 for name in a_metric}

    # the per-file score files are written into d_out
    os.makedirs(d_out, exist_ok=True)

    count = 0
    for fname in a_list:
        print(fname)

        # convert estimated file from json to txt
        with open(d_est.rstrip('/')+'/'+fname+'_'+str(output)+'.json', 'r', encoding='utf-8') as f:
            est_obj = json.load(f)

        est_file = d_est.rstrip('/')+'/'+fname+'_'+str(output)+'.txt'
        with open(est_file, 'w', encoding='utf-8') as fo:
            for obj in est_obj:
                # notes without positive duration are dropped
                if obj['offset'] - obj['onset'] > 0.0:
                    if velocity is False:
                        # MIDI note number -> frequency in Hz
                        pitch_freq = 440.0*pow(2.0, (int(obj['pitch']) - 69)/12)
                        fo.write(str(obj['onset'])+'\t'+str(obj['offset'])+'\t'+str(pitch_freq)+'\n')
                    else:
                        # NOTE(review): the pitch is written as a MIDI note
                        # number here, while mir_eval documents pitches in Hz -
                        # confirm the *_velocity.txt references use the same
                        # convention and that the pitch matching is as intended.
                        fo.write(str(obj['onset'])+'\t'+str(obj['offset'])+'\t'+str(obj['pitch'])+'\t'+str(obj['velocity'])+'\n')
        del est_obj

        # calculate score
        if velocity is False:
            ref_file = d_ref.rstrip('/')+'/'+fname+'.txt'
            out_file = d_out.rstrip('/')+'/'+fname+'_result_note_'+str(output)+'.json'
            ref_int, ref_pitch = mir_eval.io.load_valued_intervals(ref_file)
            est_int, est_pitch = mir_eval.io.load_valued_intervals(est_file)
            scores = mir_eval.transcription.evaluate(ref_int, ref_pitch, est_int, est_pitch)
        else:
            ref_file = d_ref.rstrip('/')+'/'+fname+'_velocity.txt'
            out_file = d_out.rstrip('/')+'/'+fname+'_result_note_velocity_'+str(output)+'.json'
            ref_int, ref_pitch, ref_vel = _load_transcription_velocity(ref_file)
            est_int, est_pitch, est_vel = _load_transcription_velocity(est_file)
            scores = mir_eval.transcription_velocity.evaluate(ref_int, ref_pitch, ref_vel,
                                                              est_int, est_pitch, est_vel)

        # output
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump(scores, f, ensure_ascii=False, indent=4, sort_keys=False)

        # total result
        for attr in scores:
            result[attr] += scores[attr]
        count += 1

    # output (total): average over the files and store next to the estimates
    for attr in result:
        result[attr] /= count
    with open(d_est.rstrip('/')+'/'+result_name, 'w', encoding='utf-8') as fo:
        json.dump(result, fo, ensure_ascii=False, indent=4, sort_keys=False)
    print(result)
    print('** done **')

In [None]:
import argparse
import pickle
import json
import mir_eval
import copy
import numpy as np
import math


In [None]:
# Hop length in ms. The MPE evaluation cell below takes its 16 ms path when
# this equals 16 and its 10 ms path for any other value.
hop = 16
# Estimation directory: the MPE evaluation reads the pickled `.mpe` estimates
# from here and writes the converted `.txt` files back into it.
d_est = './output/mpe'

In [None]:
def note2freq(note_number):
    """Convert a note number to its frequency in Hz.

    Note number 69 maps to 440.0 Hz and every 12 steps doubles (or halves)
    the frequency. The argument is truncated with ``int()`` before use, so
    numeric strings and floats are accepted.
    """
    semitones_from_a4 = int(note_number) - 69
    octaves_from_a4 = semitones_from_a4 / 12
    return 440.0 * pow(2.0, octaves_from_a4)

def _write_mpe_txt(path, a_mpe, nframe, nbin, frame_sec, thred, note_min):
    """Write thresholded MPE activations as a ragged time-series text file.

    One line is written per frame: the frame time in seconds
    (``index * frame_sec`` rounded to 3 decimals) followed by one
    tab-separated frequency in Hz for every bin whose activation is
    ``>= thred``. Bin ``jj`` is converted with ``note2freq(jj + note_min)``.
    """
    # `with` guarantees the handle is closed even if a write fails part-way.
    with open(path, 'w', encoding='utf-8') as fo:
        for ii in range(nframe):
            fo.write(str(round(ii*frame_sec, 3)))
            for jj in range(nbin):
                if a_mpe[ii][jj] >= thred:
                    fo.write('\t'+str(note2freq(jj+note_min)))
            fo.write('\n')


if __name__ == '__main__':
    # Multi-pitch estimation (MPE) evaluation with mir_eval.
    #
    # This cell used to be a command-line script; it now relies on names that
    # must already exist in the notebook namespace:
    #   testList  : path of the file list (one file name per line)
    #   d_ref     : reference directory
    #   d_est     : estimation directory
    #   d_out     : output directory
    #   thred_mpe : threshold value for MPE
    #   hop       : hop length in ms (16 -> 16 ms path, otherwise 10 ms path)
    #   output    : which model output to score ('1st' or '2nd')
    #   NOTE_MIN  : offset added to a bin index to get the note number

    print('** mir_eval: MPE **')
    print(' file list     : '+str(testList))
    print(' directories')
    print('  reference    : '+str(d_ref))
    print('  estimation   : '+str(d_est))
    print('  output       : '+str(d_out))
    print(' threshold mpe : '+str(thred_mpe))
    print(' hop length    : '+str(hop))
    print(' output        : '+str(output))

    with open(testList, 'r', encoding='utf-8') as f:
        # Skip blank lines so a trailing newline never becomes an empty name.
        a_list = [line.rstrip('\n') for line in f if line.strip()]
    if not a_list:
        # Fail early instead of dividing by a zero file count further down.
        raise ValueError('no file names found in '+str(testList))

    if testList.endswith('test.list'):
        valid_data = '_test'
    elif testList.endswith('valid.list'):
        valid_data = '_valid'
    elif testList.endswith('train.list'):
        valid_data = '_train'
    else:
        valid_data = ''
    print(' result file   : '+str('result_mpe'+valid_data+'_'+str(output)+'.json'))

    # Running sums of the per-file mir_eval metrics (averaged after the loop).
    result_tmp = {
        "Precision": 0.0,
        "Recall": 0.0,
        "Accuracy": 0.0,
        "Substitution Error": 0.0,
        "Miss Error": 0.0,
        "False Alarm Error": 0.0,
        "Total Error": 0.0,
        "Chroma Precision": 0.0,
        "Chroma Recall": 0.0,
        "Chroma Accuracy": 0.0,
        "Chroma Substitution Error": 0.0,
        "Chroma Miss Error": 0.0,
        "Chroma False Alarm Error": 0.0,
        "Chroma Total Error": 0.0
    }
    result = {
        '10ms': copy.deepcopy(result_tmp),
        '16ms': copy.deepcopy(result_tmp)
    }

    count = 0
    for fname in a_list:
        print(fname)

        # reference file
        ref_10ms_file = d_ref.rstrip('/')+'/'+fname+'_mpe_10ms.txt'

        # estimated file
        # NOTE: pickle.load can execute arbitrary code; only load .mpe files
        # that were produced by this pipeline.
        with open(d_est.rstrip('/')+'/'+fname+'_'+str(output)+'.mpe', 'rb') as f:
            a_mpe_est = pickle.load(f)
        nbin = a_mpe_est.shape[1]

        if hop == 16:
            # convert estimated file from .mpe to .txt
            ref_16ms_file = d_ref.rstrip('/')+'/'+fname+'_mpe_16ms.txt'
            with open(ref_16ms_file, 'r', encoding='utf-8') as f:
                a_ref_16ms = f.readlines()
            # Only as many frames as both the reference and the estimate have.
            nframe = min(len(a_ref_16ms), len(a_mpe_est))

            est_16ms_file = d_est.rstrip('/')+'/'+fname+'_mpe_16ms_'+str(output)+'.txt'
            _write_mpe_txt(est_16ms_file, a_mpe_est, nframe, nbin, 0.016, thred_mpe, NOTE_MIN)
            del a_mpe_est

            # calculate score (16ms)
            ref_16ms_times, ref_16ms_freqs = mir_eval.io.load_ragged_time_series(ref_16ms_file)
            est_16ms_times, est_16ms_freqs = mir_eval.io.load_ragged_time_series(est_16ms_file)
            scores_16ms = mir_eval.multipitch.evaluate(ref_16ms_times, ref_16ms_freqs, est_16ms_times, est_16ms_freqs)
            # output (16ms)
            with open(d_out.rstrip('/')+'/'+fname+'_result_mpe_16ms_'+str(output)+'.json', 'w', encoding='utf-8') as f:
                json.dump(scores_16ms, f, ensure_ascii=False, indent=4, sort_keys=False)

            # calculate score (10ms): resample the 16 ms estimate onto a
            # 10 ms grid covering the same time span, then score it against
            # the 10 ms reference.
            ref_10ms_times, ref_10ms_freqs = mir_eval.io.load_ragged_time_series(ref_10ms_file)
            nframe_10ms = math.ceil(est_16ms_times[-1] / 0.01 + 1)
            est_10ms_times = np.array([ii*0.01 for ii in range(nframe_10ms)])
            est_10ms_freqs = mir_eval.multipitch.resample_multipitch(est_16ms_times, est_16ms_freqs, est_10ms_times)
            scores_10ms = mir_eval.multipitch.evaluate(ref_10ms_times, ref_10ms_freqs, est_10ms_times, est_10ms_freqs)
            # output (10ms)
            with open(d_out.rstrip('/')+'/'+fname+'_result_mpe_10ms_'+str(output)+'.json', 'w', encoding='utf-8') as f:
                json.dump(scores_10ms, f, ensure_ascii=False, indent=4, sort_keys=False)

            # total result
            for attr in scores_16ms:
                result['16ms'][attr] += scores_16ms[attr]
        else:
            # convert estimated file from .mpe to .txt
            with open(ref_10ms_file, 'r', encoding='utf-8') as f:
                a_ref_10ms = f.readlines()
            nframe = min(len(a_ref_10ms), len(a_mpe_est))

            est_10ms_file = d_est.rstrip('/')+'/'+fname+'_mpe_10ms_'+str(output)+'.txt'
            _write_mpe_txt(est_10ms_file, a_mpe_est, nframe, nbin, 0.01, thred_mpe, NOTE_MIN)
            del a_mpe_est

            # calculate score
            ref_10ms_times, ref_10ms_freqs = mir_eval.io.load_ragged_time_series(ref_10ms_file)
            est_10ms_times, est_10ms_freqs = mir_eval.io.load_ragged_time_series(est_10ms_file)
            scores_10ms = mir_eval.multipitch.evaluate(ref_10ms_times, ref_10ms_freqs, est_10ms_times, est_10ms_freqs)
            # output
            with open(d_out.rstrip('/')+'/'+fname+'_result_mpe_10ms_'+str(output)+'.json', 'w', encoding='utf-8') as f:
                json.dump(scores_10ms, f, ensure_ascii=False, indent=4, sort_keys=False)

        # total result (10 ms scores exist on both paths)
        for attr in scores_10ms:
            result['10ms'][attr] += scores_10ms[attr]
        count += 1

    # output (total): turn the sums into per-file averages. The '16ms' entry
    # is only populated, and therefore only averaged, when hop == 16.
    frame_keys = ['16ms', '10ms'] if hop == 16 else ['10ms']
    for key in frame_keys:
        for attr in result[key]:
            result[key][attr] /= count

    # f1-score, computed from the averaged precision and recall (not the mean
    # of per-file F1 values).
    for key in frame_keys:
        precision = result[key]['Precision']
        recall = result[key]['Recall']
        if (precision + recall) > 0.0:
            result[key]['f1'] = (2.0 * precision * recall) / (precision + recall)
        else:
            result[key]['f1'] = 0.0

    with open(d_out.rstrip('/')+'/result_mpe'+valid_data+'_'+str(output)+'.json', 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=4, sort_keys=False)
    print(result)
    print('** done **')

In [None]:
def transcribe_audio_to_midi(model_path, audio_path, output_midi_path):
    """Transcribe one audio file to a MIDI file using the notebook-level ``AMT1``.

    The pipeline is ``AMT1.wav2feature`` -> ``AMT1.transcript`` ->
    ``AMT1.mpe2note`` -> ``AMT1.note2midi``.

    Args:
        model_path: Currently unused; kept so existing calls keep working.
            The model is whatever ``AMT1`` already holds.
            NOTE(review): confirm ``AMT1`` was built from this checkpoint.
        audio_path: Path handed to ``AMT1.wav2feature``.
        output_midi_path: Destination path handed to ``AMT1.note2midi``. Its
            parent directory is created if it does not exist yet.
    """
    import os  # local import keeps this cell runnable on its own

    features = AMT1.wav2feature(audio_path)
    outputs = AMT1.transcript(features)
    # Only the first four entries of `outputs` are consumed, in
    # onset / offset / mpe / velocity order.
    notes = AMT1.mpe2note(a_onset=outputs[0], a_offset=outputs[1], a_mpe=outputs[2], a_velocity=outputs[3])

    # Create the destination directory up front so the write cannot fail just
    # because the folder is missing. A bare file name has no directory part.
    out_dir = os.path.dirname(output_midi_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    AMT1.note2midi(notes, output_midi_path)
    print(f"Transcription complete. MIDI file saved to {output_midi_path}")

if __name__ == "__main__":
    # Example run: transcribe a single MAESTRO-V3 test recording and save the
    # result as a MIDI file.
    demo_args = {
        'model_path': './checkpoint/MAESTRO-V3/best_model.pkl',
        'audio_path': './corpus/MAESTRO-V3/wav/test_001.wav',
        'output_midi_path': './output_midi/out.midi',
    }
    transcribe_audio_to_midi(**demo_args)