In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import collections
import matplotlib.pyplot as plt
import time
import glob
import re

# ==========================
# Vocab Class Definition
# ==========================
class Vocab:
    """Vocabulary mapping tokens to integer indices and back.

    Index 0 is always '<pad>', followed by ``reserved_tokens`` and then every
    observed token whose frequency is at least ``min_freq``, in descending
    frequency order.

    Args:
        tokens (list): Flat list of tokens, or a list of token lists (flattened).
        min_freq (int): Minimum count for an observed token to be kept.
        reserved_tokens (list): Special tokens placed right after '<pad>'.
    """
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        # Fresh lists per call instead of shared mutable defaults
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else list(reserved_tokens)
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies (ties keep first-occurrence order)
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # Observed tokens that are already special are not listed a second time,
        # otherwise they would occupy two indices with only one reachable.
        special = {'<pad>', *reserved_tokens}
        self.idx_to_token = ['<pad>'] + reserved_tokens + [
            token for token, freq in self.token_freqs
            if freq >= min_freq and token not in special
        ]
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token (or a list/tuple of tokens) to index/indices; unknown tokens map to ``unk``."""
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a sequence of indices, back to tokens.

        Lists, tuples and arrays/tensors with at least one dimension return a
        list (of any length, including 0 or 1). A scalar index (int, NumPy or
        torch scalar, 0-d array) returns a single token.
        """
        ndim = getattr(indices, 'ndim', None)
        # Arrays/tensors report their rank; plain Python containers expose __len__.
        is_sequence = hasattr(indices, '__len__') if ndim is None else ndim > 0
        if is_sequence:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[int(indices)]

    @property
    def unk(self):
        """Index of '<unk>'; falls back to 1 when '<unk>' is not in the vocabulary."""
        return self.token_to_idx.get('<unk>', 1)

# ==========================
# MIDI to Note Name Conversion Function
# ==========================
def midi_to_note_name(midi_numbers):
    """
    Converts a list or array of MIDI pitch numbers to a list of pitch-class names.

    ``0`` and ``'<pad>'`` become ``'<pad>'``. Any other value that cannot be
    converted with ``int()`` (e.g. a special token such as ``'<unk>'``) is
    passed through as its string form instead of raising.

    Args:
        midi_numbers (list or array): List or array of MIDI pitch numbers

    Returns:
        note_names (list): List of note names (octave information is dropped)
    """
    note_names_list = ['C', 'C#', 'D', 'D#', 'E', 'F',
                       'F#', 'G', 'G#', 'A', 'A#', 'B']
    note_names = []
    for num in midi_numbers:
        if num == 0 or num == '<pad>':
            note_names.append('<pad>')
            continue
        try:
            pitch = int(num)
        except (TypeError, ValueError):
            # Non-numeric special token: keep it visible rather than crash
            note_names.append(str(num))
            continue
        note_names.append(note_names_list[pitch % 12])
    return note_names

# ==========================
# Dataset Class Definition
# ==========================
class MultiInputDataset(Dataset):
    """Map-style dataset of aligned note/bar input rows and key target rows.

    Item ``idx`` is a dict with keys ``'note'``, ``'bar'`` and ``'key'``, each
    holding row ``idx`` of the corresponding tensor.
    """
    def __init__(self, note, bar, key):
        """
        Stores the three aligned tensors.

        Args:
            note (Tensor): Tensor of shape (num_samples, num_steps)
            bar (Tensor): Tensor of shape (num_samples, num_steps)
            key (Tensor): Tensor of shape (num_samples, num_steps)
        """
        self.note, self.bar, self.key = note, bar, key

    def __len__(self):
        # The note tensor's leading dimension defines the number of samples
        return self.note.shape[0]

    def __getitem__(self, idx):
        return dict(note=self.note[idx], bar=self.bar[idx], key=self.key[idx])

