In [None]:
import os
import torch
import pickle
import collections
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import warnings

# ==========================
# Vocab Class Definition
# ==========================
class Vocab:
    """Vocabulary mapping between text tokens and integer indices.

    Index 0 is always '<pad>'. The reserved tokens follow in the order given,
    then the remaining tokens in descending frequency order.
    """
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """
        Builds the vocabulary from a token collection.

        Args:
            tokens (list, optional): Flat list of tokens, or a list of token lists.
            min_freq (int): Tokens seen fewer times than this are left out.
            reserved_tokens (list, optional): Tokens placed right after '<pad>'.
        """
        # None defaults avoid sharing a single mutable list between calls
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else list(reserved_tokens)
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # Create list of unique tokens
        self.idx_to_token = ['<pad>'] + reserved_tokens + [
            token for token, freq in self.token_freqs if freq >= min_freq and token != '<pad>'
        ]
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Returns the number of entries, including '<pad>' and reserved tokens."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Returns the index of a token, or a list of indices for a list/tuple.

        Tokens that are not in the vocabulary map to `self.unk`.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Returns the token for a scalar index, or a list of tokens for a sequence.

        Any iterable of indices (list, tuple, 1D array/tensor) yields a list,
        including sequences of length 0 or 1. Non-iterable values (plain ints,
        0-d arrays) yield a single token.
        """
        try:
            iterator = iter(indices)
        except TypeError:
            # Scalar index (int, NumPy integer, 0-d array/tensor)
            return self.idx_to_token[int(indices)]
        return [self.idx_to_token[int(index)] for index in iterator]

    @property
    def unk(self):
        """Index used for unknown tokens."""
        return self.token_to_idx.get('<unk>', 1)  # Returns index 1 if '<unk>' is not in the vocabulary

# ==========================
# Window Filter Function Definition
# ==========================
def apply_window_filter(key_hat, window_size=128):
    """
    Smooths a predicted key sequence by majority vote inside fixed windows.

    The sequence is cut into consecutive, non-overlapping windows of
    `window_size` steps; a trailing remainder shorter than one window is left
    untouched. Inside each window every non-pad entry (value != 0) is replaced
    by the window's single most frequent non-pad value. When several values
    tie for most frequent, the first later window that has an unambiguous
    winner supplies the replacement; if no such window exists the window is
    kept as predicted. Windows that contain only <pad> are never modified.

    Args:
        key_hat (1D numpy array): Array of predicted key indices.
        window_size (int): Size of the window (e.g., 64, 128).

    Returns:
        filtered_key_hat (1D numpy array): Filtered key sequence.
    """
    def sole_mode(segment):
        """Returns the strictly most frequent non-pad value, or None if empty/tied."""
        values = segment[segment != 0]
        if len(values) == 0:
            return None
        tally = np.bincount(values)
        top = np.argmax(tally)
        if np.sum(tally == tally[top]) != 1:
            return None
        return top

    key_hat = np.array(key_hat)
    output = key_hat.copy()

    # (start, end) bounds of every complete window
    spans = [
        (i * window_size, (i + 1) * window_size)
        for i in range(len(key_hat) // window_size)
    ]

    for position, (start, end) in enumerate(spans):
        window = key_hat[start:end]
        if not np.any(window != 0):
            continue  # Only <pad> in this window

        replacement = sole_mode(window)
        if replacement is None:
            # Tie: borrow the winner of the nearest later window that has one
            later_modes = (sole_mode(key_hat[s:e]) for s, e in spans[position + 1:])
            replacement = next((mode for mode in later_modes if mode is not None), None)

        if replacement is None:
            continue  # Tie could not be resolved; keep the original window

        # Overwrite non-pad entries only; <pad> positions stay 0
        output[start:end] = np.where(window != 0, replacement, window)

    return output

# ==========================
# Dataset Class Definition
# ==========================
class MultiInputDataset(Dataset):
    """Dataset that serves six aligned feature tensors as one dict per sample."""

    # Feature names, in the order they appear in each returned sample dict
    _FEATURES = ('note', 'bar', 'key', 'tempo', 'velocity', 'chord')

    def __init__(self, note, bar, key, tempo, velocity, chord):
        """
        Stores the per-feature tensors; all are indexed by sample along dim 0.

        Args:
            note (Tensor): Tensor of shape (num_samples, num_steps)
            bar (Tensor): Tensor of shape (num_samples, num_steps)
            key (Tensor): Tensor of shape (num_samples, num_steps)
            tempo (Tensor): Tensor of shape (num_samples, num_steps)
            velocity (Tensor): Tensor of shape (num_samples, num_steps)
            chord (Tensor): Tensor of shape (num_samples, num_steps)
        """
        self.note, self.bar, self.key = note, bar, key
        self.tempo, self.velocity, self.chord = tempo, velocity, chord

    def __len__(self):
        """Number of samples, taken from the first dimension of `note`."""
        sample_count = self.note.size(0)
        return sample_count

    def __getitem__(self, idx):
        """Returns a dict with the `idx`-th entry of every feature tensor."""
        return {feature: getattr(self, feature)[idx] for feature in self._FEATURES}

# ==========================
# Base Model Class Definition
# ==========================
class MultiInputModelBase(nn.Module):
    """Shared base that owns the note and bar embedding layers.

    Subclasses read `total_embed_dim` to size the layer that consumes the
    concatenated note + bar embeddings.
    """
    def __init__(
        self, 
        vocab_sizes,       # Dictionary of vocab sizes for each feature
        embed_dims,        # Dictionary of embedding dimensions for each feature
        hidden_size=256,   # Hidden state size
        num_layers=3, 
        num_classes=62,    # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        """
        Args:
            vocab_sizes (dict): Vocabulary size per feature; 'note' and 'bar' are used.
            embed_dims (dict): Embedding dimension per feature; 'note' and 'bar' are used.
            hidden_size (int): Accepted for subclasses; not used by the base class.
            num_layers (int): Accepted for subclasses; not used by the base class.
            num_classes (int): Accepted for subclasses; not used by the base class.
            dropout (float): Accepted for subclasses; not used by the base class.
        """
        super().__init__()
        # Embedding layers (using note and bar only); index 0 is the padding token
        self.embed_note = nn.Embedding(vocab_sizes['note'], embed_dims['note'], padding_idx=0)
        self.embed_bar = nn.Embedding(vocab_sizes['bar'], embed_dims['bar'], padding_idx=0)
        
        # Width of the concatenated embeddings. Only the features that are
        # actually embedded are counted, so extra keys in embed_dims (e.g. a
        # dict shared with another model) cannot inflate the value.
        self.total_embed_dim = embed_dims['note'] + embed_dims['bar']

# ==========================
# DeepBiLSTM Model Definition
# ==========================
class DeepBiLSTM(MultiInputModelBase):
    """Stacked bidirectional LSTM that emits per-step class logits from note and bar tokens."""

    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        hidden_size=256,    # Hidden state size
        num_layers=3, 
        num_classes=62,     # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        """
        Args:
            vocab_sizes (dict): Vocabulary size per feature (forwarded to the base class).
            embed_dims (dict): Embedding dimension per feature (forwarded to the base class).
            hidden_size (int): Hidden size of each LSTM direction.
            num_layers (int): Number of stacked LSTM layers.
            num_classes (int): Size of the output layer.
            dropout (float): Dropout used between LSTM layers and before the output layer.
        """
        super().__init__(vocab_sizes, embed_dims, hidden_size, num_layers, num_classes, dropout)

        # Forward and backward hidden states are concatenated by the LSTM
        bidirectional_width = hidden_size * 2

        self.lstm = nn.LSTM(
            input_size=self.total_embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=True
        )
        self.layer_norm = nn.LayerNorm(bidirectional_width)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc = nn.Linear(bidirectional_width, num_classes)

    def forward(self, inputs):
        """
        Args:
            inputs (dict): Must provide 'note' and 'bar' index tensors,
                batch-first: (batch_size, seq_len).

        Returns:
            Tensor: Logits of shape (batch_size, seq_len, num_classes).
        """
        embedded = [
            self.embed_note(inputs['note']),
            self.embed_bar(inputs['bar'])
        ]
        features = torch.cat(embedded, dim=-1)  # (batch_size, seq_len, total_embed_dim)

        # Only the per-step outputs are needed; final hidden/cell states are dropped
        sequence_out = self.lstm(features)[0]
        normalized = self.layer_norm(sequence_out)
        return self.fc(self.dropout_layer(normalized))

# ==========================
# Positional Encoding for Transformer
# ==========================
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch-first sequence.

    Even feature columns hold sines and odd columns hold cosines of the
    position scaled by geometrically spaced frequencies. The table is stored
    as a non-trainable buffer named 'pe'.
    """
    def __init__(self, d_model, max_len=1024):
        """
        Args:
            d_model (int): Feature dimension of the inputs (even or odd).
            max_len (int): Longest sequence length the table covers.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so the frequency vector is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        """
        Args:
            x (Tensor): (batch_size, seq_len, d_model), with seq_len <= max_len.

        Returns:
            Tensor: `x` plus the encoding of the first seq_len positions.
        """
        # x: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x

# ==========================
# DenoiseTransformer Model Definition
# ==========================
class DenoiseTransformer(nn.Module):
    """Transformer encoder that predicts per-step classes from five token streams.

    Note, bar, key, tempo and velocity indices are embedded separately,
    concatenated along the feature axis, given positional encodings, passed
    through a Transformer encoder and projected to `num_classes` logits.
    """
    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        num_classes, 
        nhead=8, 
        num_encoder_layers=4, 
        dim_feedforward=512, 
        dropout=0.1
    ):
        """
        Args:
            vocab_sizes (dict): Vocabulary size for 'note', 'bar', 'key', 'tempo', 'velocity'.
            embed_dims (dict): Embedding dimension for the same five features.
            num_classes (int): Size of the output layer.
            nhead (int): Number of attention heads.
            num_encoder_layers (int): Number of stacked encoder layers.
            dim_feedforward (int): Hidden width of each encoder feed-forward block.
            dropout (float): Dropout used inside the encoder layers.
        """
        super().__init__()
        
        # Embedding layers; index 0 is the padding token for every feature
        self.embed_note = nn.Embedding(vocab_sizes['note'], embed_dims['note'], padding_idx=0)
        self.embed_bar = nn.Embedding(vocab_sizes['bar'], embed_dims['bar'], padding_idx=0)
        self.embed_key = nn.Embedding(vocab_sizes['key'], embed_dims['key'], padding_idx=0)
        self.embed_tempo = nn.Embedding(vocab_sizes['tempo'], embed_dims['tempo'], padding_idx=0)
        self.embed_velocity = nn.Embedding(vocab_sizes['velocity'], embed_dims['velocity'], padding_idx=0)
        
        # Width of the concatenated embeddings. Only the five embedded features
        # are counted so that extra keys in embed_dims cannot make the encoder
        # width disagree with what forward() actually concatenates.
        self.total_embed_dim = sum(
            embed_dims[feature] for feature in ('note', 'bar', 'key', 'tempo', 'velocity')
        )
        
        # Positional Encoding
        self.pos_encoder = PositionalEncoding(self.total_embed_dim)
        
        # Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=self.total_embed_dim, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_encoder_layers)
        
        # Output Layer
        self.fc = nn.Linear(self.total_embed_dim, num_classes)
        
    def forward(self, inputs, src_key_padding_mask=None):
        """
        Args:
            inputs (dict): {
                'note': Tensor, 
                'bar': Tensor, 
                'key': Tensor, 
                'tempo': Tensor, 
                'velocity': Tensor
            }
            src_key_padding_mask (Tensor, optional): Padding mask (batch_size, seq_len)
        
        Returns:
            Tensor: Chord prediction results (batch_size, seq_len, num_classes)
        """
        note = self.embed_note(inputs['note'])       # (batch_size, seq_len, embed_dim_note)
        bar = self.embed_bar(inputs['bar'])          # (batch_size, seq_len, embed_dim_bar)
        key = self.embed_key(inputs['key'])          # (batch_size, seq_len, embed_dim_key)
        tempo = self.embed_tempo(inputs['tempo'])    # (batch_size, seq_len, embed_dim_tempo)
        velocity = self.embed_velocity(inputs['velocity'])  # (batch_size, seq_len, embed_dim_velocity)
        
        # Concatenate embeddings
        x = torch.cat([note, bar, key, tempo, velocity], dim=-1)      # (batch_size, seq_len, total_embed_dim)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Pass through Transformer Encoder
        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
        
        # Final output
        x = self.fc(x)                              # (batch_size, seq_len, num_classes)
        return x

# ==========================
# IntegratedModel Class Definition
# ==========================
class IntegratedModel(nn.Module):
    """Two-stage pipeline: frozen key predictor -> window filter -> chord model.

    The key model only produces inputs for the chord model. Its parameters
    are frozen and it is kept in evaluation mode even while the wrapper is in
    training mode, so dropout does not perturb the predicted keys.
    """
    def __init__(self, deep_bilstm_model, chord_model, filter_func, window_size=128):
        """
        Initializes the IntegratedModel with DeepBiLSTM and DenoiseTransformer models.
        
        Args:
            deep_bilstm_model (DeepBiLSTM): Pretrained DeepBiLSTM model for key prediction.
            chord_model (DenoiseTransformer): Pretrained DenoiseTransformer model for chord prediction.
            filter_func (function): Function to apply filtering on key predictions.
                Called as filter_func(sequence, window_size=...) on one 1D NumPy
                array per sample.
            window_size (int): Window size for filtering.
        """
        super().__init__()
        self.deep_bilstm_model = deep_bilstm_model
        self.chord_model = chord_model
        self.filter_func = filter_func
        self.window_size = window_size

        # Freeze DeepBiLSTM model parameters
        for param in self.deep_bilstm_model.parameters():
            param.requires_grad = False
        # A frozen predictor must behave deterministically (no dropout)
        self.deep_bilstm_model.eval()

    def train(self, mode=True):
        """Sets the training mode of the wrapper but keeps the frozen key model in eval mode.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            IntegratedModel: self
        """
        super().train(mode)
        # super().train() also switched the key model; undo that for the frozen part
        self.deep_bilstm_model.eval()
        return self

    def forward(self, inputs):
        """
        Forward pass of the IntegratedModel.
        
        Args:
            inputs (dict): {
                'note': Tensor, 
                'bar': Tensor, 
                'tempo': Tensor, 
                'velocity': Tensor
            }
        
        Returns:
            outputs_chord (Tensor): Chord prediction results (batch_size, seq_len, num_classes)
            predicted_key_tensor (Tensor): Filtered Key prediction results (batch_size, seq_len)
        """
        # Predict keys using DeepBiLSTM model (no gradients: the model is frozen
        # and the argmax/filter step is not differentiable anyway)
        with torch.no_grad():
            key_inputs = {
                'note': inputs['note'],
                'bar': inputs['bar']
            }
            outputs_key = self.deep_bilstm_model(key_inputs)  # (batch_size, seq_len, num_classes)
            predicted_key = outputs_key.argmax(dim=2).cpu().numpy()  # (batch_size, seq_len)

            # Apply filter to every sequence of the batch independently
            filtered_predicted_key = np.array([
                self.filter_func(pred, window_size=self.window_size)
                for pred in predicted_key
            ])
            predicted_key_tensor = torch.from_numpy(filtered_predicted_key).long().to(inputs['note'].device)

        # Prepare inputs for Chord prediction model
        chord_inputs = {
            'note': inputs['note'],
            'bar': inputs['bar'],
            'key': predicted_key_tensor,
            'tempo': inputs['tempo'],
            'velocity': inputs['velocity']
        }

        # Padding mask: True where the note token is <pad> (index 0).
        # The comparison already yields a boolean tensor.
        src_key_padding_mask = (inputs['note'] == 0)  # (batch_size, seq_len)

        # Predict chords
        outputs_chord = self.chord_model(chord_inputs, src_key_padding_mask=src_key_padding_mask)  # (batch_size, seq_len, num_classes)
        return outputs_chord, predicted_key_tensor

# ==========================
# Model Saving and Loading Functions
# ==========================
def save_models_and_vocabs(models, vocabs, save_dir):
    """
    Writes the vocabularies and the weights of every model into one directory.

    The vocab dictionary is pickled to 'vocabs.pkl'; each model's state_dict
    is stored as '<name>.pth'. The directory is created when it is missing.
    
    Args:
        models (dict): Dictionary with model names as keys and model instances as values.
        vocabs (dict): Dictionary with vocab names as keys and Vocab objects as values.
        save_dir (str): Directory path to save the models and Vocab objects.
    """
    # Make sure the target directory is there
    if os.path.exists(save_dir):
        print(f"Directory already exists: {save_dir}")
    else:
        os.makedirs(save_dir)
        print(f"Directory created: {save_dir}")

    # Vocabularies go into a single pickle file
    vocab_file = os.path.join(save_dir, 'vocabs.pkl')
    with open(vocab_file, 'wb') as handle:
        pickle.dump(vocabs, handle)
        print(f"Vocab objects saved: {vocab_file}")

    # One weights file per model, named after its dictionary key
    for model_name in models:
        weights_file = os.path.join(save_dir, f"{model_name}.pth")  # e.g., DeepBiLSTM.pth
        torch.save(models[model_name].state_dict(), weights_file)
        print(f"Model '{model_name}' saved: {weights_file}")

    print("All Vocab objects and models have been successfully saved.")

def load_models_and_vocabs(model_classes, save_dir, device):
    """
    Loads saved Vocab objects and multiple models from the specified directory.

    Only the names 'DeepBiLSTM' and 'DenoiseTransformer' are supported; the
    hyperparameters used to rebuild them are fixed in this function and must
    match the ones the checkpoints were trained with. A model whose
    '<name>.pth' file is missing is reported and left out of the result; it is
    not instantiated at all.
    
    Args:
        model_classes (dict): Dictionary with model names as keys and model classes as values.
        save_dir (str): Directory path where models and Vocab objects are saved.
        device (torch.device): Device to load the models onto (CPU or GPU).
    
    Returns:
        loaded_models (dict): Dictionary with model names as keys and loaded model instances as values.
        vocabs (dict): Dictionary with vocab names as keys and loaded Vocab objects as values.

    Raises:
        ValueError: If `model_classes` contains an unsupported model name.
    """
    # Load Vocab objects using pickle.
    # NOTE(security): unpickling can execute arbitrary code; only point
    # save_dir at directories whose contents you trust.
    vocab_save_path = os.path.join(save_dir, 'vocabs.pkl')
    with open(vocab_save_path, 'rb') as f:
        vocabs = pickle.load(f)
        print("Vocab objects loaded.")

    def build_model(name, model_class):
        """Instantiates a supported model with its fixed hyperparameters."""
        if name == 'DeepBiLSTM':
            return model_class(
                vocab_sizes={'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])},
                embed_dims={'note': 64, 'bar': 16},
                hidden_size=256,
                num_layers=3,
                num_classes=len(vocabs['vocab_key']),
                dropout=0.5
            )
        # name == 'DenoiseTransformer'
        return model_class(
            vocab_sizes={
                'note': len(vocabs['vocab_note']),
                'bar': len(vocabs['vocab_bar']),
                'key': len(vocabs['vocab_key']),
                'tempo': len(vocabs['vocab_tempo']),
                'velocity': len(vocabs['vocab_velocity'])
            },
            embed_dims={
                'note': 64,
                'bar': 16,
                'key': 32,
                'tempo': 16,
                'velocity': 16
            },
            num_classes=len(vocabs['vocab_chord']),
            nhead=8,
            num_encoder_layers=4,
            dim_feedforward=512,
            dropout=0.1
        )

    known_names = ('DeepBiLSTM', 'DenoiseTransformer')
    loaded_models = {}
    skipped = []

    for name, model_class in model_classes.items():
        if name not in known_names:
            raise ValueError(f"Unknown model name: {name}")

        # Check for the checkpoint first so no model is built (and moved to
        # the device) only to be thrown away.
        load_path = os.path.join(save_dir, f"{name}.pth")
        if not os.path.exists(load_path):
            print(f"Model file for '{name}' does not exist at {load_path}")
            skipped.append(name)
            continue

        model = build_model(name, model_class)
        model.to(device)
        model.load_state_dict(torch.load(load_path, map_location=device))
        print(f"Model '{name}' loaded from {load_path}")

        # Set the model to evaluation mode
        model.eval()
        loaded_models[name] = model

    if skipped:
        # Do not claim full success when some checkpoints were missing
        print(f"Vocab objects loaded; models skipped because no checkpoint was found: {skipped}")
    else:
        print("All models and Vocab objects have been successfully loaded.")
    return loaded_models, vocabs

# ==========================
# Evaluation and Visualization Functions
# ==========================
def compute_validation_accuracy(models, valid_loader, device, filter_func, window_size=64):
    """
    Computes the validation accuracy for each model, ignoring <pad> targets.

    What is measured depends on the model name:
      * 'DeepBiLSTM': key accuracy of its predictions after `filter_func`.
      * 'DenoiseTransformer': chord accuracy when fed the ground-truth key.
      * 'IntegratedModel': chord accuracy of the full pipeline (the model
        applies its own key filter internally).
    Any other name is not evaluated and is reported with accuracy 0.0.
    
    Args:
        models (dict): Dictionary with model names as keys and loaded model instances as values.
        valid_loader (DataLoader): Validation DataLoader.
        device (torch.device): Device to perform computations on.
        filter_func (function): Function to apply filtering on key predictions.
            Called as filter_func(sequence, window_size) per sample.
        window_size (int): Window size for filtering.
    
    Returns:
        accuracy_dict (dict): Dictionary with model names as keys and accuracy scores as values.
    """
    tallies = {name: {'correct': 0, 'total': 0} for name in models}

    # Evaluation mode is a per-model setting; set it once instead of per batch
    for model in models.values():
        model.eval()

    with torch.no_grad():
        for batch in valid_loader:
            note = batch['note'].to(device)
            bar = batch['bar'].to(device)
            tempo = batch['tempo'].to(device)
            velocity = batch['velocity'].to(device)
            key = batch['key'].to(device)
            chord = batch['chord'].to(device)

            # Positions that carry a real (non-<pad>) target
            chord_mask = chord != 0
            key_mask = key != 0

            for name, model in models.items():
                if name == 'IntegratedModel':
                    outputs_chord, _ = model({
                        'note': note,
                        'bar': bar,
                        'tempo': tempo,
                        'velocity': velocity
                    })
                    predicted = outputs_chord.argmax(dim=2)
                    target, mask = chord, chord_mask
                elif name == 'DenoiseTransformer':
                    outputs_chord = model({
                        'note': note,
                        'bar': bar,
                        'key': key,  # ground-truth key
                        'tempo': tempo,
                        'velocity': velocity
                    })
                    # Stay in torch: the target is a tensor on `device`
                    predicted = outputs_chord.argmax(dim=2)
                    target, mask = chord, chord_mask
                elif name == 'DeepBiLSTM':
                    outputs_key = model({'note': note, 'bar': bar})
                    raw_key = outputs_key.argmax(dim=2).cpu().numpy()
                    # The filter works on NumPy sequences, one sample at a time
                    filtered_key = np.array([filter_func(pred, window_size) for pred in raw_key])
                    predicted = torch.from_numpy(filtered_key).long().to(device)
                    target, mask = key, key_mask
                else:
                    continue  # Skip unknown models

                tallies[name]['correct'] += ((predicted == target) & mask).sum().item()
                tallies[name]['total'] += mask.sum().item()

    # Calculate final accuracy scores
    accuracy_dict = {}
    for name, tally in tallies.items():
        accuracy_dict[name] = tally['correct'] / tally['total'] if tally['total'] > 0 else 0.0

    return accuracy_dict

def midi_to_note_name(midi_numbers):
    """
    Maps MIDI pitch numbers to pitch-class names (octave information is dropped).

    A value of 0 is treated as padding and becomes '<pad>'; every other value
    is reduced modulo 12 and looked up in the chromatic scale starting at C.
    
    Args:
        midi_numbers (list or array): List or array of MIDI pitch numbers
    
    Returns:
        note_names (list): List of note names
    """
    pitch_classes = ('C', 'C#', 'D', 'D#', 'E', 'F',
                     'F#', 'G', 'G#', 'A', 'A#', 'B')
    return [
        '<pad>' if value == 0 else pitch_classes[int(value) % 12]
        for value in midi_numbers
    ]

def visualize_sample_prediction(models, sample, sample_index, vocabs, start_step=0, end_step=64):
    """
    Visualizes the model's prediction results for a specific sample in the validation dataset and saves it as a PNG file.

    Input notes, target chord/key and every model's predicted chord/key are
    drawn as lines over the time steps; the y-axis lists the token names.
    Predicted indices are converted to tokens with the chord/key vocabularies
    so that predictions and targets share the same axis. The figure is written
    to 'results/pred_result_<sample_index>_<start>_<end>.png', shown, and then
    closed.

    Args:
        models (dict): Dictionary with model names as keys and loaded model instances as values.
            Supported names: 'IntegratedModel', 'DeepBiLSTM', 'DenoiseTransformer';
            other names are ignored.
        sample (dict): Sample from the validation dataset.
        sample_index (int): Index of the sample in the validation dataset.
        vocabs (dict): Dictionary with vocab names as keys and Vocab objects as values.
        start_step (int): Starting time step to visualize (inclusive).
        end_step (int): Ending time step to visualize (exclusive).
    """
    def indices_to_tokens(vocab, indices):
        """Maps a 1D sequence of vocabulary indices to a list of tokens."""
        return [vocab.idx_to_token[int(index)] for index in indices]

    # Run on the device the models already live on; fall back to the
    # CUDA-if-available choice when no model exposes parameters.
    device = next(
        (param.device for model in models.values() for param in model.parameters()),
        None
    )
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab_chord = vocabs['vocab_chord']
    vocab_key = vocabs['vocab_key']

    # Add a batch dimension of 1 to every feature
    inputs = {
        'note': sample['note'].unsqueeze(0).to(device),
        'bar': sample['bar'].unsqueeze(0).to(device),
        'tempo': sample['tempo'].unsqueeze(0).to(device),
        'velocity': sample['velocity'].unsqueeze(0).to(device)
    }
    targets_chord = sample['chord'].unsqueeze(0).to(device)
    targets_key = sample['key'].unsqueeze(0).to(device)

    # Get predictions from each model, already converted to tokens
    predictions = {}
    with torch.no_grad():
        for name, model in models.items():
            if name == 'IntegratedModel':
                outputs_chord, predicted_key = model(inputs)
                predicted_chord = outputs_chord.argmax(dim=2)
                predictions[name] = {
                    'predicted_chord': indices_to_tokens(vocab_chord, predicted_chord[0].cpu().numpy()),
                    'predicted_key': indices_to_tokens(vocab_key, predicted_key[0].cpu().numpy())
                }
            elif name == 'DeepBiLSTM':
                key_inputs = {
                    'note': inputs['note'],
                    'bar': inputs['bar']
                }
                predicted_key = model(key_inputs).argmax(dim=2)
                predictions[name] = {
                    'predicted_key': indices_to_tokens(vocab_key, predicted_key[0].cpu().numpy())
                }
            elif name == 'DenoiseTransformer':
                chord_inputs = {
                    'note': inputs['note'],
                    'bar': inputs['bar'],
                    'key': targets_key,  # ground-truth key
                    'tempo': inputs['tempo'],
                    'velocity': inputs['velocity']
                }
                predicted_chord = model(chord_inputs).argmax(dim=2)
                predictions[name] = {
                    'predicted_chord': indices_to_tokens(vocab_chord, predicted_chord[0].cpu().numpy())
                }

    # Convert targets and inputs to strings ([0] removes the batch dimension
    # and keeps the sequence 1D even for a single time step)
    targets_chord_tokens = indices_to_tokens(vocab_chord, targets_chord[0].cpu().numpy())
    targets_key_tokens = indices_to_tokens(vocab_key, targets_key[0].cpu().numpy())
    note_tokens_sample = midi_to_note_name(inputs['note'][0].cpu().numpy())

    # Clamp the requested range to the sequence
    seq_length = len(targets_chord_tokens)
    start_step = max(0, start_step)
    end_step = min(end_step, seq_length)

    if start_step >= end_step:
        print(f"Start time step {start_step} is greater than or equal to the sequence length {seq_length}. Skipping visualization.")
        return

    x = range(start_step, end_step)

    # Generate set of all tokens within the specified range
    all_tokens_set = set(note_tokens_sample[start_step:end_step] +
                         targets_chord_tokens[start_step:end_step] +
                         targets_key_tokens[start_step:end_step])
    for pred in predictions.values():
        for series in pred.values():
            all_tokens_set.update(series[start_step:end_step])

    # Assign every token to exactly one y-axis group, in this priority:
    # note names, chord names, key names, everything else. Tracking the
    # already placed tokens prevents duplicate axis labels (e.g. '<pad>' is
    # both a note name and a key-vocabulary entry).
    note_names_order = ['C', 'C#', 'D', 'D#', 'E', 'F',
                        'F#', 'G', 'G#', 'A', 'A#', 'B', '<pad>']
    note_names_in_tokens = [name for name in note_names_order if name in all_tokens_set]
    placed = set(note_names_in_tokens)

    # Chord names (assuming chords contain ':', e.g., "C:maj"); key=str keeps
    # sorting well-defined even if the vocabularies hold non-string tokens
    chord_names_in_tokens = sorted(
        (token for token in all_tokens_set
         if token not in placed and isinstance(token, str) and ':' in token),
        key=str
    )
    placed.update(chord_names_in_tokens)

    # Key names
    key_vocab_tokens = set(vocab_key.idx_to_token)
    key_names_in_tokens = sorted(
        (token for token in all_tokens_set
         if token not in placed and token in key_vocab_tokens),
        key=str
    )
    placed.update(key_names_in_tokens)

    # Other tokens
    other_tokens = sorted((token for token in all_tokens_set if token not in placed), key=str)

    # Set the order of all tokens and map them to y positions
    sorted_all_tokens = note_names_in_tokens + chord_names_in_tokens + key_names_in_tokens + other_tokens
    token_to_index = {token: idx for idx, token in enumerate(sorted_all_tokens)}

    def to_plot_indices(tokens):
        """Maps the visible slice of a token sequence to y positions."""
        return [token_to_index.get(token, -1) for token in tokens[start_step:end_step]]

    note_indices_plot = to_plot_indices(note_tokens_sample)
    targets_chord_indices_plot = to_plot_indices(targets_chord_tokens)
    targets_key_indices_plot = to_plot_indices(targets_key_tokens)

    # Prepare predicted indices
    predicted_chord_indices_plot = {}
    predicted_key_indices_plot = {}
    for name, pred in predictions.items():
        if 'predicted_chord' in pred:
            predicted_chord_indices_plot[name] = to_plot_indices(pred['predicted_chord'])
        if 'predicted_key' in pred:
            predicted_key_indices_plot[name] = to_plot_indices(pred['predicted_key'])

    # Plotting
    fig = plt.figure(figsize=(20, 12))

    # Plot Input Note
    plt.plot(x, note_indices_plot, label='Input Note', alpha=0.5, color='blue')

    # Plot Target Chord
    plt.plot(x, targets_chord_indices_plot, label='Target Chord', alpha=0.8, color='green')

    # Plot Predicted Chord
    for name, indices in predicted_chord_indices_plot.items():
        plt.plot(x, indices, label=f'Predicted Chord ({name})', linestyle='--')

    # Plot Target Key
    plt.plot(x, targets_key_indices_plot, label='Target Key', linestyle='-', color='black', linewidth=2)

    # Plot Predicted Key
    for name, indices in predicted_key_indices_plot.items():
        plt.plot(x, indices, label=f'Predicted Key ({name})', linestyle='--', linewidth=2)

    plt.title(f'Chord and Key Prediction (Sample {sample_index})')
    plt.xlabel('Time Step')
    plt.ylabel('Token')

    # Set y-axis ticks to token names
    plt.yticks(ticks=range(len(sorted_all_tokens)), labels=sorted_all_tokens)

    plt.legend(loc='upper right')
    plt.tight_layout()

    # Set directory to save the graph
    save_dir_visual = 'results'
    os.makedirs(save_dir_visual, exist_ok=True)

    # Save the graph as PNG (including time step range in the filename)
    save_path = os.path.join(save_dir_visual, f'pred_result_{sample_index}_{start_step}_{end_step}.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Prediction result saved to {save_path}")

    # Display the graph, then release it so repeated calls do not pile up figures
    plt.show()
    plt.close(fig)


In [None]:
import os
import torch
import pickle
import collections
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import warnings

# ==========================
# 1. Data Loading and Preprocessing
# ==========================
# This section loads the raw token arrays, builds (or reloads) one vocabulary
# per feature, and converts every token sequence into a LongTensor of indices.
# All names defined here are module-level and are reused by the sections below.

# Define data paths (resolved relative to the current working directory)
x_path = r"aug_x_pop.npy"
y_path = r"aug_y_pop.npy"

# Load data
# NOTE(review): allow_pickle=True makes np.load unpickle object arrays, which
# can execute arbitrary code -- only use with files from a trusted source.
x = np.load(x_path, allow_pickle=True)  # shape: (num_samples, num_steps, 5)
y = np.load(y_path, allow_pickle=True)  # shape: (num_samples, num_steps, 2)

print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")

# Separate features from x (last axis selects the feature)
note_tokens = x[:, :, 0]      # (num_samples, num_steps)
bar_tokens = x[:, :, 1]       # (num_samples, num_steps)
key_tokens = x[:, :, 2]       # (num_samples, num_steps) # Actual key (ground truth)
tempo_tokens = x[:, :, 3]     # (num_samples, num_steps)
velocity_tokens = x[:, :, 4]  # (num_samples, num_steps)

# Separate feature from y (using only chord; y[:, :, 1] is not used here)
chord_tokens = y[:, :, 0]     # (num_samples, num_steps)

# Create token lists for each feature (nested Python lists, one inner list per sample)
note_tokens_list = note_tokens.tolist()
bar_tokens_list = bar_tokens.tolist()
key_tokens_list = key_tokens.tolist()
tempo_tokens_list = tempo_tokens.tolist()
velocity_tokens_list = velocity_tokens.tolist()
chord_tokens_list = chord_tokens.tolist()

# Load or create vocabularies (vocabs.pkl)
# NOTE(review): an existing vocabs.pkl is reused without checking that it was
# built from the same data; tokens missing from it are mapped to <unk> by
# Vocab.__getitem__. Delete the file to force a rebuild. pickle.load must only
# be used on trusted files.
vocab_path = r"vocabs.pkl"
if os.path.exists(vocab_path):
    with open(vocab_path, 'rb') as f:
        vocabs = pickle.load(f)
    print("Vocabularies loaded.")
else:
    # Index 0 is '<pad>' (added by Vocab itself); '<unk>' is reserved right after it.
    vocabs = {
        'vocab_note': Vocab(tokens=note_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_bar': Vocab(tokens=bar_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_key': Vocab(tokens=key_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_tempo': Vocab(tokens=tempo_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_velocity': Vocab(tokens=velocity_tokens_list, min_freq=1, reserved_tokens=['<unk>']),
        'vocab_chord': Vocab(tokens=chord_tokens_list, min_freq=1, reserved_tokens=['<unk>'])
    }
    # Save vocabularies so later runs use identical token -> index mappings
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocabs, f)
    print("Vocabularies created and saved.")

print(f"Vocabulary sizes:")
for vocab_name, vocab in vocabs.items():
    print(f"{vocab_name}: {len(vocab)}")

# Convert tokens to indices (Vocab.__getitem__ maps a whole list at once)
note_indices = [vocabs['vocab_note'][line] for line in note_tokens_list]
bar_indices = [vocabs['vocab_bar'][line] for line in bar_tokens_list]
key_indices = [vocabs['vocab_key'][line] for line in key_tokens_list]
tempo_indices = [vocabs['vocab_tempo'][line] for line in tempo_tokens_list]
velocity_indices = [vocabs['vocab_velocity'][line] for line in velocity_tokens_list]
chord_indices = [vocabs['vocab_chord'][line] for line in chord_tokens_list]

# Convert to numpy arrays
note_indices = np.array(note_indices)
bar_indices = np.array(bar_indices)
key_indices = np.array(key_indices)
tempo_indices = np.array(tempo_indices)
velocity_indices = np.array(velocity_indices)
chord_indices = np.array(chord_indices)

# Convert to Tensors (.long() because the values are used as class/embedding indices)
note_tensor = torch.from_numpy(note_indices).long()         # (num_samples, num_steps)
bar_tensor = torch.from_numpy(bar_indices).long()           # (num_samples, num_steps)
key_tensor = torch.from_numpy(key_indices).long()           # (num_samples, num_steps)
tempo_tensor = torch.from_numpy(tempo_indices).long()       # (num_samples, num_steps)
velocity_tensor = torch.from_numpy(velocity_indices).long() # (num_samples, num_steps)
chord_tensor = torch.from_numpy(chord_indices).long()       # (num_samples, num_steps)

# ==========================
# 2. Dataset and DataLoader Preparation
# ==========================
# Create the dataset (including key_tensor)
# MultiInputDataset is defined elsewhere in this file; collate_fn below reads
# its items through the keys 'note', 'bar', 'key', 'tempo', 'velocity', 'chord'.
dataset = MultiInputDataset(
    note_tensor, bar_tensor, key_tensor, tempo_tensor, velocity_tensor, chord_tensor
)

# Split the dataset (80% training, 20% validation)
# NOTE(review): random_split is called without a generator, so the split
# depends on the global torch RNG state and differs between runs unless a
# seed is set elsewhere -- confirm if reproducible validation sets are needed.
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

print(f"Train dataset size: {len(train_dataset)}")
print(f"Valid dataset size: {len(valid_dataset)}")

# Define collate_fn function
def collate_fn(batch):
    """Pad every feature of a batch to a common length and stack it.

    Args:
        batch (list[dict]): Samples with the keys 'note', 'bar', 'key',
            'tempo', 'velocity' and 'chord', each holding a 1D index tensor.

    Returns:
        dict: The same six keys, each mapped to a LongTensor of shape
        (len(batch), max_length). Positions past a sample's length stay 0,
        which is the '<pad>' index.
    """
    feature_names = ('note', 'bar', 'key', 'tempo', 'velocity', 'chord')

    # The note sequence defines the length of a sample; the other features
    # are cut to that same length when copied.
    lengths = [len(sample['note']) for sample in batch]
    longest = max(lengths)

    # One zero-initialised (i.e. fully padded) buffer per feature.
    padded = {
        name: torch.zeros(len(batch), longest, dtype=torch.long)
        for name in feature_names
    }

    for row, (sample, n_steps) in enumerate(zip(batch, lengths)):
        for name in feature_names:
            padded[name][row, :n_steps] = sample[name][:n_steps]

    return padded

# Define DataLoader
# Only the training loader shuffles; validation keeps a fixed order.
# num_workers=0 loads batches in the main process.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)

# ==========================
# 3. Model Initialization and Loading
# ==========================
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize DeepBiLSTM model
# It receives only note and bar inputs and has one output class per key token.
deep_bilstm_model_loaded = DeepBiLSTM(
    vocab_sizes={'note': len(vocabs['vocab_note']), 'bar': len(vocabs['vocab_bar'])},
    embed_dims={'note': 64, 'bar': 16},
    hidden_size=256,
    num_layers=3,
    num_classes=len(vocabs['vocab_key']),
    dropout=0.5
).to(device)

deep_bilstm_model_path = r"DeepBiLSTM_filtered.pth"

# Load state_dict (remove "model." prefix if present)
# map_location lets a checkpoint saved on another device load on `device`.
state_dict = torch.load(deep_bilstm_model_path, map_location=device)

# Remove "model." prefix from keys
# presumably the checkpoint was saved from a wrapper that held the network
# under an attribute named `model` -- TODO confirm.
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith('model.'):
        new_key = k[6:]  # Remove "model." (first 6 characters)
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# Load the modified state_dict into the model
# (strict by default: missing or unexpected keys raise an error)
deep_bilstm_model_loaded.load_state_dict(new_state_dict)

deep_bilstm_model_loaded.eval()

# Initialize Chord prediction model
# NOTE: despite the `_loaded` suffix, no checkpoint is loaded into this model
# here; it starts from its freshly initialised weights.
chord_model_loaded = DenoiseTransformer(
    vocab_sizes={
        'note': len(vocabs['vocab_note']),
        'bar': len(vocabs['vocab_bar']),
        'key': len(vocabs['vocab_key']),
        'tempo': len(vocabs['vocab_tempo']),
        'velocity': len(vocabs['vocab_velocity'])
    },
    embed_dims={
        'note': 64,
        'bar': 16,
        'key': 32,
        'tempo': 16,
        'velocity': 16
    },
    num_classes=len(vocabs['vocab_chord']),
    nhead=8,
    num_encoder_layers=4,
    dim_feedforward=512,
    dropout=0.1
).to(device)

# Initialize IntegratedModel
# It is given the key model, the chord model and the window filter; the
# training loop below calls it with note/bar/tempo/velocity inputs and gets
# back (chord logits, predicted key).
integrated_model = IntegratedModel(
    deep_bilstm_model=deep_bilstm_model_loaded,
    chord_model=chord_model_loaded,
    filter_func=apply_window_filter,
    window_size=128
).to(device)

# ==========================
# 4. Loss Function and Optimizer Definition
# ==========================
# ignore_index=0 excludes '<pad>' targets (index 0 in every Vocab) from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Only parameters with requires_grad=True are handed to the optimizer, so any
# parameters frozen inside IntegratedModel are left untouched.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, integrated_model.parameters()), lr=0.001, weight_decay=1e-5)

# Initialize history dictionary to record loss and accuracy (one entry per epoch)
history = {'train_loss': [], 'valid_loss': [], 'train_acc': [], 'valid_acc': []}

# Set directory to save models
save_dir = r"integrated_model"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    print(f"Directory created: {save_dir}")

# Save Vocabulary
# A copy is stored next to the checkpoints so they can be used with the exact
# token -> index mapping they were trained with.
vocab_save_path = os.path.join(save_dir, 'vocabs.pkl')
with open(vocab_save_path, 'wb') as f:
    pickle.dump(vocabs, f)
    print(f"Vocabularies saved at {vocab_save_path}")

# ==========================
# 5. Suppress Warnings (Optional)
# ==========================
# Silences only warnings whose message starts with this text.
warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.")

# ==========================
# 6. Training and Validation Loop
# ==========================
# Each epoch: one pass over train_loader with parameter updates, one pass over
# valid_loader without gradients, then optional checkpointing and timing output.
num_epochs = 100
best_valid_acc = 0.0  # Initialize best validation accuracy
start_time = time.time()  # Record training start time

for epoch in range(1, num_epochs + 1):
    epoch_start_time = time.time()  # Record epoch start time (not read again in this section)
    print(f'\nEpoch {epoch}/{num_epochs}')

    # ======================
    # Training Phase
    # ======================
    integrated_model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch in train_loader:
        # Extract inputs
        # The ground-truth 'key' is deliberately not passed: the model returns
        # its own predicted key alongside the chord logits.
        inputs = {
            'note': batch['note'].to(device),
            'bar': batch['bar'].to(device),
            'tempo': batch['tempo'].to(device),
            'velocity': batch['velocity'].to(device)
        }
        targets = batch['chord'].to(device)

        optimizer.zero_grad()

        outputs_chord, predicted_key = integrated_model(inputs)  # (batch_size, seq_len, num_classes), (batch_size, seq_len)

        # Reshape for loss: (batch_size * seq_len, num_classes)
        outputs_reshaped = outputs_chord.view(-1, len(vocabs['vocab_chord']))
        targets_reshaped = targets.view(-1)  # (batch_size * seq_len)

        loss = criterion(outputs_reshaped, targets_reshaped)
        loss.backward()
        optimizer.step()

        # The criterion returns a mean over non-pad tokens; weighting it by the
        # batch size and dividing by train_size below gives a sample-weighted
        # average of the per-batch means.
        running_loss += loss.item() * targets.size(0)

        # Compute predictions
        _, predicted = torch.max(outputs_chord, dim=2)  # (batch_size, seq_len)
        correct += ((predicted == targets) & (targets != 0)).sum().item()  # Exclude <pad> tokens
        total += (targets != 0).sum().item()

    epoch_loss = running_loss / train_size
    epoch_acc = correct / total  # token-level accuracy over non-pad positions
    history['train_loss'].append(epoch_loss)
    history['train_acc'].append(epoch_acc)

    # ======================
    # Validation Phase
    # ======================
    integrated_model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for batch in valid_loader:
            inputs = {
                'note': batch['note'].to(device),
                'bar': batch['bar'].to(device),
                'tempo': batch['tempo'].to(device),
                'velocity': batch['velocity'].to(device)
            }
            targets = batch['chord'].to(device)

            outputs_chord, predicted_key = integrated_model(inputs)

            outputs_reshaped = outputs_chord.view(-1, len(vocabs['vocab_chord']))
            targets_reshaped = targets.view(-1)

            loss = criterion(outputs_reshaped, targets_reshaped)
            val_loss += loss.item() * targets.size(0)  # same weighting as in training

            _, predicted = torch.max(outputs_chord, dim=2)
            val_correct += ((predicted == targets) & (targets != 0)).sum().item()
            val_total += (targets != 0).sum().item()

    epoch_val_loss = val_loss / valid_size
    epoch_val_acc = val_correct / val_total
    history['valid_loss'].append(epoch_val_loss)
    history['valid_acc'].append(epoch_val_acc)

    # ======================
    # Periodic Output
    # ======================
    # Metrics are printed on the first epoch, every 10th epoch and the last epoch.
    if epoch % 10 == 0 or epoch == 1 or epoch == num_epochs:
        print(f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, '
              f'Valid Loss: {epoch_val_loss:.4f}, Valid Acc: {epoch_val_acc:.4f}')

    # ======================
    # Save Model if Validation Accuracy Improved
    # ======================
    # NOTE(review): checkpoints are only considered on even epochs; a new best
    # accuracy reached on an odd epoch is neither saved nor recorded in
    # best_valid_acc -- confirm this throttling is intended.
    if epoch % 2 == 0 and epoch_val_acc > best_valid_acc:
        best_valid_acc = epoch_val_acc
        # The file name encodes epoch and accuracy in the form that
        # find_best_model_files parses later in this file.
        model_save_path = os.path.join(save_dir, f'IntegratedModel_epoch{epoch}_valacc{epoch_val_acc:.4f}.pt')
        torch.save(integrated_model.state_dict(), model_save_path)
        print(f"Integrated Model saved at epoch {epoch} with validation accuracy {epoch_val_acc:.4f}")

    # ======================
    # Calculate Elapsed and Remaining Time
    # ======================
    epoch_end_time = time.time()
    elapsed_time = epoch_end_time - start_time
    # Linear extrapolation: average time per finished epoch times total epochs, minus elapsed.
    remaining_time = elapsed_time * (num_epochs / epoch) - elapsed_time

    # Print time information
    print(f'Epoch {epoch}/{num_epochs} completed. Time elapsed: {elapsed_time:.2f} seconds. '
          f'Estimated remaining time: {remaining_time:.2f} seconds.')

# ==========================
# 7. Loss and Accuracy Visualization
# ==========================
# Build the x-axis from the number of epochs actually recorded instead of
# num_epochs: if training was interrupted, history is shorter than num_epochs
# and plt.plot would raise because x and y have different lengths.
# Train and validation are measured separately because an interruption during
# the validation phase leaves one more training entry than validation entries.
train_epochs = range(1, len(history['train_loss']) + 1)
valid_epochs = range(1, len(history['valid_loss']) + 1)

plt.figure(figsize=(12, 8))

# Plot Loss (top panel)
plt.subplot(2, 1, 1)
plt.plot(train_epochs, history['train_loss'], label='Train Loss', linestyle='-')
plt.plot(valid_epochs, history['valid_loss'], label='Valid Loss', linestyle='--')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# Plot Accuracy (bottom panel)
plt.subplot(2, 1, 2)
plt.plot(train_epochs, history['train_acc'], label='Train Acc', linestyle='-')
plt.plot(valid_epochs, history['valid_acc'], label='Valid Acc', linestyle='--')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.tight_layout()

# Save the loss and accuracy graphs (the figure must be saved before plt.show())
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)
save_path = os.path.join(results_dir, 'tuned-model_train_graph.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()

# ==========================
# 8. Load Best Model and Visualize Predictions
# ==========================
# Function to find best model files based on validation accuracy
def find_best_model_files(save_dir):
    """Find, for every model name, the checkpoint with the highest validation accuracy.

    Checkpoints are recognised by file name only, using the pattern
    ``<model_name>_epoch<N>_valacc<ACC>.pt``; ``.pt`` files that do not follow
    it are ignored.

    Args:
        save_dir (str): Directory that is searched (non-recursively) for ``*.pt`` files.

    Returns:
        dict: Maps each model name to ``{'file': path, 'epoch': int, 'valacc': float}``
        for its best checkpoint. Empty when nothing matches or the directory
        does not exist.
    """
    # Imported here so the function works even if no other notebook cell has
    # imported these modules.
    import glob
    import re

    # The accuracy group only accepts a plain decimal number, so float() below
    # cannot fail on strings such as "." that the looser "[\d.]+" would capture.
    pattern = re.compile(r'(.+)_epoch(\d+)_valacc(\d+(?:\.\d+)?)\.pt')

    best_models = {}
    for file in glob.glob(os.path.join(save_dir, '*.pt')):
        match = pattern.fullmatch(os.path.basename(file))
        if match is None:
            continue
        model_name, epoch, valacc = match.groups()
        valacc = float(valacc)
        current = best_models.get(model_name)
        # Strict '>' keeps the first file seen when two checkpoints tie.
        if current is None or valacc > current['valacc']:
            best_models[model_name] = {'file': file, 'epoch': int(epoch), 'valacc': valacc}
    return best_models

# Function to evaluate model
def evaluate_model(model, dataloader, criterion, device):
    """Compute average loss and token-level accuracy of a chord model on a dataloader.

    Args:
        model: Callable module that takes a dict with 'note', 'bar', 'tempo'
            and 'velocity' tensors and returns ``(chord_logits, predicted_key)``
            where ``chord_logits`` is (batch_size, seq_len, num_classes).
        dataloader: Iterable of batches (dicts that also contain 'chord') that
            exposes the underlying dataset as ``.dataset``.
        criterion: Loss called with (N, num_classes) logits and (N,) targets.
        device (torch.device): Device the batch tensors are moved to.

    Returns:
        tuple: ``(avg_loss, avg_acc)``. The loss is the per-batch criterion
        value weighted by batch size and divided by the dataset size; the
        accuracy only counts positions whose target is not 0 ('<pad>').
        Both are 0 when there is nothing to evaluate.
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'note': batch['note'].to(device),
                'bar': batch['bar'].to(device),
                'tempo': batch['tempo'].to(device),
                'velocity': batch['velocity'].to(device)
            }
            targets = batch['chord'].to(device)

            outputs_chord, _ = model(inputs)  # (batch_size, seq_len, num_classes), (batch_size, seq_len)

            # Take the class count from the logits themselves rather than from
            # a module-level vocabulary, so the function only depends on its
            # arguments. reshape also accepts non-contiguous tensors.
            num_classes = outputs_chord.size(-1)
            outputs_reshaped = outputs_chord.reshape(-1, num_classes)
            targets_reshaped = targets.reshape(-1)

            loss = criterion(outputs_reshaped, targets_reshaped)
            val_loss += loss.item() * targets.size(0)

            _, predicted = torch.max(outputs_chord, dim=2)
            non_pad = targets != 0  # exclude <pad> positions from the accuracy
            val_correct += ((predicted == targets) & non_pad).sum().item()
            val_total += non_pad.sum().item()

    # Guard both divisions so an empty dataset yields zeros instead of raising.
    num_samples = len(dataloader.dataset)
    avg_loss = val_loss / num_samples if num_samples > 0 else 0.0
    avg_acc = val_correct / val_total if val_total > 0 else 0
    return avg_loss, avg_acc

