# In [None]:
import os
import torch
import pickle
import collections
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn as nn
import numpy as np
import glob
import re
import matplotlib.pyplot as plt
import dill as pickle  # Using dill for serialization


# ==========================
# Vocab Class Definition
# ==========================
class Vocab:
    """Vocabulary mapping tokens to integer indices and back.

    Index 0 is always ``'<pad>'``. ``reserved_tokens`` follow in the given
    order, then every other token whose frequency is at least ``min_freq``,
    sorted by descending frequency (equal frequencies keep first-seen order
    because ``sorted`` is stable).
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """Build the vocabulary.

        Args:
            tokens: Flat list of tokens, or a list of token lists (one per
                sequence), which is flattened first. ``None`` means no tokens.
            min_freq: Minimum count for a token to be included.
            reserved_tokens: Tokens placed right after ``'<pad>'`` regardless
                of frequency (e.g. ``['<unk>']``). ``None`` means none.
        """
        # None defaults avoid sharing one mutable list object across instances.
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else list(reserved_tokens)
        # Flatten a 2D list (list of sequences) into one token stream.
        if len(tokens) > 0 and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # Tokens that already have a fixed slot must not be added a second time,
        # otherwise token_to_idx would point at the later duplicate.
        fixed = {'<pad>', *reserved_tokens}
        self.idx_to_token = ['<pad>'] + reserved_tokens + [
            token for token, freq in self.token_freqs
            if freq >= min_freq and token not in fixed
        ]
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Return the index of a token, or a list of indices for a list/tuple.

        Unknown tokens map to ``self.unk``.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a sequence/1-D array of indices, back to tokens.

        A scalar (Python int, NumPy scalar, or 0-d array/tensor) yields a single
        token. Anything with a length yields a list, even with one element.
        """
        # 0-d arrays/tensors define __len__ but len() raises on them, so check ndim.
        if hasattr(indices, '__len__') and getattr(indices, 'ndim', 1) != 0:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[int(indices)]

    @property
    def unk(self):
        """Index used for unknown tokens: that of '<unk>', or 1 if it is absent."""
        return self.token_to_idx.get('<unk>', 1)


# ==========================
# MIDI to Note Name Conversion Function
# ==========================
def midi_to_note_name(midi_numbers):
    """
    Translate MIDI pitch numbers into pitch-class names (octave dropped).

    Args:
        midi_numbers (list or array): Iterable of MIDI pitch numbers; the value 0
            or the string '<pad>' marks padding.

    Returns:
        list: One entry per input, either a pitch-class name such as 'C#' or '<pad>'.
    """
    pitch_classes = ('C', 'C#', 'D', 'D#', 'E', 'F',
                     'F#', 'G', 'G#', 'A', 'A#', 'B')

    def _name_of(value):
        # 0 is treated as padding rather than as MIDI note C-1.
        if value == 0 or value == '<pad>':
            return '<pad>'
        return pitch_classes[int(value) % 12]

    return [_name_of(value) for value in midi_numbers]


# ==========================
# Dataset Class Definition
# ==========================
class MultiInputDataset(Dataset):
    """Map-style dataset over row-aligned note/bar/key/chord index tensors.

    Sample ``idx`` is a dict with keys ``'note'``, ``'bar'``, ``'chord'`` and
    ``'key'``, each holding row ``idx`` of the matching tensor. The dataset
    length is taken from ``note`` alone; the other tensors are not checked.
    """

    def __init__(self, note, bar, key, chord):
        """
        Store the feature tensors.

        Args:
            note (Tensor): Expected shape (num_samples, num_steps)
            bar (Tensor): Expected shape (num_samples, num_steps)
            key (Tensor): Expected shape (num_samples, num_steps)
            chord (Tensor): Expected shape (num_samples, num_steps)
        """
        self.note, self.bar, self.key, self.chord = note, bar, key, chord

    def __len__(self):
        return self.note.size(0)

    def __getitem__(self, idx):
        fields = ('note', 'bar', 'chord', 'key')
        return {field: getattr(self, field)[idx] for field in fields}


