In [None]:
# %% [markdown]
# # Integrated Pipeline Example with LSTM Model
# This notebook integrates the two-stage training and evaluation of the Presence and Type models, and adds an LSTM-based note position prediction model.
# 
# Before running, create a `model` folder; the trained model files are saved there.

# %% [markdown]
# ## Import Necessary Libraries

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # Move to the top

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
import json
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from typing import List, Tuple, Dict, Any
import enum
import tqdm
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score, mean_squared_error, mean_absolute_error, r2_score

# %% [markdown]
# ## Constants Definition

# STFT constants
SAMPLE_RATE = 22050  # Sample rate (Hz) requested from librosa.load when reading audio
HOP_LENGTH = 512     # Hop length (in samples) between consecutive Mel spectrogram frames
NMELS = 128        # Number of Mel bands in the spectrogram
WINDOW_SIZE = 40  # Number of frames before and after
NUM_EPOCHS = 50    # Number of training epochs
BATCH_SIZE = 64    # Batch size
LEARNING_RATE = 1e-3  # Presumably the optimizer learning rate; the optimizer setup is not in this section
DROPOUT = 0.5      # Dropout rate

# Check GPU
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# %% [markdown]
# ## Data Preprocessing Functions

def contains_non_ascii(s: str) -> bool:
    """Return True when at least one character of `s` has a code point above 127."""
    for character in s:
        if ord(character) >= 128:
            return True
    return False

def extract_level_json_multiple(directories: List[Path], min_difficulty: int = 15) -> Dict[str, Any]:
    """
    Extract level.json files from multiple directories, organize relevant information, and filter songs by specified difficulty.
    If an unparseable name or other issue is encountered, skip the song.

    Only the first chart (``charts[0]``) of each level is inspected. Songs whose
    first-chart difficulty is below ``min_difficulty`` are left out silently (they
    are not counted as skipped). A level.json that parses but does not have the
    expected structure (top level is not an object, first chart or ``music`` is
    not an object, difficulty is not a number, a path is not a string) is skipped
    with reason ``malformed_level_data`` instead of aborting the whole scan.

    Args:
        directories (List[Path]): List of main directories containing multiple subfolders.
        min_difficulty (int): Minimum difficulty level.

    Returns:
        Dict[str, Any]: Contains relevant information for each level, limited to songs with difficulty >= min_difficulty.
            Keys are ``"<directory name>_<level id>"``; each value holds the raw level
            data plus ``mp3_path``, ``charts_path``, ``charter``, ``type`` and ``difficulty``.
    """
    audio_file_extensions = ('.mp3', '.ogg', '.wav')
    result = {}
    skipped_songs = 0
    skipped_reasons = defaultdict(int)

    def skip_song(reason: str, message: str) -> None:
        """Report one skipped song and record why it was skipped."""
        nonlocal skipped_songs
        print(message)
        skipped_songs += 1
        skipped_reasons[reason] += 1

    for directory in directories:
        if not directory.exists():
            # A missing directory is tallied as a reason but not as a skipped song.
            print(f"Directory does not exist: {directory}")
            skipped_reasons['missing_directory'] += 1
            continue
        for folder_path in directory.iterdir():
            if not folder_path.is_dir():
                continue
            json_file_path = folder_path / 'level.json'
            if not json_file_path.is_file():
                skip_song('missing_level_json', f"Missing level.json file in {folder_path}")
                continue
            try:
                with json_file_path.open('r', encoding='utf-8') as json_file:
                    level_data = json.load(json_file)
            except json.JSONDecodeError:
                skip_song('json_decode_error', f"JSON parse error: {json_file_path}")
                continue
            except Exception as e:
                skip_song('read_error', f"Unable to read {json_file_path}: {e}")
                continue

            # Valid JSON is not necessarily an object; indexing a list by key would raise TypeError.
            if not isinstance(level_data, dict):
                skip_song('malformed_level_data', f"Malformed level data in file: {json_file_path}")
                continue

            # Ensure all necessary fields exist
            try:
                level_id = level_data['id']
                charts = level_data['charts']
                music = level_data['music']
            except KeyError as e:
                skip_song('missing_keys', f"Missing key {e} in file: {json_file_path}")
                continue

            if not charts:
                skip_song('empty_charts', f"No charts found in file {json_file_path}")
                continue

            first_chart = charts[0] if isinstance(charts, list) else None
            if not isinstance(first_chart, dict) or not isinstance(music, dict):
                skip_song('malformed_level_data', f"Malformed level data in file: {json_file_path}")
                continue

            chart_difficulty = first_chart.get('difficulty', 0)
            if not isinstance(chart_difficulty, (int, float)):
                skip_song('malformed_level_data', f"Malformed level data in file: {json_file_path}")
                continue
            if chart_difficulty < min_difficulty:
                continue

            audio_file_name = music.get('path', '')
            if not audio_file_name:
                skip_song('missing_music_path', f"No music path specified in file {json_file_path}")
                continue
            if not isinstance(audio_file_name, str):
                skip_song('malformed_level_data', f"Malformed level data in file: {json_file_path}")
                continue

            # The audio file must exist and carry one of the accepted extensions.
            audio_file_path = folder_path / audio_file_name
            if audio_file_path.suffix.lower() not in audio_file_extensions or not audio_file_path.is_file():
                skip_song(
                    'missing_audio_file',
                    f"Audio file {audio_file_name} for song ID {level_id} not found in {folder_path}"
                )
                continue

            chart_file_name = first_chart.get('path', '')
            if not isinstance(chart_file_name, str):
                skip_song('malformed_level_data', f"Malformed level data in file: {json_file_path}")
                continue
            charts_path = folder_path / chart_file_name
            if not charts_path.is_file():
                skip_song(
                    'missing_charts_file',
                    f"Charts file {chart_file_name} for song ID {level_id} not found in {folder_path}"
                )
                continue

            # Create unique ID, ensure name is parseable
            unique_id = f"{directory.name}_{level_id}"
            try:
                unique_id.encode('ascii')  # Check if ASCII
            except UnicodeEncodeError:
                skip_song('unparseable_unique_id', f"Unparseable unique_id: {unique_id}, skipping song")
                continue

            # Add to result
            result[unique_id] = {
                'level': level_data,
                'mp3_path': str(audio_file_path),
                'charts_path': str(charts_path),
                'charter': level_data.get('charter', ''),
                'type': first_chart.get('type', ''),
                'difficulty': chart_difficulty
            }

    print(f"Total number of skipped songs: {skipped_songs}")
    for reason, count in skipped_reasons.items():
        print(f"Skipped reason '{reason}': {count} songs")
    return result

def extract_charts(path: str) -> Dict[str, Any]:
    """
    Load a chart JSON file.

    Args:
        path (str): Path to the chart JSON file.

    Returns:
        Dict[str, Any]: The parsed JSON content, or an empty dict when the path is
        not an existing regular file or its content is not valid JSON.
    """
    chart_file = Path(path)
    if not (chart_file.exists() and chart_file.is_file()):
        return {}

    with open(chart_file, 'r', encoding='utf-8') as handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError:
            # Report the problem and fall through to the empty result.
            print(f"JSON decode error for file: {path}")
    return {}

