In [1]:
import os
import numpy as np
import pickle
import torch
from fractions import Fraction
import collections

# ==========================
# Part 1. REMI to numpy
# ==========================
class Vocab:
    """Vocabulary for text."""
    def __init__(self, tokens=[], min_freq=0, reserved_tokens=[]):
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # The list of unique tokens
        self.idx_to_token = ['<pad>'] + reserved_tokens + [
            token for token, freq in self.token_freqs if freq >= min_freq and token != '<pad>'
        ]
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
    
    def __len__(self):
        return len(self.idx_to_token)
    
    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple, np.ndarray)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.token_to_idx.get(token, self.unk) for token in tokens]
    
    def to_tokens(self, indices):
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]
    
    @property
    def unk(self):
        return self.token_to_idx.get('<unk>', 1)  # Returns index 1 by default if '<unk>' is not in the vocabulary

# ==========================
# 2. Definition of IntegratedRemi2Np Class (No Changes)
# ==========================
class IntegratedRemi2Np:
    def __init__(self, vocabs, num_max_bars=64, num_input=5, num_output=2):
        self.vocabs = vocabs
        self.num_max_bars = num_max_bars
        self.num_steps = num_max_bars * 16  # Quantize one bar into 16 units
        self.num_input = num_input  # note, bar, key, tempo, velocity
        self.num_output = num_output  # chord, emotion

    def process_all_files(self, filepaths):
        x_data = []
        y_data = []

        for filepath in filepaths:
            remi_data = self.load_remi_data(filepath)
            x, y = self.preprocess(remi_data)
            x_data.append(x)
            y_data.append(y)

        print("Conversion completed")
        x_stacked = np.vstack(x_data)
        y_stacked = np.vstack(y_data)

        return x_stacked, y_stacked

    def load_remi_data(self, filepath):
        # Function to read REMI data and return it in dictionary form
        remi_data = {
            "melody": self.parse_remi_file(filepath, event_types=["Bar", "Position", "Note On"]),
            "velocity": self.parse_remi_file(filepath, event_types=["Bar", "Position", "Note Velocity"]),
            "tempo": self.parse_remi_file(filepath, event_types=["Tempo Value"]),
            "chords": self.parse_remi_file(filepath, event_types=["Bar", "Position", "Chord"]),
            "key": self.parse_remi_file(filepath, event_types=["Bar", "Position", "Key"]),
            "emotion": self.parse_remi_file(filepath, event_types=["Emotion"])  # Parsing Emotion events
        }
        return remi_data

    def parse_remi_file(self, filepath, event_types):
        events = []
        with open(filepath, 'r') as file:
            for line in file:
                for event_type in event_types:
                    if f"name={event_type}" in line:
                        events.append(self.parse_event(line))
        return events

    def parse_event(self, line):
        event = {}

        # Extract name
        if "name=" in line:
            event_name = line.split("name=")[1].split(",")[0].strip()
            event['name'] = event_name

        # Extract value
        if "value=" in line:
            event_value = line.split("value=")[1].split(",")[0].strip()
            event['value'] = event_value

        # Extract text
        if "text=" in line:
            event_text = line.split("text=")[1].strip()
            event_text = event_text.rstrip(')')
            event['text'] = event_text

        # Extract time
        if "time=" in line:
            event_time = line.split("time=")[1].split(",")[0].strip()
            event['time'] = int(event_time)

        return event

    def preprocess(self, remi_data):
        # Calculate the actual length of the song
        total_bars = self.get_total_bars(remi_data['melody'])
        actual_num_steps = min(total_bars * 16, self.num_steps)

        # Initialization: Create numpy arrays filled with <pad>
        x_sequence = np.full((self.num_steps, self.num_input), self.vocabs['vocab_note'].token_to_idx['<pad>'], dtype=int)
        y_sequence = np.full((self.num_steps, self.num_output), self.vocabs['vocab_chord'].token_to_idx['<pad>'], dtype=int)

        # Process melody
        melody_sequence = self.process_melody(remi_data['melody'], actual_num_steps)
        x_sequence[:actual_num_steps, 0] = melody_sequence

        # Process bar
        bar_sequence = self.process_bar(remi_data['melody'], actual_num_steps)
        x_sequence[:actual_num_steps, 1] = bar_sequence

        # Process key
        key_sequence = self.process_key(remi_data['key'], actual_num_steps)
        x_sequence[:actual_num_steps, 2] = key_sequence

        # Process tempo
        tempo_sequence = self.process_tempo(remi_data['tempo'], actual_num_steps)
        x_sequence[:actual_num_steps, 3] = tempo_sequence

        # Process velocity
        velocity_sequence = self.process_velocity(remi_data['velocity'], actual_num_steps)
        x_sequence[:actual_num_steps, 4] = velocity_sequence

        # Process chord
        chord_sequence = self.process_chords(remi_data['chords'], actual_num_steps)
        y_sequence[:actual_num_steps, 0] = chord_sequence

        # Process emotion (if needed)
        emotion_sequence = self.process_emotion(remi_data['emotion'], actual_num_steps)
        y_sequence[:actual_num_steps, 1] = emotion_sequence

        return x_sequence.reshape(1, self.num_steps, self.num_input), y_sequence.reshape(1, self.num_steps, self.num_output)

    def get_total_bars(self, melody_events):
        total_bars = 0
        for event in melody_events:
            if event['name'] == 'Bar':
                bar_number = int(event['text'])
                if bar_number > total_bars:
                    total_bars = bar_number
        return total_bars

    def process_melody(self, melody_events, actual_num_steps):
        melody_sequence = [self.vocabs['vocab_note'].token_to_idx['<pad>']] * actual_num_steps
        current_note = self.vocabs['vocab_note'].token_to_idx['<pad>']
        current_position = None
        current_bar = 1

        for event in melody_events:
            if event['name'] == 'Bar':
                current_bar = int(event['text'])
            elif event['name'] == 'Position':
                pos = float(Fraction(event['value']))
                index = int((pos - 1) * 16 + (current_bar - 1) * 16)
                current_position = index
            elif event['name'] == 'Note On':
                pitch = int(event['value'])
                current_note = self.vocabs['vocab_note'].token_to_idx.get(str(pitch), self.vocabs['vocab_note'].unk)

            # Place the note index at the current position
            if current_position is not None and current_position < actual_num_steps:
                melody_sequence[current_position] = current_note

        # Fill forward
        for i in range(1, actual_num_steps):
            if melody_sequence[i] == self.vocabs['vocab_note'].token_to_idx['<pad>']:
                melody_sequence[i] = melody_sequence[i - 1]

        return melody_sequence

    def process_bar(self, melody_events, actual_num_steps):
        bar_sequence = [self.vocabs['vocab_bar'].token_to_idx['<pad>']] * actual_num_steps
        current_bar = 1
        for event in melody_events:
            if event['name'] == 'Bar':
                current_bar = int(event['text'])
                # Calculate the start index of the corresponding bar
                bar_start_index = (current_bar - 1) * 16
                if bar_start_index < actual_num_steps:
                    # Fill the current bar number from the start of the current bar to just before the next bar
                    end_index = min(bar_start_index + 16, actual_num_steps)
                    bar_token = self.vocabs['vocab_bar'].token_to_idx.get(str(current_bar), self.vocabs['vocab_bar'].unk)
                    for i in range(bar_start_index, end_index):
                        bar_sequence[i] = bar_token
        return bar_sequence

    def process_key(self, key_events, actual_num_steps):
        key_sequence = [self.vocabs['vocab_key'].token_to_idx['<pad>']] * actual_num_steps
        current_position = None
        current_bar = 1
        previous_key = self.vocabs['vocab_key'].token_to_idx['<pad>']

        for event in key_events:
            if event['name'] == 'Bar':
                current_bar = int(event['text'])
            elif event['name'] == 'Position':
                pos = float(Fraction(event['value']))
                index = int((pos - 1) * 16 + (current_bar - 1) * 16)
                current_position = index
            elif event['name'] == 'Key':
                key_value = event['value']
                key_token = self.vocabs['vocab_key'].token_to_idx.get(key_value, self.vocabs['vocab_key'].unk)
                if current_position is not None and current_position < actual_num_steps:
                    key_sequence[current_position] = key_token
                    previous_key = key_token

        # Fill forward
        for i in range(actual_num_steps):
            if key_sequence[i] == self.vocabs['vocab_key'].token_to_idx['<pad>']:
                key_sequence[i] = previous_key
            else:
                previous_key = key_sequence[i]

        return key_sequence

    def process_tempo(self, tempo_events, actual_num_steps):
        tempo_sequence = [self.vocabs['vocab_tempo'].token_to_idx['<pad>']] * actual_num_steps
        tempo_value = self.vocabs['vocab_tempo'].token_to_idx['<pad>']

        if tempo_events:
            for event in tempo_events:
                if event['name'] == 'Tempo Value':
                    tempo_value = self.vocabs['vocab_tempo'].token_to_idx.get(str(event['value']), self.vocabs['vocab_tempo'].unk)
                    break  # Use only the first tempo value

        tempo_sequence = [tempo_value] * actual_num_steps
        return tempo_sequence

    def process_velocity(self, velocity_events, actual_num_steps):
        velocity_sequence = [self.vocabs['vocab_velocity'].token_to_idx['<pad>']] * actual_num_steps
        current_velocity = self.vocabs['vocab_velocity'].token_to_idx['<pad>']
        current_position = None
        current_bar = 1

        for event in velocity_events:
            if event['name'] == 'Bar':
                current_bar = int(event['text'])
            elif event['name'] == 'Position':
                pos = float(Fraction(event['value']))
                index = int((pos - 1) * 16 + (current_bar - 1) * 16)
                current_position = index
            elif event['name'] == 'Note Velocity':
                velocity_value = event['value']
                velocity_token = self.vocabs['vocab_velocity'].token_to_idx.get(str(velocity_value), self.vocabs['vocab_velocity'].unk)
                current_velocity = velocity_token

            if current_position is not None and current_position < actual_num_steps:
                velocity_sequence[current_position] = current_velocity

        # Fill forward
        for i in range(1, actual_num_steps):
            if velocity_sequence[i] == self.vocabs['vocab_velocity'].token_to_idx['<pad>']:
                velocity_sequence[i] = velocity_sequence[i - 1]

        return velocity_sequence

    def process_chords(self, chord_events, actual_num_steps):
        chord_sequence = [self.vocabs['vocab_chord'].token_to_idx['<pad>']] * actual_num_steps
        current_chord = self.vocabs['vocab_chord'].token_to_idx['<pad>']
        current_position = None
        current_bar = 1

        for event in chord_events:
            if event['name'] == 'Bar':
                current_bar = int(event['text'])
            elif event['name'] == 'Position':
                pos = float(Fraction(event['value']))
                index = int((pos - 1) * 16 + (current_bar - 1) * 16)
                current_position = index
            elif event['name'] == 'Chord':
                chord_value = event['value']
                chord_token = self.vocabs['vocab_chord'].token_to_idx.get(chord_value, self.vocabs['vocab_chord'].unk)
                current_chord = chord_token

            if current_position is not None and current_position < actual_num_steps:
                chord_sequence[current_position] = current_chord

        # Fill forward
        for i in range(1, actual_num_steps):
            if chord_sequence[i] == self.vocabs['vocab_chord'].token_to_idx['<pad>']:
                chord_sequence[i] = chord_sequence[i - 1]

        return chord_sequence

    def process_emotion(self, emotion_events, actual_num_steps):
        emotion_sequence = [self.vocabs['vocab_chord'].token_to_idx['<pad>']] * actual_num_steps  # Assuming emotion is second output
        emotion_label = self.vocabs['vocab_chord'].token_to_idx['<pad>']

        if emotion_events:
            for event in emotion_events:
                if event['name'] == 'Emotion':
                    emotion_label = self.vocabs['vocab_chord'].token_to_idx.get(str(event['value']), self.vocabs['vocab_chord'].unk)
                    break  # Use only the first emotion label

        emotion_sequence = [emotion_label] * actual_num_steps
        return emotion_sequence