# ==========================
# Base Model Class Definition
# ==========================
class MultiInputModelBase(nn.Module):
    """Shared embedding front-end for the note/bar -> key sequence models.

    Creates one embedding table per input feature (index 0 is the padding
    index) and exposes ``total_embed_dim``, the width of the concatenated
    note+bar embedding that subclasses feed into their sequence encoder.

    Args:
        vocab_sizes (dict): Vocabulary size per feature; must contain 'note' and 'bar'.
        embed_dims (dict): Embedding width per feature; must contain 'note' and 'bar'.
            Any other keys are ignored.
        hidden_size, num_layers, num_classes, dropout: Accepted so subclasses can
            forward their arguments; not used by the base class itself.
    """
    def __init__(
        self, 
        vocab_sizes,       # Dictionary of vocab sizes for each feature
        embed_dims,        # Dictionary of embedding dimensions for each feature
        hidden_size=256,   # Hidden state size
        num_layers=3, 
        num_classes=62,    # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        super(MultiInputModelBase, self).__init__()
        # Embedding layers (using note and bar features)
        self.embed_note = nn.Embedding(vocab_sizes['note'], embed_dims['note'], padding_idx=0)
        self.embed_bar = nn.Embedding(vocab_sizes['bar'], embed_dims['bar'], padding_idx=0)

        # Only the note and bar embeddings are concatenated downstream, so only
        # their widths count; summing every entry of embed_dims would mis-size
        # the encoder input if an extra feature were listed.
        self.total_embed_dim = embed_dims['note'] + embed_dims['bar']

# ==========================
# Deep LSTM Model Definition
# ==========================
class DenoiseLSTM(MultiInputModelBase):
    """Unidirectional stacked-LSTM tagger that predicts a key class at every time step.

    Args:
        vocab_sizes (dict): Vocabulary size per input feature ('note', 'bar').
        embed_dims (dict): Embedding width per input feature ('note', 'bar').
        hidden_size (int): LSTM hidden-state size.
        num_layers (int): Number of stacked LSTM layers.
        num_classes (int): Number of output key classes.
        dropout (float): Dropout between stacked LSTM layers and before the output layer.
    """
    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        hidden_size=256,    # Hidden state size
        num_layers=3, 
        num_classes=62,     # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        super(DenoiseLSTM, self).__init__(vocab_sizes, embed_dims, hidden_size, num_layers, num_classes, dropout)

        # nn.LSTM applies its dropout only between stacked layers; with a single
        # layer it is a no-op that PyTorch warns about, so pass 0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=self.total_embed_dim, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=True, 
            dropout=lstm_dropout
        )
        self.dropout_layer = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs):
        """Return key logits of shape (batch_size, seq_len, num_classes).

        Args:
            inputs (dict): 'note' and 'bar' index tensors of shape (batch_size, seq_len).
        """
        note = self.embed_note(inputs['note'])       # (batch_size, seq_len, embed_dim)
        bar = self.embed_bar(inputs['bar'])

        # Concatenate embeddings
        x = torch.cat([note, bar], dim=-1)  # (batch_size, seq_len, total_embed_dim)

        out, _ = self.lstm(x)
        out = self.dropout_layer(out)
        out = self.fc(out)     # (batch_size, seq_len, num_classes)
        return out