# ==========================
# Base Model Class Definition
# ==========================
class MultiInputModelBase(nn.Module):
    """Shared embedding front-end for the note/bar sequence models.

    Creates ``embed_note`` and ``embed_bar`` (both with padding index 0) and
    records ``total_embed_dim``, the width of the concatenated note+bar
    embedding that subclasses feed into their sequence encoder.
    """

    def __init__(
        self,
        vocab_sizes,       # Dictionary of vocab sizes for each feature
        embed_dims,        # Dictionary of embedding dimensions for each feature
        hidden_size=256,   # Hidden state size
        num_layers=3,
        num_classes=62,    # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        """
        Args:
            vocab_sizes (dict): Must contain 'note' and 'bar'; other keys are ignored.
            embed_dims (dict): Must contain 'note' and 'bar'; other keys are ignored.
            hidden_size, num_layers, num_classes, dropout: Accepted so that all
                subclasses share one constructor signature; unused here.
        """
        super().__init__()
        # Embedding layers (only note and bar are embedded)
        self.embed_note = nn.Embedding(vocab_sizes['note'], embed_dims['note'], padding_idx=0)
        self.embed_bar = nn.Embedding(vocab_sizes['bar'], embed_dims['bar'], padding_idx=0)

        # Width of the concatenated embeddings. Only the embedded features are
        # counted: extra entries in embed_dims (e.g. 'chord') would otherwise
        # inflate the width and mismatch the encoder's input size.
        self.total_embed_dim = embed_dims['note'] + embed_dims['bar']


# ==========================
# Deep LSTM Model Definition
# ==========================
class DenoiseLSTM(MultiInputModelBase):
    """Stacked unidirectional LSTM that predicts one key class per time step.

    Data flow: note and bar embeddings (from the base class) are concatenated,
    run through the LSTM, passed through dropout, and then through a per-step
    linear classifier producing (batch_size, seq_len, num_classes) logits.
    """

    def __init__(
        self,
        vocab_sizes,
        embed_dims,
        hidden_size=256,    # Hidden state size
        num_layers=3,
        num_classes=62,     # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        super().__init__(vocab_sizes, embed_dims, hidden_size, num_layers, num_classes, dropout)

        # Attribute names (lstm, dropout_layer, fc) are state_dict keys; keep them stable.
        self.lstm = nn.LSTM(
            self.total_embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.dropout_layer = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs):
        # Only inputs['note'] and inputs['bar'] are read; other keys are ignored.
        features = torch.cat(
            (self.embed_note(inputs['note']), self.embed_bar(inputs['bar'])),
            dim=-1,
        )
        sequence_out = self.lstm(features)[0]
        return self.fc(self.dropout_layer(sequence_out))


# ==========================
# Deep BiLSTM Model Definition
# ==========================
class DenoiseBiLSTM(MultiInputModelBase):
    """Stacked bidirectional LSTM tagger with layer normalisation.

    Data flow: concatenated note/bar embeddings are encoded in both directions
    (2 * hidden_size features per step). The result goes through LayerNorm and
    dropout, then a per-step linear classifier producing
    (batch_size, seq_len, num_classes) logits.
    """

    def __init__(
        self,
        vocab_sizes,
        embed_dims,
        hidden_size=256,    # Hidden state size
        num_layers=3,
        num_classes=62,     # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        super().__init__(vocab_sizes, embed_dims, hidden_size, num_layers, num_classes, dropout)

        # Forward and backward states are concatenated, doubling the feature width.
        both_directions = hidden_size * 2
        # Attribute names are state_dict keys; keep them stable.
        self.lstm = nn.LSTM(
            self.total_embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.layer_norm = nn.LayerNorm(both_directions)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc = nn.Linear(both_directions, num_classes)

    def forward(self, inputs):
        # Only inputs['note'] and inputs['bar'] are read; other keys are ignored.
        embedded = torch.cat(
            (self.embed_note(inputs['note']), self.embed_bar(inputs['bar'])),
            dim=-1,
        )
        encoded, _state = self.lstm(embedded)
        normalized = self.layer_norm(encoded)
        return self.fc(self.dropout_layer(normalized))


# ==========================
# Positional Encoding for Transformer
# ==========================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold sin(pos * w_i) and odd columns hold
    cos(pos * w_i), with w_i = 10000 ** (-2i / d_model). The table is
    registered as the buffer ``pe`` of shape (1, max_len, d_model), so it is
    part of the module's state_dict and follows ``.to(device)``.
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column,
        # so trim the frequencies to match (otherwise the assignment fails).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x + pe[:, :seq_len]`` for ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If seq_len exceeds the ``max_len`` the table was built with.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}")
        return x + self.pe[:, :seq_len, :]


# ==========================
# Transformer Model Definition
# ==========================
class DenoiseTransformer(MultiInputModelBase):
    """Transformer-encoder tagger over concatenated note/bar embeddings.

    Data flow: embeddings are concatenated (width ``total_embed_dim``),
    sinusoidal positions are added, the result is passed through
    ``num_encoder_layers`` batch-first encoder layers, and then a per-step
    linear classifier produces (batch_size, seq_len, num_classes) logits.
    ``total_embed_dim`` must be divisible by ``nhead``.
    """

    def __init__(
        self,
        vocab_sizes,
        embed_dims,
        num_classes=62,     # Number of target classes (key_vocab_size)
        nhead=8,
        num_encoder_layers=4,
        dim_feedforward=512,
        dropout=0.1
    ):
        # hidden_size / num_layers are unused by the base; passed for its signature.
        super().__init__(vocab_sizes, embed_dims, hidden_size=256, num_layers=3, num_classes=num_classes, dropout=dropout)

        model_width = self.total_embed_dim
        # Attribute names are state_dict keys; keep them stable.
        self.pos_encoder = PositionalEncoding(model_width)
        layer_template = nn.TransformerEncoderLayer(
            d_model=model_width,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_encoder_layers)
        self.fc = nn.Linear(model_width, num_classes)

    def forward(self, inputs):
        # Only inputs['note'] and inputs['bar'] are read; other keys are ignored.
        embedded = torch.cat(
            (self.embed_note(inputs['note']), self.embed_bar(inputs['bar'])),
            dim=-1,
        )
        # NOTE(review): no src_key_padding_mask is passed, so padded (index 0)
        # positions take part in self-attention.
        encoded = self.transformer_encoder(self.pos_encoder(embedded))
        return self.fc(encoded)


# ==========================
# Window Filter Function Definition
# ==========================
def _window_mode(window):
    """Return the unique most frequent non-pad (non-zero) value of *window*, or None.

    None is returned when the window holds only pad values or when several
    values tie for the highest count.
    """
    non_pad = window[window != 0]
    if non_pad.size == 0:
        return None
    counts = np.bincount(non_pad)
    most_freq = np.argmax(counts)
    if np.sum(counts == counts[most_freq]) > 1:
        return None
    return most_freq


def apply_window_filter(key_hat, window_size=64):
    """
    Smooth a predicted key sequence by majority vote in fixed, non-overlapping windows.

    Within each full window, every non-pad (non-zero) entry is replaced by the
    window's most frequent non-pad value. If the window has a tie, the value
    comes from the nearest following window that has a unique mode. If no such
    window exists, the window is left unchanged. Pad entries (0) are never
    modified. A trailing partial window (len(key_hat) % window_size entries)
    is left unchanged.

    Args:
        key_hat (1D array-like of non-negative ints): Predicted key indices.
        window_size (int): Size of each window (e.g., 64).

    Returns:
        filtered_key_hat (1D numpy array): Filtered copy of the key sequence.

    Raises:
        ValueError: If window_size is not positive.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    key_hat = np.array(key_hat)
    output = key_hat.copy()
    num_windows = len(key_hat) // window_size

    # Compute each window's unique mode once (None for ties / all-pad windows).
    modes = [
        _window_mode(key_hat[n * window_size:(n + 1) * window_size])
        for n in range(num_windows)
    ]

    # Scan right-to-left, carrying the nearest following resolvable mode, so
    # tie resolution costs O(num_windows) instead of a rescan per window.
    following = None
    for n in range(num_windows - 1, -1, -1):
        replacement = modes[n] if modes[n] is not None else following
        if modes[n] is not None:
            following = modes[n]
        if replacement is None:
            continue  # Unresolvable tie (or only pad): keep the original window
        start, end = n * window_size, (n + 1) * window_size
        window = key_hat[start:end]
        output[start:end] = np.where(window != 0, replacement, window)

    return output


# ==========================
# Model Saving and Loading Functions
# ==========================
def save_models_and_vocabs(models, vocabs, save_dir):
    """
    Write the vocab dictionary and every model's weights into a directory.

    Produces ``<save_dir>/vocabs.pkl`` (the whole ``vocabs`` dict, serialized
    with whatever the module-level name ``pickle`` is bound to) and one
    ``<save_dir>/<name>.pth`` state_dict file per model. The directory is
    created if it does not exist. Progress is printed to stdout.

    Args:
        models (dict): Model name -> nn.Module instance.
        vocabs (dict): Vocab name -> Vocab object.
        save_dir (str): Target directory.
    """
    if os.path.exists(save_dir):
        print(f"Directory already exists: {save_dir}")
    else:
        os.makedirs(save_dir)
        print(f"Directory created: {save_dir}")

    vocab_file = os.path.join(save_dir, 'vocabs.pkl')
    with open(vocab_file, 'wb') as handle:
        pickle.dump(vocabs, handle)
        print(f"Vocab objects saved: {vocab_file}")

    # Only the parameters/buffers are stored; the class must be rebuilt on load.
    for model_name in models:
        weights_file = os.path.join(save_dir, f"{model_name}.pth")  # e.g., DeepLSTM.pth
        torch.save(models[model_name].state_dict(), weights_file)
        print(f"Model '{model_name}' saved: {weights_file}")

    print("All Vocab objects and models have been successfully saved.")


def load_models_and_vocabs(model_classes, save_dir, device):
    """
    Load the saved Vocab dict and rebuild/restore each named model.

    Each model is constructed from its name's entry in a fixed hyperparameter
    table (which must match the configuration the checkpoints were trained
    with), moved to ``device``, filled from ``<save_dir>/<name>.pth`` and set to
    eval mode. Models whose weight file is missing are reported and omitted.

    Security: ``vocabs.pkl`` is unpickled and the ``.pth`` files are read with
    ``torch.load``; both can execute arbitrary code, so only load trusted files.

    Args:
        model_classes (dict): Model name ('DeepLSTM', 'DeepBiLSTM' or
            'Transformer') -> model class.
        save_dir (str): Directory where models and Vocab objects are saved.
        device (torch.device): Device to load the models onto (CPU or GPU).

    Returns:
        loaded_models (dict): Model name -> loaded model instance (eval mode).
        vocabs (dict): Vocab name -> loaded Vocab object.

    Raises:
        ValueError: If a model name has no entry in the hyperparameter table.
    """
    # Load Vocab objects (module-level `pickle`; see the security note above).
    vocab_save_path = os.path.join(save_dir, 'vocabs.pkl')
    with open(vocab_save_path, 'rb') as f:
        vocabs = pickle.load(f)
        print("Vocab objects loaded.")

    # Architecture-specific constructor arguments, keyed by model name.
    architecture_kwargs = {
        'DeepLSTM': dict(hidden_size=256, num_layers=3, dropout=0.5),
        'DeepBiLSTM': dict(hidden_size=256, num_layers=3, dropout=0.5),
        'Transformer': dict(nhead=8, num_encoder_layers=4, dim_feedforward=512, dropout=0.1),
    }

    loaded_models = {}

    for name, model_class in model_classes.items():
        if name not in architecture_kwargs:
            raise ValueError(f"Unknown model name: {name}")

        # Fresh dicts per model so instances never share mutable config objects.
        model = model_class(
            vocab_sizes={'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])},
            embed_dims={'note': 64, 'bar': 16},
            num_classes=len(vocabs['vocab_key']),
            **architecture_kwargs[name],
        )
        model.to(device)

        load_path = os.path.join(save_dir, f"{name}.pth")
        if not os.path.exists(load_path):
            print(f"Model file for '{name}' does not exist at {load_path}")
            continue
        model.load_state_dict(torch.load(load_path, map_location=device))
        print(f"Model '{name}' loaded from {load_path}")

        model.eval()
        loaded_models[name] = model

    print("All models and Vocab objects have been successfully loaded.")
    return loaded_models, vocabs


# ==========================
# Evaluation and Visualization Functions
# ==========================
def compute_validation_accuracy(models, valid_loader, device, filter_func, window_size=64):
    """
    Token-level accuracy of each model's filtered predictions on a loader.

    For every batch, each model's argmax predictions are passed row by row
    through ``filter_func(row, window_size)``. Positions where the target key
    is 0 (<pad>) are excluded from both numerator and denominator.

    Args:
        models (dict): Model name -> model instance (switched to eval mode).
        valid_loader (iterable): Yields dict batches with 'note', 'bar', 'key', 'chord' tensors.
        device (torch.device): Device to perform computations on.
        filter_func (callable): ``filter_func(pred_row, window_size)`` -> 1D numpy array.
        window_size (int): Window size forwarded to ``filter_func``.

    Returns:
        dict: Model name -> accuracy in [0, 1]; 0.0 when there are no non-pad targets.
    """
    hits = dict.fromkeys(models, 0)
    counted = dict.fromkeys(models, 0)

    for batch in valid_loader:
        note, bar, key, chord = (batch[field].to(device) for field in ('note', 'bar', 'key', 'chord'))
        non_pad = key != 0
        model_inputs = {'note': note, 'bar': bar, 'chord': chord}

        for name, model in models.items():
            model.eval()
            with torch.no_grad():
                logits = model(model_inputs)  # (batch_size, seq_len, num_classes)
                raw_pred = torch.max(logits, dim=2).indices.cpu().numpy()
                smoothed = np.array([filter_func(row, window_size) for row in raw_pred])
                smoothed = torch.from_numpy(smoothed).long().to(device)
                hits[name] += (smoothed.eq(key) & non_pad).sum().item()
                counted[name] += non_pad.sum().item()

    return {
        name: (hits[name] / counted[name] if counted[name] > 0 else 0.0)
        for name in models
    }


def visualize_predictions_with_accuracy(models, vocabs, valid_dataset, valid_loader, device, filter_func, window_size=64, sample_idx=500, max_steps=256):
    """
    Plot filtered predictions of several models for one validation sample,
    annotated with each model's overall filtered validation accuracy.

    Args:
        models (dict): Model name -> loaded model instance.
        vocabs (dict): Must contain 'vocab_key' and 'vocab_note' Vocab objects.
        valid_dataset (Dataset): Validation dataset yielding dicts with 'note', 'bar', 'chord', 'key'.
        valid_loader (DataLoader): Validation DataLoader used for the accuracy numbers.
        device (torch.device): Device to perform computations on.
        filter_func (function): Filter applied to predictions, both for the
            accuracy computation and for the plotted sample (e.g., apply_window_filter).
        window_size (int): Size of the window for filtering.
        sample_idx (int): Index of the sample in the validation dataset to visualize.
        max_steps (int or None): Maximum number of time steps to plot; None plots all.
    """
    # Overall (filtered) accuracy over the whole validation loader.
    accuracy_dict = compute_validation_accuracy(models, valid_loader, device, filter_func, window_size)

    # Select the sample and add a batch dimension of 1.
    sample = valid_dataset[sample_idx]
    inputs = {
        'note': sample['note'].unsqueeze(0).to(device),
        'bar': sample['bar'].unsqueeze(0).to(device),
        'chord': sample['chord'].unsqueeze(0).to(device)
    }
    targets = sample['key'].unsqueeze(0).to(device)

    # Raw argmax predictions per model for this sample.
    predictions = {}
    for name, model in models.items():
        model.eval()
        with torch.no_grad():
            outputs = model(inputs)  # (1, seq_len, num_classes)
            _, predicted = torch.max(outputs, dim=2)  # (1, seq_len)
            predictions[name] = predicted.cpu().numpy().squeeze()  # (seq_len,)

    vocab_key = vocabs['vocab_key']
    targets_tokens = vocab_key.to_tokens(targets.cpu().numpy().squeeze())
    note_tokens_sample = vocabs['vocab_note'].to_tokens(inputs['note'].cpu().numpy().squeeze())

    # Use the caller's filter so the plotted predictions match the reported accuracy.
    filtered_predicted_tokens = {
        name: vocab_key.to_tokens(filter_func(pred, window_size))
        for name, pred in predictions.items()
    }

    # Determine the number of steps to display
    num_display = len(targets_tokens) if max_steps is None else min(len(targets_tokens), max_steps)

    plt.figure(figsize=(20, 15))

    # NOTE(review): note tokens (note vocabulary) and key tokens (key vocabulary)
    # are drawn on a shared y-axis.
    for idx, (name, pred_tokens) in enumerate(filtered_predicted_tokens.items(), 1):
        plt.subplot(len(models), 1, idx)
        plt.plot(range(num_display), note_tokens_sample[:num_display], label='Input Note', alpha=0.5)
        plt.plot(range(num_display), targets_tokens[:num_display], label='Target Key', alpha=0.8)
        plt.plot(range(num_display), pred_tokens[:num_display], label='Predicted Key', alpha=0.8)
        plt.title(f'Prediction using {name}')
        plt.xlabel('Time Step')
        plt.ylabel('Token')
        plt.legend()
        # Display validation accuracy in the top-right corner of the axes.
        plt.text(0.99, 0.95, f'Validation Accuracy: {accuracy_dict[name]:.4f}',
                 horizontalalignment='right',
                 verticalalignment='top',
                 transform=plt.gca().transAxes,
                 fontsize=12, bbox=dict(facecolor='white', alpha=0.5))

    plt.tight_layout()
    plt.show()


# In [None]:
import os
import torch
import pickle
import collections
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn as nn
import numpy as np
import glob
import re
import matplotlib.pyplot as plt

# Assuming the above definitions are in a module or have been run in the same script

# ==========================
# 1. Data Loading and Preprocessing
# ==========================
# Locations of the augmented dataset arrays; replace with real paths.
x_path = r"example path ... aug_x_pop.npy"
y_path = r"example path ... aug_y_pop.npy"

# allow_pickle=True permits object arrays (and arbitrary code on load): trusted files only.
x = np.load(x_path, allow_pickle=True)
y = np.load(y_path, allow_pickle=True)
for label, arr in (("x", x), ("y", y)):
    print(f"{label} shape: {arr.shape}")

# Per-feature token grids. note, bar and the target key come from x[..., 0:3];
# chord is assumed to be y[..., 0].
note_tokens, bar_tokens, key_tokens = (x[:, :, i] for i in range(3))
chord_tokens = y[:, :, 0]

# Python lists of token sequences, one inner list per sample.
note_tokens_list, bar_tokens_list, key_tokens_list, chord_tokens_list = (
    grid.tolist() for grid in (note_tokens, bar_tokens, key_tokens, chord_tokens)
)

_token_lists = {
    'vocab_note': note_tokens_list,
    'vocab_bar': bar_tokens_list,
    'vocab_key': key_tokens_list,
    'vocab_chord': chord_tokens_list,
}

# One vocabulary per feature; '<pad>' is index 0 and '<unk>' index 1.
vocabs = {
    vname: Vocab(tokens=rows, min_freq=1, reserved_tokens=['<unk>'])
    for vname, rows in _token_lists.items()
}

print("Vocabulary sizes:")
for vocab_name, vocab in vocabs.items():
    print(f"{vocab_name}: {len(vocab)}")

# Token -> index conversion, one integer array per feature.
note_indices, bar_indices, key_indices, chord_indices = (
    np.array([vocabs[vname][row] for row in rows])
    for vname, rows in _token_lists.items()
)

# int64 tensors of shape (num_samples, num_steps).
note_tensor, bar_tensor, key_tensor, chord_tensor = (
    torch.from_numpy(idx).long()
    for idx in (note_indices, bar_indices, key_indices, chord_indices)
)

# ==========================
# 2. Dataset and DataLoader Preparation
# ==========================
# Create the dataset (note, bar, chord are inputs; key is the target)
dataset = MultiInputDataset(
    note_tensor, bar_tensor, key_tensor, chord_tensor
)

# Split the dataset (80% training, 20% validation). A fixed-seed generator makes
# the split reproducible across sessions. Otherwise, models reloaded later would
# be evaluated on a fresh random split that can overlap their training data.
split_seed = 42
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_dataset, valid_dataset = random_split(
    dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(split_seed)
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Valid dataset size: {len(valid_dataset)}")

# Define DataLoader (batch size 32)
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# ==========================
# 3. Model Initialization
# ==========================
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# The models predict one key class per time step.
num_classes = len(vocabs['vocab_key'])
print(f"Number of classes: {num_classes}")

# Only note and bar are embedded by the models.
vocab_sizes = {'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])}
embed_dims = {'note': 64, 'bar': 16}

# Model name -> class (the names are also used for saved checkpoint files).
model_classes = {
    'DeepLSTM': DenoiseLSTM,
    'DeepBiLSTM': DenoiseBiLSTM,
    'Transformer': DenoiseTransformer
}

# Per-architecture constructor arguments; construction order matters for the RNG.
_model_kwargs = {
    'DeepLSTM': dict(hidden_size=256, num_layers=3, num_classes=num_classes, dropout=0.5),
    'DeepBiLSTM': dict(hidden_size=256, num_layers=3, num_classes=num_classes, dropout=0.5),
    'Transformer': dict(num_classes=num_classes, nhead=8, num_encoder_layers=4,
                        dim_feedforward=512, dropout=0.1),
}
models = {
    name: model_classes[name](vocab_sizes=vocab_sizes, embed_dims=embed_dims, **kwargs).to(device)
    for name, kwargs in _model_kwargs.items()
}

# ==========================
# 4. Loss Function and Optimizers
# ==========================
import torch.optim as optim  # not imported by the earlier cells; required for Adam below

# CrossEntropyLoss expects (N, C) logits and (N,) targets; pad index 0 is ignored.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# One Adam optimizer per model (lr 0.001, small L2 weight decay).
optimizers = {
    name: optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    for name, model in models.items()
}

# Per-epoch loss and accuracy history for each model.
history = {
    name: {'train_loss': [], 'valid_loss': [], 'train_acc': [], 'valid_acc': []}
    for name in models.keys()
}

# Directory for training checkpoints.
save_dir = r"note_bar_to_key_models"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    print(f"Directory created: {save_dir}")
else:
    print(f"Directory already exists: {save_dir}")

# ==========================
# 5. Training and Validation Loop
# ==========================
import time  # not imported by the earlier cells; used for elapsed/remaining-time reporting

num_epochs = 50
best_valid_loss = {name: float('inf') for name in models.keys()}
start_time = time.time()  # Record training start time

for epoch in range(1, num_epochs + 1):
    print(f'\nEpoch {epoch}/{num_epochs}')

    # Models are trained one after another within each epoch.
    for name, model in models.items():
        optimizer = optimizers[name]

        # ---- Training pass ----
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch in train_loader:
            # Inputs (note, bar, chord); the target is the key sequence.
            inputs = {
                'note': batch['note'].to(device),
                'bar': batch['bar'].to(device),
                'chord': batch['chord'].to(device)
            }
            targets = batch['key'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)  # (batch_size, seq_len, num_classes)

            # Flatten batch and time: (batch_size * seq_len, num_classes) vs (batch_size * seq_len,)
            loss = criterion(outputs.reshape(-1, num_classes), targets.reshape(-1))
            loss.backward()

            # Apply gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item() * targets.size(0)

            # Token accuracy over non-pad (index != 0) positions only.
            _, predicted = torch.max(outputs, dim=2)  # (batch_size, seq_len)
            non_pad = targets != 0
            correct += ((predicted == targets) & non_pad).sum().item()
            total += non_pad.sum().item()

        epoch_loss = running_loss / train_size
        epoch_acc = correct / total if total > 0 else 0.0
        history[name]['train_loss'].append(epoch_loss)
        history[name]['train_acc'].append(epoch_acc)

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for batch in valid_loader:
                inputs = {
                    'note': batch['note'].to(device),
                    'bar': batch['bar'].to(device),
                    'chord': batch['chord'].to(device)
                }
                targets = batch['key'].to(device)

                outputs = model(inputs)
                loss = criterion(outputs.reshape(-1, num_classes), targets.reshape(-1))
                val_loss += loss.item() * targets.size(0)

                _, predicted = torch.max(outputs, dim=2)
                non_pad = targets != 0
                val_correct += ((predicted == targets) & non_pad).sum().item()
                val_total += non_pad.sum().item()

        epoch_val_loss = val_loss / valid_size
        epoch_val_acc = val_correct / val_total if val_total > 0 else 0.0
        history[name]['valid_loss'].append(epoch_val_loss)
        history[name]['valid_acc'].append(epoch_val_acc)

        # Periodic output
        if epoch % 10 == 0 or epoch == 1 or epoch == num_epochs:
            print(f'[{name}] Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, '
                  f'Valid Loss: {epoch_val_loss:.4f}, Valid Acc: {epoch_val_acc:.4f}')

        # Checkpoint only on even epochs, and only when validation loss beats the best
        # value recorded at a previous checkpoint; odd-epoch improvements are not tracked.
        if epoch % 2 == 0 and epoch_val_loss < best_valid_loss[name]:
            best_valid_loss[name] = epoch_val_loss
            model_save_path = os.path.join(save_dir, f'{name}_epoch{epoch}_valacc{epoch_val_acc:.4f}.pt')
            torch.save(model.state_dict(), model_save_path)
            print(f"Model '{name}' saved at epoch {epoch} with validation accuracy {epoch_val_acc:.4f}")

    # Linear extrapolation of the remaining time from the average epoch duration so far.
    epoch_end_time = time.time()
    elapsed_time = epoch_end_time - start_time
    remaining_time = elapsed_time * (num_epochs / epoch) - elapsed_time

    print(f'Epoch {epoch}/{num_epochs} completed. Time elapsed: {elapsed_time:.2f} seconds. '
          f'Estimated remaining time: {remaining_time:.2f} seconds.')

# ==========================
# 6. Visualize Loss and Accuracy
# ==========================
# Plot per-epoch training/validation loss (top) and accuracy (bottom) for every
# model in `models`, using the curves recorded in `history`.
# Each series is plotted against its own length rather than `num_epochs`: if
# training was interrupted or stopped early, the recorded series are shorter
# than `num_epochs` (and train/valid may differ by one entry). A fixed
# `range(1, num_epochs + 1)` x-axis would then make `plt.plot` raise.
plt.figure(figsize=(20, 10))

# Visualize Loss
plt.subplot(2, 1, 1)
for name in models:
    plt.plot(range(1, len(history[name]['train_loss']) + 1), history[name]['train_loss'],
             label=f'{name} Train', linestyle='-')
    plt.plot(range(1, len(history[name]['valid_loss']) + 1), history[name]['valid_loss'],
             label=f'{name} Valid', linestyle='--')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# Visualize Accuracy
plt.subplot(2, 1, 2)
for name in models:
    plt.plot(range(1, len(history[name]['train_acc']) + 1), history[name]['train_acc'],
             label=f'{name} Train', linestyle='-')
    plt.plot(range(1, len(history[name]['valid_acc']) + 1), history[name]['valid_acc'],
             label=f'{name} Valid', linestyle='--')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

# ==========================
# 7. Model Loading and Evaluation
# ==========================
# Map the model-name prefix used in checkpoint filenames to its class.
model_classes_mapping = {
    'DeepLSTM': DenoiseLSTM,
    'DeepBiLSTM': DenoiseBiLSTM,
    'Transformer': DenoiseTransformer
}

# Define the directory where models are saved.
# Leave the placeholder empty to reuse the `save_dir` the training loop above
# wrote its checkpoints to. Previously an empty placeholder overwrote that
# value, so the search silently fell back to the current working directory.
# Fill it in to evaluate checkpoints stored somewhere else.
save_dir = r"" or save_dir  # use your path

# Find best model files based on validation accuracy
best_models = find_best_model_files(save_dir)

# Initialize a dictionary to store evaluation results
evaluation_results = {}

# Evaluate each best model
for model_name, info in best_models.items():
    print(f"Evaluating best model for {model_name}: {info['file']} (Epoch {info['epoch']}, Val Acc {info['valacc']:.4f})")

    # Guard clause: skip checkpoints whose name has no registered class.
    if model_name not in model_classes_mapping:
        print(f"Model class for '{model_name}' is not defined.")
        continue
    model_class = model_classes_mapping[model_name]

    # Rebuild the architecture. NOTE(review): these hyperparameters must match
    # the ones used when the checkpoint was trained, or load_state_dict will
    # fail on shape/key mismatches. Confirm they agree with the training setup.
    if model_name == 'Transformer':
        model = model_class(
            vocab_sizes={'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])},
            embed_dims={'note': 64, 'bar': 16},
            num_classes=len(vocabs['vocab_key']),
            nhead=8, num_encoder_layers=4, dim_feedforward=512, dropout=0.1
        )
    else:
        model = model_class(
            vocab_sizes={'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])},
            embed_dims={'note': 64, 'bar': 16},
            hidden_size=256, num_layers=3, num_classes=len(vocabs['vocab_key']), dropout=0.5
        )

    # Load the saved state_dict (tensors mapped onto `device`), then move the
    # module to `device` once.
    model.load_state_dict(torch.load(info['file'], map_location=device))
    model.to(device)

    # Evaluate the model. ignore_index=0 skips the padding class: Vocab reserves
    # index 0 for '<pad>', and the training loop likewise excludes targets == 0
    # from accuracy.
    val_loss, val_acc = evaluate_model(model, valid_loader, nn.CrossEntropyLoss(ignore_index=0), device)

    # Store the evaluation results
    evaluation_results[model_name] = {'val_loss': val_loss, 'val_acc': val_acc}
    print(f"{model_name} - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

# ==========================
# 8. Visualize Validation Results
# ==========================
# Two side-by-side bar charts comparing the reloaded best checkpoints:
# left panel shows validation loss, right panel shows validation accuracy.
plt.figure(figsize=(14, 6))

# One entry per evaluated model, in insertion order of `evaluation_results`.
model_names = list(evaluation_results)
val_losses = [metrics['val_loss'] for metrics in evaluation_results.values()]
val_accs = [metrics['val_acc'] for metrics in evaluation_results.values()]

# Left panel: validation loss
plt.subplot(121)
plt.bar(model_names, val_losses, color='skyblue')
plt.xlabel('Model')
plt.ylabel('Validation Loss')
plt.title('Validation Loss of Best Models')
plt.xticks(rotation=45)

# Right panel: validation accuracy
plt.subplot(122)
plt.bar(model_names, val_accs, color='salmon')
plt.xlabel('Model')
plt.ylabel('Validation Accuracy')
plt.title('Validation Accuracy of Best Models')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

# ==========================
# 9. Visualization Function Call
# ==========================
# Display settings for the per-sample prediction plot.
window_size = 128  # Passed with filter_func=apply_window_filter; set as needed
sample_idx = 15    # Index of the sample in valid_dataset to visualize
max_steps = 256    # Maximum number of time steps to display

# Reload the saved models (and vocabularies) from save_dir.
loaded_models, loaded_vocabs = load_models_and_vocabs(model_classes_mapping, save_dir, device)

# NOTE(review): `loaded_vocabs` is not used. Predictions are decoded with the
# in-memory `vocabs`, presumably the ones valid_dataset was encoded with.
# Confirm the two are identical if models come from a different run.
visualize_predictions_with_accuracy(
    models=loaded_models,
    vocabs=vocabs,
    valid_dataset=valid_dataset,
    valid_loader=valid_loader,
    device=device,
    filter_func=apply_window_filter,
    window_size=window_size,
    sample_idx=sample_idx,
    max_steps=max_steps
)