# Load or Create Vocabularies
vocab_path = r"../vocabs.pkl"
if os.path.exists(vocab_path):
    with open(vocab_path, 'rb') as f:
        vocabs = pickle.load(f)
    print("Vocabularies loaded.")
else:
    # Create lists of tokens from Origin data
    note_tokens_list = x_origin[:, :, 0].tolist()
    bar_tokens_list = x_origin[:, :, 1].tolist()
    key_tokens_list = x_origin[:, :, 2].tolist()
    tempo_tokens_list = x_origin[:, :, 3].tolist()
    velocity_tokens_list = x_origin[:, :, 4].tolist()
    chord_tokens_list = y_origin[:, :, 0].tolist()
    
    vocabs = {
        'vocab_note': Vocab(tokens=note_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_bar': Vocab(tokens=bar_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_key': Vocab(tokens=key_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_tempo': Vocab(tokens=tempo_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_velocity': Vocab(tokens=velocity_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_chord': Vocab(tokens=chord_tokens_list, min_freq=1, reserved_tokens=['<unk>'])
    }
    # Save Vocabulary
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocabs, f)
    print("Vocabularies created and saved.")

print(f"Vocabulary sizes:")
for vocab_name, vocab in vocabs.items():
    print(f"{vocab_name}: {len(vocab)}")


def remove_key_chord_blocks(input_file_path, output_file_path):
    with open(input_file_path, 'r') as infile, open(output_file_path, 'w') as outfile:
        lines = infile.readlines()
        i = 0
        while i < len(lines):
            line = lines[i]
            if 'name=Position' in line:
                if i + 1 < len(lines):
                    next_line = lines[i + 1]
                    if 'name=Key' in next_line or 'name=Chord' in next_line:
                        i += 2
                        continue
            if 'name=Key' in line or 'name=Chord' in line:
                i += 1
                continue
            outfile.write(line)
            i += 1

# generate imperfect REMI without chord & key information. We will convert this imperfect REMI into restored REMI using our model.
input_file_path = r'001_remi.txt'
output_file_path = r'001_no_key_chord_remi.txt'
remove_key_chord_blocks(input_file_path, output_file_path)

Vocabularies loaded.
Vocabulary sizes:
vocab_note: 94
vocab_bar: 66
vocab_key: 26
vocab_tempo: 86
vocab_velocity: 129
vocab_chord: 62


In [2]:
# ==========================
# part 2. data preprocess
# ==========================

def indices_to_tokens(vocab, indices):
    tokens_list = []
    for c in indices:
        tokens = vocab.to_tokens(c)
        tokens_list.append(tokens)
    return tokens_list

# IntegratedRemi2Np class instance
preprocessor_target = IntegratedRemi2Np(vocabs=vocabs, num_max_bars=64, num_input=5, num_output=2)
preprocessor_source = IntegratedRemi2Np(vocabs=vocabs, num_max_bars=64, num_input=5, num_output=2)

# REMI files list
target_remi_files = [r"001_remi.txt"]  # original(with key & chord)
source_remi_files = [r"001_no_key_chord_remi.txt"]  # source(key & chord removed)

# convert to numpy
x_target, y_target = preprocessor_target.process_all_files(target_remi_files)
x_source, y_source = preprocessor_source.process_all_files(source_remi_files)

# save numpy dataset(if needed)
np.save(r"x_target_001.npy", x_target)
np.save(r"y_target_001.npy", y_target)

np.save(r"x_source_001.npy", x_source)
np.save(r"y_source_001.npy", y_source)

print("REMI to NP conversion completed")

# data load (if needed)
x_target_path = r"x_target_001.npy"
y_target_path = r"y_target_001.npy"

x_target = np.load(x_target_path, allow_pickle=True)  # shape: (num_samples, num_steps, 5)
y_target = np.load(y_target_path, allow_pickle=True)  # shape: (num_samples, num_steps, 2)

print(f"x_target shape: {x_target.shape}")
print(f"y_target shape: {y_target.shape}")

x_source_path = r"x_source_001.npy"
y_source_path = r"y_source_001.npy"

x_source = np.load(x_source_path, allow_pickle=True)  # shape: (1, num_steps, 5)
y_source = np.load(y_source_path, allow_pickle=True)  # shape: (1, num_steps, 2)

print(f"x_source shape: {x_source.shape}")  # (1, num_steps, 5)
print(f"y_source shape: {y_source.shape}")  # (1, num_steps, 2)

# convert index to token

x_target_tokens = {
    'note': indices_to_tokens(vocabs['vocab_note'], x_target[:, :, 0]),
    'bar': indices_to_tokens(vocabs['vocab_bar'], x_target[:, :, 1]),
    'key': indices_to_tokens(vocabs['vocab_key'], x_target[:, :, 2]),
    'tempo': indices_to_tokens(vocabs['vocab_tempo'], x_target[:, :, 3]),
    'velocity': indices_to_tokens(vocabs['vocab_velocity'], x_target[:, :, 4]),
}

y_target_tokens = {
    'chord': indices_to_tokens(vocabs['vocab_chord'], y_target[:, :, 0]),
    'emotion': indices_to_tokens(vocabs['vocab_chord'], y_target[:, :, 1]), 
}

x_source_tokens = {
    'note': indices_to_tokens(vocabs['vocab_note'], x_source[:, :, 0]),
    'bar': indices_to_tokens(vocabs['vocab_bar'], x_source[:, :, 1]),
    'key': indices_to_tokens(vocabs['vocab_key'], x_source[:, :, 2]),
    'tempo': indices_to_tokens(vocabs['vocab_tempo'], x_source[:, :, 3]),
    'velocity': indices_to_tokens(vocabs['vocab_velocity'], x_source[:, :, 4]),
}

y_source_tokens = {
    'chord': indices_to_tokens(vocabs['vocab_chord'], y_source[:, :, 0]),
    'emotion': indices_to_tokens(vocabs['vocab_chord'], y_source[:, :, 1]), 
}

Conversion completed
Conversion completed
REMI to NP conversion completed
x_target shape: (1, 1024, 5)
y_target shape: (1, 1024, 2)
x_source shape: (1, 1024, 5)
y_source shape: (1, 1024, 2)


In [3]:
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt

# ==========================
# Part 3. Model Load
# ==========================

# Multi-input Process Model
class MultiInputModelBase(nn.Module):
    def __init__(
        self, 
        vocab_sizes,     
        embed_dims,       
        hidden_size=256,  
        num_layers=3, 
        num_classes=62,    
        dropout=0.5
    ):
        super(MultiInputModelBase, self).__init__()
        # Embedding layers
        self.embed_note = nn.Embedding(vocab_sizes['note'], embed_dims['note'], padding_idx=0)
        self.embed_bar = nn.Embedding(vocab_sizes['bar'], embed_dims['bar'], padding_idx=0)
        self.total_embed_dim = sum(embed_dims.values())

# DeepBiLSTM Key Model
class DeepBiLSTM(MultiInputModelBase):
    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        hidden_size=256,   
        num_layers=3, 
        num_classes=62,  
        dropout=0.5
    ):
        super(DeepBiLSTM, self).__init__(vocab_sizes, embed_dims, hidden_size, num_layers, num_classes, dropout)
        
        self.lstm = nn.LSTM(
            input_size=self.total_embed_dim, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=True, 
            bidirectional=True, 
            dropout=dropout
        )
        self.layer_norm = nn.LayerNorm(hidden_size * 2)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * 2, num_classes)
    
    def forward(self, inputs):
        # Inputs: dictionary of tensors
        note = self.embed_note(inputs['note'])       # (batch_size, seq_len, embed_dim)
        bar = self.embed_bar(inputs['bar'])
        
        # Concatenate embeddings
        x = torch.cat([note, bar], dim=-1)  # (batch_size, seq_len, total_embed_dim)
        
        out, _ = self.lstm(x)
        out = self.layer_norm(out)
        out = self.dropout_layer(out)
        out = self.fc(out)     # (batch_size, seq_len, num_classes)
        return out

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x