# ==========================
# Deep BiLSTM Model Definition
# ==========================
class DenoiseBiLSTM(MultiInputModelBase):
    """Bidirectional stacked-LSTM tagger that predicts a key class at every time step.

    The forward and backward hidden states are concatenated (width
    ``2 * hidden_size``) and layer-normalized before the output projection.

    Args:
        vocab_sizes (dict): Vocabulary size per input feature ('note', 'bar').
        embed_dims (dict): Embedding width per input feature ('note', 'bar').
        hidden_size (int): LSTM hidden-state size per direction.
        num_layers (int): Number of stacked LSTM layers.
        num_classes (int): Number of output key classes.
        dropout (float): Dropout between stacked LSTM layers and before the output layer.
    """
    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        hidden_size=256,    # Hidden state size
        num_layers=3, 
        num_classes=62,     # Number of target classes (key_vocab_size)
        dropout=0.5
    ):
        super(DenoiseBiLSTM, self).__init__(vocab_sizes, embed_dims, hidden_size, num_layers, num_classes, dropout)

        # nn.LSTM applies its dropout only between stacked layers; with a single
        # layer it is a no-op that PyTorch warns about, so pass 0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=self.total_embed_dim, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=True, 
            bidirectional=True, 
            dropout=lstm_dropout
        )
        self.layer_norm = nn.LayerNorm(hidden_size * 2)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, inputs):
        """Return key logits of shape (batch_size, seq_len, num_classes).

        Args:
            inputs (dict): 'note' and 'bar' index tensors of shape (batch_size, seq_len).
        """
        note = self.embed_note(inputs['note'])       # (batch_size, seq_len, embed_dim)
        bar = self.embed_bar(inputs['bar'])

        # Concatenate embeddings
        x = torch.cat([note, bar], dim=-1)  # (batch_size, seq_len, total_embed_dim)

        out, _ = self.lstm(x)           # (batch_size, seq_len, 2 * hidden_size)
        out = self.layer_norm(out)
        out = self.dropout_layer(out)
        out = self.fc(out)     # (batch_size, seq_len, num_classes)
        return out

# ==========================
# Positional Encoding for Transformer
# ==========================
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``,
    with geometrically spaced frequencies ``w_i``. The table is precomputed for
    ``max_len`` positions and registered as a (non-trainable) buffer.

    Args:
        d_model (int): Feature dimension of the inputs (may be odd).
        max_len (int): Longest sequence length supported.
    """
    def __init__(self, d_model, max_len=2048):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x (Tensor): (batch_size, seq_len, d_model)

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]

# ==========================
# Transformer Model Definition
# ==========================
class DenoiseTransformer(MultiInputModelBase):
    """Transformer-encoder tagger that predicts a key class at every time step.

    The note and bar embeddings are concatenated. Sinusoidal positional
    encodings are added, the result passes through a stack of batch-first
    encoder layers, and a linear layer projects each position to
    ``num_classes`` logits.
    """
    def __init__(
        self, 
        vocab_sizes, 
        embed_dims, 
        num_classes=62,     # Number of target classes (key_vocab_size)
        nhead=8, 
        num_encoder_layers=4, 
        dim_feedforward=512, 
        dropout=0.1
    ):
        super(DenoiseTransformer, self).__init__(vocab_sizes, embed_dims, num_classes=num_classes, dropout=dropout)

        d_model = self.total_embed_dim
        self.pos_encoder = PositionalEncoding(d_model)
        # One prototype layer; nn.TransformerEncoder clones it num_encoder_layers times.
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_encoder_layers)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, inputs):
        """Map 'note'/'bar' index tensors (batch_size, seq_len) to logits (batch_size, seq_len, num_classes)."""
        features = torch.cat(
            [self.embed_note(inputs['note']), self.embed_bar(inputs['bar'])],
            dim=-1,
        )
        encoded = self.transformer_encoder(self.pos_encoder(features))
        return self.fc(encoded)

# ==========================
# Function to Find Best Model Files
# ==========================
def find_best_model_files(save_dir):
    """
    Finds the checkpoint with the highest validation accuracy for each model type.

    Only files named ``{model}_epoch{N}_valacc{ACC}.pt`` are considered. Files
    whose accuracy field is not a valid number are skipped. Ties on accuracy
    go to the higher epoch, so the result does not depend on directory order.

    Args:
        save_dir (str): Directory where the model files are saved

    Returns:
        best_models (dict): Model name -> {'file': path, 'epoch': int, 'valacc': float}
    """
    pattern = re.compile(r'(.+)_epoch(\d+)_valacc([\d.]+)\.pt')
    best_models = {}
    # Sort for a deterministic scan order (glob order is filesystem-dependent).
    for file in sorted(glob.glob(os.path.join(save_dir, '*.pt'))):
        match = pattern.fullmatch(os.path.basename(file))
        if match is None:
            continue
        model_name, epoch, valacc = match.groups()
        try:
            valacc = float(valacc)
        except ValueError:
            # e.g. 'valacc..pt' matches the character class but is not a number
            continue
        epoch = int(epoch)
        current = best_models.get(model_name)
        if current is None or (valacc, epoch) > (current['valacc'], current['epoch']):
            best_models[model_name] = {'file': file, 'epoch': epoch, 'valacc': valacc}
    return best_models

# ==========================
# Function to Evaluate Model on Validation Set
# ==========================
def evaluate_model(model, dataloader, criterion, device):
    """
    Evaluates the model on the validation dataset.

    The loss is the criterion's per-batch value weighted by batch size and
    averaged over samples. Accuracy is the fraction of non-pad (index != 0)
    target positions predicted correctly.

    Args:
        model (nn.Module): The trained model; called with a dict of 'note'/'bar' tensors
        dataloader (DataLoader): Validation DataLoader yielding dicts with 'note', 'bar', 'key'
        criterion (nn.Module): Loss function taking (N, C) logits and (N,) targets
        device (torch.device): Device to run the evaluation on

    Returns:
        avg_loss (float): Average validation loss (nan for an empty dataset)
        avg_acc (float): Average validation accuracy (nan if there are no non-pad targets)
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'note': batch['note'].to(device),
                'bar': batch['bar'].to(device)
            }
            targets = batch['key'].to(device)

            outputs = model(inputs)  # (batch_size, seq_len, num_classes)
            # Take the class count from the logits themselves instead of relying
            # on a module-level global.
            outputs_reshaped = outputs.reshape(-1, outputs.size(-1))
            targets_reshaped = targets.reshape(-1)

            loss = criterion(outputs_reshaped, targets_reshaped)
            val_loss += loss.item() * targets.size(0)

            _, predicted = torch.max(outputs, dim=2)
            non_pad = targets != 0
            val_correct += ((predicted == targets) & non_pad).sum().item()
            val_total += non_pad.sum().item()

    num_samples = len(dataloader.dataset)
    avg_loss = val_loss / num_samples if num_samples else float('nan')
    avg_acc = val_correct / val_total if val_total else float('nan')
    return avg_loss, avg_acc