# Find best model files
# (one entry per model-name prefix found among the *.pt files in save_dir)
best_models = find_best_model_files(save_dir)

# Initialize a dictionary to store evaluation results
evaluation_results = {}

# Evaluate each best model
for model_name, info in best_models.items():
    print(f"Evaluating best model for {model_name}: {info['file']} (Epoch {info['epoch']}, Val Acc {info['valacc']:.4f})")
    
    # Only checkpoints named 'IntegratedModel' can be rebuilt here; any other
    # prefix is reported and skipped.
    if model_name == 'IntegratedModel':
        # Instantiate the model
        # NOTE(review): this wraps the very same deep_bilstm_model_loaded and
        # chord_model_loaded objects that integrated_model uses. If
        # IntegratedModel registers them as sub-modules, load_state_dict below
        # also overwrites the weights seen through integrated_model -- confirm.
        integrated_model_loaded = IntegratedModel(
            deep_bilstm_model=deep_bilstm_model_loaded,
            chord_model=chord_model_loaded,
            filter_func=apply_window_filter,
            window_size=128
        ).to(device)
        
        # Load the model's state_dict
        integrated_model_loaded.load_state_dict(torch.load(info['file'], map_location=device))
        integrated_model_loaded.eval()
        
        # Evaluate the model (fresh criterion with the same <pad>-ignoring setting as in training)
        val_loss, val_acc = evaluate_model(integrated_model_loaded, valid_loader, nn.CrossEntropyLoss(ignore_index=0), device)
        
        # Store the evaluation results
        evaluation_results[model_name] = {'val_loss': val_loss, 'val_acc': val_acc}
        print(f"{model_name} - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
    else:
        print(f"Model '{model_name}' is not recognized for evaluation.")

# ==========================
# 9. Validation Results Visualization
# ==========================
# Generate bar graphs to visualize validation loss and accuracy
# (one bar per entry in evaluation_results; both panels stay empty when no
# checkpoint was evaluated above)
plt.figure(figsize=(14, 6))

# Validation Loss Visualization (left panel)
plt.subplot(1, 2, 1)
model_names = list(evaluation_results.keys())
val_losses = [evaluation_results[name]['val_loss'] for name in model_names]
plt.bar(model_names, val_losses, color='skyblue')
plt.xlabel('Model')
plt.ylabel('Validation Loss')
plt.title('Validation Loss of Best Models')
plt.xticks(rotation=45)

# Validation Accuracy Visualization (right panel, same model order as the left one)
plt.subplot(1, 2, 2)
val_accs = [evaluation_results[name]['val_acc'] for name in model_names]
plt.bar(model_names, val_accs, color='salmon')
plt.xlabel('Model')
plt.ylabel('Validation Accuracy')
plt.title('Validation Accuracy of Best Models')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

# ==========================
# 10. Visualization Function Call
# ==========================
def visualize_predictions(models, vocabs, valid_dataset, valid_loader, device, filter_func, window_size=128, sample_indices=None, max_steps=256):
    """
    Visualizes predictions for multiple samples in the validation dataset.

    Each requested sample is fetched from ``valid_dataset`` and handed to
    ``visualize_sample_prediction``. Indices outside the dataset are reported
    and skipped instead of aborting the whole run.

    Args:
        models (dict): Dictionary with model names as keys and loaded model instances as values.
        vocabs (dict): Dictionary with vocab names as keys and Vocab objects as values.
        valid_dataset (Dataset): Validation dataset.
        valid_loader (DataLoader): Validation DataLoader (currently unused; kept for interface compatibility).
        device (torch.device): Device to perform computations on (currently unused).
        filter_func (function): Function to apply filtering on key predictions (currently unused).
        window_size (int): Window size for filtering (currently unused).
        sample_indices (list): List of sample indices to visualize. Defaults to [10, 11, 13, 14].
        max_steps (int): Forwarded as the last positional argument of
            ``visualize_sample_prediction`` (presumably its end step -- confirm
            against that function's signature). The preceding argument stays
            fixed at 64, so the default of 256 reproduces the previous
            hard-coded ``64, 256`` call.
    """
    # None instead of a list literal avoids sharing one mutable default
    # between calls.
    if sample_indices is None:
        sample_indices = [10, 11, 13, 14]

    num_samples = len(valid_dataset)
    for sample_idx in sample_indices:
        # Negative indices are allowed as long as they stay inside the dataset.
        if not -num_samples <= sample_idx < num_samples:
            print(f"\nSkipping Sample {sample_idx}: validation dataset has only {num_samples} samples")
            continue
        print(f"\nVisualizing Sample {sample_idx}")
        sample = valid_dataset[sample_idx]
        visualize_sample_prediction(models, sample, sample_idx, vocabs, 64, max_steps)

# Select the models to visualize. Nothing is loaded from disk here: the
# in-memory `integrated_model` from the training loop is used under the name
# 'IntegratedModel', together with the in-memory vocabularies.
# NOTE(review): this is not `integrated_model_loaded` (the best checkpoint
# evaluated above) -- confirm which of the two is meant to be visualized.
loaded_models, loaded_vocabs = {'IntegratedModel': integrated_model}, vocabs

# Visualize predictions for specific samples
# (indices refer to positions inside valid_dataset)
sample_indices = [10, 11, 13, 14]  # Change as needed
visualize_predictions(
    models=loaded_models,
    vocabs=loaded_vocabs,
    valid_dataset=valid_dataset,
    valid_loader=valid_loader,
    device=device,
    filter_func=apply_window_filter,
    window_size=128,
    sample_indices=sample_indices,
    max_steps=256
)


In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from collections import Counter

def plot_confusion_matrix(N=24):
    """Plot and save a row-normalized confusion matrix for the most frequent chords.

    Runs the module-level ``integrated_model_loaded`` over ``valid_loader``,
    drops '<pad>' and '<unk>' positions, keeps the ``N`` most frequent target
    chords and plots how often each of them is predicted as each other one.
    Positions whose target or prediction lies outside those chords are left
    out. The figure is written to ``results/confusion_matrix.png`` and shown.

    Args:
        N (int): Maximum number of chord classes to include. When fewer
            distinct chords occur in the validation targets, the matrix is
            sized to the chords that are actually present.

    Returns:
        None. Returns early, without plotting, when no chord class is
        available.
    """
    integrated_model_loaded.eval()
    val_all_preds = []
    val_all_targets = []

    with torch.no_grad():
        for batch in valid_loader:
            inputs = {
                'note': batch['note'].to(device),
                'bar': batch['bar'].to(device),
                'tempo': batch['tempo'].to(device),
                'velocity': batch['velocity'].to(device)
            }
            targets = batch['chord'].to(device)

            outputs_chord, _ = integrated_model_loaded(inputs)

            _, predicted = torch.max(outputs_chord, dim=2)
            mask = targets != 0  # exclude <pad> token

            val_all_preds.extend(predicted[mask].cpu().numpy())
            val_all_targets.extend(targets[mask].cpu().numpy())

    # Convert the collected per-token values to arrays
    val_all_preds = np.array(val_all_preds)
    val_all_targets = np.array(val_all_targets)

    # Exclude every position where the target or the prediction is <unk>
    unk_token_idx = vocabs['vocab_chord'].token_to_idx.get('<unk>', None)
    if unk_token_idx is not None:
        valid_indices = (val_all_targets != unk_token_idx) & (val_all_preds != unk_token_idx)
        val_all_preds = val_all_preds[valid_indices]
        val_all_targets = val_all_targets[valid_indices]

    # Select the top-N chords by frequency among the targets
    chord_counts = Counter(val_all_targets)
    top_chord_indices = [chord for chord, _ in chord_counts.most_common(N)]

    # Size everything by the number of classes actually found, so matrix,
    # tick positions and tick labels always agree even when fewer than N
    # distinct chords occur.
    num_classes = len(top_chord_indices)
    if num_classes == 0:
        print(f"No chord classes available for N={N}; confusion matrix not generated.")
        return

    # Map vocabulary indices to token names for the axis labels
    idx_to_token = vocabs['vocab_chord'].idx_to_token
    class_names = [idx_to_token[idx] for idx in top_chord_indices]

    # Keep only positions where both target and prediction are top-N chords
    keep = np.isin(val_all_targets, top_chord_indices) & np.isin(val_all_preds, top_chord_indices)

    # Re-number the kept vocabulary indices to 0..num_classes-1 (most frequent first)
    label_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(top_chord_indices)}
    val_all_preds_mapped = np.array([label_mapping[idx] for idx in val_all_preds[keep]], dtype=int)
    val_all_targets_mapped = np.array([label_mapping[idx] for idx in val_all_targets[keep]], dtype=int)

    # Compute the confusion matrix (rows: true label, columns: predicted label)
    cm = confusion_matrix(val_all_targets_mapped, val_all_preds_mapped, labels=list(range(num_classes)))

    # Row-normalize. A row can be empty when every prediction for that chord
    # fell outside the selected classes; such rows are left at 0 instead of
    # producing NaN from a division by zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    # Make sure the output directory exists
    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)

    # Plot
    plt.figure(figsize=(12, 10))
    plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Normalized Confusion Matrix for Major & minor Chords')
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=90)
    plt.yticks(tick_marks, class_names)

    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()

    # Save (before plt.show())
    save_path = os.path.join(results_dir, 'confusion_matrix.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Confusion matrix saved to {save_path}")

    # Display the graph on screen
    plt.show()

# Plot the confusion matrix for the 24 most frequent chords in the validation targets.
plot_confusion_matrix(N=24) # 24 for maj & min, 60 for all chord classes