# Transformer Chord Model
class Transformer(nn.Module):
    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        num_classes, 
        nhead=8, 
        num_encoder_layers=4, 
        dim_feedforward=512, 
        dropout=0.1
    ):
        super(Transformer, self).__init__()
        
        # Embedding layers
        self.embed_note = nn.Embedding(vocab_sizes['note'], embed_dims['note'], padding_idx=0)
        self.embed_bar = nn.Embedding(vocab_sizes['bar'], embed_dims['bar'], padding_idx=0)
        self.embed_key = nn.Embedding(vocab_sizes['key'], embed_dims['key'], padding_idx=0)
        self.embed_tempo = nn.Embedding(vocab_sizes['tempo'], embed_dims['tempo'], padding_idx=0)
        self.embed_velocity = nn.Embedding(vocab_sizes['velocity'], embed_dims['velocity'], padding_idx=0)
        
        # Sum of embedding dimensions
        self.total_embed_dim = sum(embed_dims.values())
        
        # Positional Encoding
        self.pos_encoder = PositionalEncoding(self.total_embed_dim)
        
        # Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=self.total_embed_dim, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_encoder_layers)
        
        # Output layer
        self.fc = nn.Linear(self.total_embed_dim, num_classes)
        
    def forward(self, inputs, src_key_padding_mask=None):
        """
        Args:
            inputs (dict): {'note': Tensor, 'bar': Tensor, 'key': Tensor, 'tempo': Tensor, 'velocity': Tensor}
            src_key_padding_mask (Tensor, optional): Padding mask (batch_size, seq_len)
        
        Returns:
            Tensor: Chord prediction results (batch_size, seq_len, num_classes)
        """
        note = self.embed_note(inputs['note'])       # (batch_size, seq_len, embed_dim_note)
        bar = self.embed_bar(inputs['bar'])          # (batch_size, seq_len, embed_dim_bar)
        key = self.embed_key(inputs['key'])          # (batch_size, seq_len, embed_dim_key)
        tempo = self.embed_tempo(inputs['tempo'])    # (batch_size, seq_len, embed_dim_tempo)
        velocity = self.embed_velocity(inputs['velocity'])  # (batch_size, seq_len, embed_dim_velocity)
        
        # Concatenate embeddings
        x = torch.cat([note, bar, key, tempo, velocity], dim=-1)      # (batch_size, seq_len, total_embed_dim)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Pass through Transformer Encoder
        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
        
        # Final output
        x = self.fc(x)                              # (batch_size, seq_len, num_classes)
        return x