# ==========================
# Visualization Function for Sample Prediction (Modified)
# ==========================
def visualize_sample_prediction(models, sample, sample_index, vocab_note, vocab_key, start_step=0, end_step=64):
    """
    Plots every model's per-step key prediction for one validation sample and saves it as a PNG file.

    Three things are drawn on one shared categorical y-axis:
      - the input notes, as pitch-class names;
      - the target keys;
      - each model's predicted keys.
    The figure is saved to
    ``results/key_pred_result_{sample_index}_{start_step}_{end_step}.png`` and then shown.

    Args:
        models (dict): Dictionary of trained models (name -> nn.Module)
        sample (dict): Sample from the validation dataset; 'note', 'bar' and 'key'
            hold vocabulary-index tensors of shape (seq_len,)
        sample_index (int): Index of the sample in the validation dataset (used in title and filename)
        vocab_note (Vocab): Vocabulary for note tokens
        vocab_key (Vocab): Vocabulary for key tokens
        start_step (int): Starting time step to visualize (inclusive)
        end_step (int): Ending time step to visualize (exclusive)
    """
    fallback_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Store prediction results (one array of key indices per model)
    predictions = {}
    for name, model in models.items():
        # Feed each model on the device its parameters live on, so a model that
        # was moved elsewhere (e.g. to CPU) still works.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else fallback_device
        inputs = {
            'note': sample['note'].unsqueeze(0).to(device),
            'bar': sample['bar'].unsqueeze(0).to(device)
        }
        model.eval()
        with torch.no_grad():
            outputs = model(inputs)  # (1, seq_len, num_classes)
            _, predicted = torch.max(outputs, dim=2)  # (1, seq_len)
        predictions[name] = predicted.cpu().numpy().squeeze()  # (seq_len,)

    # The sample holds vocabulary *indices*. Map them back to the raw tokens
    # before converting notes to pitch-class names; the indices themselves are
    # frequency ranks, not MIDI numbers.
    targets_tokens = vocab_key.to_tokens(sample['key'].cpu().numpy().squeeze())
    note_raw_tokens = vocab_note.to_tokens(sample['note'].cpu().numpy().squeeze())
    note_tokens_sample = midi_to_note_name(note_raw_tokens)

    predicted_tokens = {name: vocab_key.to_tokens(pred) for name, pred in predictions.items()}

    # Determine sequence length and clamp the requested window to it
    seq_length = len(targets_tokens)
    start_step = max(0, start_step)
    end_step = min(end_step, seq_length)

    if start_step >= end_step:
        print(f"Start time step {start_step} is greater than or equal to the sequence length {seq_length}. Skipping visualization.")
        return

    x = range(start_step, end_step)

    # Generate set of all tokens within the specified range
    all_tokens_set = set(note_tokens_sample[start_step:end_step] + 
                         targets_tokens[start_step:end_step])
    for pred in predicted_tokens.values():
        all_tokens_set.update(pred[start_step:end_step])

    # Note names first, in chromatic order
    note_names_order = ['C', 'C#', 'D', 'D#', 'E', 'F',
                        'F#', 'G', 'G#', 'A', 'A#', 'B', '<pad>']
    note_names_in_tokens = [name for name in note_names_order if name in all_tokens_set]

    # Other tokens (key tokens, '<unk>', ...). They may mix numbers with
    # strings, which cannot be compared directly: order numbers first, then strings.
    other_tokens = [token for token in all_tokens_set if token not in note_names_in_tokens]
    other_tokens.sort(key=lambda token: (isinstance(token, str), token))

    # Set the order of all tokens and map each to its y position
    sorted_all_tokens = note_names_in_tokens + other_tokens
    token_to_index = {token: idx for idx, token in enumerate(sorted_all_tokens)}

    # Convert note, target, and predicted tokens to numeric indices
    note_indices_plot = [token_to_index.get(token, -1) for token in note_tokens_sample[start_step:end_step]]
    targets_indices_plot = [token_to_index.get(token, -1) for token in targets_tokens[start_step:end_step]]

    predicted_indices_plot = {}
    for name, pred_tokens in predicted_tokens.items():
        predicted_indices_plot[name] = [token_to_index.get(token, -1) for token in pred_tokens[start_step:end_step]]

    # Plotting
    plt.figure(figsize=(20, 10))

    plt.plot(x, note_indices_plot, label='Input Note', alpha=0.5)
    plt.plot(x, targets_indices_plot, label='Target Key', linestyle='-', color='black', linewidth=1)

    for name, indices in predicted_indices_plot.items():
        plt.plot(x, indices, label=f'Predicted Key ({name})', linestyle='--')

    plt.title(f'Key Prediction (Sample {sample_index})')
    plt.xlabel('Time Step')
    plt.ylabel('Token')

    # Set y-axis ticks to token names
    plt.yticks(ticks=range(len(sorted_all_tokens)), labels=sorted_all_tokens)

    plt.legend(loc='upper right')
    plt.tight_layout()

    # Save the graph as PNG (including time step range in the filename)
    save_dir = 'results'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'key_pred_result_{sample_index}_{start_step}_{end_step}.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Prediction result saved to {save_path}")

    # Display the graph
    plt.show()