def find_single_tempo_songs(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Keep only the songs whose chart declares exactly one tempo entry.

    Args:
        data (Dict[str, Any]): Dictionary containing all song information; each value
            must provide a 'charts_path' entry.

    Returns:
        List[Dict[str, Any]]: The song entries (values of `data`) with constant BPM,
        in the iteration order of `data`.
    """
    def has_constant_tempo(song: Dict[str, Any]) -> bool:
        # An unreadable or empty chart, or one without a tempo list, never qualifies.
        chart = extract_charts(song['charts_path'])
        if not chart or 'tempo_list' not in chart:
            return False
        return len(chart['tempo_list']) == 1

    return [song for song in data.values() if has_constant_tempo(song)]

def map_note_to_time(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Map notes to time.

    Each note's tick is converted to an elapsed time by accumulating
    ``ticks * (tempo value / time_base)`` over the tempo segments that precede it.
    Notes are processed in ascending tick order so the forward-only tempo cursor
    stays valid even when ``note_list`` is not sorted; the returned list keeps the
    original ``note_list`` order.

    Args:
        data (Dict[str, Any]): Chart data. Reads 'time_base' (default 1000),
            'music_offset' (default 0), 'tempo_list' (entries with 'tick' and
            'value') and 'note_list' (entries with 'tick' and optional 'id',
            'type', 'x').

    Returns:
        List[Dict[str, Any]]: Time mapping information for each note, with keys
        'note_id', 'note_tick', 'note_time_microseconds', 'note_type' and 'note_x'.
        Empty when the chart has no tempo entries.
    """
    time_base = data.get('time_base', 1000)
    offset_universal = 0.033  # fixed adjustment subtracted from the music offset (rationale not visible here)
    offset = data.get('music_offset', 0) - offset_universal
    tempo_list = sorted(data.get('tempo_list', []), key=lambda x: x['tick'])
    note_list = data.get('note_list', [])

    if not tempo_list:
        return []

    # NOTE(review): the offset is scaled by 1e6, so it is presumably in seconds and the
    # tempo values presumably in microseconds per time_base ticks - confirm against the chart format.
    offset_shift = offset * 1_000_000

    # Stable sort: input that is already ordered is visited exactly as before.
    tick_order = sorted(range(len(note_list)), key=lambda idx: note_list[idx]['tick'])
    note_time_map = [None] * len(note_list)

    accumulated_time = 0
    last_tick = 0
    tempo_index = 0
    current_tempo = tempo_list[0]['value']

    for note_index in tick_order:
        note = note_list[note_index]
        note_tick = note['tick']

        # Consume every tempo change that happens at or before this note.
        while tempo_index < len(tempo_list) - 1 and tempo_list[tempo_index + 1]['tick'] <= note_tick:
            next_tempo_tick = tempo_list[tempo_index + 1]['tick']
            accumulated_time += (next_tempo_tick - last_tick) * (current_tempo / time_base)
            last_tick = next_tempo_tick
            tempo_index += 1
            current_tempo = tempo_list[tempo_index]['value']

        note_time = accumulated_time + (note_tick - last_tick) * (current_tempo / time_base)
        note_time_map[note_index] = {
            'note_id': note.get('id', 0),
            'note_tick': note_tick,
            'note_time_microseconds': note_time - offset_shift,
            'note_type': note.get('type', 0),
            'note_x': note.get('x', 0.0)
        }

    return note_time_map

def generate_mel_spectrogram(
    audio_path: Path,
    log_enable: bool = True,
    bpm_info: List[Dict[str, float]] = None,
    note_info: List[Dict[str, Any]] = None,
    max_frames: int = 5000  # New parameter to limit maximum frames
) -> dict:
    """
    Build a (frames, NMELS) Mel spectrogram for one audio file together with
    per-frame note labels, keeping at most `max_frames` frames.

    Args:
        audio_path (Path): Path to the audio file.
        log_enable (bool): When True, take the natural log of the spectrogram after
            clipping values below 1e-5.
        bpm_info (List[Dict[str, float]]): BPM information. Only its truthiness is
            checked here: labels are filled in only when it is non-empty.
        note_info (List[Dict[str, Any]]): Note information; each entry must carry
            'note_time_microseconds'.
        max_frames (int): Maximum number of frames.

    Returns:
        dict: 'mel' (spectrogram, time-major), 'labels' (1 on frames holding a note,
        else 0) and 'position_labels' (the frame index on frames holding a note,
        else -1).
    """
    waveform, sr = librosa.load(str(audio_path), sr=SAMPLE_RATE)
    assert sr == SAMPLE_RATE, f"Expected sample rate {SAMPLE_RATE}, but got {sr}"

    spectrogram = librosa.feature.melspectrogram(
        y=waveform, sr=sr, hop_length=HOP_LENGTH, fmin=30.0, n_mels=NMELS, htk=True
    )
    if log_enable:
        spectrogram = np.log(np.clip(spectrogram, 1e-5, None))

    # Time-major layout, then truncate overly long songs.
    mel = spectrogram.T
    if mel.shape[0] > max_frames:
        mel = mel[:max_frames]

    frame_count = mel.shape[0]
    presence_labels = np.zeros(frame_count, dtype=int)
    position_labels = np.full(frame_count, -1, dtype=int)  # -1 indicates no note

    if bpm_info and note_info:
        for note in note_info:
            time_sec = note['note_time_microseconds'] / 1_000_000
            frame_idx = int(time_sec * SAMPLE_RATE / HOP_LENGTH)
            if frame_idx < 0 or frame_idx >= frame_count:
                continue  # note falls outside the (possibly truncated) spectrogram
            presence_labels[frame_idx] = 1
            position_labels[frame_idx] = frame_idx

    return {
        "mel": mel,
        "labels": presence_labels,  # shape: (frame_count,)
        "position_labels": position_labels,  # shape: (frame_count,)
    }

# %% [markdown]
# ## Dataset and DataLoader

class TimeUnit(enum.Enum):
    """Enumeration of time units; each member's value is the same string as its name."""
    # NOTE(review): no use of this enum is visible in this part of the notebook - confirm it is still needed.
    milliseconds = "milliseconds"
    frames = "frames"
    seconds = "seconds"

class OnsetDataset(Dataset):
    """
    PyTorch Dataset yielding one sample per spectrogram frame.

    A sample is the centre frame plus `window_size` frames of context on each side
    (81 frames with the default of 40), the presence label of the centre frame, and
    the song's difficulty. All windows are materialised eagerly in the constructor.
    """

    def __init__(self, data: Dict[str, Any], bpm_info: Dict[str, List[Dict[str, float]]], score_positions: Dict[str, List[Dict[str, Any]]], window_size: int = 40, transform=None):
        self.data = data
        self.bpm_info = bpm_info
        self.score_positions = score_positions
        self.transform = transform
        self.window_size = window_size
        # Must come last: prepare_samples reads the attributes set above.
        self.samples = self.prepare_samples()

    def prepare_samples(self) -> List[Tuple[np.ndarray, np.ndarray, int]]:
        """
        Build the (mel_window, label, difficulty) tuples for every frame of every song.
        Windows that reach past either end of the song are zero-padded to full length.
        """
        collected = []
        half_width = self.window_size
        for song_id, song in self.data.items():
            mp3_path = song["mp3_path"]
            charts_path = song["charts_path"]
            difficulty = song['difficulty']

            mel_dict = generate_mel_spectrogram(
                audio_path=Path(mp3_path),
                log_enable=True,
                bpm_info=self.bpm_info.get(song_id, None),
                note_info=self.score_positions.get(song_id, None)
            )
            if "labels" not in mel_dict:
                continue

            mel = mel_dict["mel"]  # shape: (num_frames, n_mels)
            labels = mel_dict["labels"]  # shape: (num_frames,)
            num_frames = mel.shape[0]

            for center in range(num_frames):
                lower = center - half_width
                upper = center + half_width + 1

                # Frames missing on each side once the window is clipped to the song.
                missing_before = max(-lower, 0)
                missing_after = max(upper - num_frames, 0)

                mel_window = mel[max(lower, 0):min(upper, num_frames)]
                if missing_before > 0 or missing_after > 0:
                    mel_window = np.pad(
                        mel_window,
                        ((missing_before, missing_after), (0, 0)),
                        mode='constant'
                    )

                collected.append((mel_window, labels[center], difficulty))
        return collected

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return (float mel window tensor, float scalar label tensor, difficulty) for sample `idx`."""
        window_np, raw_label, difficulty = self.samples[idx]
        mel_window = torch.from_numpy(window_np).float()  # shape: (81, n_mels)
        label = torch.tensor(raw_label).float()  # shape: ()

        if self.transform:
            mel_window, label = self.transform(mel_window, label)

        return mel_window, label, difficulty

def collate_fn_padded(batch: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    """
    Collate (mel_window, label, difficulty) samples into batch tensors.

    Mel windows and labels are stacked along a new leading batch dimension; the
    difficulties are passed through unchanged as a tuple, one entry per sample.
    """
    mel_windows, label_tensors, difficulties = tuple(zip(*batch))

    stacked_mel = torch.stack(mel_windows)  # (batch_size, 81, n_mels)
    stacked_labels = torch.stack(label_tensors)  # (batch_size,)
    return stacked_mel, stacked_labels, difficulties

# %% [markdown]
# ## Presence Model Definition (Keeping Code Nearly Unchanged)

class CNNOnsetDetector(nn.Module):
    """
    Convolutional Neural Network (CNN) based Onset Detection Model.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages run along the time
    axis, followed by two fully connected layers that emit `num_classes` raw
    scores (no sigmoid is applied here).

    Args:
        input_channels (int): Number of features per frame (n_mels).
        num_classes (int): Size of the output layer.
        dropout (float): Dropout probability used before each fully connected layer.
        input_length (int): Number of frames per input window. Defaults to 81
            (40 context frames on each side plus the centre frame); pass
            ``2 * window_size + 1`` when a different window size is used.

    Raises:
        ValueError: If `input_length` is too short to survive the three pooling stages.
    """
    def __init__(self, input_channels: int, num_classes: int = 1, dropout: float = 0.5, input_length: int = 81):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Each pooling stage halves the time axis (rounding down); the padded
        # convolutions leave it unchanged.
        self.pool_layers = 3
        self.input_length = input_length
        if input_length < 2 ** self.pool_layers:
            raise ValueError(
                f"input_length must be at least {2 ** self.pool_layers} frames, got {input_length}"
            )
        self.feature_length = input_length // (2 ** self.pool_layers)

        self.fc1 = nn.Linear(256 * self.feature_length, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, input_length, n_mels]

        Returns:
            Tensor, shape [batch_size, num_classes], raw scores.
        """
        x = x.permute(0, 2, 1)  # Conv1d expects [batch_size, n_mels, input_length]

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            # With the default 81 frames the time axis shrinks 81 -> 40 -> 20 -> 10.
            x = self.pool(self.relu(norm(conv(x))))

        x = x.flatten(start_dim=1)  # [batch_size, 256 * feature_length]
        x = self.dropout(x)
        x = self.relu(self.fc1(x))  # [batch_size, 512]
        x = self.dropout(x)
        return self.fc2(x)  # [batch_size, num_classes]

class CNNOnsetFeatureExtractor(nn.Module):
    """
    CNN that turns a window of Mel frames into a 512-dimensional feature vector.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages run along the time
    axis, followed by one fully connected layer. There is no classification head:
    `forward` returns the 512 activations after ReLU and dropout.

    Args:
        input_channels (int): Number of features per frame (n_mels).
        dropout (float): Dropout probability used before and after the fully
            connected layer.
        input_length (int): Number of frames per input window. Defaults to 81
            (40 context frames on each side plus the centre frame); pass
            ``2 * window_size + 1`` when a different window size is used.

    Raises:
        ValueError: If `input_length` is too short to survive the three pooling stages.
    """
    def __init__(self, input_channels: int, dropout: float = 0.5, input_length: int = 81):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Each pooling stage halves the time axis (rounding down); the padded
        # convolutions leave it unchanged.
        self.pool_layers = 3
        self.input_length = input_length
        if input_length < 2 ** self.pool_layers:
            raise ValueError(
                f"input_length must be at least {2 ** self.pool_layers} frames, got {input_length}"
            )
        self.feature_length = input_length // (2 ** self.pool_layers)

        self.fc1 = nn.Linear(256 * self.feature_length, 512)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape (batch_size, input_length, n_mels)

        Returns:
            Tensor, shape (batch_size, 512).
        """
        x = x.permute(0, 2, 1)  # Conv1d expects (batch_size, n_mels, input_length)

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            # With the default 81 frames the time axis shrinks 81 -> 40 -> 20 -> 10.
            x = self.pool(self.relu(norm(conv(x))))

        x = x.flatten(start_dim=1)  # (batch_size, 256 * feature_length)
        x = self.dropout(x)
        x = self.relu(self.fc1(x))  # (batch_size, 512)
        # No output layer: the 512-dimensional features are returned directly.
        return self.dropout(x)

# %% [markdown]
# ## Type Model Definition (Modified to CNN)

class CNNTypePredictor(nn.Module):
    """
    Convolutional Neural Network (CNN) based Type Prediction Model.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages run along the time
    axis, followed by two fully connected layers that emit one raw score per note
    type (no softmax is applied here).

    Args:
        input_channels (int): Number of features per frame (n_mels).
        num_types (int): Number of note types, i.e. size of the output layer.
        dropout (float): Dropout probability used before each fully connected layer.
        input_length (int): Number of frames per input window. Defaults to 81
            (40 context frames on each side plus the centre frame); pass
            ``2 * window_size + 1`` when a different window size is used.

    Raises:
        ValueError: If `input_length` is too short to survive the three pooling stages.
    """
    def __init__(self, input_channels: int, num_types: int = 5, dropout: float = 0.5, input_length: int = 81):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Each pooling stage halves the time axis (rounding down); the padded
        # convolutions leave it unchanged.
        self.pool_layers = 3
        self.input_length = input_length
        if input_length < 2 ** self.pool_layers:
            raise ValueError(
                f"input_length must be at least {2 ** self.pool_layers} frames, got {input_length}"
            )
        self.feature_length = input_length // (2 ** self.pool_layers)

        self.fc1 = nn.Linear(256 * self.feature_length, 512)
        self.fc_type = nn.Linear(512, num_types)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, input_length, n_mels]

        Returns:
            Tensor, shape [batch_size, num_types], raw type scores.
        """
        x = x.permute(0, 2, 1)  # Conv1d expects [batch_size, n_mels, input_length]

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            # With the default 81 frames the time axis shrinks 81 -> 40 -> 20 -> 10.
            x = self.pool(self.relu(norm(conv(x))))

        x = x.flatten(start_dim=1)  # [batch_size, 256 * feature_length]
        x = self.dropout(x)
        x = self.relu(self.fc1(x))  # [batch_size, 512]
        x = self.dropout(x)
        return self.fc_type(x)  # [batch_size, num_types]

# %% [markdown]
# ## Visualization Functions

def visualize_presence_predictions(mel: np.ndarray, labels: np.ndarray, preds: np.ndarray, start_time: float = 0, end_time: float = 5):
    """
    Plot a time slice of the Mel spectrogram above the presence predictions and
    the ground-truth labels for the same slice.

    Args:
        mel (np.ndarray): Mel spectrogram, shape (seq_len, feature_dim)
        labels (np.ndarray): True labels, shape (seq_len,)
        preds (np.ndarray): Model's presence scores (pre-sigmoid), shape (seq_len,)
        start_time (float): Start time for visualization (seconds)
        end_time (float): End time for visualization (seconds)
    """
    # Raw scores -> probabilities -> hard 0/1 decisions at a 0.5 threshold.
    probabilities = 1 / (1 + np.exp(-preds))
    decisions = (probabilities >= 0.5).astype(int)

    # One time stamp per frame across the whole sequence.
    num_frames = mel.shape[0]
    time_axis = np.linspace(0, num_frames * HOP_LENGTH / SAMPLE_RATE, num=num_frames)

    # Frame range to display; the end is capped at the sequence length.
    first_frame = int(start_time * SAMPLE_RATE / HOP_LENGTH)
    last_frame = min(int(end_time * SAMPLE_RATE / HOP_LENGTH), num_frames)
    shown = slice(first_frame, last_frame)

    mel_shown = mel[shown]
    labels_shown = labels[shown]
    probabilities_shown = probabilities[shown]
    decisions_shown = decisions[shown]

    fig, axs = plt.subplots(2, 1, figsize=(15, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
    spectrogram_ax, presence_ax = axs[0], axs[1]

    # Upper panel: spectrogram.
    img = librosa.display.specshow(
        mel_shown.T,
        sr=SAMPLE_RATE,
        hop_length=HOP_LENGTH,
        x_coords=time_axis[shown],
        ax=spectrogram_ax,
        x_axis='time',
        y_axis='mel',
        fmax=8000
    )
    spectrogram_ax.set_title('Mel Spectrogram')
    fig.colorbar(img, ax=spectrogram_ax, format='%+2.0f dB')

    # Lower panel: probability curve, thresholded decision, ground truth.
    presence_ax.plot(time_axis[shown], probabilities_shown.flatten(),
                     label='Presence Prediction (Raw)', color='red', alpha=0.6)
    presence_ax.plot(time_axis[shown], decisions_shown,
                     label='Presence Prediction (Threshold=0.50)', color='orange', alpha=0.6)
    presence_ax.plot(time_axis[shown], labels_shown.flatten(),
                     label='Presence Ground Truth', color='blue', linestyle='dashed')

    presence_ax.set_title('Presence Predictions vs Ground Truth (Threshold: 0.50)')
    presence_ax.legend(loc='upper right')
    presence_ax.set_xlabel('Time (s)')
    presence_ax.set_ylabel('Presence')

    plt.tight_layout()
    plt.show()

def visualize_presence_predictions_single(mel: np.ndarray, label: np.ndarray, pred: np.ndarray, start_time: float = 0, end_time: float = 5):
    """
    Visualize the presence prediction results and true label for a single sample.

    The upper panel shows the sample's Mel window on a time axis centred on the
    current frame; the lower panel draws three overlapping bars at x = 0 for the
    predicted probability, the thresholded decision and the ground truth.

    Args:
        mel (np.ndarray): Mel spectrogram, shape (81, n_mels)
        label (np.ndarray): True label, shape (1,)
        pred (np.ndarray): Model's presence score (pre-sigmoid), shape (1,)
        start_time (float): Unused; accepted only for signature compatibility.
        end_time (float): Unused; accepted only for signature compatibility.
    """
    # Apply Sigmoid activation
    presence_pred = 1 / (1 + np.exp(-pred))

    # Apply threshold of 0.5
    presence_final = (presence_pred >= 0.5).astype(int)

    # Time axis relative to the window centre: WINDOW_SIZE frames on either side of 0.
    half_span = WINDOW_SIZE * HOP_LENGTH / SAMPLE_RATE
    times = np.linspace(-half_span, half_span, num=mel.shape[0])

    # Create subplots
    fig, axs = plt.subplots(2, 1, figsize=(15, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})

    # Plot Mel spectrogram
    img = librosa.display.specshow(
        mel.T,
        sr=SAMPLE_RATE,
        hop_length=HOP_LENGTH,
        x_coords=times,
        ax=axs[0],
        x_axis='time',
        y_axis='mel',
        fmax=8000
    )
    axs[0].set_title('Mel Spectrogram')
    fig.colorbar(img, ax=axs[0], format='%+2.0f dB')

    # Plot Presence predictions and true label
    axs[1].bar(0, presence_pred, label='Presence Prediction (Raw)', color='red', alpha=0.6)
    axs[1].bar(0, presence_final, label='Presence Prediction (Threshold=0.50)', color='orange', alpha=0.6)
    axs[1].bar(0, label, label='Presence Ground Truth', color='blue', alpha=0.6)

    axs[1].set_title('Presence Predictions vs Ground Truth')
    axs[1].legend(loc='upper right')
    axs[1].set_xlabel('Current Frame')
    axs[1].set_ylabel('Presence')

    plt.tight_layout()
    plt.show()

def visualize_type_predictions(mel: np.ndarray, labels: np.ndarray, preds: np.ndarray, start_time: float = 0, end_time: float = 5, hop_length: int = HOP_LENGTH, sample_rate: int = SAMPLE_RATE):
    """
    Visualize the model's type prediction results and true labels.

    One probability curve is drawn per type. Every frame whose true label is
    non-zero is marked with an 'x' on the curve of its true type plus a dashed
    vertical line in that type's colour.

    Args:
        mel (np.ndarray): Mel spectrogram, shape (seq_len, feature_dim); only its
            length is used, to build the time axis.
        labels (np.ndarray): True labels, shape (seq_len,), integer type indices
        preds (np.ndarray): Model's type scores (pre-softmax), shape (seq_len, num_types)
        start_time (float): Start time for visualization (seconds)
        end_time (float): End time for visualization (seconds)
        hop_length (int): hop_length parameter
        sample_rate (int): Sample rate
    """
    # Apply Softmax activation
    preds_prob = F.softmax(torch.tensor(preds), dim=-1).numpy()

    # Calculate time axis
    total_time = mel.shape[0] * hop_length / sample_rate
    times = np.linspace(0, total_time, num=mel.shape[0])

    # Determine frame range for visualization
    start_frame = int(start_time * sample_rate / hop_length)
    end_frame = int(end_time * sample_rate / hop_length)

    # Crop data
    labels_cropped = labels[start_frame:end_frame]
    preds_cropped = preds_prob[start_frame:end_frame]
    times_cropped = times[start_frame:end_frame]

    # Define color mapping (using matplotlib's tab10 color set)
    cmap = plt.get_cmap('tab10')
    num_types = preds_cropped.shape[1]
    colors = [cmap(i) for i in range(num_types)]

    # Create plot
    plt.figure(figsize=(15, 8))

    # Plot type probabilities
    for type_idx in range(num_types):
        plt.plot(
            times_cropped,
            preds_cropped[:, type_idx],
            label=f'Type {type_idx}',
            color=colors[type_idx],
            alpha=0.6
        )

    # Plot true labels
    for idx, label in enumerate(labels_cropped):
        if label == 0:
            continue  # Skip type 0 (assumed to be no event)
        plt.scatter(
            times_cropped[idx],
            preds_cropped[idx, label],
            color=colors[label],
            marker='x',
            s=50,
            # Always labelled: labelling only the marker at idx == 0 left the ground
            # truth out of the legend unless the first shown frame held a note.
            # Repeated labels are collapsed by the de-duplication below.
            label=f'Ground Truth Type {label}',
            zorder=5
        )
        # Plot vertical lines
        plt.axvline(
            x=times_cropped[idx],
            color=colors[label],
            linestyle='--',
            alpha=0.5,
            linewidth=1
        )

    # Add title and labels
    plt.title('Type Probabilities and Ground Truth', fontsize=14)
    plt.xlabel('Time (s)', fontsize=12)
    plt.ylabel('Probability', fontsize=12)

    # Set legend, avoid duplicates (one handle is kept per distinct label)
    handles, labels_legend = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels_legend, handles))
    plt.legend(by_label.values(), by_label.keys(), loc='upper right', fontsize='small')

    # Set x-axis range
    plt.xlim(start_time, end_time)

    # Only show y-axis label on the right
    ax = plt.gca()
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    ax.yaxis.set_label_coords(1.05, 0.5)

    # Display plot
    plt.tight_layout()
    plt.show()

def visualize_position_predictions(true_positions: List[int], pred_positions: List[int], num_samples: int = 100):
    """
    Scatter-plot true against predicted note positions for a random subset of samples.

    Args:
        true_positions (List[int]): List of true positions.
        pred_positions (List[int]): List of predicted positions.
        num_samples (int): Number of samples to visualize; capped at the number of
            true positions available.
    """
    plt.figure(figsize=(15, 6))

    available = len(true_positions)
    if available < num_samples:
        num_samples = available

    # Draw distinct sample indices (no replacement) from NumPy's global random state.
    chosen = np.random.choice(available, size=num_samples, replace=False)
    true_subset = np.array(true_positions)[chosen]
    pred_subset = np.array(pred_positions)[chosen]

    sample_axis = range(num_samples)
    for values, legend_text, colour in (
        (true_subset, 'True Position', 'blue'),
        (pred_subset, 'Predicted Position', 'red'),
    ):
        plt.scatter(sample_axis, values, label=legend_text, alpha=0.6, color=colour)

    plt.title('True vs Predicted Note Positions')
    plt.xlabel('Sample Index')
    plt.ylabel('Position Index')
    plt.legend()
    plt.show()

def visualize_position_distribution(true_positions: List[int], pred_positions: List[int]):
    """
    Overlay density-normalised histograms (81 bins each) of the true and the
    predicted positions.

    Args:
        true_positions (List[int]): List of true positions.
        pred_positions (List[int]): List of predicted positions.
    """
    plt.figure(figsize=(10, 6))

    histogram_specs = (
        (true_positions, 'True Positions', 'blue'),
        (pred_positions, 'Predicted Positions', 'red'),
    )
    for values, legend_text, colour in histogram_specs:
        plt.hist(values, bins=81, alpha=0.5, label=legend_text, color=colour, density=True)

    plt.title('Distribution of True and Predicted Positions')
    plt.xlabel('Position Index')
    plt.ylabel('Density')
    plt.legend()
    plt.show()

# %% [markdown]
# ## Training and Validation Functions

# Presence Model Training Function
def train_epoch_cnn(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, loss_fn: nn.Module) -> float:
    """
    Run one training pass of the presence CNN over `dataloader`.

    Each batch is moved to `device`, scored, and back-propagated; gradients are
    clipped to a total norm of 1.0 before the optimizer step.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    loss_total = 0.0
    batches = tqdm.tqdm(dataloader, desc="Training Presence", leave=False)

    for mel, labels, difficulties in batches:
        mel = mel.to(device)  # (batch_size, 81, n_mels)
        labels = labels.to(device)  # (batch_size,)

        optimizer.zero_grad()

        # Model output is (batch_size, 1); drop the trailing axis to match the labels.
        scores = model(mel).squeeze(1)
        loss = loss_fn(scores, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        loss_total += batch_loss
        batches.set_postfix({'Loss': f'{batch_loss:.4f}'})

    return loss_total / len(dataloader)

# Presence Model Validation Function
def validate_epoch_cnn(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module) -> Tuple[float, Dict[int, Dict[str, List[Any]]]]:
    """
    Evaluate the presence CNN on a validation loader without gradient tracking.

    Besides the loss, the raw (un-thresholded) model outputs and the matching
    labels are collected per difficulty so callers can compute their own
    metrics afterwards.

    Args:
        model (nn.Module): Network to evaluate (switched to eval mode here).
        dataloader (DataLoader): Yields `(mel, labels, difficulties)` batches.
        device (torch.device): Device the batch tensors are moved to.
        loss_fn (nn.Module): Loss applied to `(squeezed output, labels)`.

    Returns:
        Tuple[float, Dict]: The sum of per-batch losses divided by the number
        of batches, and a mapping `difficulty -> {'y_true': [...], 'y_scores': [...]}`.
    """
    model.eval()
    loss_total = 0.0
    per_difficulty = defaultdict(lambda: {'y_true': [], 'y_scores': []})
    batches = tqdm.tqdm(dataloader, desc="Validation Presence", leave=False)

    with torch.no_grad():
        for mel, labels, difficulties in batches:
            mel, labels = mel.to(device), labels.to(device)

            # Forward pass; drop the trailing singleton dim so it lines up with labels.
            scores = model(mel).squeeze(1)

            batch_loss = loss_fn(scores, labels).item()
            loss_total += batch_loss
            batches.set_postfix({'Loss': f'{batch_loss:.4f}'})

            # Move to NumPy once per batch, then bucket every sample by difficulty.
            scores_np = scores.cpu().numpy()
            targets_np = labels.cpu().numpy()
            for row in range(mel.size(0)):
                bucket = per_difficulty[difficulties[row]]
                bucket['y_true'].append(targets_np[row])
                bucket['y_scores'].append(scores_np[row])

    return loss_total / len(dataloader), per_difficulty

# Type Model Training Function (CNN version)
def train_epoch_cnn_type(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, loss_fn: nn.Module, num_types: int) -> Tuple[float, float]:
    """
    Train the Type CNN for one epoch.

    Args:
        model (nn.Module): Network to train (switched to train mode here).
        dataloader (DataLoader): Yields `(mel, labels, extra)` batches; the
            third element is ignored.
        optimizer (torch.optim.Optimizer): Optimizer updating `model`.
        device (torch.device): Device the batch tensors are moved to.
        loss_fn (nn.Module): Loss applied to `(model output, labels)`.
        num_types (int): Not used by this function; kept so existing callers
            keep working.

    Returns:
        Tuple[float, float]: `(epoch_loss, epoch_acc)` where `epoch_loss` is
        the sum of per-batch losses divided by the number of batches and
        `epoch_acc` is the fraction of correctly classified samples over the
        whole epoch.
    """
    model.train()
    running_loss = 0.0
    correct_total = 0
    sample_total = 0
    progress_bar = tqdm.tqdm(dataloader, desc="Training Type CNN", leave=False)

    for mel, labels, _extra in progress_bar:
        mel = mel.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        type_pred = model(mel)

        # Compute loss
        loss = loss_fn(type_pred, labels)

        # Backward pass with gradient clipping to keep updates bounded
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Softmax is monotonic, so the argmax of the raw outputs is the predicted class.
        preds = torch.argmax(type_pred, dim=1)
        batch_correct = int((preds == labels).sum().item())
        batch_size = labels.size(0)
        correct_total += batch_correct
        sample_total += batch_size

        # Per-batch accuracy is only used for the progress bar.
        acc = batch_correct / batch_size if batch_size else 0.0
        progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Acc': f'{acc*100:.2f}%'})

    epoch_loss = running_loss / len(dataloader)
    # Count-based accuracy stays exact even when batches differ in size.
    epoch_acc = correct_total / sample_total if sample_total else 0.0
    return epoch_loss, epoch_acc

# Type Model Validation Function (CNN version)
def validate_epoch_cnn_type(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module, num_types: int) -> Tuple[float, float]:
    """
    Evaluate the Type CNN on a validation loader without gradient tracking.

    Args:
        model (nn.Module): Network to evaluate (switched to eval mode here).
        dataloader (DataLoader): Yields `(mel, labels, extra)` batches; the
            third element is ignored.
        device (torch.device): Device the batch tensors are moved to.
        loss_fn (nn.Module): Loss applied to `(model output, labels)`.
        num_types (int): Not used by this function; kept so existing callers
            keep working.

    Returns:
        Tuple[float, float]: `(epoch_loss, epoch_acc)` where `epoch_loss` is
        the sum of per-batch losses divided by the number of batches and
        `epoch_acc` is the fraction of correctly classified samples over the
        whole loader.
    """
    model.eval()
    running_loss = 0.0
    correct_total = 0
    sample_total = 0
    progress_bar = tqdm.tqdm(dataloader, desc="Validation Type CNN", leave=False)

    with torch.no_grad():
        for mel, labels, _extra in progress_bar:
            mel = mel.to(device)
            labels = labels.to(device)

            # Forward pass
            type_pred = model(mel)

            # Compute loss
            batch_loss = loss_fn(type_pred, labels).item()
            running_loss += batch_loss

            # Softmax is monotonic, so the argmax of the raw outputs is the predicted class.
            preds = torch.argmax(type_pred, dim=1)
            batch_correct = int((preds == labels).sum().item())
            batch_size = labels.size(0)
            correct_total += batch_correct
            sample_total += batch_size

            # Per-batch accuracy is only used for the progress bar.
            acc = batch_correct / batch_size if batch_size else 0.0
            progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Acc': f'{acc*100:.2f}%'})

    epoch_loss = running_loss / len(dataloader)
    # A final, smaller batch must not weigh as much as a full one, so accuracy
    # is computed from sample counts rather than as a mean of per-batch means.
    epoch_acc = correct_total / sample_total if sample_total else 0.0
    return epoch_loss, epoch_acc

# %% [markdown]
# ## Model Evaluation Functions

def evaluate_test_set_cnn(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module):
    """
    Evaluate the presence CNN on the test set and report metrics per difficulty.

    Model outputs are squeezed on dim 1, passed through a sigmoid and
    thresholded at 0.5. A histogram comparing the predicted probabilities
    with the ground-truth labels is shown, then accuracy / precision /
    recall / F1 are computed and printed for every difficulty seen.

    Args:
        model (nn.Module): Network to evaluate (switched to eval mode here).
        dataloader (DataLoader): Yields `(mel, labels, difficulties)` batches.
        device (torch.device): Device the batch tensors are moved to.
        loss_fn (nn.Module): Not used; the metrics do not depend on the loss.
            Kept so existing callers keep working.

    Returns:
        dict: `difficulty -> {'accuracy', 'precision', 'recall', 'f1_score'}`.
    """
    model.eval()
    difficulty_metrics = defaultdict(lambda: {'y_true': [], 'y_pred': []})
    all_preds, all_labels = [], []
    progress_bar = tqdm.tqdm(dataloader, desc="Testing Presence CNN", leave=False)

    with torch.no_grad():
        for mel, labels, difficulties in progress_bar:
            mel = mel.to(device)
            labels = labels.to(device)

            # Forward pass; drop the trailing singleton dim so it lines up with labels.
            outputs = model(mel).squeeze(1)

            # torch.sigmoid is numerically stable, unlike 1 / (1 + exp(-x)),
            # which overflows in exp for strongly negative outputs.
            presence_prob_np = torch.sigmoid(outputs).cpu().numpy()
            presence_target_np = labels.cpu().numpy()

            # Use threshold 0.5 for prediction
            y_pred = (presence_prob_np >= 0.5).astype(int)
            y_true = presence_target_np.astype(int)

            # Collect all probabilities and labels for the distribution plot
            all_preds.extend(presence_prob_np.tolist())
            all_labels.extend(presence_target_np.tolist())

            # Group by difficulty level
            for i in range(mel.size(0)):
                bucket = difficulty_metrics[difficulties[i]]
                bucket['y_true'].append(y_true[i])
                bucket['y_pred'].append(y_pred[i])

    # Convert to NumPy arrays
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Plot distribution (density=True, so the y axis is a density, not a count)
    plt.figure(figsize=(10, 6))
    plt.hist(all_preds, bins=50, alpha=0.7, label="Predictions", color="blue", density=True)
    plt.hist(all_labels, bins=50, alpha=0.7, label="Ground Truth", color="orange", density=True)
    plt.title("Frame-wise Prediction and Ground Truth Distribution")
    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.legend()
    plt.show()

    # Calculate metrics for each difficulty level
    final_metrics = {}
    for diff, metrics in difficulty_metrics.items():
        y_true = np.array(metrics['y_true'])
        y_pred = np.array(metrics['y_pred'])
        acc = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        final_metrics[diff] = {
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        }
        print(f"Difficulty {diff}: Accuracy={acc:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1 Score={f1:.4f}")

    return final_metrics

def evaluate_test_set_type_cnn(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module, num_types: int):
    """
    Score the Type CNN on the test set.

    Predictions are the argmax of the softmax over dim 1. Accuracy plus
    support-weighted precision, recall and F1 are computed over all samples,
    printed, and returned.

    Args:
        model (nn.Module): Network to evaluate (switched to eval mode here).
        dataloader (DataLoader): Yields `(mel, labels, extra)` batches; the
            third element is ignored.
        device (torch.device): Device the batch tensors are moved to.
        loss_fn (nn.Module): Called on every batch; its value is not reported.
        num_types (int): Not used by this function.

    Returns:
        dict: Keys 'accuracy', 'precision', 'recall' and 'f1_score'.
    """
    model.eval()
    predicted_classes = []
    true_classes = []
    batches = tqdm.tqdm(dataloader, desc="Testing Type CNN", leave=False)

    with torch.no_grad():
        for mel, labels, _extra in batches:
            mel, labels = mel.to(device), labels.to(device)

            logits = model(mel)

            # The loss is evaluated for every batch but not included in the result.
            loss_fn(logits, labels)

            class_probs = F.softmax(logits, dim=1)
            batch_preds = torch.argmax(class_probs, dim=1).cpu().numpy()
            batch_true = labels.cpu().numpy()

            predicted_classes.extend(batch_preds.tolist())
            true_classes.extend(batch_true.tolist())

    # Options shared by the three class-averaged metrics.
    weighted = {'average': 'weighted', 'zero_division': 0}
    acc = accuracy_score(true_classes, predicted_classes)
    precision = precision_score(true_classes, predicted_classes, **weighted)
    recall = recall_score(true_classes, predicted_classes, **weighted)
    f1 = f1_score(true_classes, predicted_classes, **weighted)

    print(f"Type Prediction - Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

    return {'accuracy': acc, 'precision': precision, 'recall': recall, 'f1_score': f1}

# %% [markdown]
# ## Data Preparation

current_directory = Path().cwd()
# The four dataset shards sit next to the notebook's working directory.
dataset_dirs = [current_directory / f"../dataset/{shard}" for shard in ("A", "B", "C", "Z")]
data = extract_level_json_multiple(dataset_dirs, min_difficulty=15)
print(f"Filtered Data Count (Difficulty>=15): {len(data)}")
# Debug aid: uncomment to work on the first 30 songs only.
# data = dict(list(data.items())[:30])
bpm_info_dict = {}
score_positions_dict = {}

for unique_id, song in data.items():
    level_data = song['level']
    song_id = unique_id  # The unique ID doubles as the song key in both dicts
    charts_data = extract_charts(song['charts_path'])
    if not charts_data:
        # No usable chart for this song: it gets no tempo or note entry.
        continue

    bpm_info = charts_data.get('tempo_list', [])
    bpm_info_dict[song_id] = bpm_info
    note_time_map = map_note_to_time(charts_data)
    # One record per note: its time, type and x position (missing fields default to 0 / 0.0).
    score_positions = [
        {
            'note_time_microseconds': note['note_time_microseconds'],
            'note_type': note.get('note_type', 0),
            'note_x': note.get('note_x', 0.0),
        }
        for note in note_time_map
    ]
    score_positions_dict[song_id] = score_positions

# %% [markdown]
# ## Define Dataset and DataLoader

# Define Presence dataset and DataLoader
# Define Presence dataset and DataLoader
presence_dataset = OnsetDataset(
    data=data, 
    bpm_info=bpm_info_dict, 
    score_positions=score_positions_dict,
    window_size=WINDOW_SIZE  # 40 frames before and after
)

# Seed the splits so train / validation / test membership is identical on every
# run. Without this, a checkpoint saved by an earlier run may be evaluated on
# samples it was trained on.
PRESENCE_SPLIT_SEED = 42
presence_split_generator = torch.Generator().manual_seed(PRESENCE_SPLIT_SEED)

# NOTE(review): the split is over individual samples, not songs. If OnsetDataset
# emits one overlapping window per frame (TODO confirm), neighbouring windows of
# the same song land in different splits and the held-out metrics are optimistic.
train_size = int(0.8 * len(presence_dataset))
val_size = len(presence_dataset) - train_size
train_dataset, val_dataset = random_split(presence_dataset, [train_size, val_size], generator=presence_split_generator)

print(f"Presence Train Size: {len(train_dataset)}")
print(f"Presence Validation Size: {len(val_dataset)}")
# Below is the training code ---------------------------------------------------------------------------------------------------
# Create test set (using part of the validation set as test set)
test_size = int(0.5 * len(val_dataset))
val_size = len(val_dataset) - test_size
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size], generator=presence_split_generator)

print(f"Presence Validation Size after split: {len(val_dataset)}")
print(f"Presence Test Size: {len(test_dataset)}")

# Only the training loader shuffles and drops the ragged final batch.
presence_train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate_fn_padded)
presence_val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded)
presence_test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded)

# %% [markdown]
# ## Presence Model Training and Evaluation

# Initialize Presence model
# Initialize Presence model
presence_model = CNNOnsetDetector(input_channels=NMELS, num_classes=1, dropout=DROPOUT)
presence_model.to(device)

# Define optimizer and loss function
presence_optimizer = torch.optim.Adam(presence_model.parameters(), lr=LEARNING_RATE)
presence_loss_fn = nn.BCEWithLogitsLoss()

# Train the best Presence model
best_presence_val_loss = float('inf')
best_presence_model_path = "model/best_cnn_onset_model.pth"
# Create the checkpoint folder up front so the first save cannot fail after a
# full training epoch just because the folder is missing.
os.makedirs(os.path.dirname(best_presence_model_path), exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Presence Model")
    
    # Training
    train_loss = train_epoch_cnn(presence_model, presence_train_loader, presence_optimizer, device, presence_loss_fn)
    print(f"Train Loss: {train_loss:.4f}")
    
    # Validation
    val_loss, val_difficulty_preds = validate_epoch_cnn(presence_model, presence_val_loader, device, presence_loss_fn)
    print(f"Validation Loss: {val_loss:.4f}")
    
    # Checkpoint only when the validation loss improves on the best seen so far
    if val_loss < best_presence_val_loss:
        best_presence_val_loss = val_loss
        torch.save(presence_model.state_dict(), best_presence_model_path)
        print("Saved Best Presence Model")

# Load the best Presence model. map_location keeps this working when the
# checkpoint was written on a different device than the current one.
presence_model.load_state_dict(torch.load(best_presence_model_path, map_location=device))
presence_model.to(device)
presence_model.eval()
# End of the training code -----------------------------------------------------------------------------------------------------

# %% [markdown]
# ## Define Type Dataset and DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Multi-class focal loss: `-alpha * (1 - p_t) ** gamma * log(p_t)`,
    where `p_t` is the softmax probability assigned to the true class.
    """

    def __init__(self, alpha=2.0, gamma=2, reduction='mean'):
        """
        alpha: Scalar factor multiplying every sample's loss
        gamma: Modulation parameter, the larger gamma is, the more focus on hard-to-classify samples
        reduction: Loss aggregation method, 'mean' or 'sum'; any other value returns the per-sample losses
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs: raw class scores, [batch_size, num_classes]
        targets: integer class indices, [batch_size]
        """
        # Work in log space: taking log(softmax(x)) directly yields -inf (and an
        # infinite loss) once the true-class probability underflows to 0.
        log_probs = F.log_softmax(inputs, dim=1)
        # Log-probability of the true class for every sample
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        # Focal loss formula
        loss = -self.alpha * (1 - pt) ** self.gamma * log_pt

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

# %% 
class TypeDataset(Dataset):
    """
    Dataset of fixed-length Mel-spectrogram windows, one per frame of every song.

    A window spans `window_size` frames on each side of its centre frame
    (81 frames for the default of 40) and is zero-padded at song boundaries.
    """
    def __init__(self, data: Dict[str, Any], bpm_info: Dict[str, List[Dict[str, float]]], score_positions: Dict[str, List[Dict[str, Any]]], window_size: int = 40, transform=None):
        self.data = data
        self.bpm_info = bpm_info
        self.score_positions = score_positions
        self.window_size = window_size
        self.transform = transform
        # All windows are materialised eagerly at construction time.
        self.samples = self.prepare_samples()

    def prepare_samples(self) -> List[Tuple[np.ndarray, int, int]]:
        """
        Build the `(mel_window, label, difficulty)` tuples for every frame of every song.

        Songs whose spectrogram result carries no "labels" entry are skipped.
        """
        half = self.window_size
        collected = []
        for song_id, song in self.data.items():
            mp3_path = song["mp3_path"]
            _charts_path = song["charts_path"]  # looked up but not otherwise used
            difficulty = song['difficulty']

            mel_dict = generate_mel_spectrogram(
                audio_path=Path(mp3_path),
                log_enable=True,
                bpm_info=self.bpm_info.get(song_id, None),
                note_info=self.score_positions.get(song_id, None)
            )
            if "labels" not in mel_dict:
                continue

            mel = mel_dict["mel"]        # indexed as (frame, mel bin)
            labels = mel_dict["labels"]  # one label per frame
            num_frames = mel.shape[0]

            for center in range(num_frames):
                first = max(center - half, 0)
                last = min(center + half + 1, num_frames)

                # Number of frames missing on each side of the window.
                pad_before = max(half - center, 0)
                pad_after = max(center + half + 1 - num_frames, 0)

                mel_window = mel[first:last]
                if pad_before > 0 or pad_after > 0:
                    # Zero-fill the missing frames; the mel-bin axis is left untouched.
                    mel_window = np.pad(mel_window, ((pad_before, pad_after), (0, 0)), mode='constant')

                collected.append((mel_window, labels[center], difficulty))
        return collected

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        window, raw_label, difficulty = self.samples[idx]
        mel_tensor = torch.from_numpy(window).float()  # float32 copy of the stored window
        label_tensor = torch.tensor(raw_label).long()  # scalar class index

        if self.transform:
            mel_tensor, label_tensor = self.transform(mel_tensor, label_tensor)

        return mel_tensor, label_tensor, difficulty

def collate_fn_padded_type(batch: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    """
    Collate `(mel_window, label, difficulty)` samples into one batch.

    The windows and the labels are each stacked along a new leading batch
    dimension; the difficulties are passed through as the tuple produced by
    unzipping the batch.
    """
    mel_windows, label_tensors, difficulties = zip(*batch)
    stacked_mel = torch.stack(mel_windows)
    stacked_labels = torch.stack(label_tensors)
    return stacked_mel, stacked_labels, difficulties

type_dataset = TypeDataset(
    data=data, 
    bpm_info=bpm_info_dict, 
    score_positions=score_positions_dict,
    window_size=WINDOW_SIZE  # 40 frames before and after
)

# Below is the training code ---------------------------------------------------------------------------------------------------
# Seed the splits so train / validation / test membership is identical on every
# run. Without this, a checkpoint saved by an earlier run may be evaluated on
# samples it was trained on.
TYPE_SPLIT_SEED = 42
type_split_generator = torch.Generator().manual_seed(TYPE_SPLIT_SEED)

# Split the dataset
# NOTE(review): the split is over individual per-frame windows, not songs, so
# overlapping windows of one song end up in different splits.
type_train_size = int(0.8 * len(type_dataset))
type_val_size = len(type_dataset) - type_train_size
type_train_dataset, type_val_dataset = random_split(type_dataset, [type_train_size, type_val_size], generator=type_split_generator)

print(f"Type Train Size: {len(type_train_dataset)}")
print(f"Type Validation Size: {len(type_val_dataset)}")

# Create test set (using part of the validation set as test set)
type_test_size = int(0.5 * len(type_val_dataset))
type_val_size = len(type_val_dataset) - type_test_size
type_val_dataset, type_test_dataset = random_split(type_val_dataset, [type_val_size, type_test_size], generator=type_split_generator)

print(f"Type Validation Size after split: {len(type_val_dataset)}")
print(f"Type Test Size: {len(type_test_dataset)}")

# Only the training loader shuffles and drops the ragged final batch.
type_train_loader = DataLoader(type_train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate_fn_padded_type)
type_val_loader = DataLoader(type_val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded_type)
type_test_loader = DataLoader(type_test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded_type)

# %% [markdown]
# ## Type Model Definition and Training

# Initialize Type model
# Initialize Type model
num_types = 5  # Adjust based on requirements
type_model = CNNTypePredictor(input_channels=NMELS, num_types=num_types, dropout=DROPOUT)
type_model.to(device)

# Define optimizer and loss function
type_optimizer = torch.optim.Adam(type_model.parameters(), lr=LEARNING_RATE)
type_loss_fn = FocalLoss(alpha=1.0, gamma=2, reduction='mean')


# Train the best Type model
best_type_val_loss = float('inf')
best_type_model_path = "model/best_cnn_type_model.pth"
# Create the checkpoint folder up front so the first save cannot fail after a
# full training epoch just because the folder is missing.
os.makedirs(os.path.dirname(best_type_model_path), exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Type Model")
    
    # Training
    train_loss, train_acc = train_epoch_cnn_type(type_model, type_train_loader, type_optimizer, device, type_loss_fn, num_types)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%")
    
    # Validation
    val_loss, val_acc = validate_epoch_cnn_type(type_model, type_val_loader, device, type_loss_fn, num_types)
    print(f"Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc*100:.2f}%")
    
    # Checkpoint only when the validation loss improves on the best seen so far
    if val_loss < best_type_val_loss:
        best_type_val_loss = val_loss
        torch.save(type_model.state_dict(), best_type_model_path)
        print("Saved Best Type Model")

# Load the best Type model. map_location keeps this working when the
# checkpoint was written on a different device than the current one.
type_model.load_state_dict(torch.load(best_type_model_path, map_location=device))
type_model.to(device)
type_model.eval()
# End of the training code -----------------------------------------------------------------------------------------------------

# Rebuild both networks and restore their best checkpoints, so the following
# cells can run without repeating the training above.
num_types = 5  # Adjust based on requirements
best_presence_model_path = "model/best_cnn_onset_model.pth"
best_type_model_path = "model/best_cnn_type_model.pth"

# map_location lets a checkpoint saved on one device be restored on another.
presence_model = CNNOnsetDetector(input_channels=NMELS, num_classes=1, dropout=DROPOUT)
presence_model.load_state_dict(torch.load(best_presence_model_path, map_location=device))
presence_model.to(device)
presence_model.eval()  # Switch to evaluation mode

type_model = CNNTypePredictor(input_channels=NMELS, num_types=num_types, dropout=DROPOUT)
type_model.load_state_dict(torch.load(best_type_model_path, map_location=device))
type_model.to(device)
type_model.eval()  # Switch to evaluation mode


Using device: cuda
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.clouddiver.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/arwtdydhqhfa.helamind.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/andogaru.fumiko.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.summernight.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.cereris.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.alone.cytoidlevel/level.json
未找到音频文件 �TAKUMI³�OЯDIN -Apocalyptic War-(Re Mastering).mp3 对应于 song ID ant.ordin-tc 在 /data1/yuchen/cytoid/final_code/../dataset/A/ant.ordin-tc.cytoidlevel
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anthony.lolk_muricaaaaa.cytoidlevel/level.json
未找到音频文件 Langley_D - deli.+駄々子 - 最果ての勇者にラブソングを.ogg 对应于 song ID anoppo.furiy 在 /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.furiy.cytoidleve

[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?


Presence Train Size: 858862
Presence Validation Size: 214716
Presence Validation Size after split: 107358
Presence Test Size: 107358
Epoch 1/50 - Training Presence Model


                                                                                      

Train Loss: 0.4299


                                                                                      

Validation Loss: 0.3996
Saved Best Presence Model
Epoch 2/50 - Training Presence Model


                                                                                      

Train Loss: 0.3967


                                                                                      

Validation Loss: 0.3695
Saved Best Presence Model
Epoch 3/50 - Training Presence Model


                                                                                      

Train Loss: 0.3722


                                                                                      

Validation Loss: 0.3430
Saved Best Presence Model
Epoch 4/50 - Training Presence Model


                                                                                      

Train Loss: 0.3535


                                                                                      

Validation Loss: 0.3294
Saved Best Presence Model
Epoch 5/50 - Training Presence Model


                                                                                      

Train Loss: 0.3393


                                                                                      

Validation Loss: 0.3150
Saved Best Presence Model
Epoch 6/50 - Training Presence Model


                                                                                      

Train Loss: 0.3280


                                                                                      

Validation Loss: 0.3057
Saved Best Presence Model
Epoch 7/50 - Training Presence Model


                                                                                      

Train Loss: 0.3201


                                                                                      

Validation Loss: 0.2988
Saved Best Presence Model
Epoch 8/50 - Training Presence Model


                                                                                      

Train Loss: 0.3127


                                                                                      

Validation Loss: 0.2934
Saved Best Presence Model
Epoch 9/50 - Training Presence Model


                                                                                      

Train Loss: 0.3074


                                                                                      

Validation Loss: 0.2872
Saved Best Presence Model
Epoch 10/50 - Training Presence Model


                                                                                      

Train Loss: 0.3027


                                                                                      

Validation Loss: 0.2834
Saved Best Presence Model
Epoch 11/50 - Training Presence Model


                                                                                      

Train Loss: 0.2976


                                                                                      

Validation Loss: 0.2771
Saved Best Presence Model
Epoch 12/50 - Training Presence Model


                                                                                      

Train Loss: 0.2942


                                                                                      

Validation Loss: 0.2765
Saved Best Presence Model
Epoch 13/50 - Training Presence Model


                                                                                      

Train Loss: 0.2910


                                                                                      

Validation Loss: 0.2746
Saved Best Presence Model
Epoch 14/50 - Training Presence Model


                                                                                      

Train Loss: 0.2878


                                                                                      

Validation Loss: 0.2748
Epoch 15/50 - Training Presence Model


                                                                                      

Train Loss: 0.2853


                                                                                      

Validation Loss: 0.2700
Saved Best Presence Model
Epoch 16/50 - Training Presence Model


                                                                                      

Train Loss: 0.2830


                                                                                      

Validation Loss: 0.2707
Epoch 17/50 - Training Presence Model


                                                                                      

Train Loss: 0.2807


                                                                                      

Validation Loss: 0.2649
Saved Best Presence Model
Epoch 18/50 - Training Presence Model


                                                                                      

Train Loss: 0.2785


                                                                                      

Validation Loss: 0.2646
Saved Best Presence Model
Epoch 19/50 - Training Presence Model


                                                                                      

Train Loss: 0.2768


                                                                                      

Validation Loss: 0.2626
Saved Best Presence Model
Epoch 20/50 - Training Presence Model


                                                                                      

Train Loss: 0.2749


                                                                                      

Validation Loss: 0.2617
Saved Best Presence Model
Epoch 21/50 - Training Presence Model


                                                                                      

Train Loss: 0.2735


                                                                                      

Validation Loss: 0.2600
Saved Best Presence Model
Epoch 22/50 - Training Presence Model


                                                                                      

Train Loss: 0.2713


                                                                                      

Validation Loss: 0.2579
Saved Best Presence Model
Epoch 23/50 - Training Presence Model


                                                                                      

Train Loss: 0.2691


                                                                                      

Validation Loss: 0.2619
Epoch 24/50 - Training Presence Model


                                                                                      

Train Loss: 0.2684


                                                                                      

Validation Loss: 0.2576
Saved Best Presence Model
Epoch 25/50 - Training Presence Model


                                                                                      

Train Loss: 0.2672


                                                                                      

Validation Loss: 0.2602
Epoch 26/50 - Training Presence Model


                                                                                      

Train Loss: 0.2658


                                                                                      

Validation Loss: 0.2557
Saved Best Presence Model
Epoch 27/50 - Training Presence Model


                                                                                      

Train Loss: 0.2646


                                                                                      

Validation Loss: 0.2583
Epoch 28/50 - Training Presence Model


                                                                                      

Train Loss: 0.2639


                                                                                      

Validation Loss: 0.2545
Saved Best Presence Model
Epoch 29/50 - Training Presence Model


                                                                                      

Train Loss: 0.2626


                                                                                      

Validation Loss: 0.2545
Saved Best Presence Model
Epoch 30/50 - Training Presence Model


                                                                                      

Train Loss: 0.2611


                                                                                      

Validation Loss: 0.2543
Saved Best Presence Model
Epoch 31/50 - Training Presence Model


                                                                                      

Train Loss: 0.2607


                                                                                      

Validation Loss: 0.2520
Saved Best Presence Model
Epoch 32/50 - Training Presence Model


                                                                                      

Train Loss: 0.2594


                                                                                      

Validation Loss: 0.2531
Epoch 33/50 - Training Presence Model


                                                                                      

Train Loss: 0.2585


                                                                                      

Validation Loss: 0.2526
Epoch 34/50 - Training Presence Model


                                                                                      

Train Loss: 0.2574


                                                                                      

Validation Loss: 0.2506
Saved Best Presence Model
Epoch 35/50 - Training Presence Model


                                                                                      

Train Loss: 0.2569


                                                                                      

Validation Loss: 0.2528
Epoch 36/50 - Training Presence Model


                                                                                      

Train Loss: 0.2566


                                                                                      

Validation Loss: 0.2504
Saved Best Presence Model
Epoch 37/50 - Training Presence Model


                                                                                      

Train Loss: 0.2552


                                                                                      

Validation Loss: 0.2506
Epoch 38/50 - Training Presence Model


                                                                                      

Train Loss: 0.2550


                                                                                      

Validation Loss: 0.2498
Saved Best Presence Model
Epoch 39/50 - Training Presence Model


                                                                                      

Train Loss: 0.2544


                                                                                      

Validation Loss: 0.2491
Saved Best Presence Model
Epoch 40/50 - Training Presence Model


                                                                                      

Train Loss: 0.2531


                                                                                      

Validation Loss: 0.2503
Epoch 41/50 - Training Presence Model


                                                                                      

Train Loss: 0.2525


                                                                                      

Validation Loss: 0.2486
Saved Best Presence Model
Epoch 42/50 - Training Presence Model


                                                                                      

Train Loss: 0.2517


                                                                                      

Validation Loss: 0.2474
Saved Best Presence Model
Epoch 43/50 - Training Presence Model


                                                                                      

Train Loss: 0.2516


                                                                                      

Validation Loss: 0.2475
Epoch 44/50 - Training Presence Model


                                                                                      

Train Loss: 0.2509


                                                                                      

Validation Loss: 0.2479
Epoch 45/50 - Training Presence Model


                                                                                      

Train Loss: 0.2505


                                                                                      

Validation Loss: 0.2472
Saved Best Presence Model
Epoch 46/50 - Training Presence Model


                                                                                      

Train Loss: 0.2496


                                                                                      

Validation Loss: 0.2492
Epoch 47/50 - Training Presence Model


                                                                                      

Train Loss: 0.2492


                                                                                      

Validation Loss: 0.2457
Saved Best Presence Model
Epoch 48/50 - Training Presence Model


                                                                                      

Train Loss: 0.2488


                                                                                      

Validation Loss: 0.2474
Epoch 49/50 - Training Presence Model


                                                                                      

Train Loss: 0.2480


                                                                                      

Validation Loss: 0.2457
Epoch 50/50 - Training Presence Model


                                                                                      

Train Loss: 0.2476


  presence_model.load_state_dict(torch.load(best_presence_model_path))


Validation Loss: 0.2468


[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?


Type Train Size: 858862
Type Validation Size: 214716
Type Validation Size after split: 107358
Type Test Size: 107358
Epoch 1/50 - Training Type Model


                                                                                                  

Train Loss: 0.1162, Train Acc: 83.59%


                                                                                                  

Validation Loss: 0.1110, Validation Acc: 83.77%
Saved Best Type Model
Epoch 2/50 - Training Type Model


                                                                                                  

Train Loss: 0.1107, Train Acc: 83.65%


                                                                                                  

Validation Loss: 0.1060, Validation Acc: 83.95%
Saved Best Type Model
Epoch 3/50 - Training Type Model


                                                                                                  

Train Loss: 0.1071, Train Acc: 83.87%


                                                                                                  

Validation Loss: 0.1023, Validation Acc: 84.24%
Saved Best Type Model
Epoch 4/50 - Training Type Model


                                                                                                  

Train Loss: 0.1041, Train Acc: 84.22%


                                                                                                  

Validation Loss: 0.0988, Validation Acc: 84.75%
Saved Best Type Model
Epoch 5/50 - Training Type Model


                                                                                                   

Train Loss: 0.1021, Train Acc: 84.52%


                                                                                                  

Validation Loss: 0.0965, Validation Acc: 85.04%
Saved Best Type Model
Epoch 6/50 - Training Type Model


                                                                                                  

Train Loss: 0.1002, Train Acc: 84.76%


                                                                                                  

Validation Loss: 0.0944, Validation Acc: 85.43%
Saved Best Type Model
Epoch 7/50 - Training Type Model


                                                                                                  

Train Loss: 0.0987, Train Acc: 84.97%


                                                                                                  

Validation Loss: 0.0936, Validation Acc: 85.50%
Saved Best Type Model
Epoch 8/50 - Training Type Model


                                                                                                   

Train Loss: 0.0976, Train Acc: 85.16%


                                                                                                  

Validation Loss: 0.0915, Validation Acc: 85.89%
Saved Best Type Model
Epoch 9/50 - Training Type Model


                                                                                                  

Train Loss: 0.0966, Train Acc: 85.27%


                                                                                                  

Validation Loss: 0.0907, Validation Acc: 86.11%
Saved Best Type Model
Epoch 10/50 - Training Type Model


                                                                                                  

Train Loss: 0.0957, Train Acc: 85.41%


                                                                                                  

Validation Loss: 0.0899, Validation Acc: 86.11%
Saved Best Type Model
Epoch 11/50 - Training Type Model


                                                                                                  

Train Loss: 0.0947, Train Acc: 85.56%


                                                                                                  

Validation Loss: 0.0896, Validation Acc: 86.29%
Saved Best Type Model
Epoch 12/50 - Training Type Model


                                                                                                  

Train Loss: 0.0940, Train Acc: 85.64%


                                                                                                  

Validation Loss: 0.0880, Validation Acc: 86.39%
Saved Best Type Model
Epoch 13/50 - Training Type Model


                                                                                                  

Train Loss: 0.0934, Train Acc: 85.71%


                                                                                                   

Validation Loss: 0.0887, Validation Acc: 86.56%
Epoch 14/50 - Training Type Model


                                                                                                  

Train Loss: 0.0930, Train Acc: 85.79%


                                                                                                  

Validation Loss: 0.0894, Validation Acc: 86.34%
Epoch 15/50 - Training Type Model


                                                                                                  

Train Loss: 0.0925, Train Acc: 85.85%


                                                                                                  

Validation Loss: 0.0875, Validation Acc: 86.56%
Saved Best Type Model
Epoch 16/50 - Training Type Model


                                                                                                  

Train Loss: 0.0919, Train Acc: 85.91%


                                                                                                  

Validation Loss: 0.0867, Validation Acc: 86.64%
Saved Best Type Model
Epoch 17/50 - Training Type Model


                                                                                                  

Train Loss: 0.0916, Train Acc: 85.98%


                                                                                                  

Validation Loss: 0.0861, Validation Acc: 86.79%
Saved Best Type Model
Epoch 18/50 - Training Type Model


                                                                                                  

Train Loss: 0.0911, Train Acc: 86.01%


                                                                                                  

Validation Loss: 0.0856, Validation Acc: 86.66%
Saved Best Type Model
Epoch 19/50 - Training Type Model


                                                                                                  

Train Loss: 0.0907, Train Acc: 86.10%


                                                                                                  

Validation Loss: 0.0855, Validation Acc: 86.73%
Saved Best Type Model
Epoch 20/50 - Training Type Model


                                                                                                   

Train Loss: 0.0905, Train Acc: 86.15%


                                                                                                   

Validation Loss: 0.0858, Validation Acc: 86.91%
Epoch 21/50 - Training Type Model


                                                                                                  

Train Loss: 0.0901, Train Acc: 86.20%


                                                                                                  

Validation Loss: 0.0850, Validation Acc: 86.90%
Saved Best Type Model
Epoch 22/50 - Training Type Model


                                                                                                  

Train Loss: 0.0900, Train Acc: 86.18%


                                                                                                  

Validation Loss: 0.0841, Validation Acc: 86.98%
Saved Best Type Model
Epoch 23/50 - Training Type Model


                                                                                                  

Train Loss: 0.0897, Train Acc: 86.24%


                                                                                                  

Validation Loss: 0.0848, Validation Acc: 86.78%
Epoch 24/50 - Training Type Model


                                                                                                  

Train Loss: 0.0895, Train Acc: 86.27%


                                                                                                  

Validation Loss: 0.0846, Validation Acc: 86.98%
Epoch 25/50 - Training Type Model


                                                                                                  

Train Loss: 0.0891, Train Acc: 86.33%


                                                                                                   

Validation Loss: 0.0838, Validation Acc: 86.91%
Saved Best Type Model
Epoch 26/50 - Training Type Model


                                                                                                  

Train Loss: 0.0892, Train Acc: 86.35%


                                                                                                  

Validation Loss: 0.0845, Validation Acc: 87.06%
Epoch 27/50 - Training Type Model


                                                                                                   

Train Loss: 0.0889, Train Acc: 86.36%


                                                                                                  

Validation Loss: 0.0826, Validation Acc: 87.28%
Saved Best Type Model
Epoch 28/50 - Training Type Model


                                                                                                   

Train Loss: 0.0886, Train Acc: 86.37%


                                                                                                  

Validation Loss: 0.0834, Validation Acc: 87.16%
Epoch 29/50 - Training Type Model


                                                                                                 

Train Loss: nan, Train Acc: 85.02%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 30/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 31/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 32/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 33/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 34/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 35/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 36/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 37/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 38/50 - Training Type Model


                                                                                                

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 39/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 40/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 41/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 42/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 43/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 44/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 45/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 46/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 47/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 48/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 49/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 50/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


  type_model.load_state_dict(torch.load(best_type_model_path))
  presence_model.load_state_dict(torch.load(best_presence_model_path, map_location=device))
  type_model.load_state_dict(torch.load(best_type_model_path, map_location=device))
  presence_model.load_state_dict(torch.load(best_presence_model_path))


Validation Loss: nan, Validation Acc: 83.77%


CNNOnsetDetector(
  (conv1): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=2560, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [None]:
# Visualize the results of both models


import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # Keep at the very top, before torch is imported

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
import json
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from typing import List, Tuple, Dict, Any
import enum
import tqdm
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score, mean_squared_error, mean_absolute_error, r2_score

# %% [markdown]
# ## Constants Definition

# %%
# STFT constants
SAMPLE_RATE = 22050  # Sample rate passed to librosa when loading audio
HOP_LENGTH = 512     # Hop length used for the Mel spectrogram
NMELS = 128        # Number of Mel bands
WINDOW_SIZE = 40  # Number of frames before and after
NUM_EPOCHS = 50    # Number of training epochs
BATCH_SIZE = 64    # Batch size
LEARNING_RATE = 1e-3
DROPOUT = 0.5      # Dropout rate

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# %% [markdown]
# ## Data Preprocessing Functions

# %%
def contains_non_ascii(s: str) -> bool:
    """Return True if at least one character of ``s`` has a code point above 127."""
    for char in s:
        if ord(char) > 127:
            return True
    return False

def extract_level_json_multiple(directories: List[Path], min_difficulty: int = 15) -> Dict[str, Any]:
    """
    Collect song metadata from the ``level.json`` file of every song folder.

    Each entry of ``directories`` is scanned one level deep and every sub-folder
    is treated as one song. Only the first chart (``charts[0]``) of a song is
    considered. Songs whose first chart has ``difficulty < min_difficulty`` are
    filtered out silently. Songs with missing or malformed data are skipped,
    reported on stdout and counted per reason; a malformed ``level.json`` never
    aborts the scan of the remaining songs.

    Args:
        directories (List[Path]): Root directories whose sub-folders each hold one song.
        min_difficulty (int): Minimum difficulty of the first chart for a song to be kept.

    Returns:
        Dict[str, Any]: Maps ``"<directory name>_<level id>"`` to a dict with the keys
        ``level``, ``mp3_path``, ``charts_path``, ``charter``, ``type`` and ``difficulty``.
    """
    result = {}
    skipped_songs = 0
    skipped_reasons = defaultdict(int)
    audio_file_extensions = ('.mp3', '.ogg', '.wav')

    def skip_song(reason: str, message: str) -> None:
        # Report one skipped song and update both the total and the per-reason counter.
        nonlocal skipped_songs
        print(message)
        skipped_songs += 1
        skipped_reasons[reason] += 1

    for directory in directories:
        # is_dir() also rejects a path that exists but is a regular file,
        # which would otherwise make iterdir() raise.
        if not directory.is_dir():
            print(f"目录不存在: {directory}")
            skipped_reasons['missing_directory'] += 1
            continue
        for folder_path in directory.iterdir():
            if not folder_path.is_dir():
                continue
            json_file_path = folder_path / 'level.json'
            if not json_file_path.is_file():
                skip_song('missing_level_json', f"缺少 level.json 文件在 {folder_path}")
                continue
            try:
                with json_file_path.open('r', encoding='utf-8') as json_file:
                    level_data = json.load(json_file)
            except json.JSONDecodeError:
                skip_song('json_decode_error', f"JSON 解析错误: {json_file_path}")
                continue
            except Exception as e:
                skip_song('read_error', f"无法读取 {json_file_path}: {e}")
                continue

            # Message shared by every structural check below.
            invalid_message = f"level.json 结构无效: {json_file_path}"

            # Valid JSON is not necessarily an object; anything else cannot be indexed by key.
            if not isinstance(level_data, dict):
                skip_song('invalid_structure', invalid_message)
                continue

            # Make sure all required fields exist.
            try:
                level_id = level_data['id']
                charts = level_data['charts']
                music = level_data['music']
            except KeyError as e:
                skip_song('missing_keys', f"缺少键 {e} 在文件: {json_file_path}")
                continue

            if not charts:
                skip_song('empty_charts', f"在文件 {json_file_path} 中未找到任何 charts")
                continue

            # Only the first chart is used; it must be an object inside a list.
            first_chart = charts[0] if isinstance(charts, list) else None
            if not isinstance(first_chart, dict):
                skip_song('invalid_structure', invalid_message)
                continue

            chart_difficulty = first_chart.get('difficulty', 0)
            if not isinstance(chart_difficulty, (int, float)):
                # e.g. null or a string: cannot be compared against min_difficulty.
                skip_song('invalid_structure', invalid_message)
                continue
            if chart_difficulty < min_difficulty:
                continue

            if not isinstance(music, dict):
                skip_song('invalid_structure', invalid_message)
                continue
            audio_file_name = music.get('path', '')
            if not audio_file_name:
                skip_song('missing_music_path', f"在文件 {json_file_path} 中未指定 music path")
                continue
            if not isinstance(audio_file_name, str):
                skip_song('invalid_structure', invalid_message)
                continue

            # The referenced audio file must exist and carry a supported extension.
            audio_file_path = folder_path / audio_file_name
            if not (audio_file_path.suffix.lower() in audio_file_extensions and audio_file_path.is_file()):
                skip_song(
                    'missing_audio_file',
                    f"未找到音频文件 {audio_file_name} 对应于 song ID {level_id} 在 {folder_path}"
                )
                continue

            chart_file_name = first_chart.get('path', '')
            charts_path = folder_path / chart_file_name if isinstance(chart_file_name, str) else None
            if charts_path is None or not charts_path.is_file():
                skip_song(
                    'missing_charts_file',
                    f"未找到 charts 文件 {chart_file_name} 对应于 song ID {level_id} 在 {folder_path}"
                )
                continue

            # Build the unique ID and require it to be pure ASCII.
            unique_id = f"{directory.name}_{level_id}"
            try:
                unique_id.encode('ascii')
            except UnicodeEncodeError:
                skip_song('unparseable_unique_id', f"无法解析的 unique_id: {unique_id}，跳过该曲子")
                continue

            # Add to the result.
            result[unique_id] = {
                'level': level_data,
                'mp3_path': str(audio_file_path),
                'charts_path': str(charts_path),
                'charter': level_data.get('charter', ''),
                'type': first_chart.get('type', ''),
                'difficulty': chart_difficulty
            }

    print(f"总共跳过的曲子数量: {skipped_songs}")
    for reason, count in skipped_reasons.items():
        print(f"跳过原因 '{reason}': {count} 个曲子")
    return result

def extract_charts(path: str) -> Dict[str, Any]:
    """
    Load chart data from a JSON file.

    Args:
        path (str): Path of the chart JSON file.

    Returns:
        Dict[str, Any]: The parsed JSON content, or an empty dict when ``path``
        is not an existing regular file or its content is not valid JSON.
    """
    chart_file = Path(path)
    # is_file() is already False for a missing path, so one check covers both cases.
    if not chart_file.is_file():
        return {}

    with open(chart_file, 'r', encoding='utf-8') as handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError:
            print(f"JSON decode error for file: {path}")
            return {}

def find_single_tempo_songs(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Select the songs whose chart declares exactly one tempo entry.

    Args:
        data (Dict[str, Any]): Song records keyed by song ID; each record must
            provide a ``charts_path`` entry.

    Returns:
        List[Dict[str, Any]]: The records (in iteration order of ``data``) whose
        chart file contains a ``tempo_list`` of length 1.
    """
    constant_tempo = []
    for song in data.values():
        chart = extract_charts(song['charts_path'])
        # Unreadable/empty charts and charts without a tempo list never qualify.
        if not chart or 'tempo_list' not in chart:
            continue
        if len(chart['tempo_list']) != 1:
            continue
        constant_tempo.append(song)
    return constant_tempo

def map_note_to_time(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Convert the tick position of every note into a time stamp.

    Time is accumulated piecewise over the tempo segments: each tick inside a
    segment lasts ``tempo value / time_base``. The tempo list is sorted by tick
    here, and the notes are processed in tick order as well, so an unsorted
    ``note_list`` is still timed with the correct tempo segment. The returned
    list keeps the order of ``note_list``.

    Args:
        data (Dict[str, Any]): Chart data. Reads ``time_base`` (default 1000),
            ``music_offset`` (default 0), ``tempo_list`` (items with ``tick``
            and ``value``) and ``note_list`` (items with ``tick`` and optional
            ``id``, ``type``, ``x``).

    Returns:
        List[Dict[str, Any]]: One dict per note with the keys ``note_id``,
        ``note_tick``, ``note_time_microseconds``, ``note_type`` and ``note_x``.
        Empty when the chart has no tempo entries.

    Raises:
        KeyError: If a note or tempo entry lacks its ``tick`` (or a tempo its ``value``).
    """
    time_base = data.get('time_base', 1000)
    offset_universal = 0.033
    # The offset is scaled by 1_000_000 below, i.e. it is presumably in seconds
    # while tempo values are microseconds per time_base ticks -- TODO confirm.
    offset = data.get('music_offset', 0) - offset_universal
    tempo_list = sorted(data.get('tempo_list', []), key=lambda x: x['tick'])
    note_list = data.get('note_list', [])

    if not tempo_list:
        return []

    note_time_map = [None] * len(note_list)
    accumulated_time = 0
    last_tick = 0
    current_tempo = tempo_list[0]['value']
    tempo_index = 0

    # The tempo cursor below only moves forward, so notes must be visited in
    # ascending tick order. The stable sort keeps equal ticks in input order.
    tick_order = sorted(range(len(note_list)), key=lambda i: note_list[i]['tick'])

    for note_index in tick_order:
        note = note_list[note_index]
        note_tick = note['tick']

        # Consume every tempo change located at or before this note.
        while tempo_index < len(tempo_list) - 1 and tempo_list[tempo_index + 1]['tick'] <= note_tick:
            next_tempo_tick = tempo_list[tempo_index + 1]['tick']
            ticks_in_interval = next_tempo_tick - last_tick
            tick_duration = (current_tempo / time_base)
            accumulated_time += ticks_in_interval * tick_duration
            last_tick = next_tempo_tick
            tempo_index += 1
            current_tempo = tempo_list[tempo_index]['value']

        # Remaining ticks inside the current tempo segment.
        ticks_in_interval = note_tick - last_tick
        tick_duration = (current_tempo / time_base)
        note_time = accumulated_time + ticks_in_interval * tick_duration

        # Write back at the note's original position to preserve input order.
        note_time_map[note_index] = {
            'note_id': note.get('id', 0),
            'note_tick': note_tick,
            'note_time_microseconds': note_time - offset * 1_000_000,
            'note_type': note.get('type', 0),
            'note_x': note.get('x', 0.0)
        }

    return note_time_map

def generate_mel_spectrogram(
    audio_path: Path,
    log_enable: bool = True,
    bpm_info: List[Dict[str, float]] = None,
    note_info: List[Dict[str, Any]] = None,
    max_frames: int = 5000  # Upper bound on the number of frames that are kept
) -> dict:
    """
    Compute a Mel spectrogram and per-frame note labels for one audio file.

    The spectrogram is transposed to (frames, Mel bands) and truncated to at
    most ``max_frames`` frames. Labels are only filled in when both
    ``bpm_info`` and ``note_info`` are non-empty; ``bpm_info`` is used solely
    for that check.

    Args:
        audio_path (Path): Path of the audio file.
        log_enable (bool): Apply ``log`` to the spectrogram (clipped below at 1e-5).
        bpm_info (List[Dict[str, float]]): BPM information; only tested for being non-empty.
        note_info (List[Dict[str, Any]]): Notes; each needs ``note_time_microseconds``.
        max_frames (int): Maximum number of frames to keep.

    Returns:
        dict: ``mel`` (the spectrogram), ``labels`` (1 on frames containing a
        note, else 0) and ``position_labels`` (the frame index on frames
        containing a note, else -1).
    """
    data, sr = librosa.load(str(audio_path), sr=SAMPLE_RATE)
    assert sr == SAMPLE_RATE, f"Expected sample rate {SAMPLE_RATE}, but got {sr}"

    mel = librosa.feature.melspectrogram(
        y=data,
        sr=sr, 
        hop_length=HOP_LENGTH, 
        fmin=30.0, 
        n_mels=NMELS, 
        htk=True
    )
    if log_enable:
        mel = np.log(np.clip(mel, 1e-5, None))
    mel = mel.T  # (time steps, features)

    # Limit the length of the Mel spectrogram.
    if mel.shape[0] > max_frames:
        mel = mel[:max_frames]

    data_dic = {"mel": mel}

    # Initialise the presence labels and position labels.
    presence_labels = np.zeros(mel.shape[0], dtype=int)
    position_labels = -1 * np.ones(mel.shape[0], dtype=int)  # -1 means "no note"

    if bpm_info and note_info:
        mel_length = mel.shape[0]
        for note in note_info:
            time_sec = note['note_time_microseconds'] / 1_000_000
            # Floor instead of int(): int() truncates toward zero, which would
            # map notes slightly before t=0 onto frame 0 and let them pass the
            # range check below.
            frame_idx = int(np.floor(time_sec * SAMPLE_RATE / HOP_LENGTH))
            if 0 <= frame_idx < mel_length:  # drop notes outside the kept frames
                presence_labels[frame_idx] = 1  # Presence
                # The stored position is the absolute frame index, not an
                # offset relative to a window centre.
                position_labels[frame_idx] = frame_idx

    data_dic["labels"] = presence_labels  # shape: (mel_length,)
    data_dic["position_labels"] = position_labels  # shape: (mel_length,)

    return data_dic

# %% [markdown]
# ## Dataset and Data Loading

# %%
class TimeUnit(enum.Enum):
    """Units in which a time value can be expressed; each member's value equals its name."""
    milliseconds = "milliseconds"
    frames = "frames"
    seconds = "seconds"

class OnsetDataset(Dataset):
    """
    PyTorch Dataset of fixed-size Mel-spectrogram windows for note-presence detection.

    Every frame of every song becomes one sample: the frame itself plus
    ``window_size`` frames on each side (81 frames with the default of 40),
    zero-padded where the window runs past the start or end of the song.
    All windows are built eagerly in the constructor.
    """
    def __init__(self, data: Dict[str, Any], bpm_info: Dict[str, List[Dict[str, float]]], score_positions: Dict[str, List[Dict[str, Any]]], window_size: int = 40, transform=None):
        self.data = data
        self.bpm_info = bpm_info
        self.score_positions = score_positions
        self.transform = transform
        self.window_size = window_size
        self.samples = self.prepare_samples()

    def prepare_samples(self) -> List[Tuple[np.ndarray, np.ndarray, int]]:
        """
        Build the list of ``(mel_window, label, difficulty)`` samples.

        ``mel_window`` has ``2 * window_size + 1`` rows, ``label`` is the
        presence label of the centre frame, and ``difficulty`` is copied from
        the song entry.
        """
        half = self.window_size
        collected = []
        for song_id, song in self.data.items():
            mp3_path = song["mp3_path"]
            charts_path = song["charts_path"]  # not used below; a missing key still raises KeyError here
            difficulty = song['difficulty']

            mel_dict = generate_mel_spectrogram(
                audio_path=Path(mp3_path),
                log_enable=True,
                bpm_info=self.bpm_info.get(song_id, None),
                note_info=self.score_positions.get(song_id, None)
            )
            if "labels" not in mel_dict:
                continue

            mel = mel_dict["mel"]        # (num_frames, n_mels)
            labels = mel_dict["labels"]  # (num_frames,)
            num_frames = mel.shape[0]

            for center in range(num_frames):
                lo = center - half
                hi = center + half + 1

                # Number of frames missing on each side of the window.
                pad_before = max(-lo, 0)
                pad_after = max(hi - num_frames, 0)

                mel_window = mel[max(lo, 0):min(hi, num_frames)]
                if pad_before > 0 or pad_after > 0:
                    mel_window = np.pad(mel_window, ((pad_before, pad_after), (0, 0)), mode='constant')

                collected.append((mel_window, labels[center], difficulty))
        return collected

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return sample ``idx`` as ``(float window tensor, float scalar label, difficulty)``."""
        window, target, difficulty = self.samples[idx]
        window_tensor = torch.from_numpy(window).float()  # (2 * window_size + 1, n_mels)
        target_tensor = torch.tensor(target).float()      # scalar

        if self.transform:
            window_tensor, target_tensor = self.transform(window_tensor, target_tensor)

        return window_tensor, target_tensor, difficulty

def collate_fn_padded(batch: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    """
    Collate ``(mel_window, label, difficulty)`` samples into one batch.

    Returns the windows stacked along a new leading batch axis, the labels
    stacked the same way, and the difficulties in sample order (as produced
    by ``zip``, i.e. a tuple).
    """
    columns = zip(*batch)
    mel_windows, label_items, difficulties = columns

    stacked_mel = torch.stack(mel_windows, dim=0)       # (batch_size, frames, n_mels)
    stacked_labels = torch.stack(label_items, dim=0)    # (batch_size,)
    return stacked_mel, stacked_labels, difficulties

# %% [markdown]
# ## Presence Model Definition (kept almost unchanged from Code 2)

# %%
class CNNOnsetDetector(nn.Module):
    """
    CNN-based onset (note presence) detector.

    Three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2)`` stages run along
    the time axis, with the Mel bins as input channels, followed by two fully
    connected layers.

    Args:
        input_channels (int): Number of Mel bins per frame.
        num_classes (int): Width of the output layer (one logit by default).
        dropout (float): Dropout probability used before each linear layer.
        input_length (int): Number of frames in every input window. The
            default of 81 matches a window of 40 frames on each side of the
            centre frame; pass ``2 * window_size + 1`` when a different
            window size is used so that ``fc1`` is sized correctly.

    Raises:
        ValueError: If ``input_length`` is too short to survive the three
            pooling stages.
    """
    def __init__(self, input_channels: int, num_classes: int = 1, dropout: float = 0.5, input_length: int = 81):
        super(CNNOnsetDetector, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Each pooling stage halves the time axis (rounding down); the
        # convolutions keep it unchanged because kernel_size=3 with padding=1.
        self.pool_layers = 3
        if input_length < 2 ** self.pool_layers:
            raise ValueError(
                f"input_length must be at least {2 ** self.pool_layers}, got {input_length}"
            )
        self.feature_length = input_length
        for _ in range(self.pool_layers):
            self.feature_length = self.feature_length // 2

        self.fc1 = nn.Linear(256 * self.feature_length, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch_size, input_length, n_mels].

        Returns:
            Tensor of shape [batch_size, num_classes] holding raw logits.
        """
        x = x.permute(0, 2, 1)  # -> [batch_size, n_mels, input_length]

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in stages:
            x = self.pool(self.relu(bn(conv(x))))  # time axis is halved at every stage

        x = x.flatten(start_dim=1)  # [batch_size, 256 * feature_length]
        x = self.dropout(x)
        x = self.relu(self.fc1(x))  # [batch_size, 512]
        x = self.dropout(x)
        return self.fc2(x)          # [batch_size, num_classes]

class CNNOnsetFeatureExtractor(nn.Module):
    """
    Convolutional feature extractor that maps a Mel window to a 512-d vector.

    It has the same three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2)``
    stages and the same ``fc1`` layer as the onset detector in this file, but
    stops after ``fc1`` instead of applying a classification layer.

    Args:
        input_channels (int): Number of Mel bins per frame.
        dropout (float): Dropout probability used before and after ``fc1``.
        input_length (int): Number of frames in every input window. The
            default of 81 matches a window of 40 frames on each side of the
            centre frame; pass ``2 * window_size + 1`` for other window sizes.

    Raises:
        ValueError: If ``input_length`` is too short to survive the three
            pooling stages.
    """
    def __init__(self, input_channels: int, dropout: float = 0.5, input_length: int = 81):
        super(CNNOnsetFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Each pooling stage halves the time axis (rounding down); the
        # convolutions keep it unchanged because kernel_size=3 with padding=1.
        self.pool_layers = 3
        if input_length < 2 ** self.pool_layers:
            raise ValueError(
                f"input_length must be at least {2 ** self.pool_layers}, got {input_length}"
            )
        self.feature_length = input_length
        for _ in range(self.pool_layers):
            self.feature_length = self.feature_length // 2

        self.fc1 = nn.Linear(256 * self.feature_length, 512)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, input_length, n_mels).

        Returns:
            Tensor of shape (batch_size, 512): the activations after ``fc1``,
            ReLU and dropout.
        """
        x = x.permute(0, 2, 1)  # -> (batch_size, n_mels, input_length)

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in stages:
            x = self.pool(self.relu(bn(conv(x))))  # time axis is halved at every stage

        x = x.flatten(start_dim=1)  # (batch_size, 256 * feature_length)
        x = self.dropout(x)
        x = self.relu(self.fc1(x))  # (batch_size, 512)
        x = self.dropout(x)

        # No classification layer: the 512-d features are returned directly.
        return x


# %% [markdown]
# ## Type Model Definition (changed to a CNN)

# %%
class CNNTypePredictor(nn.Module):
    """
    CNN-based note-type classifier.

    Three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2)`` stages run along
    the time axis, with the Mel bins as input channels, followed by a hidden
    linear layer and a ``num_types``-way output layer.

    Args:
        input_channels (int): Number of Mel bins per frame.
        num_types (int): Number of note types to score.
        dropout (float): Dropout probability used before each linear layer.
        input_length (int): Number of frames in every input window. The
            default of 81 matches a window of 40 frames on each side of the
            centre frame; pass ``2 * window_size + 1`` for other window sizes.

    Raises:
        ValueError: If ``input_length`` is too short to survive the three
            pooling stages.
    """
    def __init__(self, input_channels: int, num_types: int = 5, dropout: float = 0.5, input_length: int = 81):
        super(CNNTypePredictor, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Each pooling stage halves the time axis (rounding down); the
        # convolutions keep it unchanged because kernel_size=3 with padding=1.
        self.pool_layers = 3
        if input_length < 2 ** self.pool_layers:
            raise ValueError(
                f"input_length must be at least {2 ** self.pool_layers}, got {input_length}"
            )
        self.feature_length = input_length
        for _ in range(self.pool_layers):
            self.feature_length = self.feature_length // 2

        self.fc1 = nn.Linear(256 * self.feature_length, 512)
        self.fc_type = nn.Linear(512, num_types)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch_size, input_length, n_mels].

        Returns:
            Tensor of shape [batch_size, num_types] holding raw type logits.
        """
        x = x.permute(0, 2, 1)  # -> [batch_size, n_mels, input_length]

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in stages:
            x = self.pool(self.relu(bn(conv(x))))  # time axis is halved at every stage

        x = x.flatten(start_dim=1)  # [batch_size, 256 * feature_length]
        x = self.dropout(x)
        x = self.relu(self.fc1(x))  # [batch_size, 512]
        x = self.dropout(x)
        type_out = self.fc_type(x)  # [batch_size, num_types]

        return type_out

# %% [markdown]
# ## Visualization Functions

# %%
def visualize_presence_predictions(mel: np.ndarray, labels: np.ndarray, preds: np.ndarray, start_time: float = 0, end_time: float = 5):
    """
    Plot a Mel-spectrogram excerpt above the presence predictions and ground truth.

    Args:
        mel (np.ndarray): Mel spectrogram, expected as (seq_len, feature_dim).
        labels (np.ndarray): Ground-truth presence labels, expected as (seq_len,).
        preds (np.ndarray): Raw presence scores (logits), expected as (seq_len,);
            a sigmoid and a 0.5 threshold are applied here.
        start_time (float): Start of the plotted excerpt, in seconds.
        end_time (float): End of the plotted excerpt, in seconds.
    """
    # Logits -> probabilities -> hard 0/1 decisions at 0.5.
    presence_prob = 1 / (1 + np.exp(-preds))
    presence_binary = (presence_prob >= 0.5).astype(int)

    # One time stamp per frame, spread evenly over the clip duration.
    n_frames = mel.shape[0]
    total_time = n_frames * HOP_LENGTH / SAMPLE_RATE
    times = np.linspace(0, total_time, num=n_frames)

    # Frame range of the excerpt; the end is clamped to the sequence length.
    first_frame = int(start_time * SAMPLE_RATE / HOP_LENGTH)
    last_frame = min(int(end_time * SAMPLE_RATE / HOP_LENGTH), n_frames)
    excerpt = slice(first_frame, last_frame)

    mel_excerpt = mel[excerpt]
    labels_excerpt = labels[excerpt]
    prob_excerpt = presence_prob[excerpt]
    binary_excerpt = presence_binary[excerpt]
    excerpt_times = times[excerpt]

    fig, axs = plt.subplots(2, 1, figsize=(15, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
    ax_mel, ax_presence = axs[0], axs[1]

    # Top panel: spectrogram.
    # NOTE(review): the colorbar is formatted as dB and the axis uses fmax=8000;
    # confirm both match how the spectrogram passed in was computed.
    img = librosa.display.specshow(
        mel_excerpt.T,
        sr=SAMPLE_RATE,
        hop_length=HOP_LENGTH,
        x_coords=excerpt_times,
        ax=ax_mel,
        x_axis='time',
        y_axis='mel',
        fmax=8000
    )
    ax_mel.set_title('Mel Spectrogram')
    fig.colorbar(img, ax=ax_mel, format='%+2.0f dB')

    # Bottom panel: probability, thresholded decision and ground truth.
    ax_presence.plot(
        excerpt_times,
        prob_excerpt.flatten(),
        label='Presence Prediction (Raw)',
        color='red',
        alpha=0.6
    )
    ax_presence.plot(
        excerpt_times,
        binary_excerpt,
        label='Presence Prediction (Threshold=0.50)',
        color='orange',
        alpha=0.6
    )
    ax_presence.plot(
        excerpt_times,
        labels_excerpt.flatten(),
        label='Presence Ground Truth',
        color='blue',
        linestyle='dashed'
    )

    ax_presence.set_title('Presence Predictions vs Ground Truth (Threshold: 0.50)')
    ax_presence.legend(loc='upper right')
    ax_presence.set_xlabel('Time (s)')
    ax_presence.set_ylabel('Presence')

    plt.tight_layout()
    plt.show()

def visualize_presence_predictions_single(mel: np.ndarray, label: np.ndarray, pred: np.ndarray, start_time: float = 0, end_time: float = 5):
    """
    Plot one window sample: its Mel spectrogram and the presence prediction for the centre frame.

    Args:
        mel (np.ndarray): Mel spectrogram of the window, expected as (81, n_mels).
        label (np.ndarray): Ground-truth presence label, expected as (1,).
        pred (np.ndarray): Raw presence score (logit), expected as (1,); a
            sigmoid and a 0.5 threshold are applied here.
        start_time (float): Accepted for signature parity; not used by this function.
        end_time (float): Accepted for signature parity; not used by this function.
    """
    # Logit -> probability -> hard 0/1 decision at 0.5.
    presence_prob = 1 / (1 + np.exp(-pred))
    presence_binary = (presence_prob >= 0.5).astype(int)

    # Time axis centred on the current frame: WINDOW_SIZE frames on either side.
    half_span = WINDOW_SIZE * HOP_LENGTH / SAMPLE_RATE
    times = np.linspace(-half_span, half_span, num=mel.shape[0])

    fig, axs = plt.subplots(2, 1, figsize=(15, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
    ax_mel, ax_presence = axs[0], axs[1]

    # Top panel: spectrogram of the window.
    img = librosa.display.specshow(
        mel.T,
        sr=SAMPLE_RATE,
        hop_length=HOP_LENGTH,
        x_coords=times,
        ax=ax_mel,
        x_axis='time',
        y_axis='mel',
        fmax=8000
    )
    ax_mel.set_title('Mel Spectrogram')
    fig.colorbar(img, ax=ax_mel, format='%+2.0f dB')

    # Bottom panel: three bars drawn at x=0 (probability, decision, ground truth).
    bar_specs = (
        (presence_prob, 'Presence Prediction (Raw)', 'red'),
        (presence_binary, 'Presence Prediction (Threshold=0.50)', 'orange'),
        (label, 'Presence Ground Truth', 'blue'),
    )
    for height, legend_text, colour in bar_specs:
        ax_presence.bar(0, height, label=legend_text, color=colour, alpha=0.6)

    ax_presence.set_title('Presence Predictions vs Ground Truth')
    ax_presence.legend(loc='upper right')
    ax_presence.set_xlabel('Current Frame')
    ax_presence.set_ylabel('Presence')

    plt.tight_layout()
    plt.show()

def visualize_type_predictions(mel: np.ndarray, labels: np.ndarray, preds: np.ndarray, start_time: float = 0, end_time: float = 5, hop_length: int = HOP_LENGTH, sample_rate: int = SAMPLE_RATE):
    """
    Plot per-type probabilities over time together with the ground-truth types.

    Args:
        mel (np.ndarray): Mel spectrogram, expected as (seq_len, feature_dim);
            only its length is used, to build the time axis.
        labels (np.ndarray): Ground-truth type per frame, expected as (seq_len,).
            Type 0 is treated as "no event" and is not drawn.
        preds (np.ndarray): Raw type scores (logits), expected as
            (seq_len, num_types); a softmax is applied here.
        start_time (float): Start of the plotted excerpt, in seconds.
        end_time (float): End of the plotted excerpt, in seconds.
        hop_length (int): Hop length used to convert seconds to frames.
        sample_rate (int): Sample rate used to convert seconds to frames.
    """
    # Logits -> per-frame probability distribution over types.
    preds_prob = F.softmax(torch.tensor(preds), dim=-1).numpy()

    # One time stamp per frame, spread evenly over the clip duration.
    total_time = mel.shape[0] * hop_length / sample_rate
    times = np.linspace(0, total_time, num=mel.shape[0])

    # Frame range of the excerpt (slicing clamps an end past the sequence).
    start_frame = int(start_time * sample_rate / hop_length)
    end_frame = int(end_time * sample_rate / hop_length)

    labels_cropped = labels[start_frame:end_frame]
    preds_cropped = preds_prob[start_frame:end_frame]
    times_cropped = times[start_frame:end_frame]

    # One colour per type, taken from matplotlib's tab10 palette.
    cmap = plt.get_cmap('tab10')
    num_types = preds_cropped.shape[1]
    colors = [cmap(i) for i in range(num_types)]

    plt.figure(figsize=(15, 8))

    # Probability curve of every type.
    for type_idx in range(num_types):
        plt.plot(
            times_cropped,
            preds_cropped[:, type_idx],
            label=f'Type {type_idx}',
            color=colors[type_idx],
            alpha=0.6
        )

    # Ground truth: a marker on the true type's curve plus a vertical guide line.
    # Only the first marker of each type gets a legend label, so every type
    # present in the excerpt appears in the legend exactly once.
    labelled_types = set()
    for idx, label in enumerate(labels_cropped):
        true_type = int(label)
        if true_type == 0:
            continue  # type 0 is assumed to mean "no event"
        if true_type in labelled_types:
            legend_text = '_nolegend_'
        else:
            legend_text = f'Ground Truth Type {true_type}'
            labelled_types.add(true_type)
        plt.scatter(
            times_cropped[idx],
            preds_cropped[idx, true_type],
            color=colors[true_type],
            marker='x',
            s=50,
            label=legend_text,
            zorder=5
        )
        plt.axvline(
            x=times_cropped[idx],
            color=colors[true_type],
            linestyle='--',
            alpha=0.5,
            linewidth=1
        )

    plt.title('Type Probabilities and Ground Truth', fontsize=14)
    plt.xlabel('Time (s)', fontsize=12)
    plt.ylabel('Probability', fontsize=12)

    # Build the legend from unique labels only.
    handles, labels_legend = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels_legend, handles))
    plt.legend(by_label.values(), by_label.keys(), loc='upper right', fontsize='small')

    plt.xlim(start_time, end_time)

    # Move the y-axis label and ticks to the right-hand side of the plot.
    ax = plt.gca()
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    ax.yaxis.set_label_coords(1.05, 0.5)

    plt.tight_layout()
    plt.show()

def visualize_position_predictions(true_positions: List[int], pred_positions: List[int], num_samples: int = 100):
    """
    Scatter-plot true against predicted note positions for a random subset of samples.

    Args:
        true_positions (List[int]): Ground-truth positions.
        pred_positions (List[int]): Predicted positions, aligned with ``true_positions``.
        num_samples (int): Number of samples to draw (without replacement);
            capped at ``len(true_positions)``.
    """
    plt.figure(figsize=(15, 6))

    # Never ask for more samples than are available.
    num_samples = min(num_samples, len(true_positions))
    chosen = np.random.choice(len(true_positions), size=num_samples, replace=False)
    true_subset = np.array(true_positions)[chosen]
    pred_subset = np.array(pred_positions)[chosen]

    sample_axis = range(num_samples)
    plt.scatter(sample_axis, true_subset, label='True Position', alpha=0.6, color='blue')
    plt.scatter(sample_axis, pred_subset, label='Predicted Position', alpha=0.6, color='red')
    plt.title('True vs Predicted Note Positions')
    plt.xlabel('Sample Index')
    plt.ylabel('Position Index')
    plt.legend()
    plt.show()

def visualize_position_distribution(true_positions: List[int], pred_positions: List[int]):
    """
    Overlay density histograms (81 bins each) of the true and predicted positions.

    Args:
        true_positions (List[int]): Ground-truth positions.
        pred_positions (List[int]): Predicted positions.
    """
    plt.figure(figsize=(10, 6))

    series = (
        (true_positions, 'True Positions', 'blue'),
        (pred_positions, 'Predicted Positions', 'red'),
    )
    for values, legend_text, colour in series:
        plt.hist(values, bins=81, alpha=0.5, label=legend_text, color=colour, density=True)

    plt.title('Distribution of True and Predicted Positions')
    plt.xlabel('Position Index')
    plt.ylabel('Density')
    plt.legend()
    plt.show()

# %% [markdown]
# ## Training and Validation Functions

# %%
# Presence model: training function
def train_epoch_cnn(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, loss_fn: nn.Module) -> float:
    """
    Train the presence CNN for one epoch.

    Each batch is a ``(mel, labels, difficulties)`` triple. Gradients are
    clipped to a norm of 1.0 before every optimizer step.

    Returns:
        float: Sum of the batch losses divided by the number of batches.
    """
    model.train()
    progress_bar = tqdm.tqdm(dataloader, desc="Training Presence", leave=False)

    total_loss = 0.0
    for mel, labels, difficulties in progress_bar:
        mel = mel.to(device)        # (batch_size, frames, n_mels)
        labels = labels.to(device)  # (batch_size,)

        optimizer.zero_grad()

        # Forward pass: (batch_size, 1) logits flattened to (batch_size,).
        logits = model(mel).squeeze(1)
        loss = loss_fn(logits, labels)

        # Backward pass with gradient clipping.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}'})

    return total_loss / len(dataloader)

# Presence model: validation function
def validate_epoch_cnn(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module) -> Tuple[float, Dict[int, Dict[str, List[Any]]]]:
    """
    Evaluate the presence CNN on a validation loader without updating it.

    Returns:
        Tuple: ``(epoch_loss, difficulty_preds)`` where ``epoch_loss`` is the
        sum of the batch losses divided by the number of batches, and
        ``difficulty_preds`` maps each difficulty to
        ``{'y_true': [...], 'y_scores': [...]}`` holding the targets and the
        raw (pre-sigmoid) scores of the samples with that difficulty.
    """
    model.eval()
    total_loss = 0.0
    difficulty_preds = defaultdict(lambda: {'y_true': [], 'y_scores': []})
    progress_bar = tqdm.tqdm(dataloader, desc="Validation Presence", leave=False)

    with torch.no_grad():
        for mel, labels, difficulties in progress_bar:
            mel = mel.to(device)        # (batch_size, frames, n_mels)
            labels = labels.to(device)  # (batch_size,) - already flat, no squeeze needed

            # Forward pass: (batch_size, 1) logits flattened to (batch_size,).
            logits = model(mel).squeeze(1)

            batch_loss = loss_fn(logits, labels).item()
            total_loss += batch_loss
            progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}'})

            # Group raw scores and targets by the difficulty of each sample.
            scores_np = logits.cpu().numpy()
            targets_np = labels.cpu().numpy()
            for i in range(mel.size(0)):
                difficulty = difficulties[i]
                target_i = targets_np[i]
                score_i = scores_np[i]
                bucket = difficulty_preds[difficulty]
                bucket['y_true'].append(target_i)
                bucket['y_scores'].append(score_i)

    return total_loss / len(dataloader), difficulty_preds

# Type model: training function (CNN version)
def train_epoch_cnn_type(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, loss_fn: nn.Module, num_types: int) -> Tuple[float, float]:
    """
    Train the type CNN for one epoch.

    Each batch is a ``(mel, labels, extra)`` triple; the third element is not
    used here. Gradients are clipped to a norm of 1.0 before every optimizer
    step.

    Args:
        model: Network returning ``(batch_size, num_types)`` logits.
        dataloader: Iterable of batches; must support ``len()``.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the batch tensors are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        num_types: Accepted for interface compatibility; not used here.

    Returns:
        Tuple[float, float]: ``(epoch_loss, epoch_acc)``. The loss is the sum
        of the batch losses divided by the number of batches. The accuracy is
        the fraction of correctly classified samples over the whole epoch, so
        a smaller final batch does not get extra weight.
    """
    model.train()
    running_loss = 0.0
    correct_total = 0
    sample_total = 0
    progress_bar = tqdm.tqdm(dataloader, desc="Training Type CNN", leave=False)

    for mel, labels, _unused in progress_bar:
        mel = mel.to(device)        # (batch_size, frames, n_mels)
        labels = labels.to(device)  # (batch_size,)

        optimizer.zero_grad()

        type_pred = model(mel)  # (batch_size, num_types)
        loss = loss_fn(type_pred, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Softmax is monotonic, so the arg-max of the logits is the predicted type.
        preds = type_pred.argmax(dim=1)
        batch_correct = int((preds == labels).sum().item())
        batch_count = labels.size(0)
        correct_total += batch_correct
        sample_total += batch_count
        acc = batch_correct / batch_count if batch_count else 0.0

        progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Acc': f'{acc*100:.2f}%'})

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_total / sample_total if sample_total else 0.0
    return epoch_loss, epoch_acc

# Type model: validation function (CNN version)
def validate_epoch_cnn_type(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module, num_types: int) -> Tuple[float, float]:
    """
    Evaluate the type CNN on a validation loader without updating it.

    Each batch is a ``(mel, labels, extra)`` triple; the third element is not
    used here.

    Args:
        model: Network returning ``(batch_size, num_types)`` logits.
        dataloader: Iterable of batches; must support ``len()``.
        device: Device the batch tensors are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        num_types: Accepted for interface compatibility; not used here.

    Returns:
        Tuple[float, float]: ``(epoch_loss, epoch_acc)``. The loss is the sum
        of the batch losses divided by the number of batches. The accuracy is
        the fraction of correctly classified samples over the whole loader, so
        a smaller final batch does not get extra weight.
    """
    model.eval()
    running_loss = 0.0
    correct_total = 0
    sample_total = 0
    progress_bar = tqdm.tqdm(dataloader, desc="Validation Type CNN", leave=False)

    with torch.no_grad():
        for mel, labels, _unused in progress_bar:
            mel = mel.to(device)
            labels = labels.to(device)

            type_pred = model(mel)  # (batch_size, num_types)

            batch_loss = loss_fn(type_pred, labels).item()
            running_loss += batch_loss

            # Softmax is monotonic, so the arg-max of the logits is the predicted type.
            preds = type_pred.argmax(dim=1)
            batch_correct = int((preds == labels).sum().item())
            batch_count = labels.size(0)
            correct_total += batch_correct
            sample_total += batch_count
            acc = batch_correct / batch_count if batch_count else 0.0

            progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Acc': f'{acc*100:.2f}%'})

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_total / sample_total if sample_total else 0.0
    return epoch_loss, epoch_acc

# %% [markdown]
# ## Model Evaluation Functions

# %%
def evaluate_test_set_cnn(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module):
    """
    Evaluate the presence CNN on a test loader and report metrics per difficulty.

    Each batch is a ``(mel, labels, difficulties)`` triple. Logits are passed
    through a sigmoid and thresholded at 0.5. A histogram of all predicted
    probabilities against the ground-truth labels is shown, and one line of
    metrics is printed per difficulty.

    Args:
        model: Network returning ``(batch_size, 1)`` logits.
        dataloader: Iterable of batches.
        device: Device the batch tensors are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar
            tensor; used only to report the mean test loss.

    Returns:
        dict: ``{difficulty: {'accuracy', 'precision', 'recall', 'f1_score'}}``.
    """
    model.eval()
    difficulty_metrics = defaultdict(lambda: {'y_true': [], 'y_pred': []})
    all_preds, all_labels = [], []
    running_loss = 0.0
    num_batches = 0
    progress_bar = tqdm.tqdm(dataloader, desc="Testing Presence CNN", leave=False)

    with torch.no_grad():
        for mel, labels, difficulties in progress_bar:
            mel = mel.to(device)        # (batch_size, frames, n_mels)
            labels = labels.to(device)  # (batch_size,) or (batch_size, 1)

            outputs = model(mel).squeeze(1)  # (batch_size,)
            # Labels produced by the collate function are already 1-D, where
            # squeeze(1) would raise; only flatten a trailing singleton axis.
            if labels.dim() > 1:
                labels = labels.squeeze(1)

            running_loss += loss_fn(outputs, labels).item()
            num_batches += 1

            # torch.sigmoid avoids the overflow of a hand-written 1 / (1 + exp(-x)).
            presence_prob_np = torch.sigmoid(outputs).cpu().numpy()
            presence_target_np = labels.cpu().numpy()

            y_pred = (presence_prob_np >= 0.5).astype(int)
            y_true = presence_target_np.astype(int)

            # Keep every probability / label for the distribution plot.
            all_preds.extend(presence_prob_np.tolist())
            all_labels.extend(presence_target_np.tolist())

            # Group hard predictions by the difficulty of each sample.
            for i in range(mel.size(0)):
                bucket = difficulty_metrics[difficulties[i]]
                bucket['y_true'].append(y_true[i])
                bucket['y_pred'].append(y_pred[i])

    if num_batches:
        print(f"Presence Test Loss: {running_loss / num_batches:.4f}")

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Distribution of predicted probabilities against ground-truth labels.
    plt.figure(figsize=(10, 6))
    plt.hist(all_preds, bins=50, alpha=0.7, label="Predictions", color="blue", density=True)
    plt.hist(all_labels, bins=50, alpha=0.7, label="Ground Truth", color="orange", density=True)
    plt.title("Frame-wise Prediction and Ground Truth Distribution")
    plt.xlabel("Value")
    plt.ylabel("Frequency")
    plt.legend()
    plt.show()

    # Metrics for every difficulty level.
    final_metrics = {}
    for diff, metrics in difficulty_metrics.items():
        y_true = np.array(metrics['y_true'])
        y_pred = np.array(metrics['y_pred'])
        acc = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        final_metrics[diff] = {
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        }
        print(f"Difficulty {diff}: Accuracy={acc:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1 Score={f1:.4f}")

    return final_metrics

def evaluate_test_set_type_cnn(model: nn.Module, dataloader: DataLoader, device: torch.device, loss_fn: nn.Module, num_types: int):
    """
    Evaluate the type CNN on a test loader and report weighted metrics.

    Each batch is a ``(mel, labels, extra)`` triple; the third element is not
    used here.

    Args:
        model: Network returning ``(batch_size, num_types)`` logits.
        dataloader: Iterable of batches.
        device: Device the batch tensors are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar
            tensor; used only to report the mean test loss.
        num_types: Accepted for interface compatibility; not used here.

    Returns:
        dict: ``{'accuracy', 'precision', 'recall', 'f1_score'}`` where
        precision, recall and F1 use weighted averaging over the types.
    """
    model.eval()
    all_preds = []
    all_labels = []
    # Must be initialised before the loop accumulates into it.
    running_loss = 0.0
    num_batches = 0
    progress_bar = tqdm.tqdm(dataloader, desc="Testing Type CNN", leave=False)

    with torch.no_grad():
        for mel, labels, _unused in progress_bar:
            mel = mel.to(device)        # (batch_size, frames, n_mels)
            labels = labels.to(device)  # (batch_size,)

            type_pred = model(mel)  # (batch_size, num_types)

            running_loss += loss_fn(type_pred, labels).item()
            num_batches += 1

            # Softmax is monotonic, so the arg-max of the logits is the predicted type.
            preds = type_pred.argmax(dim=1).cpu().numpy()
            true = labels.cpu().numpy()

            all_preds.extend(preds.tolist())
            all_labels.extend(true.tolist())

    if num_batches:
        print(f"Type Prediction - Test Loss: {running_loss / num_batches:.4f}")

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    print(f"Type Prediction - Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

    return {'accuracy': acc, 'precision': precision, 'recall': recall, 'f1_score': f1}

# %% [markdown]
# ## Data Preparation

# %%
# Dataset folders are resolved relative to the current working directory.
current_directory = Path().cwd()
dataset_dirs = [
    current_directory / "../dataset/A",
    current_directory / "../dataset/B",
    current_directory / "../dataset/C",
    current_directory / "../dataset/Z"
]
# Keep only the songs whose difficulty is at least 15.
data = extract_level_json_multiple(dataset_dirs, min_difficulty=15)
print(f"Filtered Data Count (Difficulty>=15): {len(data)}")
# Debug aid: uncomment to keep only the first 30 songs for a quick run.
# data = dict(list(data.items())[:30])
bpm_info_dict = {}
score_positions_dict = {}

# Build, per song, the tempo list and the list of notes (time, type, x position).
# Songs whose chart cannot be extracted get no entry in either dictionary.
for unique_id, song in data.items():
    level_data = song['level']  # NOTE(review): not used in this cell - confirm whether later cells need it
    song_id = unique_id  # the unique ID is used as the song key
    charts_data = extract_charts(song['charts_path'])
    if charts_data:
        bpm_info = charts_data.get('tempo_list', [])
        bpm_info_dict[song_id] = bpm_info
        note_time_map = map_note_to_time(charts_data) 
        # Details of every note: time, type and position.
        score_positions = [] 
        for note in note_time_map:
            score_positions.append({
                'note_time_microseconds': note['note_time_microseconds'],
                'note_type': note.get('note_type', 0),  # falls back to 0 when the field is missing
                'note_x': note.get('note_x', 0.0)      # falls back to 0.0 when the field is missing
            })
        score_positions_dict[song_id] = score_positions

# %% [markdown]
# ## Define the Dataset and DataLoaders

# %%
# Build the presence dataset and its DataLoaders.
presence_dataset = OnsetDataset(
    data=data, 
    bpm_info=bpm_info_dict, 
    score_positions=score_positions_dict,
    window_size=WINDOW_SIZE  # WINDOW_SIZE frames on each side of the centre frame
)

# 80 % of the samples go to training; the remainder is split again below.
# NOTE(review): random_split is called without a seeded generator, so the
# split differs between runs - consider seeding for reproducibility.
# NOTE(review): samples are per-frame windows, so overlapping windows from the
# same song can land in different splits - confirm this is acceptable.
train_size = int(0.8 * len(presence_dataset))
val_size = len(presence_dataset) - train_size
train_dataset, val_dataset = random_split(presence_dataset, [train_size, val_size])

print(f"Presence Train Size: {len(train_dataset)}")
print(f"Presence Validation Size: {len(val_dataset)}")
# The following is the training code ---------------------------------------------------------------------------------
# Carve the test set out of the validation set (half of it).
test_size = int(0.5 * len(val_dataset))
val_size = len(val_dataset) - test_size
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size])

print(f"Presence Validation Size after split: {len(val_dataset)}")
print(f"Presence Test Size: {len(test_dataset)}")

# Only the training loader shuffles and drops the last incomplete batch.
presence_train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate_fn_padded)
presence_val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded)
presence_test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded)

# %% [markdown]
# ## Presence Model Training and Evaluation

# %%
# Initialise the presence model.
presence_model = CNNOnsetDetector(input_channels=NMELS, num_classes=1, dropout=DROPOUT)
presence_model.to(device)

# Optimizer and loss function.
presence_optimizer = torch.optim.Adam(presence_model.parameters(), lr=LEARNING_RATE)
presence_loss_fn = nn.BCEWithLogitsLoss()

# Train and keep the checkpoint with the lowest validation loss.
best_presence_val_loss = float('inf')
best_presence_model_path = "model/best_cnn_onset_model.pth"
# Create the checkpoint folder up front so the first save cannot fail after a
# full epoch of training just because the folder is missing.
Path(best_presence_model_path).parent.mkdir(parents=True, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Presence Model")

    # Training pass.
    train_loss = train_epoch_cnn(presence_model, presence_train_loader, presence_optimizer, device, presence_loss_fn)
    print(f"Train Loss: {train_loss:.4f}")

    # Validation pass.
    val_loss, val_difficulty_preds = validate_epoch_cnn(presence_model, presence_val_loader, device, presence_loss_fn)
    print(f"Validation Loss: {val_loss:.4f}")

    # Save the model whenever the validation loss improves.
    if val_loss < best_presence_val_loss:
        best_presence_val_loss = val_loss
        torch.save(presence_model.state_dict(), best_presence_model_path)
        print("Saved Best Presence Model")

# Reload the best checkpoint; map_location keeps the load working when the
# checkpoint was written on a different device than the current one.
presence_model.load_state_dict(torch.load(best_presence_model_path, map_location=device))
presence_model.to(device)
presence_model.eval()
# The following is the training code ---------------------------------------------------------------------------------



# %% [markdown]
# ## Define the Type Dataset and DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Multi-class focal loss computed from raw logits.

    Per sample the loss is ``-alpha * (1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the true class.
    """
    def __init__(self, alpha=2.0, gamma=2, reduction='mean'):
        """
        Args:
            alpha: Scalar weight multiplied into every sample's loss.
            gamma: Focusing parameter; larger values put more weight on
                samples that are classified poorly.
            reduction: ``'mean'`` or ``'sum'`` to aggregate over the batch;
                any other value returns the per-sample losses unreduced.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Logits of shape [batch_size, num_classes].
            targets: Integer class indices of shape [batch_size].

        Returns:
            Scalar tensor for ``'mean'`` / ``'sum'``, otherwise a tensor of
            shape [batch_size].
        """
        # log_softmax keeps log(p_t) finite even when p_t underflows to zero,
        # where softmax followed by log would produce -inf.
        log_probs = F.log_softmax(inputs, dim=1)
        log_pt = log_probs.gather(1, targets.long().unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        loss = -self.alpha * (1 - pt) ** self.gamma * log_pt

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

# %%
class TypeDataset(Dataset):
    """
    Frame-level PyTorch Dataset for the note-type model.

    Each sample is one centre frame together with `window_size` frames on
    either side (81 frames in total for the default of 40), zero-padded at
    the start and end of a song, plus that frame's label and the song's
    difficulty.
    """
    def __init__(self, data: Dict[str, Any], bpm_info: Dict[str, List[Dict[str, float]]], score_positions: Dict[str, List[Dict[str, Any]]], window_size: int = 40, transform=None):
        self.window_size = window_size
        self.transform = transform
        self.data = data
        self.bpm_info = bpm_info
        self.score_positions = score_positions
        # All windows are materialised once, at construction time.
        self.samples = self.prepare_samples()

    def prepare_samples(self) -> List[Tuple[np.ndarray, int, int]]:
        """
        Build the list of (mel_window, label, difficulty) samples, one per
        frame of every song for which the spectrogram helper returned labels.
        """
        half = self.window_size
        collected = []
        for song_id, song in self.data.items():
            audio_file = song["mp3_path"]
            # Not used below; still looked up so a record missing this key
            # fails here exactly as before.
            _charts_file = song["charts_path"]
            song_difficulty = song['difficulty']

            mel_dict = generate_mel_spectrogram(
                audio_path=Path(audio_file),
                log_enable=True,
                bpm_info=self.bpm_info.get(song_id),
                note_info=self.score_positions.get(song_id),
            )
            # Songs without labels contribute no samples.
            if "labels" not in mel_dict:
                continue

            mel = mel_dict["mel"]        # (num_frames, n_mels) per the original comment
            labels = mel_dict["labels"]  # (num_frames,) per the original comment
            total_frames = mel.shape[0]

            for center in range(total_frames):
                lo = max(center - half, 0)
                hi = min(center + half + 1, total_frames)

                # Number of zero rows needed on each side near the boundaries
                missing_left = max(half - center, 0)
                missing_right = max(center + half + 1 - total_frames, 0)

                window = mel[lo:hi]
                if missing_left or missing_right:
                    window = np.pad(
                        window,
                        ((missing_left, missing_right), (0, 0)),
                        mode='constant',
                    )

                collected.append((window, labels[center], song_difficulty))
        return collected

    def __len__(self) -> int:
        """Number of frame-level samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return (float mel window, long label, difficulty) for sample `idx`."""
        window, frame_label, song_difficulty = self.samples[idx]
        window_tensor = torch.from_numpy(window).float()  # (2 * window_size + 1, n_mels)
        label_tensor = torch.tensor(frame_label).long()   # scalar tensor

        if self.transform:
            window_tensor, label_tensor = self.transform(window_tensor, label_tensor)

        return window_tensor, label_tensor, song_difficulty

def collate_fn_padded_type(batch: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    """
    Custom collate_fn that assembles (mel_window, label, difficulty) samples
    into a batch.

    The mel windows and the labels are each stacked along a new leading batch
    dimension; the difficulties are passed through as the tuple produced by
    unzipping the batch.
    """
    window_seq, label_seq, difficulties = zip(*batch)

    stacked_windows = torch.stack(window_seq)  # (batch_size, frames, n_mels)
    stacked_labels = torch.stack(label_seq)    # (batch_size,)

    return stacked_windows, stacked_labels, difficulties

type_dataset = TypeDataset(
    data=data, 
    bpm_info=bpm_info_dict, 
    score_positions=score_positions_dict,
    window_size=WINDOW_SIZE  # WINDOW_SIZE frames before and after the centre frame
)

# Training-code boundary marker ---------------------------------------------------------------------------------------
# Fixed seed for the train/val/test partition. Without it every kernel restart
# draws a different partition, so a checkpoint reloaded from disk could be
# evaluated on frames it was trained on in an earlier run.
TYPE_SPLIT_SEED = 42
type_split_generator = torch.Generator().manual_seed(TYPE_SPLIT_SEED)

# NOTE(review): the split is per frame, not per song, so overlapping windows
# from the same song can end up in different subsets. Consider splitting by
# song if leakage-free metrics are required.

# Split the dataset: 80% train, 20% held out
type_train_size = int(0.8 * len(type_dataset))
type_val_size = len(type_dataset) - type_train_size
type_train_dataset, type_val_dataset = random_split(
    type_dataset, [type_train_size, type_val_size], generator=type_split_generator
)

print(f"Type Train Size: {len(type_train_dataset)}")
print(f"Type Validation Size: {len(type_val_dataset)}")

# Carve the test set out of the held-out part (half validation, half test)
type_test_size = int(0.5 * len(type_val_dataset))
type_val_size = len(type_val_dataset) - type_test_size
type_val_dataset, type_test_dataset = random_split(
    type_val_dataset, [type_val_size, type_test_size], generator=type_split_generator
)

print(f"Type Validation Size after split: {len(type_val_dataset)}")
print(f"Type Test Size: {len(type_test_dataset)}")

# Only the training loader shuffles and drops the last incomplete batch
type_train_loader = DataLoader(type_train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate_fn_padded_type)
type_val_loader = DataLoader(type_val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded_type)
type_test_loader = DataLoader(type_test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_fn_padded_type)

# %% [markdown]
# ## Type模型定义与训练

# %%
# Initialise the Type model
num_types = 5  # adjust as needed
type_model = CNNTypePredictor(input_channels=NMELS, num_types=num_types, dropout=DROPOUT)
type_model.to(device)

# Optimizer and loss function
type_optimizer = torch.optim.Adam(type_model.parameters(), lr=LEARNING_RATE)
type_loss_fn = FocalLoss(alpha=1.0, gamma=2, reduction='mean')


# Train and keep the best Type model
best_type_val_loss = float('inf')
best_type_model_path = "model/best_cnn_type_model.pth"

# Create the checkpoint directory up front: without it the first torch.save
# raises only after a complete training epoch has already been spent.
os.makedirs(os.path.dirname(best_type_model_path), exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Type Model")
    
    # Train for one epoch
    train_loss, train_acc = train_epoch_cnn_type(type_model, type_train_loader, type_optimizer, device, type_loss_fn, num_types)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%")
    
    # Validate
    val_loss, val_acc = validate_epoch_cnn_type(type_model, type_val_loader, device, type_loss_fn, num_types)
    print(f"Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc*100:.2f}%")
    
    # A non-finite loss never compares below the best loss, so no further
    # checkpoint would be saved; stop instead of running the remaining epochs
    # for nothing.
    if not (np.isfinite(float(train_loss)) and np.isfinite(float(val_loss))):
        print("Non-finite loss detected; stopping Type training early")
        break
    
    # Keep the checkpoint with the lowest validation loss seen so far
    if val_loss < best_type_val_loss:
        best_type_val_loss = val_loss
        torch.save(type_model.state_dict(), best_type_model_path)
        print("Saved Best Type Model")

# Load the best Type model.
# map_location keeps the reload working when the checkpoint was written on a
# different device; weights_only=True is sufficient because the file holds
# only a state_dict, and it avoids unpickling arbitrary objects.
type_model.load_state_dict(
    torch.load(best_type_model_path, map_location=device, weights_only=True)
)
type_model.to(device)
type_model.eval()
# Training-code boundary marker ---------------------------------------------------------------------------------------


num_types = 5  # adjust as needed
best_presence_model_path = "model/best_cnn_onset_model.pth"
best_type_model_path = "model/best_cnn_type_model.pth"

# Rebuild both networks and restore their best checkpoints from disk.
# map_location lets a checkpoint saved on another device load here;
# weights_only=True is sufficient because the files hold only state_dicts.
presence_model = CNNOnsetDetector(input_channels=NMELS, num_classes=1, dropout=DROPOUT)
presence_model.load_state_dict(
    torch.load(best_presence_model_path, map_location=device, weights_only=True)
)
presence_model.to(device)
presence_model.eval()  # switch to evaluation mode

type_model = CNNTypePredictor(input_channels=NMELS, num_types=num_types, dropout=DROPOUT)
type_model.load_state_dict(
    torch.load(best_type_model_path, map_location=device, weights_only=True)
)
type_model.to(device)
type_model.eval()  # switch to evaluation mode

# The presence model used to be loaded a second time here without
# map_location. That reload was redundant and would fail on a CPU-only host
# with a CUDA-saved checkpoint, so it has been removed.


Using device: cuda
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.clouddiver.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/arwtdydhqhfa.helamind.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/andogaru.fumiko.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.summernight.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.cereris.cytoidlevel/level.json
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.alone.cytoidlevel/level.json
未找到音频文件 �TAKUMI³�OЯDIN -Apocalyptic War-(Re Mastering).mp3 对应于 song ID ant.ordin-tc 在 /data1/yuchen/cytoid/final_code/../dataset/A/ant.ordin-tc.cytoidlevel
JSON 解析错误: /data1/yuchen/cytoid/final_code/../dataset/A/anthony.lolk_muricaaaaa.cytoidlevel/level.json
未找到音频文件 Langley_D - deli.+駄々子 - 最果ての勇者にラブソングを.ogg 对应于 song ID anoppo.furiy 在 /data1/yuchen/cytoid/final_code/../dataset/A/anoppo.furiy.cytoidleve

[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?


Presence Train Size: 858862
Presence Validation Size: 214716
Presence Validation Size after split: 107358
Presence Test Size: 107358
Epoch 1/50 - Training Presence Model


                                                                                      

Train Loss: 0.4299


                                                                                      

Validation Loss: 0.3996
Saved Best Presence Model
Epoch 2/50 - Training Presence Model


                                                                                      

Train Loss: 0.3967


                                                                                      

Validation Loss: 0.3695
Saved Best Presence Model
Epoch 3/50 - Training Presence Model


                                                                                      

Train Loss: 0.3722


                                                                                      

Validation Loss: 0.3430
Saved Best Presence Model
Epoch 4/50 - Training Presence Model


                                                                                      

Train Loss: 0.3535


                                                                                      

Validation Loss: 0.3294
Saved Best Presence Model
Epoch 5/50 - Training Presence Model


                                                                                      

Train Loss: 0.3393


                                                                                      

Validation Loss: 0.3150
Saved Best Presence Model
Epoch 6/50 - Training Presence Model


                                                                                      

Train Loss: 0.3280


                                                                                      

Validation Loss: 0.3057
Saved Best Presence Model
Epoch 7/50 - Training Presence Model


                                                                                      

Train Loss: 0.3201


                                                                                      

Validation Loss: 0.2988
Saved Best Presence Model
Epoch 8/50 - Training Presence Model


                                                                                      

Train Loss: 0.3127


                                                                                      

Validation Loss: 0.2934
Saved Best Presence Model
Epoch 9/50 - Training Presence Model


                                                                                      

Train Loss: 0.3074


                                                                                      

Validation Loss: 0.2872
Saved Best Presence Model
Epoch 10/50 - Training Presence Model


                                                                                      

Train Loss: 0.3027


                                                                                      

Validation Loss: 0.2834
Saved Best Presence Model
Epoch 11/50 - Training Presence Model


                                                                                      

Train Loss: 0.2976


                                                                                      

Validation Loss: 0.2771
Saved Best Presence Model
Epoch 12/50 - Training Presence Model


                                                                                      

Train Loss: 0.2942


                                                                                      

Validation Loss: 0.2765
Saved Best Presence Model
Epoch 13/50 - Training Presence Model


                                                                                      

Train Loss: 0.2910


                                                                                      

Validation Loss: 0.2746
Saved Best Presence Model
Epoch 14/50 - Training Presence Model


                                                                                      

Train Loss: 0.2878


                                                                                      

Validation Loss: 0.2748
Epoch 15/50 - Training Presence Model


                                                                                      

Train Loss: 0.2853


                                                                                      

Validation Loss: 0.2700
Saved Best Presence Model
Epoch 16/50 - Training Presence Model


                                                                                      

Train Loss: 0.2830


                                                                                      

Validation Loss: 0.2707
Epoch 17/50 - Training Presence Model


                                                                                      

Train Loss: 0.2807


                                                                                      

Validation Loss: 0.2649
Saved Best Presence Model
Epoch 18/50 - Training Presence Model


                                                                                      

Train Loss: 0.2785


                                                                                      

Validation Loss: 0.2646
Saved Best Presence Model
Epoch 19/50 - Training Presence Model


                                                                                      

Train Loss: 0.2768


                                                                                      

Validation Loss: 0.2626
Saved Best Presence Model
Epoch 20/50 - Training Presence Model


                                                                                      

Train Loss: 0.2749


                                                                                      

Validation Loss: 0.2617
Saved Best Presence Model
Epoch 21/50 - Training Presence Model


                                                                                      

Train Loss: 0.2735


                                                                                      

Validation Loss: 0.2600
Saved Best Presence Model
Epoch 22/50 - Training Presence Model


                                                                                      

Train Loss: 0.2713


                                                                                      

Validation Loss: 0.2579
Saved Best Presence Model
Epoch 23/50 - Training Presence Model


                                                                                      

Train Loss: 0.2691


                                                                                      

Validation Loss: 0.2619
Epoch 24/50 - Training Presence Model


                                                                                      

Train Loss: 0.2684


                                                                                      

Validation Loss: 0.2576
Saved Best Presence Model
Epoch 25/50 - Training Presence Model


                                                                                      

Train Loss: 0.2672


                                                                                      

Validation Loss: 0.2602
Epoch 26/50 - Training Presence Model


                                                                                      

Train Loss: 0.2658


                                                                                      

Validation Loss: 0.2557
Saved Best Presence Model
Epoch 27/50 - Training Presence Model


                                                                                      

Train Loss: 0.2646


                                                                                      

Validation Loss: 0.2583
Epoch 28/50 - Training Presence Model


                                                                                      

Train Loss: 0.2639


                                                                                      

Validation Loss: 0.2545
Saved Best Presence Model
Epoch 29/50 - Training Presence Model


                                                                                      

Train Loss: 0.2626


                                                                                      

Validation Loss: 0.2545
Saved Best Presence Model
Epoch 30/50 - Training Presence Model


                                                                                      

Train Loss: 0.2611


                                                                                      

Validation Loss: 0.2543
Saved Best Presence Model
Epoch 31/50 - Training Presence Model


                                                                                      

Train Loss: 0.2607


                                                                                      

Validation Loss: 0.2520
Saved Best Presence Model
Epoch 32/50 - Training Presence Model


                                                                                      

Train Loss: 0.2594


                                                                                      

Validation Loss: 0.2531
Epoch 33/50 - Training Presence Model


                                                                                      

Train Loss: 0.2585


                                                                                      

Validation Loss: 0.2526
Epoch 34/50 - Training Presence Model


                                                                                      

Train Loss: 0.2574


                                                                                      

Validation Loss: 0.2506
Saved Best Presence Model
Epoch 35/50 - Training Presence Model


                                                                                      

Train Loss: 0.2569


                                                                                      

Validation Loss: 0.2528
Epoch 36/50 - Training Presence Model


                                                                                      

Train Loss: 0.2566


                                                                                      

Validation Loss: 0.2504
Saved Best Presence Model
Epoch 37/50 - Training Presence Model


                                                                                      

Train Loss: 0.2552


                                                                                      

Validation Loss: 0.2506
Epoch 38/50 - Training Presence Model


                                                                                      

Train Loss: 0.2550


                                                                                      

Validation Loss: 0.2498
Saved Best Presence Model
Epoch 39/50 - Training Presence Model


                                                                                      

Train Loss: 0.2544


                                                                                      

Validation Loss: 0.2491
Saved Best Presence Model
Epoch 40/50 - Training Presence Model


                                                                                      

Train Loss: 0.2531


                                                                                      

Validation Loss: 0.2503
Epoch 41/50 - Training Presence Model


                                                                                      

Train Loss: 0.2525


                                                                                      

Validation Loss: 0.2486
Saved Best Presence Model
Epoch 42/50 - Training Presence Model


                                                                                      

Train Loss: 0.2517


                                                                                      

Validation Loss: 0.2474
Saved Best Presence Model
Epoch 43/50 - Training Presence Model


                                                                                      

Train Loss: 0.2516


                                                                                      

Validation Loss: 0.2475
Epoch 44/50 - Training Presence Model


                                                                                      

Train Loss: 0.2509


                                                                                      

Validation Loss: 0.2479
Epoch 45/50 - Training Presence Model


                                                                                      

Train Loss: 0.2505


                                                                                      

Validation Loss: 0.2472
Saved Best Presence Model
Epoch 46/50 - Training Presence Model


                                                                                      

Train Loss: 0.2496


                                                                                      

Validation Loss: 0.2492
Epoch 47/50 - Training Presence Model


                                                                                      

Train Loss: 0.2492


                                                                                      

Validation Loss: 0.2457
Saved Best Presence Model
Epoch 48/50 - Training Presence Model


                                                                                      

Train Loss: 0.2488


                                                                                      

Validation Loss: 0.2474
Epoch 49/50 - Training Presence Model


                                                                                      

Train Loss: 0.2480


                                                                                      

Validation Loss: 0.2457
Epoch 50/50 - Training Presence Model


                                                                                      

Train Loss: 0.2476


  presence_model.load_state_dict(torch.load(best_presence_model_path))


Validation Loss: 0.2468


[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?


Type Train Size: 858862
Type Validation Size: 214716
Type Validation Size after split: 107358
Type Test Size: 107358
Epoch 1/50 - Training Type Model


                                                                                                  

Train Loss: 0.1162, Train Acc: 83.59%


                                                                                                  

Validation Loss: 0.1110, Validation Acc: 83.77%
Saved Best Type Model
Epoch 2/50 - Training Type Model


                                                                                                  

Train Loss: 0.1107, Train Acc: 83.65%


                                                                                                  

Validation Loss: 0.1060, Validation Acc: 83.95%
Saved Best Type Model
Epoch 3/50 - Training Type Model


                                                                                                  

Train Loss: 0.1071, Train Acc: 83.87%


                                                                                                  

Validation Loss: 0.1023, Validation Acc: 84.24%
Saved Best Type Model
Epoch 4/50 - Training Type Model


                                                                                                  

Train Loss: 0.1041, Train Acc: 84.22%


                                                                                                  

Validation Loss: 0.0988, Validation Acc: 84.75%
Saved Best Type Model
Epoch 5/50 - Training Type Model


                                                                                                   

Train Loss: 0.1021, Train Acc: 84.52%


                                                                                                  

Validation Loss: 0.0965, Validation Acc: 85.04%
Saved Best Type Model
Epoch 6/50 - Training Type Model


                                                                                                  

Train Loss: 0.1002, Train Acc: 84.76%


                                                                                                  

Validation Loss: 0.0944, Validation Acc: 85.43%
Saved Best Type Model
Epoch 7/50 - Training Type Model


                                                                                                  

Train Loss: 0.0987, Train Acc: 84.97%


                                                                                                  

Validation Loss: 0.0936, Validation Acc: 85.50%
Saved Best Type Model
Epoch 8/50 - Training Type Model


                                                                                                   

Train Loss: 0.0976, Train Acc: 85.16%


                                                                                                  

Validation Loss: 0.0915, Validation Acc: 85.89%
Saved Best Type Model
Epoch 9/50 - Training Type Model


                                                                                                  

Train Loss: 0.0966, Train Acc: 85.27%


                                                                                                  

Validation Loss: 0.0907, Validation Acc: 86.11%
Saved Best Type Model
Epoch 10/50 - Training Type Model


                                                                                                  

Train Loss: 0.0957, Train Acc: 85.41%


                                                                                                  

Validation Loss: 0.0899, Validation Acc: 86.11%
Saved Best Type Model
Epoch 11/50 - Training Type Model


                                                                                                  

Train Loss: 0.0947, Train Acc: 85.56%


                                                                                                  

Validation Loss: 0.0896, Validation Acc: 86.29%
Saved Best Type Model
Epoch 12/50 - Training Type Model


                                                                                                  

Train Loss: 0.0940, Train Acc: 85.64%


                                                                                                  

Validation Loss: 0.0880, Validation Acc: 86.39%
Saved Best Type Model
Epoch 13/50 - Training Type Model


                                                                                                  

Train Loss: 0.0934, Train Acc: 85.71%


                                                                                                   

Validation Loss: 0.0887, Validation Acc: 86.56%
Epoch 14/50 - Training Type Model


                                                                                                  

Train Loss: 0.0930, Train Acc: 85.79%


                                                                                                  

Validation Loss: 0.0894, Validation Acc: 86.34%
Epoch 15/50 - Training Type Model


                                                                                                  

Train Loss: 0.0925, Train Acc: 85.85%


                                                                                                  

Validation Loss: 0.0875, Validation Acc: 86.56%
Saved Best Type Model
Epoch 16/50 - Training Type Model


                                                                                                  

Train Loss: 0.0919, Train Acc: 85.91%


                                                                                                  

Validation Loss: 0.0867, Validation Acc: 86.64%
Saved Best Type Model
Epoch 17/50 - Training Type Model


                                                                                                  

Train Loss: 0.0916, Train Acc: 85.98%


                                                                                                  

Validation Loss: 0.0861, Validation Acc: 86.79%
Saved Best Type Model
Epoch 18/50 - Training Type Model


                                                                                                  

Train Loss: 0.0911, Train Acc: 86.01%


                                                                                                  

Validation Loss: 0.0856, Validation Acc: 86.66%
Saved Best Type Model
Epoch 19/50 - Training Type Model


                                                                                                  

Train Loss: 0.0907, Train Acc: 86.10%


                                                                                                  

Validation Loss: 0.0855, Validation Acc: 86.73%
Saved Best Type Model
Epoch 20/50 - Training Type Model


                                                                                                   

Train Loss: 0.0905, Train Acc: 86.15%


                                                                                                   

Validation Loss: 0.0858, Validation Acc: 86.91%
Epoch 21/50 - Training Type Model


                                                                                                  

Train Loss: 0.0901, Train Acc: 86.20%


                                                                                                  

Validation Loss: 0.0850, Validation Acc: 86.90%
Saved Best Type Model
Epoch 22/50 - Training Type Model


                                                                                                  

Train Loss: 0.0900, Train Acc: 86.18%


                                                                                                  

Validation Loss: 0.0841, Validation Acc: 86.98%
Saved Best Type Model
Epoch 23/50 - Training Type Model


                                                                                                  

Train Loss: 0.0897, Train Acc: 86.24%


                                                                                                  

Validation Loss: 0.0848, Validation Acc: 86.78%
Epoch 24/50 - Training Type Model


                                                                                                  

Train Loss: 0.0895, Train Acc: 86.27%


                                                                                                  

Validation Loss: 0.0846, Validation Acc: 86.98%
Epoch 25/50 - Training Type Model


                                                                                                  

Train Loss: 0.0891, Train Acc: 86.33%


                                                                                                   

Validation Loss: 0.0838, Validation Acc: 86.91%
Saved Best Type Model
Epoch 26/50 - Training Type Model


                                                                                                  

Train Loss: 0.0892, Train Acc: 86.35%


                                                                                                  

Validation Loss: 0.0845, Validation Acc: 87.06%
Epoch 27/50 - Training Type Model


                                                                                                   

Train Loss: 0.0889, Train Acc: 86.36%


                                                                                                  

Validation Loss: 0.0826, Validation Acc: 87.28%
Saved Best Type Model
Epoch 28/50 - Training Type Model


                                                                                                   

Train Loss: 0.0886, Train Acc: 86.37%


                                                                                                  

Validation Loss: 0.0834, Validation Acc: 87.16%
Epoch 29/50 - Training Type Model


                                                                                                 

Train Loss: nan, Train Acc: 85.02%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 30/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 31/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 32/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 33/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 34/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 35/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 36/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 37/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 38/50 - Training Type Model


                                                                                                

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 39/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 40/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 41/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 42/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 43/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 44/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 45/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 46/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 47/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 48/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 49/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


                                                                                               

Validation Loss: nan, Validation Acc: 83.77%
Epoch 50/50 - Training Type Model


                                                                                               

Train Loss: nan, Train Acc: 83.64%


  type_model.load_state_dict(torch.load(best_type_model_path))
  presence_model.load_state_dict(torch.load(best_presence_model_path, map_location=device))
  type_model.load_state_dict(torch.load(best_type_model_path, map_location=device))
  presence_model.load_state_dict(torch.load(best_presence_model_path))


Validation Loss: nan, Validation Acc: 83.77%


CNNOnsetDetector(
  (conv1): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=2560, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)