# 4.4 Define Filter Function (No Changes)
def apply_window_filter(key_hat, window_size=128):
    """
    key_hat: 1D numpy array of shape (num_steps,)
    window_size: int, size of the window (16 or 64 or 128)
    
    Returns:
        filtered_key_hat: 1D numpy array after applying the filter
    """
    key_hat = np.array(key_hat)
    output = key_hat.copy()
    num_steps = len(key_hat)
    num_windows = num_steps // window_size

    for n in range(num_windows):
        start = n * window_size
        end = (n + 1) * window_size
        window = key_hat[start:end]
        window_non_pad = window[window != 0]

        if len(window_non_pad) == 0:
            continue  # Skip if only <pad> tokens are present

        counts = np.bincount(window_non_pad)
        most_freq = np.argmax(counts)
        freq = counts[most_freq]

        # Check if the most frequent value is unique
        if np.sum(counts == freq) > 1:
            # If multiple modes exist, try to resolve using the next window's value
            replacement = None
            for m in range(n + 1, num_windows):
                next_start = m * window_size
                next_end = (m + 1) * window_size
                next_window = key_hat[next_start:end]
                next_window_non_pad = next_window[next_window != 0]
                if len(next_window_non_pad) == 0:
                    continue
                next_counts = np.bincount(next_window_non_pad)
                next_most_freq = np.argmax(next_counts)
                next_freq = next_counts[next_most_freq]
                if np.sum(next_counts == next_freq) == 1:
                    replacement = next_most_freq
                    break
            if replacement is not None:
                # Replace the current window's non-pad keys with the replacement
                window_filtered = np.where(window != 0, replacement, window)
                output[start:end] = window_filtered
            else:
                # If no replacement found, keep the original
                continue
        else:
            # If there is a unique mode, replace with it
            replacement = most_freq
            window_filtered = np.where(window != 0, replacement, window)
            output[start:end] = window_filtered

    return output