In [None]:
# ==========================
# Data Loading and Preprocessing
# ==========================
# Load data (use your provided paths)
x_path = r"example ... aug_x_pop.npy"
y_path = r"example ... aug_y_pop.npy"

# Load data
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
# file; only load .npy files from a trusted source.
x = np.load(x_path, allow_pickle=True)  # shape: (num_samples, num_steps, 5)
y = np.load(y_path, allow_pickle=True)  # shape: (num_samples, num_steps, 2)

print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")

# Separate features from x (using note, bar, and key)
# NOTE(review): y is only printed above; the key target below is taken from
# x[:, :, 2], not from y. Confirm this is intended.
note_tokens = x[:, :, 0]      # (num_samples, num_steps)
bar_tokens = x[:, :, 1]       # (num_samples, num_steps)
key_tokens = x[:, :, 2]       # (num_samples, num_steps)

# Create token lists for each feature (one Python list of raw values per sample)
note_tokens_list = note_tokens.tolist()
bar_tokens_list = bar_tokens.tolist()
key_tokens_list = key_tokens.tolist()

# Create vocabularies with '<unk>' as a reserved token.
# Each vocab puts the string '<pad>' at index 0 and '<unk>' at index 1, then the
# observed values by frequency. NOTE(review): if the raw arrays encode padding
# as the value 0 (midi_to_note_name treats 0 as padding), that value gets its
# own non-zero index here, so CrossEntropyLoss(ignore_index=0) and the
# "targets != 0" accuracy masks would not skip it. Verify the padding encoding.
vocab_note = Vocab(tokens=note_tokens_list, min_freq=1, reserved_tokens=['<unk>'])
vocab_bar = Vocab(tokens=bar_tokens_list, min_freq=1, reserved_tokens=['<unk>'])
vocab_key = Vocab(tokens=key_tokens_list, min_freq=1, reserved_tokens=['<unk>'])

