# Apollo Model

## Imports

In [12]:
import tensorflow as tf
import pickle
import numpy as np
import librosa
import pretty_midi
from keras import layers, Model, activations

## Building the Model

### AMT

In [4]:
class AMT():
    """Inference wrapper around a pickled transcription model.

    Pipeline:
      * ``wavToFeatures``: audio file -> log-mel frames.
      * ``transcript`` / ``trasnscipt_stride``: frames -> per-frame note activations.
      * ``mpeToNote``: activations -> note events.
      * ``noteToMIDI``: note events -> .mid file.

    All sizes and thresholds are read from the nested ``config`` dict
    (sections 'feature', 'input' and 'midi').
    """

    def __init__(self, config, model_path, batch_size = 1, verbose = False):
        """
        Args:
            config: nested settings dict (sections 'feature', 'input', 'midi').
            model_path: path to a pickled model, or None to build the wrapper
                without a model (feature extraction and note post-processing
                still work).
            batch_size: stored on the instance; not otherwise used in this class.
            verbose: if True, print the loaded model.
        """
        self.config = config
        # Set unconditionally so the attribute exists even without a model.
        self.batch_size = batch_size

        if model_path is None:
            self.model = None
        else:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # model files from trusted sources.
            with open(model_path, 'rb') as model_file:
                self.model = pickle.load(model_file)
            # The PyTorch-style `.to(self.device)` / `.eval()` calls were
            # removed: `self.device` was never defined and Keras models select
            # inference behaviour through `training=False` at call time.
            if verbose:
                print(self.model)

    def wavToFeatures(self, wav_file):
        """Compute the log-mel spectrogram of an audio file.

        Returns:
            float32 tf.Tensor of shape [num_frames, mel_bins] (time-major,
            i.e. the transpose of librosa's [n_mels, t] output).
        """
        feature_cfg = self.config['feature']
        target_sr = feature_cfg['sr']

        # Load (mono) directly at the target rate. Previously librosa loaded
        # at its default 22050 Hz and the signal was then resampled a second
        # time, which is slower and drops content above 11025 Hz whenever
        # target_sr > 22050.
        y, _ = librosa.load(wav_file, sr=target_sr, mono=True)

        mel_spectrogram = librosa.feature.melspectrogram(
            y=y,
            sr=target_sr,
            n_fft=feature_cfg['fft_bins'],
            hop_length=feature_cfg['hop_sample'],
            n_mels=feature_cfg['mel_bins'],
            window='hann')

        # log_offset keeps log() finite for silent bins.
        mel_spectrogram_log = np.log(mel_spectrogram + feature_cfg['log_offset'])

        return tf.convert_to_tensor(mel_spectrogram_log.T, dtype=tf.float32)

    def _padding_rows(self, n_rows):
        """Return an [n_rows, n_bins] float32 block filled with config['input']['min_value']."""
        return np.full(
            [n_rows, self.config['feature']['n_bins']],
            self.config['input']['min_value'],
            dtype=np.float32
        )

    def _new_output_buffers(self, n_rows, mode):
        """Allocate zeroed [n_rows, num_note] float32 result buffers.

        Returns four buffers (onset, offset, mpe, velocity) for set A, plus
        the same four for set B when mode == 'combination'.
        """
        n_buffers = 8 if mode == 'combination' else 4
        num_note = self.config['midi']['num_note']
        return [np.zeros((n_rows, num_note), dtype=np.float32) for _ in range(n_buffers)]

    def _run_model(self, input_spec, mode, ablation):
        """Run the model on one window.

        Returns (onset, offset, mpe, velocity) for set A, followed by the
        same four for set B in 'combination' mode. Each output still carries
        the leading batch axis of size 1.
        """
        outputs = self.model(input_spec, training=False)
        if mode != 'combination':
            onset_A, offset_A, mpe_A, velocity_A = outputs
            return onset_A, offset_A, mpe_A, velocity_A
        if ablation:
            (onset_A, offset_A, mpe_A, velocity_A,
             onset_B, offset_B, mpe_B, velocity_B) = outputs
        else:
            # The non-ablation model returns an attention map as its 5th
            # output; it is not needed here.
            (onset_A, offset_A, mpe_A, velocity_A, _attention,
             onset_B, offset_B, mpe_B, velocity_B) = outputs
        return onset_A, offset_A, mpe_A, velocity_A, onset_B, offset_B, mpe_B, velocity_B

    @staticmethod
    def _store_window(buffers, outputs, dst, src):
        """Copy one window's model outputs into the result buffers.

        `buffers` and `outputs` are aligned sequences grouped as
        (onset, offset, mpe, velocity). `dst` selects the destination frame
        rows and `src` the rows of the batch-squeezed output to keep.
        Velocity outputs have an extra trailing class axis and are reduced
        with argmax over that axis.
        """
        for index, (buffer, output) in enumerate(zip(buffers, outputs)):
            frames = tf.squeeze(output, axis=0)[src]
            if index % 4 == 3:
                frames = tf.argmax(frames, axis=2)
            buffer[dst] = frames.numpy()

    def transcript(self, a_feature, mode = 'combination', ablation = False):
        """Run the model over consecutive, non-overlapping windows of `a_feature`.

        Args:
            a_feature: array-like [num_frames, n_bins] features
                (e.g. from ``wavToFeatures``).
            mode: 'combination' expects the model to return two output sets
                (A and B); any other value expects a single set.
            ablation: in 'combination' mode, True means the model returns
                8 outputs; False means 9, with an attention map in 5th
                position (which is discarded).

        Returns:
            Tuple of float32 arrays: (onset, offset, mpe, velocity) for set A,
            then the same four for set B in 'combination' mode. Each array is
            [num_frames + len_s, num_note], where len_s pads num_frames up to
            a multiple of config['input']['num_frame']. Velocity holds the
            argmax class index.
        """
        a_feature = np.asarray(a_feature, dtype=np.float32)
        input_cfg = self.config['input']
        num_frame = input_cfg['num_frame']
        n_frames = a_feature.shape[0]

        # Extra frames needed to round the input up to whole windows.
        len_s = int(np.ceil(n_frames / num_frame) * num_frame) - n_frames

        # [margin_b padding | features | len_s + margin_f padding]
        a_input = np.concatenate([
            self._padding_rows(input_cfg['margin_b']),
            a_feature,
            self._padding_rows(len_s + input_cfg['margin_f']),
        ], axis=0)
        a_input = tf.convert_to_tensor(a_input, dtype=tf.float32)
        window = input_cfg['margin_b'] + num_frame + input_cfg['margin_f']

        buffers = self._new_output_buffers(n_frames + len_s, mode)
        for i in range(0, n_frames, num_frame):
            # Model input is [1, n_bins, window].
            input_spec = tf.expand_dims(tf.transpose(a_input[i:i + window], perm=[1, 0]), axis=0)
            outputs = self._run_model(input_spec, mode, ablation)
            self._store_window(buffers, outputs, slice(i, i + num_frame), slice(None))

        return tuple(buffers)

    def trasnscipt_stride(self, a_feature, n_offset, mode = 'combination', ablation = False):
        """Like ``transcript``, but with windows that advance by half a window.

        Input is shifted by `n_offset` frames. From each model output only
        the half-window slice [n_offset, n_offset + num_frame // 2) is kept.
        Arguments and return value are otherwise as in ``transcript``.
        The output length is num_frames + len_s, with len_s computed from the
        half-window grid below.
        """
        a_feature = np.asarray(a_feature, dtype=np.float32)
        input_cfg = self.config['input']
        num_frame = input_cfg['num_frame']
        half_frame = int(num_frame / 2)
        n_frames = a_feature.shape[0]

        # Extra frames needed so the padded length is a whole number of half-windows.
        tmp_len = n_frames + input_cfg['margin_b'] + input_cfg['margin_f'] + half_frame
        len_s = int(np.ceil(tmp_len / half_frame) * half_frame) - tmp_len

        a_input = np.concatenate([
            self._padding_rows(input_cfg['margin_b'] + n_offset),
            a_feature,
            self._padding_rows(len_s + input_cfg['margin_f'] + (half_frame - n_offset)),
        ], axis=0)
        a_input = tf.convert_to_tensor(a_input, dtype=tf.float32)
        window = input_cfg['margin_b'] + num_frame + input_cfg['margin_f']

        buffers = self._new_output_buffers(n_frames + len_s, mode)
        for i in range(0, n_frames, half_frame):
            input_spec = tf.expand_dims(tf.transpose(a_input[i:i + window], perm=[1, 0]), axis=0)
            outputs = self._run_model(input_spec, mode, ablation)
            self._store_window(buffers, outputs, slice(i, i + half_frame),
                               slice(n_offset, n_offset + half_frame))

        return tuple(buffers)

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    transcript_stride = trasnscipt_stride

    @staticmethod
    def _side_is_lower(activation, j, value, frames):
        """Return True if, scanning `frames` in order, the first value that
        differs from `value` in column j is lower. Equal values (plateaus)
        are skipped. Returns True if no differing value exists."""
        for k in frames:
            if value > activation[k][j]:
                return True
            if value < activation[k][j]:
                return False
        return True

    @classmethod
    def _detect_peaks(cls, activation, j, threshold, hop_sec, time_key):
        """Find local maxima >= threshold in column j of a [frames][notes] activation.

        The peak time (seconds) is refined by interpolating with the
        immediate neighbours.

        Returns:
            list of {'loc': frame_index, time_key: seconds}.
        """
        n = len(activation)
        peaks = []
        for i in range(n):
            value = activation[i][j]
            if value >= threshold:
                if not (cls._side_is_lower(activation, j, value, range(i - 1, -1, -1))
                        and cls._side_is_lower(activation, j, value, range(i + 1, n))):
                    continue
                if i == 0 or i == n - 1:
                    peak_time = i * hop_sec
                else:
                    prev_value = activation[i - 1][j]
                    next_value = activation[i + 1][j]
                    if prev_value == next_value:
                        peak_time = i * hop_sec
                    elif prev_value > next_value:
                        peak_time = (i * hop_sec - (hop_sec * 0.5 * (prev_value - next_value) / (value - next_value)))
                    else:
                        peak_time = (i * hop_sec + (hop_sec * 0.5 * (next_value - prev_value) / (value - prev_value)))
                peaks.append({'loc': i, time_key: peak_time})
        return peaks

    def mpeToNote(self, a_onset = None, a_offset = None, a_mpe = None, a_velocity = None, thred_onset = 0.5, thred_offset = 0.5, thred_mpe = 0.5, mode_velocity = 'ignore_zero', mode_offset = 'shorter'):
        """Convert frame-level activations into note events.

        Args:
            a_onset, a_offset, a_mpe: 2-D indexable [frames][num_note]
                activations (e.g. outputs of ``transcript``).
            a_velocity: [frames][num_note] velocities, read at each onset frame.
            thred_onset, thred_offset: minimum peak height for onset / offset
                detection.
            thred_mpe: a note counts as released at the first frame after its
                onset (and before the next onset) whose MPE value is below
                this threshold.
            mode_velocity: 'ignore_zero' drops notes with velocity 0; any
                other value keeps them.
            mode_offset: used when both an offset peak and an MPE release are
                found. 'offset' uses the offset peak, 'longer' the later one,
                and anything else (default 'shorter') the earlier one.

        Returns:
            list of {'pitch', 'onset', 'offset', 'velocity'} dicts, sorted by
            onset and then pitch. Times are in seconds
            (frame * hop_sample / sr).
        """
        a_note = []
        hop_sec = float(self.config['feature']['hop_sample'] / self.config['feature']['sr'])

        for j in range(self.config['midi']['num_note']):
            # Peak picking runs once per pitch. It was previously nested inside
            # the per-frame onset loop, which re-ran offset detection and note
            # assembly for every frame and duplicated notes.
            a_onset_detect = self._detect_peaks(a_onset, j, thred_onset, hop_sec, 'onset_time')
            a_offset_detect = self._detect_peaks(a_offset, j, thred_offset, hop_sec, 'offset_time')

            time_offset = 0.0
            time_mpe = 0.0
            for idx_on, onset in enumerate(a_onset_detect):
                loc_onset = onset['loc']
                time_onset = onset['onset_time']

                # The next onset of this pitch (or the end of input) bounds the note.
                if idx_on + 1 < len(a_onset_detect):
                    loc_next = a_onset_detect[idx_on + 1]['loc']
                    time_next = a_onset_detect[idx_on + 1]['onset_time']
                else:
                    loc_next = len(a_mpe)
                    time_next = (loc_next - 1) * hop_sec

                # End candidate 1: the first offset peak after the onset.
                loc_offset = loc_onset + 1
                flag_offset = False
                for detected in a_offset_detect:
                    if loc_onset < detected['loc']:
                        loc_offset = detected['loc']
                        time_offset = detected['offset_time']
                        flag_offset = True
                        break
                # Clip to the next onset. This check used to sit inside the
                # search loop, where it never ran after a match.
                if loc_offset > loc_next:
                    loc_offset = loc_next
                    time_offset = time_next

                # End candidate 2: the first frame between this onset and the
                # next one where MPE drops below the threshold. The search
                # previously started at frame 0, i.e. before the onset.
                loc_mpe = loc_onset + 1
                flag_mpe = False
                for k in range(loc_onset + 1, loc_next):
                    if a_mpe[k][j] < thred_mpe:
                        loc_mpe = k
                        time_mpe = k * hop_sec
                        flag_mpe = True
                        break

                pitch_value = int(j + self.config['midi']['note_min'])
                velocity_value = int(a_velocity[loc_onset][j])

                if not flag_offset and not flag_mpe:
                    offset_value = float(time_next)
                elif flag_offset and not flag_mpe:
                    offset_value = float(time_offset)
                elif not flag_offset and flag_mpe:
                    offset_value = float(time_mpe)
                elif mode_offset == 'offset':
                    offset_value = float(time_offset)
                elif mode_offset == 'longer':
                    offset_value = float(time_offset) if loc_offset >= loc_mpe else float(time_mpe)
                else:  # 'shorter'
                    offset_value = float(time_offset) if loc_offset <= loc_mpe else float(time_mpe)

                if mode_velocity != 'ignore_zero' or velocity_value > 0:
                    a_note.append({'pitch': pitch_value, 'onset': float(time_onset), 'offset': offset_value, 'velocity': velocity_value})

                # If the newest note starts before the previous same-pitch
                # note ends, end the previous note at the new onset.
                # This needs two notes: with only one, [-1] and [-2] are the
                # same dict, and the old `> 0` check collapsed the first note
                # to zero length.
                if (len(a_note) > 1
                        and a_note[-1]['pitch'] == a_note[-2]['pitch']
                        and a_note[-1]['onset'] < a_note[-2]['offset']):
                    a_note[-2]['offset'] = a_note[-1]['onset']

        # Order by onset, then pitch. This is equivalent to two stable sorts
        # (pitch first, then onset).
        return sorted(a_note, key=lambda note: (note['onset'], note['pitch']))

    def noteToMIDI(self, a_note, midi_file):
        """Write note dicts (as returned by ``mpeToNote``) to `midi_file` as a
        single instrument (program 0)."""
        midi = pretty_midi.PrettyMIDI()
        instrument = pretty_midi.Instrument(program=0)
        for note in a_note:
            instrument.notes.append(pretty_midi.Note(velocity=note['velocity'], pitch=note['pitch'], start=note['onset'], end=note['offset']))
        midi.instruments.append(instrument)
        midi.write(midi_file)

        return

### Spectrogram-to-MIDI model

#### Model

In [None]:
class Model_SPEC2MIDI(Model):
    """Spectrogram-to-MIDI model: an encoder followed by a decoder that
    produces two output sets (A and B) plus an attention map."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder_spec2midi = encoder
        self.decoder_spec2midi = decoder

    def call(self, input_spec):
        """Encode `input_spec` and decode it into note activations.

        Returns the decoder's nine outputs in the decoder's own order:
        onset/offset/mpe/velocity for set A, the attention map, then
        onset/offset/mpe/velocity for set B. Keeping attention in 5th
        position matches the 9-way unpacking used by AMT.transcript in this
        file; it was previously moved to the end.
        """
        enc_vector = self.encoder_spec2midi(input_spec)

        (output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention,
         output_onset_B, output_offset_B, output_mpe_B, output_velocity_B) = self.decoder_spec2midi(enc_vector)

        return (output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention,
                output_onset_B, output_offset_B, output_mpe_B, output_velocity_B)

    # Keras dispatches `model(x)` to `call`; the PyTorch-style name is kept
    # as an alias for existing callers.
    forward = call

#### Encoder

In [None]:
class Encoder_SPEC2MIDI(Model):
    """Frequency-axis encoder: a per-frame CNN over a sliding time context,
    followed by self-attention across frequency bins."""

    def __init__(self, n_margin,n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.n_frame = n_frame
        self.n_bin = n_bin
        self.cnn_channel = cnn_channel
        self.cnn_kernel = cnn_kernel
        self.hid_dim = hid_dim
        self.device = device

        # Keras Conv2D takes (filters, kernel_size); the input channel count is inferred.
        self.conv = layers.Conv2D(filters=self.cnn_channel, kernel_size=(1, self.cnn_kernel))
        # Each frame sees n_margin frames of context on either side.
        self.n_proc = n_margin * 2 + 1
        self.cnn_dim = self.cnn_channel * (self.n_proc - (self.cnn_kernel - 1))
        # Keras Dense takes only the output size (its 2nd positional argument is `activation`).
        self.tok_embedding_freq = layers.Dense(hid_dim)
        self.pos_embedding_freq = layers.Embedding(n_bin, hid_dim)
        self.layers_freq = [
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ]
        self.dropout = layers.Dropout(dropout)
        self.scale_freq = tf.sqrt(tf.constant([hid_dim], dtype=tf.float32))

    def call(self, spec_in):
        """Encode a spectrogram window.

        Assumes spec_in is [batch, n_bin, n_margin + n_frame + n_margin], so
        the sliding windows yield exactly n_frame positions (TODO confirm
        against the data loader).

        Returns:
            Tensor of shape [batch, n_frame, n_bin, hid_dim].
        """
        # Dynamic batch size so the layer also works under tf.function.
        batch_size = tf.shape(spec_in)[0]

        # Sliding windows of n_proc frames along time (stride 1):
        # [batch, n_bin, n_frame, n_proc]
        spec = tf.image.extract_patches(
            images=tf.expand_dims(spec_in, -1),
            sizes=[1, 1, self.n_proc, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # Swap the bin and frame axes (a reshape would scramble the data):
        # [batch, n_frame, n_bin, n_proc]
        spec = tf.transpose(spec, perm=[0, 2, 1, 3])

        # One CNN pass per frame (channels_last): [batch*n_frame, n_bin, n_proc, 1]
        spec_cnn = tf.reshape(spec, (batch_size * self.n_frame, self.n_bin, self.n_proc, 1))
        spec_cnn = self.conv(spec_cnn)
        # [N, n_bin, n_proc-(k-1), C] -> [N, n_bin, C, n_proc-(k-1)], then flatten per bin.
        spec_cnn = tf.transpose(spec_cnn, perm=[0, 1, 3, 2])
        spec_cnn_freq = tf.reshape(spec_cnn, (batch_size * self.n_frame, self.n_bin, self.cnn_dim))

        spec_emb_freq = self.tok_embedding_freq(spec_cnn_freq)

        pos_freq = tf.tile(tf.range(self.n_bin)[tf.newaxis, :], [batch_size * self.n_frame, 1])
        spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq))

        for layer_freq in self.layers_freq:
            spec_freq = layer_freq(spec_freq)

        return tf.reshape(spec_freq, (batch_size, self.n_frame, self.n_bin, self.hid_dim))

    # Keras dispatches `model(x)` to `call`; alias kept for existing callers.
    forward = call

#### Decoder

In [None]:
class Decoder_SPEC2MIDI(Model):
    """Two-stage decoder.

    1. Frequency stage: per-note queries cross-attend to the encoder's bins
       (output set A).
    2. Time stage: per-note self-attention across frames (output set B).
    """

    def __init__(self,n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.n_frame = n_frame
        self.n_bin = n_bin
        self.n_note = n_note
        self.n_velocity = n_velocity
        self.hid_dim = hid_dim
        self.device = device

        self.sigmoid = layers.Activation('sigmoid')
        self.dropout = layers.Dropout(rate=dropout)

        # Frequency stage.
        self.pos_embedding_freq = layers.Embedding(n_note, hid_dim)
        self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device)
        self.layers_freq = [
            DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers - 1)
        ]

        # Keras Dense takes only the output size (its 2nd positional argument is `activation`).
        self.fc_onset_freq = layers.Dense(1)
        self.fc_offset_freq = layers.Dense(1)
        self.fc_mpe_freq = layers.Dense(1)
        self.fc_velocity_freq = layers.Dense(self.n_velocity)

        # Time stage. This previously overwrote pos_embedding_freq and never
        # defined scale_time / pos_embedding_time, which `call` needs.
        self.scale_time = tf.sqrt(tf.constant([hid_dim], dtype=tf.float32))
        self.scale_freq = self.scale_time  # kept for attribute compatibility
        self.pos_embedding_time = layers.Embedding(n_frame, hid_dim)
        self.layers_time = [
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ]

        self.fc_onset_time = layers.Dense(1)
        self.fc_offset_time = layers.Dense(1)
        self.fc_mpe_time = layers.Dense(1)
        self.fc_velocity_time = layers.Dense(self.n_velocity)

    def call(self, enc_spec):
        """Decode encoder output into note activations.

        Args:
            enc_spec: tensor reshaped here to [batch*n_frame, n_bin, hid_dim]
                (the encoder produces [batch, n_frame, n_bin, hid_dim]).

        Returns:
            (onset_A, offset_A, mpe_A, velocity_A, attention_A,
             onset_B, offset_B, mpe_B, velocity_B).
            onset/offset/mpe are sigmoid activations of shape
            [batch, n_frame, n_note]; velocity holds logits of shape
            [batch, n_frame, n_note, n_velocity].
        """
        batch_size = tf.shape(enc_spec)[0]
        enc_spec = tf.reshape(enc_spec, [batch_size * self.n_frame, self.n_bin, self.hid_dim])

        # ---- Frequency stage: one learned query per note ----
        pos_freq = tf.tile(tf.range(self.n_note)[tf.newaxis, :], [batch_size * self.n_frame, 1])
        midi_freq = self.pos_embedding_freq(pos_freq)

        midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq)
        for layer_freq in self.layers_freq:
            midi_freq, attention_freq = layer_freq(enc_spec, midi_freq)

        att_shape = tf.shape(attention_freq)
        attention_freq = tf.reshape(attention_freq, [batch_size, self.n_frame, att_shape[1], att_shape[2], att_shape[3]])

        def per_frame(x):
            # [batch*n_frame, n_note, 1] -> [batch, n_frame, n_note]
            return tf.reshape(x, [batch_size, self.n_frame, self.n_note])

        output_onset_freq = self.sigmoid(per_frame(self.fc_onset_freq(midi_freq)))
        output_offset_freq = self.sigmoid(per_frame(self.fc_offset_freq(midi_freq)))
        output_mpe_freq = self.sigmoid(per_frame(self.fc_mpe_freq(midi_freq)))
        output_velocity_freq = tf.reshape(self.fc_velocity_freq(midi_freq), [batch_size, self.n_frame, self.n_note, self.n_velocity])

        # ---- Time stage: regroup per note, attend across frames ----
        # [batch*n_frame, n_note, hid] -> [batch*n_note, n_frame, hid]
        midi_time = tf.reshape(midi_freq, [batch_size, self.n_frame, self.n_note, self.hid_dim])
        midi_time = tf.transpose(midi_time, perm=[0, 2, 1, 3])
        midi_time = tf.reshape(midi_time, [batch_size * self.n_note, self.n_frame, self.hid_dim])

        pos_time = tf.tile(tf.range(self.n_frame)[tf.newaxis, :], [batch_size * self.n_note, 1])
        midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time))

        for layer_time in self.layers_time:
            # EncoderLayer returns a single tensor.
            midi_time = layer_time(midi_time)

        def per_note_to_frame(x):
            # [batch*n_note, n_frame, 1] -> [batch, n_frame, n_note]
            return tf.transpose(tf.reshape(x, [batch_size, self.n_note, self.n_frame]), perm=[0, 2, 1])

        output_onset_time = self.sigmoid(per_note_to_frame(self.fc_onset_time(midi_time)))
        output_offset_time = self.sigmoid(per_note_to_frame(self.fc_offset_time(midi_time)))
        output_mpe_time = self.sigmoid(per_note_to_frame(self.fc_mpe_time(midi_time)))
        output_velocity_time = tf.transpose(
            tf.reshape(self.fc_velocity_time(midi_time), [batch_size, self.n_note, self.n_frame, self.n_velocity]),
            perm=[0, 2, 1, 3])

        return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, attention_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time

    # Keras dispatches `model(x)` to `call`; alias kept for existing callers.
    forward = call

#### Encoder Layer class

In [None]:
class EncoderLayer(layers.Layer):
    """Post-norm transformer encoder layer: self-attention, then feed-forward.

    Each sublayer uses a residual connection followed by layer normalization.
    Both sublayers share the same normalization layer.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device=None):
        # `device` is optional because some call sites in this file omit it;
        # it is only forwarded to the attention layer.
        # Base class fixed: inheriting from the `layers` module raised a
        # TypeError at class creation.
        super().__init__()

        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.feed_forward = PositionwiseFeedForwardLayer(hid_dim, pf_dim, dropout)

        self.dropout = layers.Dropout(rate=dropout)

    def call(self, src):
        """Return the transformed `src`; shape is unchanged
        (presumably [batch, seq_len, hid_dim])."""
        _src, _ = self.self_attention(src, src, src)
        src = self.layer_norm(src + self.dropout(_src))
        _src = self.feed_forward(src)
        src = self.layer_norm(src + self.dropout(_src))

        return src

    # Keras dispatches `layer(x)` to `call`; alias kept for existing callers.
    forward = call

#### Decoder Layer Zero

In [None]:
class DecoderLayer_Zero(layers.Layer):
    """First decoder layer: cross-attention to the encoder output, then
    feed-forward.

    There is no self-attention, because the queries are plain positional
    embeddings at this point.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        # Base class fixed: inheriting from the `layers` module raised a TypeError.
        super().__init__()

        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.feed_forward = PositionwiseFeedForwardLayer(hid_dim, pf_dim, dropout)

        self.dropout = layers.Dropout(rate=dropout)

    def call(self, enc_src, trg):
        """Attend from `trg` (queries) to `enc_src` (keys/values).

        Returns:
            (trg, attention): the updated queries and the attention weights
            returned by the attention layer.
        """
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src)
        trg = self.layer_norm(trg + self.dropout(_trg))
        _trg = self.feed_forward(trg)
        trg = self.layer_norm(trg + self.dropout(_trg))

        return trg, attention

    # Keras dispatches `layer(...)` to `call`; alias kept for existing callers.
    forward = call

#### Decoder Layer

In [None]:
class DecoderLayer(layers.Layer):
    """Transformer decoder layer: self-attention, cross-attention to the
    encoder output, then feed-forward.

    Each sublayer uses a residual connection followed by layer normalization.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        # Base class fixed: inheriting from the `layers` module raised a TypeError.
        super().__init__()
        # LayerNormalization's first positional argument is `axis`; passing
        # hid_dim there normalised over a nonexistent axis.
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.feed_forward = PositionwiseFeedForwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = layers.Dropout(dropout)

    def call(self, enc_src, trg):
        """Update `trg` (presumably [batch, trg_len, hid_dim]) using
        `enc_src` (presumably [batch, src_len, hid_dim]).

        Returns:
            (trg, attention): attention is the cross-attention weights.
        """
        # Self-attention, then dropout + residual + norm.
        _trg, _ = self.self_attention(trg, trg, trg)
        trg = self.layer_norm(trg + self.dropout(_trg))

        # Cross-attention to the encoder output.
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src)
        trg = self.layer_norm(trg + self.dropout(_trg))

        # Position-wise feed-forward.
        _trg = self.feed_forward(trg)
        trg = self.layer_norm(trg + self.dropout(_trg))

        return trg, attention

    # Keras dispatches `layer(...)` to `call`; alias kept for existing callers.
    forward = call

#### Multi-Head Attention Layer

In [None]:
class MultiHeadAttentionLayer(layers.Layer):
    """Scaled dot-product multi-head attention.

    Returns (x, attention), where x is [batch, query_len, hid_dim] and
    attention is [batch, n_heads, query_len, key_len].
    """

    def __init__(self, hid_dim, n_heads, dropout, device=None):
        # `device` is accepted for signature compatibility only; it is optional
        # because some call sites in this file omit it.
        # Base class fixed: inheriting from the `layers` module raised a TypeError.
        super().__init__()
        if hid_dim % n_heads != 0:
            raise ValueError(f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})")

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        # Keras Dense takes only the output size (its 2nd positional argument is `activation`).
        self.fc_q = layers.Dense(hid_dim)
        self.fc_k = layers.Dense(hid_dim)
        self.fc_v = layers.Dense(hid_dim)
        self.fc_o = layers.Dense(hid_dim)
        self.dropout = layers.Dropout(dropout)
        self.scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))

    def _split_heads(self, x, batch_size):
        """[batch, len, hid_dim] -> [batch, n_heads, len, head_dim]"""
        x = tf.reshape(x, [batch_size, -1, self.n_heads, self.head_dim])
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value):
        """Attend from `query` to `key`/`value`.

        `call` was previously nested inside __init__ (so it was never a
        method), and K/V were projected from `query`, which silently turned
        every cross-attention into self-attention.
        """
        batch_size = tf.shape(query)[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        energy = tf.matmul(Q, K, transpose_b=True) / self.scale
        attention = activations.softmax(energy, axis=-1)

        x = tf.matmul(self.dropout(attention), V)
        # [batch, n_heads, q_len, head_dim] -> [batch, q_len, hid_dim]
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        x = tf.reshape(x, [batch_size, -1, self.hid_dim])

        return self.fc_o(x), attention

    # Keras dispatches `layer(...)` to `call`; alias kept for existing callers.
    forward = call

#### Positionwise Feed Forward Layer

In [None]:
class PositionwiseFeedForwardLayer(layers.Layer):
    """Two-layer position-wise MLP: hid_dim -> pf_dim (ReLU, dropout) -> hid_dim."""

    def __init__(self, hid_dim, pf_dim, dropout):
        # Base class fixed: inheriting from the `layers` module raised a TypeError.
        super().__init__()
        # Keras Dense takes only the output size; the previous
        # Dense(in, out) calls passed `out` as the activation.
        self.fc_1 = layers.Dense(pf_dim)
        self.fc_2 = layers.Dense(hid_dim)
        self.dropout = layers.Dropout(dropout)

    def call(self, x):
        """Apply the MLP to the last axis of `x`; the output's last axis is hid_dim."""
        x = self.dropout(activations.relu(self.fc_1(x)))
        return self.fc_2(x)

    # Keras dispatches `layer(x)` to `call`; alias kept for existing callers.
    forward = call

### Spectrogram-to-MIDI Ablation Models

#### Single output model

In [None]:
class Model_single(Model):
    """Encoder-decoder wrapper whose decoder produces a single set of outputs.

    Args:
        encoder: Callable applied to the input spectrogram.
        decoder: Callable applied to the encoder output; must return the
            4-tuple ``(onset, offset, mpe, velocity)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_spec):
        """Run the encoder, then the decoder.

        Args:
            input_spec: Input spectrogram, assumed
                [batch_size, n_bin, margin + n_frame + margin], e.g. (8, 256, 192).

        Returns:
            Tuple ``(output_onset, output_offset, output_mpe, output_velocity)``.
        """
        # Use the attributes set in __init__ (encoder_spec2midi/decoder_spec2midi
        # were never defined on this class).
        enc_vector = self.encoder(input_spec)

        output_onset, output_offset, output_mpe, output_velocity = self.decoder(enc_vector)

        return output_onset, output_offset, output_mpe, output_velocity

    def call(self, input_spec):
        """Keras entry point; ``model(x)`` dispatches here."""
        return self.forward(input_spec)

#### Combined output model

In [None]:
class Model_combination(Model):
    """Encoder-decoder wrapper whose decoder produces two sets (A and B) of outputs.

    Args:
        encoder: Callable applied to the input spectrogram.
        decoder: Callable applied to the encoder output; must return the
            8-tuple ``(onset_A, offset_A, mpe_A, velocity_A,
            onset_B, offset_B, mpe_B, velocity_B)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_spec):
        """Run the encoder, then the decoder.

        Args:
            input_spec: Input spectrogram, assumed
                [batch_size, n_bin, margin + n_frame + margin], e.g. (8, 256, 192).

        Returns:
            The decoder's eight outputs, A set first, then B set.
        """
        # Use the attributes set in __init__ (encoder_spec2midi/decoder_spec2midi
        # were never defined on this class).
        enc_vector = self.encoder(input_spec)

        (output_onset_A, output_offset_A, output_mpe_A, output_velocity_A,
         output_onset_B, output_offset_B, output_mpe_B, output_velocity_B) = self.decoder(enc_vector)

        return (output_onset_A, output_offset_A, output_mpe_A, output_velocity_A,
                output_onset_B, output_offset_B, output_mpe_B, output_velocity_B)

    def call(self, input_spec):
        """Keras entry point; ``model(x)`` dispatches here."""
        return self.forward(input_spec)

#### Encoder

In [None]:
class Encoder_CNNtime_SAfreq(tf.Module):
    """Time-axis CNN front end followed by a stack of EncoderLayer blocks over frequency bins.

    The spectrogram is convolved along time. Each target frame is then
    represented by a window of ``n_proc - (cnn_kernel - 1)`` convolved time
    steps, flattened per frequency bin into a token. The resulting
    ``n_bin``-token sequence of every frame is passed through ``n_layers``
    EncoderLayer blocks.

    Args:
        n_margin: Context frames on each side of the target frames.
        n_frame: Number of target frames per example.
        n_bin: Number of frequency bins.
        cnn_channel: Number of convolution filters.
        cnn_kernel: Convolution kernel width along time.
        hid_dim: Token embedding size.
        n_layers: Number of EncoderLayer blocks.
        n_heads: Forwarded to each EncoderLayer.
        pf_dim: Forwarded to each EncoderLayer.
        dropout: Dropout rate (also forwarded to each EncoderLayer).
        device: Kept for interface compatibility; forwarded to EncoderLayer.
    """

    def __init__(self, n_margin, n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.n_frame = n_frame
        self.n_bin = n_bin
        self.cnn_channel = cnn_channel
        self.cnn_kernel = cnn_kernel
        self.hid_dim = hid_dim
        self.device = device

        # Keras Conv2D is (filters, kernel_size); input channels are inferred.
        # Kernel is 1 wide over frequency and cnn_kernel wide over time.
        self.conv = layers.Conv2D(self.cnn_channel, (1, self.cnn_kernel))
        self.n_proc = n_margin * 2 + 1
        # A 'valid' convolution trims (cnn_kernel - 1) time steps, so each
        # frame's receptive window shrinks to this many convolved steps.
        self.n_win = self.n_proc - (self.cnn_kernel - 1)
        self.cnn_dim = self.cnn_channel * self.n_win
        # Keras Dense takes only the output size; the input size (cnn_dim) is inferred.
        self.tok_embedding_freq = layers.Dense(hid_dim)
        self.pos_embedding_freq = layers.Embedding(self.n_bin, self.hid_dim)

        self.layers_freq = [
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ]

        self.dropout = layers.Dropout(dropout)
        self.scale_freq = tf.math.sqrt(tf.cast(self.hid_dim, tf.float32))

    def forward(self, spec_in, training=None):
        """Encode a spectrogram batch.

        Args:
            spec_in: Spectrogram, assumed
                [batch_size, n_bin, n_margin + n_frame + n_margin] — TODO confirm.
            training: Forwarded to Dropout so it is only active while training.

        Returns:
            Tensor of shape [batch_size, n_frame, n_bin, hid_dim].
        """
        # Dynamic batch size so this also works when the static dim is None.
        batch_size = tf.shape(spec_in)[0]

        # Keras convolutions are channels-last: add the channel axis at the end.
        # [B, n_bin, T, 1] -> [B, n_bin, T - (cnn_kernel - 1), cnn_channel]
        spec_cnn = self.conv(tf.expand_dims(spec_in, axis=-1))

        # Stride-1 sliding windows of n_win steps along time (PyTorch unfold):
        # -> [B, n_bin, n_frame, n_win, cnn_channel]
        spec_cnn = tf.signal.frame(spec_cnn, frame_length=self.n_win, frame_step=1, axis=2)

        # -> [B, n_frame, n_bin, cnn_channel, n_win], then flatten each bin's token.
        spec_cnn = tf.transpose(spec_cnn, perm=[0, 2, 1, 4, 3])
        spec_cnn = tf.reshape(spec_cnn, [batch_size * self.n_frame, self.n_bin, self.cnn_dim])

        spec_emb_freq = self.tok_embedding_freq(spec_cnn)

        # One learned embedding per frequency bin; the [n_bin, hid_dim] result
        # broadcasts over the batch*frame axis, so no tiling is needed.
        pos_freq = tf.range(self.n_bin)
        spec_freq = self.dropout(
            (spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq),
            training=training,
        )

        for layer_freq in self.layers_freq:
            spec_freq = layer_freq(spec_freq)

        spec_freq = tf.reshape(spec_freq, [batch_size, self.n_frame, self.n_bin, self.hid_dim])

        return spec_freq

    def __call__(self, spec_in, training=None):
        """Make the module callable like a Keras layer; delegates to ``forward``."""
        return self.forward(spec_in, training=training)

#### Encoder CNN block + SA freq 

In [None]:
class Encoder_CNNblock_SAfreq(tf.Module):
    """Encoder with a convolution-block front end (TensorFlow port in progress).

    Only the first convolution of the first block is ported. Constructing the
    class raises ``NotImplementedError`` until the remaining layers are written.
    The previous draft did not parse (``Conv2D(ch1, ,)`` and dangling
    ``layers.`` placeholders), which broke the whole cell.

    Args:
        n_margin, n_layers, n_heads, pf_dim, dropout, dropout_convblock:
            Accepted for interface compatibility; not used yet.
        n_frame: Number of target frames per example.
        n_bin: Number of frequency bins.
        hid_dim: Hidden size.
        device: Kept for interface compatibility with the PyTorch original.
    """

    def __init__(self, n_margin, n_frame, n_bin, hid_dim, n_layers, n_heads, pf_dim, dropout, dropout_convblock, device):
        super().__init__()

        self.device = device
        self.n_frame = n_frame
        # Original attribute name, kept for any code that already reads it.
        self.frame = n_frame
        self.n_bin = n_bin
        self.hid_dim = hid_dim

        k = 3  # convolution kernel size
        # p is presumably the PyTorch-style padding amount; with k=3, padding=1
        # preserves spatial size, which Keras expresses as padding='same'.
        p = 1
        conv_padding = 'same' if p == (k - 1) // 2 else 'valid'

        layers_conv_1 = []

        ch1 = 48
        # Keras infers input channels, so only filters and kernel size are given.
        layers_conv_1.append(layers.Conv2D(ch1, kernel_size=k, strides=1, padding=conv_padding))
        # TODO: port the remaining six layers of this block (the draft had
        # six more `layers_conv_1.append(layers.` placeholders) and the rest
        # of the encoder.
        self.layers_conv_1 = layers_conv_1

        raise NotImplementedError(
            'Encoder_CNNblock_SAfreq is only partially ported to TensorFlow/Keras.'
        )

## Preparing the DataSet

In [3]:
class DataSet:
    """Windowed dataset over pickled features and frame-level labels.

    Every array is loaded from a pickle file and converted to a ``tf.Tensor``.
    Item ``i`` is anchored at frame ``s = self.idx[i]``:

    - the feature window is ``features[s - margin_b : s + input.num_frame + margin_f]``,
      transposed;
    - each label window is ``label[s : s + output.num_frame]``.

    Args:
        features: Path to a pickle of the feature array.
        label_onset: Path to a pickle of the onset labels.
        label_offset: Path to a pickle of the offset labels.
        lable_mpe: Path to a pickle of the MPE labels. The misspelled parameter
            name is kept so existing keyword callers keep working.
        label_velocity: Path to a pickle of the velocity labels, or ``None``
            to omit velocity from the items.
        idx: Path to a pickle of window start indices.
        config: Nested dict. Reads ``config['input']['margin_b']``,
            ``config['input']['margin_f']``, ``config['input']['num_frame']``
            and ``config['output']['num_frame']``.
        n_slice: If > 1, keep only every ``n_slice``-th start index.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """

    def __init__(self, features, label_onset, label_offset, lable_mpe, label_velocity, idx, config, n_slice):
        super().__init__()

        features = self._load_pickle(features)
        label_onset = self._load_pickle(label_onset)
        label_offset = self._load_pickle(label_offset)
        # The parameter is named lable_mpe; reading an undefined label_mpe raised NameError.
        label_mpe = self._load_pickle(lable_mpe)

        self.flag_velocity = label_velocity is not None
        if self.flag_velocity:
            label_velocity = self._load_pickle(label_velocity)

        idx = self._load_pickle(idx)

        self.features = tf.convert_to_tensor(features, dtype=tf.float32)
        self.label_onset = tf.convert_to_tensor(label_onset, dtype=tf.float32)
        self.label_offset = tf.convert_to_tensor(label_offset, dtype=tf.float32)
        self.label_mpe = tf.convert_to_tensor(label_mpe, dtype=tf.float32)

        if self.flag_velocity:
            self.label_velocity = tf.convert_to_tensor(label_velocity, dtype=tf.float32)

        idx_all = tf.convert_to_tensor(idx)
        if n_slice > 1:
            # Trim to a multiple of n_slice, then keep every n_slice-th index.
            n_keep = (len(idx_all) // n_slice) * n_slice
            self.idx = idx_all[:n_keep][::n_slice]
        else:
            self.idx = idx_all

        # Previously this overwrote self.features with config and left
        # self.config (read by __getitem__) unset.
        self.config = config
        self.data_size = len(self.idx)

    @staticmethod
    def _load_pickle(path):
        """Return the object stored in the pickle file at ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        """Return ``(spec, onset, offset, mpe[, velocity])`` for window ``idx``.

        The velocity labels are cast to int64; all other outputs are float32.
        """
        start = self.idx[idx]
        feat_s = start - self.config['input']['margin_b']
        feat_e = start + self.config['input']['num_frame'] + self.config['input']['margin_f']

        label_s = start
        label_e = start + self.config['output']['num_frame']

        # The feature window must end at feat_e (including margin_f), not at
        # the label end. tf.Tensor has no `.T`; tf.transpose reverses the axes.
        spec = tf.transpose(self.features[feat_s:feat_e])

        label_onset = self.label_onset[label_s:label_e]
        label_offset = self.label_offset[label_s:label_e]
        # Already float32 (PyTorch's .float() does not exist on tf tensors).
        label_mpe = self.label_mpe[label_s:label_e]

        if self.flag_velocity:
            # TF equivalent of PyTorch's .long().
            label_velocity = tf.cast(self.label_velocity[label_s:label_e], tf.int64)
            return spec, label_onset, label_offset, label_mpe, label_velocity
        return spec, label_onset, label_offset, label_mpe




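* Reminder for later: wrap the data in a `tf.data.Dataset`, e.g. `dataset = tf.data.Dataset.from_tensor_slices((my_dataset.data, my_dataset.labels))`. Note that `DataSet` has no `data`/`labels` attributes; its tensors are `features`, `label_onset`, `label_offset`, `label_mpe` (and `label_velocity` when present), and the windowing happens in `__getitem__`.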