# 4.5 Define IntegratedModel Class (With Changes)
class IntegratedModel(nn.Module):
    def __init__(self, deep_bilstm_model, chord_model, filter_func, window_size=128):
        super(IntegratedModel, self).__init__()
        self.deep_bilstm_model = deep_bilstm_model
        self.chord_model = chord_model
        self.filter_func = filter_func
        self.window_size = window_size

        # Freeze DeepBiLSTM model parameters to prevent training
        for param in self.deep_bilstm_model.parameters():
            param.requires_grad = False

    def forward(self, inputs):
        """
        Args:
            inputs (dict): {'note': Tensor, 'bar': Tensor, 'tempo': Tensor, 'velocity': Tensor}

        Returns:
            outputs_chord (Tensor): Chord prediction results (batch_size, seq_len, num_classes)
            predicted_key_tensor (Tensor): Filtered Key prediction results (batch_size, seq_len)
        """
        # Predict key using DeepBiLSTM model
        with torch.no_grad():
            key_inputs = {
                'note': inputs['note'],
                'bar': inputs['bar']
            }
            outputs_key = self.deep_bilstm_model(key_inputs)
            _, predicted_key = torch.max(outputs_key, dim=2)  # (batch_size, seq_len)
            predicted_key = predicted_key.cpu().numpy()

            # Apply filter
            filtered_predicted_key = []
            for pred in predicted_key:
                filtered = self.filter_func(pred, window_size=self.window_size)
                filtered_predicted_key.append(filtered)

            filtered_predicted_key = np.array(filtered_predicted_key)
            predicted_key_tensor = torch.from_numpy(filtered_predicted_key).long().to(inputs['note'].device)

        # Use the predicted key as input for the Chord prediction model
        chord_inputs = {
            'note': inputs['note'],
            'bar': inputs['bar'],
            'key': predicted_key_tensor,
            'tempo': inputs['tempo'],
            'velocity': inputs['velocity']
        }

        # Create padding mask (based on note padding)
        src_key_padding_mask = (inputs['note'] == 0)  # (batch_size, seq_len)
        src_key_padding_mask = src_key_padding_mask.bool()

        # Predict chord
        outputs_chord = self.chord_model(chord_inputs, src_key_padding_mask=src_key_padding_mask)
        return outputs_chord, predicted_key_tensor