print(f"Vocabulary sizes:")
print(f"Note: {len(vocab_note)}")
print(f"Bar: {len(vocab_bar)}")
print(f"Key: {len(vocab_key)}")

# Convert tokens to indices (every token was seen while building the vocab with
# min_freq=1, so none fall back to '<unk>')
note_indices = [vocab_note[line] for line in note_tokens_list]
bar_indices = [vocab_bar[line] for line in bar_tokens_list]
key_indices = [vocab_key[line] for line in key_tokens_list]

# Convert to numpy arrays (rows are equal length because they come from one 3-D array)
note_indices = np.array(note_indices)
bar_indices = np.array(bar_indices)
key_indices = np.array(key_indices)

# Convert to Tensors of int64 indices, as required by nn.Embedding and CrossEntropyLoss
note_tensor = torch.from_numpy(note_indices).long()         # (num_samples, num_steps)
bar_tensor = torch.from_numpy(bar_indices).long()
key_tensor = torch.from_numpy(key_indices).long()

# ==========================
# Dataset and DataLoader Preparation
# ==========================
# Create the dataset (using note and bar as inputs, key as target)
dataset = MultiInputDataset(
    note_tensor, bar_tensor, key_tensor
)

# Split the dataset (80% training, 20% validation).
# The split is seeded so it is identical on every run. Checkpoints written to
# the save directory persist across runs and are reloaded later for
# evaluation; with an unseeded split, a reloaded model could be scored on
# samples it was trained on.
split_seed = 42
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_dataset, valid_dataset = random_split(
    dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(split_seed),
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Valid dataset size: {len(valid_dataset)}")

# Define DataLoader (batch size 32); only the training set is shuffled
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# ==========================
# Model Initialization
# ==========================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Every model predicts one key-vocabulary entry per time step.
num_classes = len(vocab_key)
print(f"Number of classes: {num_classes}")

# Input features with their vocabulary sizes and embedding widths.
vocab_sizes = dict(note=len(vocab_note), bar=len(vocab_bar))
embed_dims = dict(note=64, bar=16)

# The three architectures compared here. They are built one after another in
# this order, which fixes how the global RNG is consumed by weight init.
models = {}
models['DeepLSTM'] = DenoiseLSTM(
    vocab_sizes=vocab_sizes, embed_dims=embed_dims, hidden_size=256,
    num_layers=3, num_classes=num_classes, dropout=0.5,
).to(device)
models['DeepBiLSTM'] = DenoiseBiLSTM(
    vocab_sizes=vocab_sizes, embed_dims=embed_dims, hidden_size=256,
    num_layers=3, num_classes=num_classes, dropout=0.5,
).to(device)
models['Transformer'] = DenoiseTransformer(
    vocab_sizes=vocab_sizes, embed_dims=embed_dims, num_classes=num_classes,
    nhead=8, num_encoder_layers=4, dim_feedforward=512, dropout=0.1,
).to(device)

# ==========================
# Loss Function and Optimizers
# ==========================
# Logits are flattened to (N, C) and targets to (N,); targets equal to index 0
# ('<pad>') contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# One Adam optimizer per model (lr 1e-3, light weight decay).
optimizers = {
    model_name: optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-5)
    for model_name, net in models.items()
}