# Model Loading
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize DeepBiLSTM model
deep_bilstm_model_loaded = DeepBiLSTM(
    vocab_sizes={'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])},
    embed_dims={'note': 64, 'bar': 16},
    hidden_size=256,
    num_layers=3,
    num_classes=len(vocabs['vocab_key']),
    dropout=0.5
).to(device)

deep_bilstm_model_path = r"../../model_checkpoint/DeepBiLSTM_filtered.pth" # Modify the path as needed

# Modify part: Remove "model." prefix from state_dict keys
state_dict = torch.load(deep_bilstm_model_path, map_location=device)

# Remove "model." prefix from keys
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith('model.'):
        new_key = k[6:]  # Remove "model." (first 6 characters)
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# Load the modified state_dict into the model
deep_bilstm_model_loaded.load_state_dict(new_state_dict)

deep_bilstm_model_loaded.eval()

# Initialize Chord Prediction Transformer model
chord_model_loaded = Transformer(
    vocab_sizes={
        'note': len(vocabs['vocab_note']),
        'bar': len(vocabs['vocab_bar']),
        'key': len(vocabs['vocab_key']),
        'tempo': len(vocabs['vocab_tempo']),
        'velocity': len(vocabs['vocab_velocity'])
    },
    embed_dims={
        'note': 64,
        'bar': 16,
        'key': 32,
        'tempo': 16,
        'velocity': 16
    },
    num_classes=len(vocabs['vocab_chord']),
    nhead=8,
    num_encoder_layers=2,
    dim_feedforward=128,
    dropout=0.3
).to(device)

# Create IntegratedModel instance
integrated_model = IntegratedModel(
    deep_bilstm_model=deep_bilstm_model_loaded,
    chord_model=chord_model_loaded,
    filter_func=apply_window_filter,
    window_size=128
).to(device)

# Duplicate IntegratedModel for loading saved weights
integrated_model_loaded = IntegratedModel(
    deep_bilstm_model=deep_bilstm_model_loaded,
    chord_model=chord_model_loaded,
    filter_func=apply_window_filter,
    window_size=128
).to(device)

integrated_model_path = r"../../model_checkpoint/IntegratedModel_epoch200_valacc0.9303.pt" # Modify the path as needed
integrated_model_loaded.load_state_dict(torch.load(integrated_model_path, map_location=device))
integrated_model_loaded.eval()


Using device: cuda


  state_dict = torch.load(deep_bilstm_model_path, map_location=device)
  integrated_model_loaded.load_state_dict(torch.load(integrated_model_path, map_location=device))


IntegratedModel(
  (deep_bilstm_model): DeepBiLSTM(
    (embed_note): Embedding(94, 64, padding_idx=0)
    (embed_bar): Embedding(66, 16, padding_idx=0)
    (lstm): LSTM(80, 256, num_layers=3, batch_first=True, dropout=0.5, bidirectional=True)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout_layer): Dropout(p=0.5, inplace=False)
    (fc): Linear(in_features=512, out_features=26, bias=True)
  )
  (chord_model): Transformer(
    (embed_note): Embedding(94, 64, padding_idx=0)
    (embed_bar): Embedding(66, 16, padding_idx=0)
    (embed_key): Embedding(26, 32, padding_idx=0)
    (embed_tempo): Embedding(86, 16, padding_idx=0)
    (embed_velocity): Embedding(129, 16, padding_idx=0)
    (pos_encoder): PositionalEncoding()
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=144, out

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# ==========================
# Part 4. Predict Key & Chord
# ==========================

class SingleSampleDataset(Dataset):
    def __init__(self, x_data):
        self.x_data = x_data  # x_data is a numpy array of shape (1, num_steps, 5)
    
    def __len__(self):
        return 1  # Single sample
    
    def __getitem__(self, idx):
        sample = {
            'note': torch.tensor(self.x_data[0, :, 0], dtype=torch.long),
            'bar': torch.tensor(self.x_data[0, :, 1], dtype=torch.long),
            'tempo': torch.tensor(self.x_data[0, :, 3], dtype=torch.long),
            'velocity': torch.tensor(self.x_data[0, :, 4], dtype=torch.long)
        }
        return sample

# Visualization Function
def midi_to_note_name(midi_numbers):
    """
    Converts a list of MIDI pitch numbers to a list of note names.

    Args:
        midi_numbers (list or array): List or array of MIDI pitch numbers

    Returns:
        note_names (list): List of note names
    """
    note_names_list = ['C', 'C#', 'D', 'D#', 'E', 'F',
                       'F#', 'G', 'G#', 'A', 'A#', 'B']
    note_names = []
    for num in midi_numbers:
        if num == 0:
            note_names.append('<pad>')
        else:
            note_name = note_names_list[int(num) % 12]
            note_names.append(note_name)
    return note_names

"""
Visualizes the model's prediction results for a single sample.
"""

# Accuracy Calculation Function
def calculate_accuracy(predicted_tokens, target_tokens):
    """
    Calculates the accuracy between two lists of tokens.

    Args:
        predicted_tokens (list): List of tokens predicted by the model
        target_tokens (list): List of target (correct) tokens
        
    Returns:
        float: Accuracy
    """
    total = len(target_tokens)
    matches = sum(p == t for p, t in zip(predicted_tokens, target_tokens))
    accuracy = matches / total
    return accuracy

# Initialize the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Create Dataset (pass x_source)
single_sample_dataset = SingleSampleDataset(x_source)

# Create DataLoader
single_sample_loader = DataLoader(single_sample_dataset, batch_size=1, shuffle=False)

# Set the model to evaluation mode
integrated_model_loaded.eval()
with torch.no_grad():
    for batch in single_sample_loader:
        # Move input data to the device
        inputs = {
            'note': batch['note'].to(device),        # shape: (batch_size, seq_len)
            'bar': batch['bar'].to(device),
            'tempo': batch['tempo'].to(device),
            'velocity': batch['velocity'].to(device)
        }

        # Perform prediction with the model
        outputs_chord, predicted_key_tensor = integrated_model_loaded(inputs)
        _, predicted_chord = torch.max(outputs_chord, dim=2)  # predicted_chord shape: (batch_size, seq_len)

        # Move predictions to CPU and convert to numpy arrays
        predicted_chord_np = predicted_chord.cpu().numpy()        # shape: (batch_size, seq_len)
        predicted_key_np = predicted_key_tensor.cpu().numpy()    # shape: (batch_size, seq_len)

        print('Predicted chord & key generated')

        # Convert predicted key and chord to tokens
        predicted_chord_tokens = vocabs['vocab_chord'].to_tokens(predicted_chord_np.squeeze())
        predicted_key_tokens = vocabs['vocab_key'].to_tokens(predicted_key_np.squeeze())

        # Extract input note sequence (indices)
        note_indices_np = batch['note'].cpu().numpy().squeeze()   # shape: (seq_len,)
        # Convert indices to tokens (for visualization)
        note_tokens = vocabs['vocab_note'].to_tokens(note_indices_np)

# Visualization
# Sequence length
seq_len = len(note_tokens)

# Display only a subset of sequences for visualization (e.g., first 64)
num_display = min(64, seq_len)
x_axis = range(num_display)

# Convert note sequence to note names
note_names = midi_to_note_name(note_indices_np[:num_display])

# Extract actual target tokens
targets_chord_tokens = y_target_tokens['chord'][0][:num_display]
targets_key_tokens = x_target_tokens['key'][0][:num_display]

# Extract predicted tokens (already converted above)
predicted_chord_tokens_display = predicted_chord_tokens[:num_display]
predicted_key_tokens_display = predicted_key_tokens[:num_display]

# Create a set of all tokens
all_tokens_set = set(note_names +
                     targets_chord_tokens +
                     predicted_chord_tokens_display +
                     targets_key_tokens +
                     predicted_key_tokens_display)

# Separate note names and chord names
note_names_order = ['C', 'C#', 'D', 'D#', 'E', 'F',
                    'F#', 'G', 'G#', 'A', 'A#', 'B', '<pad>']
note_names_in_tokens = [name for name in note_names_order if name in all_tokens_set]

# Extract chord names (tokens containing ':')
chord_names_in_tokens = [token for token in all_tokens_set if ':' in token]
chord_names_in_tokens.sort()

# Extract key names
key_names_in_tokens = [token for token in all_tokens_set if token in vocabs['vocab_key'].idx_to_token]
key_names_in_tokens.sort()

# Extract other tokens (e.g., '<unk>', etc.)
other_tokens = [token for token in all_tokens_set if token not in note_names_in_tokens and
                token not in chord_names_in_tokens and token not in key_names_in_tokens]
other_tokens.sort()

# Set the order of all tokens
sorted_all_tokens = note_names_in_tokens + chord_names_in_tokens + key_names_in_tokens + other_tokens

# Map tokens to numerical indices
token_to_index = {token: idx for idx, token in enumerate(sorted_all_tokens)}
index_to_token = {idx: token for token, idx in token_to_index.items()}

# Convert note, target, and predicted tokens to numerical indices
note_indices_plot = [token_to_index.get(token, -1) for token in note_names[:num_display]]
targets_chord_indices_plot = [token_to_index.get(token, -1) for token in targets_chord_tokens]
predicted_chord_indices_plot = [token_to_index.get(token, -1) for token in predicted_chord_tokens_display]
targets_key_indices_plot = [token_to_index.get(token, -1) for token in targets_key_tokens]
predicted_key_indices_plot = [token_to_index.get(token, -1) for token in predicted_key_tokens_display]

# Plotting
plt.figure(figsize=(20, 8))

plt.plot(x_axis, note_indices_plot, label='Input Note (x_source)', alpha=0.5)
plt.plot(x_axis, targets_chord_indices_plot, label='Target Chord (y_target)', alpha=0.8)
plt.plot(x_axis, predicted_chord_indices_plot, label='Predicted Chord', alpha=0.8)
plt.plot(x_axis, targets_key_indices_plot, label='Target Key (x_source)', linestyle='-', color='black', linewidth=1)
plt.plot(x_axis, predicted_key_indices_plot, label='Predicted Key', linestyle='--', color='gray', linewidth=2)

plt.title('Chord and Key Prediction (Integrated Model)')
plt.xlabel('Time Step')
plt.ylabel('Token')

# Set y-axis ticks to token names
plt.yticks(ticks=range(len(sorted_all_tokens)), labels=sorted_all_tokens)

plt.legend(loc='upper right')  # Fix legend position to upper right
plt.tight_layout()
plt.savefig(r'inference_plot.png', dpi=300, bbox_inches='tight')  # Added file extension
plt.show()

# Calculate Accuracy
# Calculate accuracy between two token lists
def calculate_accuracy(predicted_tokens, target_tokens):
    """
    Calculates the accuracy between two lists of tokens.
    
    Args:
        predicted_tokens (list): List of tokens predicted by the model
        target_tokens (list): List of target (correct) tokens
        
    Returns:
        float: Accuracy
    """
    total = len(target_tokens)
    matches = sum(p == t for p, t in zip(predicted_tokens, target_tokens))
    accuracy = matches / total
    return accuracy

# Calculate accuracy for Chord and Key
chord_accuracy = calculate_accuracy(predicted_chord_tokens_display, targets_chord_tokens)
key_accuracy = calculate_accuracy(predicted_key_tokens_display, targets_key_tokens)

# Print results
print(f"Predicted Chord vs Target Chord Accuracy: {chord_accuracy:.4f}")
print(f"Predicted Key vs Target Key Accuracy: {key_accuracy:.4f}")


Using device: cuda
Predicted chord & key generated


In [None]:
import os
from fractions import Fraction
from collections import defaultdict

# ==========================
# Part 5. REMI Restoration
# ==========================

# Set file paths
input_remi_file = r"001_no_key_chord_remi.txt"
output_remi_file = r"001_processed_remi.txt"


def parse_remi_file_blocks(filepath):
    """
    Parses a REMI file into event blocks and returns a list of blocks.
    Each block consists of a list of event dictionaries.
    
    Args:
        filepath (str): Path to the REMI file
    
    Returns:
        list of tuples: List of (block_type, block_events) tuples
    """
    blocks = []
    current_block = []
    with open(filepath, 'r') as file:
        lines = file.readlines()
        num_lines = len(lines)
        i = 0
        while i < num_lines:
            line = lines[i].strip()
            if not line:
                i += 1
                continue  # Skip empty lines
            event = parse_event_line(line)
            if not event:
                i += 1
                continue  # Skip invalid events

            if event['name'] == 'Note On':
                # Note block includes the two previous events and the next event
                note_block = []
                if i >= 2:
                    prev_event1 = parse_event_line(lines[i - 2].strip())
                    prev_event2 = parse_event_line(lines[i - 1].strip())
                    if prev_event1 and prev_event2 and prev_event1['name'] == 'Position' and prev_event2['name'] == 'Note Velocity':
                        note_block.extend([prev_event1, prev_event2])
                note_block.append(event)  # 'Note On' event
                # Add the next event ('Note Duration')
                if i + 1 < num_lines:
                    next_event = parse_event_line(lines[i + 1].strip())
                    if next_event and next_event['name'] == 'Note Duration':
                        note_block.append(next_event)
                        i += 1  # Increment index as 'Note Duration' has been processed
                blocks.append(('note', note_block))
            elif event['name'] == 'Bar':
                # Bar block consists of a single event
                blocks.append(('bar', [event]))
            elif event['name'] == 'Tempo Class':
                # Tempo block includes the current event and the next event
                tempo_block = [event]
                if i + 1 < num_lines:
                    next_event = parse_event_line(lines[i + 1].strip())
                    if next_event and next_event['name'] == 'Tempo Value':
                        tempo_block.append(next_event)
                        i += 1  # Increment index as 'Tempo Value' has been processed
                blocks.append(('tempo', tempo_block))
            else:
                # Skip other events
                pass
            i += 1

    return blocks


def parse_event_line(line):
    """
    Parses a single line from a REMI file and returns an event dictionary.
    
    Args:
        line (str): A single line from the REMI file
    
    Returns:
        dict: Event dictionary
    """
    if not line.startswith('Event('):
        return None
    event = {}
    # Extract name
    if "name=" in line:
        event_name = line.split("name=")[1].split(",")[0].strip()
        event['name'] = event_name
    else:
        return None  # Skip lines without 'name='

    # Extract time
    if "time=" in line:
        event_time = line.split("time=")[1].split(",")[0].strip()
        event['time'] = int(event_time)
    else:
        event['time'] = None  # Or set to 0

    # Extract value
    if "value=" in line:
        event_value = line.split("value=")[1].split(",")[0].strip()
        event['value'] = event_value
    else:
        event['value'] = None

    # Extract text
    if "text=" in line:
        event_text = line.split("text=")[1].strip().rstrip(')')
        event['text'] = event_text
    else:
        event['text'] = None

    return event


# Define ticks
ticks_per_bar = 960  # 960 ticks per bar
ticks_per_time_step = ticks_per_bar / 16  # 60 ticks per time step


def time_step_to_tick_time(t0, time_step):
    """
    Converts a time step to tick time.
    
    Args:
        t0 (int): Initial time (ticks)
        time_step (int): Time step index
    
    Returns:
        int: Tick time
    """
    tick_time = t0 + ticks_per_time_step * time_step
    return int(tick_time)


def get_initial_time(note_dict, bar_dict):
    """
    Returns the initial time based on the first 'Note' block.
    
    Args:
        bar_dict (dict): Dictionary of Bar blocks
        note_dict (dict): Dictionary of Note blocks
    
    Returns:
        int: Initial time (ticks)
    """
    if note_dict:
        return min(note_dict.keys()) + min(bar_dict.keys())
    else:
        return 0


def get_last_time(blocks):
    """
    Returns the time of the last event in the song.
    
    Args:
        blocks (list): List of (block_type, block_events) tuples
    
    Returns:
        int: Time of the last event (ticks)
    """
    last_time = 0
    for _, block in blocks:
        for event in block:
            if event['time'] is not None and event['time'] > last_time:
                last_time = event['time']
    return last_time


def get_num_time_steps(t0, last_time, ticks_per_time_step):
    """
    Calculates the number of time steps based on the last time.
    
    Args:
        t0 (int): Initial time
        last_time (int): Last time in the song
        ticks_per_time_step (float): Ticks per time step
    
    Returns:
        int: Number of time steps in the song
    """
    total_ticks = last_time - t0
    num_time_steps = int(total_ticks / ticks_per_time_step) + 1
    return num_time_steps


def find_change_points(tokens):
    """
    Finds change points in a token sequence and returns their indices and tokens.
    
    Args:
        tokens (list): Token sequence
    
    Returns:
        list of tuples: List of (time_step_index, token) tuples
    """
    change_points = []
    prev_token = None
    for i, token in enumerate(tokens):
        if token != prev_token:
            change_points.append((i, token))
            prev_token = token
    return change_points

# Parse REMI file
blocks = parse_remi_file_blocks(input_remi_file)
print(f"Total number of blocks in the original REMI file: {len(blocks)}")


# Store Blocks in Dictionaries
note_dict = defaultdict(list)
bar_dict = {}
tempo_dict = {}
# Chord and Key are generated from model predictions, so they are excluded here

for block_type, block in blocks:
    time = block[0]['time']  # All events within a block have the same time

    if block_type == 'note':
        note_dict[time].append(block)
    elif block_type == 'bar':
        bar_dict[time] = block
    elif block_type == 'tempo':
        tempo_dict[time] = block

print(f"Number of Note blocks: {sum(len(v) for v in note_dict.values())}")
print(f"Number of Bar blocks: {len(bar_dict)}")
print(f"Number of Tempo blocks: {len(tempo_dict)}")

# Define necessary variables and functions

# Initial time (t0) extraction
def get_initial_time(note_dict, bar_dict):
    """
    Returns the initial time based on the first 'Note' block.
    
    Args:
        bar_dict (dict): Dictionary of Bar blocks
        note_dict (dict): Dictionary of Note blocks
    
    Returns:
        int: Initial time (ticks)
    """
    if note_dict:
        return min(note_dict.keys()) + min(bar_dict.keys())
    else:
        return 0


t0 = get_initial_time(note_dict, bar_dict)
print(f"Initial time (t0): {t0}")

# Extract the last time of the song
last_time_in_song = get_last_time(blocks)
print(f"Last time in the song: {last_time_in_song}")

# Calculate the number of time steps in the song
num_time_steps_in_song = get_num_time_steps(t0, last_time_in_song, ticks_per_time_step)
print(f"Number of time steps in the song: {num_time_steps_in_song}")

# Trim prediction results to match the song length
predicted_chord_tokens = predicted_chord_tokens[:num_time_steps_in_song]
predicted_key_tokens = predicted_key_tokens[:num_time_steps_in_song]

# Find change points
chord_change_points = find_change_points(predicted_chord_tokens)
key_change_points = find_change_points(predicted_key_tokens)

print(f"Chord change points: {chord_change_points}")
print(f"Key change points: {key_change_points}")

# Create Chord and Key dictionaries
chord_dict = {}
key_dict = {}

# Create Chord dictionary
for time_step, chord_token in chord_change_points:
    tick_time = time_step_to_tick_time(t0, time_step)
    position_in_bar = (time_step % 16) + 1  # From 1 to 16
    position_value = f"{position_in_bar}/16"

    # Create Position event
    position_event = {
        'name': 'Position',
        'time': tick_time,
        'value': position_value,
        'text': str(tick_time)
    }
    # Create Chord event
    chord_event = {
        'name': 'Chord',
        'time': tick_time,
        'value': chord_token,
        'text': chord_token
    }
    # Create block
    chord_block = [position_event, chord_event]
    chord_dict.setdefault(tick_time, []).append(chord_block)

# Create Key dictionary
for time_step, key_token in key_change_points:
    tick_time = time_step_to_tick_time(t0, time_step)
    position_in_bar = (time_step % 16) + 1  # From 1 to 16
    position_value = f"{position_in_bar}/16"

    # Create Position event
    position_event = {
        'name': 'Position',
        'time': tick_time,
        'value': position_value,
        'text': str(tick_time)
    }
    # Create Key event
    key_event = {
        'name': 'Key',
        'time': tick_time,
        'value': key_token,
        'text': key_token
    }
    # Create block
    key_block = [position_event, key_event]
    key_dict.setdefault(tick_time, []).append(key_block)

print(f"Number of Chord blocks created: {sum(len(v) for v in chord_dict.values())}")
print(f"Number of Key blocks created: {sum(len(v) for v in key_dict.values())}")

# Combine dictionaries
from collections import defaultdict

# Collect all time keys
all_times = set()
all_times.update(tempo_dict.keys())
all_times.update(bar_dict.keys())
all_times.update(key_dict.keys())
all_times.update(chord_dict.keys())
all_times.update(note_dict.keys())

# Define block priority: tempo > bar > key > chord > note
block_priority = {
    'tempo': 0,
    'bar': 1,
    'key': 2,
    'chord': 3,
    'note': 4,
}

# Create combined_dict
combined_dict = defaultdict(list)

for time in sorted(all_times):
    blocks_at_time = []

    # Add blocks based on priority
    if time in tempo_dict:
        blocks_at_time.append(('tempo', tempo_dict[time]))
    if time in bar_dict:
        blocks_at_time.append(('bar', bar_dict[time]))
    if time in key_dict:
        for key_block in key_dict[time]:
            blocks_at_time.append(('key', key_block))
    if time in chord_dict:
        for chord_block in chord_dict[time]:
            blocks_at_time.append(('chord', chord_block))
    if time in note_dict:
        # Add all note blocks
        for note_block in note_dict[time]:
            blocks_at_time.append(('note', note_block))

    # Sort blocks based on priority
    blocks_at_time.sort(key=lambda x: block_priority.get(x[0], 100))

    # Add to combined_dict
    for block_type, block in blocks_at_time:
        combined_dict[time].append(block)

print(f"Total number of combined time keys: {len(combined_dict)}")


# Write to REMI file
with open(output_remi_file, 'w') as output_file:
    # Save combined_dict contents in chronological order
    for time in sorted(combined_dict.keys()):
        blocks = combined_dict[time]
        for block in blocks:
            for event in block:
                event_str = f"Event(name={event['name']}, time={event['time']}, value={event['value']}, text={event['text']})"
                output_file.write(event_str + '\n')
            # output_file.write('\n')  # Add empty line to separate blocks

print(f"A new REMI file has been created: {output_remi_file}")