# Per-model learning curves, one entry per epoch.
history = {
    model_name: {metric: [] for metric in ('train_loss', 'valid_loss', 'train_acc', 'valid_acc')}
    for model_name in models
}

# Checkpoint directory (created if missing)
save_dir = "note_bar_to_key_models"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# ==========================
# Training and Validation Loop
# ==========================
num_epochs = 50
best_valid_loss = {name: float('inf') for name in models.keys()}
start_time = time.time()  # Record training start time

for epoch in range(1, num_epochs + 1):
    print(f'\nEpoch {epoch}/{num_epochs}')

    for name, model in models.items():
        # Each model has its own optimizer; look it up once per epoch, not per batch.
        optimizer = optimizers[name]
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch in train_loader:
            # Extract inputs (note and bar); the key sequence is the target
            inputs = {
                'note': batch['note'].to(device),
                'bar': batch['bar'].to(device)
            }
            targets = batch['key'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)  # (batch_size, seq_len, num_classes)

            # Flatten for the loss: (batch_size * seq_len, num_classes) vs (batch_size * seq_len,)
            loss = criterion(outputs.view(-1, num_classes), targets.view(-1))
            loss.backward()

            # Apply gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Weight the batch loss by batch size so dividing by train_size
            # below gives a per-sample average
            running_loss += loss.item() * targets.size(0)

            # Token accuracy, excluding <pad> (index 0) positions
            _, predicted = torch.max(outputs, dim=2)  # (batch_size, seq_len)
            non_pad = targets != 0
            correct += ((predicted == targets) & non_pad).sum().item()
            total += non_pad.sum().item()

        epoch_loss = running_loss / train_size
        # Guard against a training split with no non-pad target tokens
        epoch_acc = correct / total if total else float('nan')
        history[name]['train_loss'].append(epoch_loss)
        history[name]['train_acc'].append(epoch_acc)

        # Validation step: evaluate_model applies the same batch-size-weighted
        # loss average and non-pad token accuracy as the training metrics above
        epoch_val_loss, epoch_val_acc = evaluate_model(model, valid_loader, criterion, device)
        history[name]['valid_loss'].append(epoch_val_loss)
        history[name]['valid_acc'].append(epoch_val_acc)

        # Periodic output
        if epoch % 10 == 0 or epoch == 1 or epoch == num_epochs:
            print(f'[{name}] Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, '
                  f'Valid Loss: {epoch_val_loss:.4f}, Valid Acc: {epoch_val_acc:.4f}')

        # Save a checkpoint on even epochs, only when validation loss improves
        if epoch % 2 == 0 and epoch_val_loss < best_valid_loss[name]:
            best_valid_loss[name] = epoch_val_loss
            model_save_path = os.path.join(save_dir, f'{name}_epoch{epoch}_valacc{epoch_val_acc:.4f}.pt')
            torch.save(model.state_dict(), model_save_path)
            print(f"Model '{name}' saved at epoch {epoch} with validation accuracy {epoch_val_acc:.4f}")

    # Elapsed time so far, and remaining time extrapolated from the mean epoch duration
    elapsed_time = time.time() - start_time
    remaining_time = elapsed_time * (num_epochs / epoch) - elapsed_time

    # Print time information
    print(f'Epoch {epoch}/{num_epochs} completed. Time elapsed: {elapsed_time:.2f} seconds. '
          f'Estimated remaining time: {remaining_time:.2f} seconds.')

# ==========================
# Visualize Loss and Accuracy
# ==========================
# Two stacked panels (loss on top, accuracy below); solid = train, dashed = valid.
epochs_axis = range(1, num_epochs + 1)
plt.figure(figsize=(20, 10))

for panel, (metric, ylabel, title) in enumerate(
    [
        ('loss', 'Loss', 'Training and Validation Loss'),
        ('acc', 'Accuracy', 'Training and Validation Accuracy'),
    ],
    start=1,
):
    plt.subplot(2, 1, panel)
    for name in models:
        plt.plot(epochs_axis, history[name][f'train_{metric}'], label=f'{name} Train', linestyle='-')
        plt.plot(epochs_axis, history[name][f'valid_{metric}'], label=f'{name} Valid', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()

plt.tight_layout()
plt.show()

# ==========================
# Load Best Models and Evaluate
# ==========================
# Set directory where models are saved
save_dir = r"note_bar_to_key_models"

# Find best model files (highest validation accuracy per model name)
best_models_info = find_best_model_files(save_dir)

# Load and evaluate each best model. Iterate over a snapshot because entries
# for undefined models are removed from best_models_info below.
for model_name, info in list(best_models_info.items()):
    print(f"Evaluating best model for {model_name}: {info['file']} (Epoch {info['epoch']}, Val Acc {info['valacc']:.4f})")

    # Only models defined in this notebook can be restored
    if model_name not in models:
        print(f"Model {model_name} is not defined.")
        # Drop the entry (e.g. a checkpoint left over from an earlier run) so the
        # summary plots do not look up a 'valloss' that was never computed.
        del best_models_info[model_name]
        continue
    model = models[model_name]

    # Load saved state_dict (this overwrites the in-memory weights of `model`)
    model.load_state_dict(torch.load(info['file'], map_location=device))
    model.to(device)

    # Evaluate the model
    val_loss, val_acc = evaluate_model(model, valid_loader, criterion, device)

    # Replace the filename-derived accuracy with the freshly measured metrics
    info['valloss'] = val_loss
    info['valacc'] = val_acc
    print(f"{model_name} - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

# ==========================
# Visualize Validation Results
# ==========================
# Side-by-side bar charts of each reloaded best model's validation loss and accuracy
plt.figure(figsize=(14, 6))

model_names = list(best_models_info)
val_losses = [best_models_info[name]['valloss'] for name in model_names]
val_accs = [best_models_info[name]['valacc'] for name in model_names]

bar_panels = (
    (1, val_losses, 'skyblue', 'Validation Loss', 'Validation Loss of Best Models'),
    (2, val_accs, 'salmon', 'Validation Accuracy', 'Validation Accuracy of Best Models'),
)
for position, heights, colour, ylabel, title in bar_panels:
    plt.subplot(1, 2, position)
    plt.bar(model_names, heights, color=colour)
    plt.xlabel('Model')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

# ==========================
# Visualize Sample Predictions
# ==========================
# Select a sample from the validation dataset
sample_index = 500  # You can change this index to visualize different samples
# Fall back to the last sample when the validation split is smaller than the requested index
if sample_index >= len(valid_dataset):
    print(f"Sample index {sample_index} is out of range for {len(valid_dataset)} validation samples; "
          f"using the last sample instead.")
    sample_index = len(valid_dataset) - 1
sample = valid_dataset[sample_index]

# Visualize sample prediction (models hold the best reloaded weights at this point)
visualize_sample_prediction(models, sample, sample_index, vocab_note, vocab_key, start_step=0, end_step=100)
