# Task 2 - Symbolic Conditioned Generation


## Data Processor

In [None]:
import os
from pathlib import Path
from typing import List, Tuple, Optional, Dict
import pretty_midi
import numpy as np
import json
import pickle

class POP909DataProcessor:
    """
    Data processor for attention-based global context approach.

    Walks the POP909 dataset (one folder per song holding ``<song_id>.mid`` and
    ``chord_audio.txt``), simplifies chord symbols, quantizes the melody track,
    cuts every song into overlapping chord-aligned segments and turns each chord
    of a segment into a training example. Full chord sequences are kept so a
    model can attend to the global harmonic context around a focus position.

    Chord times are compared directly with pretty_midi note times, so the chord
    file is assumed to use the same unit (seconds).
    """

    def __init__(self, data_path: str, output_path: str, melody_segment_length: int = 32):
        self.data_path = Path(data_path)
        self.output_path = Path(output_path)
        # parents=True so a nested output directory can be created in one call.
        self.output_path.mkdir(parents=True, exist_ok=True)

        self.time_resolution = 0.125   # melody quantization grid step
        self.min_chord_duration = 0.5  # shorter chords are discarded while parsing
        # Only recorded in dataset_stats.json; segmenting itself is chord-based.
        self.melody_segment_length = melody_segment_length

        self.chord_vocab = set()  # simplified chord symbols seen while parsing
        self.note_vocab = set()   # MIDI pitches seen while quantizing

        self.processed_data = []  # one dict per successfully processed song

    def parse_chord_file(self, chord_file_path: str) -> List[Tuple[float, float, str]]:
        """
        Parse a tab-separated ``start<TAB>end<TAB>chord`` file.

        Chords shorter than ``min_chord_duration`` and no-chord ``N`` entries are
        skipped; kept symbols are simplified and added to ``chord_vocab``. A line
        whose times are not numbers is reported and skipped instead of aborting
        the rest of the file.

        Returns:
            ``(start_time, end_time, simplified_chord)`` tuples in file order;
            empty if the file cannot be read.
        """
        chords = []
        try:
            with open(chord_file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error parsing chord file {chord_file_path}: {e}")
            return chords

        for line_number, line in enumerate(lines, start=1):
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue  # blank or incomplete line
            start_text, end_text, chord_symbol = parts[:3]
            try:
                start_time, end_time = float(start_text), float(end_text)
            except ValueError:
                print(f"Skipping malformed line {line_number} in {chord_file_path}: {line.strip()!r}")
                continue
            if end_time - start_time >= self.min_chord_duration and chord_symbol != "N":
                simplified_chord = self.normalize_chord_symbol(chord_symbol)
                chords.append((start_time, end_time, simplified_chord))
                self.chord_vocab.add(simplified_chord)

        return chords

    def extract_melody_track(self, midi_file_path: str) -> "Optional[pretty_midi.Instrument]":
        """
        Load a MIDI file and return its melody instrument.

        Prefers the first instrument whose name contains "MELODY"
        (case-insensitive) and falls back to the first instrument.

        Returns:
            The selected instrument, or ``None`` when the file cannot be loaded
            or contains no instruments.
        """
        midi_file_path = str(midi_file_path)
        try:
            midi_data = pretty_midi.PrettyMIDI(midi_file_path)
        except Exception as e:  # best-effort loader: a bad file skips the song, not the run
            print(f"Error loading MIDI file {midi_file_path}: {e}")
            return None

        for instrument in midi_data.instruments:
            if instrument.name and "MELODY" in instrument.name.upper():
                return instrument
        return midi_data.instruments[0] if midi_data.instruments else None

    def quantize_melody(self, melody_track: "pretty_midi.Instrument", song_duration: float) -> List[Dict]:
        """
        Snap melody notes onto the ``time_resolution`` grid.

        Both note boundaries are rounded to the nearest grid step. A quantized note
        covers steps ``[start_step, end_step)`` and always spans at least one step.
        Notes that fall outside the grid built over ``[0, song_duration)`` are skipped.

        Returns:
            One dict per kept note with quantized ``start_time``/``end_time`` plus the
            raw ``pitch``, ``velocity`` and unquantized ``duration``.
        """
        quantized_melody = []
        time_steps = np.arange(0, song_duration, self.time_resolution)
        num_steps = len(time_steps)

        for note in melody_track.notes:
            start_step = int(np.round(note.start / self.time_resolution))
            end_step = int(np.round(note.end / self.time_resolution))
            # Guarantee a positive length: a note rounding to zero steps would
            # otherwise be discarded by the overlap test in create_chord_aligned_segments.
            end_step = max(end_step, start_step + 1)

            if start_step >= num_steps or end_step > num_steps:
                continue

            quantized_melody.append({
                'start_time': float(time_steps[start_step]),
                # The end is the grid boundary *after* the last occupied step, so the
                # quantized duration is (end_step - start_step) steps, not one less.
                'end_time': end_step * self.time_resolution,
                'pitch': note.pitch,
                'velocity': note.velocity,
                'duration': note.end - note.start
            })
            self.note_vocab.add(note.pitch)

        return quantized_melody

    def normalize_chord_symbol(self, chord_symbol: str) -> str:
        """
        Normalize and simplify a chord symbol; empty input and ``N`` map to ``'N'``.
        """
        if chord_symbol == 'N' or not chord_symbol or chord_symbol.strip() == '':
            return 'N'

        return self.simplify_chord_symbol(chord_symbol.strip())

    def process_song(self, song_folder: Path) -> Optional[Dict]:
        """
        Process one song folder into chord-aligned melody segments.

        Returns ``None`` (after printing a reason) when files are missing, fewer
        than four usable chords remain, the MIDI has no melody notes, or no
        segments could be built.
        """
        song_id = song_folder.name
        midi_file_path = song_folder / f"{song_id}.mid"
        chord_file_path = song_folder / "chord_audio.txt"

        if not midi_file_path.exists() or not chord_file_path.exists():
            print(f"Skipping {song_id} due to missing files")
            return None

        chords = self.parse_chord_file(chord_file_path)
        if len(chords) < 4:
            print(f"Skipping {song_id} - not enough chords ({len(chords)})")
            return None

        melody_track = self.extract_melody_track(midi_file_path)
        # extract_melody_track returns None on load errors, and a track may be empty.
        if melody_track is None or not melody_track.notes:
            print(f"Skipping {song_id} - no melody notes found")
            return None

        # Neither list is guaranteed to be ordered by end time, so take the true maxima.
        song_duration = max(
            max(chord_end for _, chord_end, _ in chords),
            max(note.end for note in melody_track.notes),
        )

        quantized_melody = self.quantize_melody(melody_track, song_duration)
        if not quantized_melody:
            print(f"No quantized melody for song {song_id}")
            return None

        melody_segments = self.create_chord_aligned_segments(chords, quantized_melody)
        if not melody_segments:
            print(f"No melody segments for song {song_id}")
            return None

        return {
            'song_id': song_id,
            'melody_segments': melody_segments,
            'total_duration': song_duration,
            'num_chords': len(chords),
            'num_melody_notes': len(quantized_melody),
            'full_chord_sequence': [chord[2] for chord in chords]
        }

    def process_dataset(self) -> None:
        """
        Process every song folder under ``<data_path>/POP909`` into ``processed_data``.
        """
        pop909_folder = self.data_path / "POP909"
        if not pop909_folder.exists():
            print(f"POP909 folder not found at {pop909_folder}")
            return

        song_folders = [f for f in pop909_folder.iterdir() if f.is_dir()]
        print(f"Found {len(song_folders)} song folders in POP909 dataset")

        successful_songs = 0
        for song_folder in sorted(song_folders):
            song_data = self.process_song(song_folder)
            if song_data:
                self.processed_data.append(song_data)
                successful_songs += 1

                if successful_songs % 50 == 0:
                    print(f"Processed {successful_songs}/{len(song_folders)} songs")

        print(f"Completed processing. {successful_songs}/{len(song_folders)} songs processed successfully")

    def create_vocabularies(self) -> Dict:
        """
        Build chord/note vocabularies, write them to ``vocabularies.json`` and return them.

        Note: JSON object keys are strings, so the integer keys of ``note_vocab``,
        ``idx_to_chord`` and ``idx_to_note`` come back as strings when the file is
        reloaded; the returned in-memory dict keeps integer keys.
        """
        chord_vocab_list = ['<PAD>', '<START>', '<END>', '<UNK>'] + sorted(self.chord_vocab)
        note_vocab_list = list(range(0, 128))

        vocab_data = {
            'chord_vocab': {chord: idx for idx, chord in enumerate(chord_vocab_list)},
            'note_vocab': {note: idx for idx, note in enumerate(note_vocab_list)},
            'idx_to_chord': {idx: chord for idx, chord in enumerate(chord_vocab_list)},
            'idx_to_note': {idx: note for idx, note in enumerate(note_vocab_list)}
        }

        with open(self.output_path / "vocabularies.json", 'w') as f:
            json.dump(vocab_data, f, indent=2)

        print(f"Chord vocabulary size: {len(chord_vocab_list)}")
        print(f"Note vocabulary size: {len(note_vocab_list)}")

        return vocab_data

    def save_processed_data(self) -> None:
        """
        Write ``processed_data`` as pickle and JSON, plus summary statistics.
        """
        with open(self.output_path / "chord_melody_data.pkl", 'wb') as f:
            pickle.dump(self.processed_data, f)

        with open(self.output_path / "chord_melody_data.json", 'w') as f:
            json.dump(self.processed_data, f, indent=2)

        total_segments = sum(len(song['melody_segments']) for song in self.processed_data)
        avg_segments_per_song = total_segments / len(self.processed_data) if self.processed_data else 0

        stats = {
            'total_songs': len(self.processed_data),
            'total_melody_segments': total_segments,
            'avg_segments_per_song': avg_segments_per_song,
            'unique_chords': len(self.chord_vocab),
            'note_range': [min(self.note_vocab), max(self.note_vocab)] if self.note_vocab else [0, 127],
            'melody_segment_length': self.melody_segment_length
        }

        with open(self.output_path / "dataset_stats.json", 'w') as f:
            json.dump(stats, f, indent=2)

        print(f"Processed data saved to {self.output_path}")
        print(f"Dataset statistics: {stats}")

    def create_training_sequences(self) -> List[Dict]:
        """
        Turn every chord that has melody notes into one training example.

        Each example keeps the full chord context of its segment, marks the
        chord of interest with ``focus_position`` and carries that chord's
        notes (timed relative to the chord start) as the target melody.
        The list is pickled to ``training_sequences.pkl`` and returned.
        """
        training_sequences = []

        for song in self.processed_data:
            song_duration = song['total_duration']
            for segment in song['melody_segments']:
                for chord_pair in segment['chord_melody_pairs']:
                    if not chord_pair['notes']:
                        continue  # nothing to learn for a chord without melody

                    melody_notes = [
                        {
                            'pitch': note['pitch'],
                            'start_time': note['start_in_chord'],
                            'duration': note['end_in_chord'] - note['start_in_chord'],
                            'velocity': note['velocity']
                        }
                        for note in chord_pair['notes']
                    ]
                    chord_duration = chord_pair['chord_duration']
                    # Relative position of *this chord* (not of its segment) in the song.
                    chord_position = (chord_pair['absolute_start_time'] / song_duration
                                      if song_duration > 0 else 0.0)

                    training_sequences.append({
                        'full_chord_sequence': segment['full_chord_sequence'],
                        'chord_durations': segment['chord_durations'],
                        'focus_position': chord_pair['chord_position_in_segment'],
                        'target_chord': chord_pair['chord'],
                        'chord_duration': chord_duration,
                        'output_melody': melody_notes,
                        'song_id': song['song_id'],
                        'segment_id': segment['segment_id'],
                        'timing_context': {
                            # Rough tempo guess that assumes each chord lasts two beats.
                            'bpm_estimate': 60.0 / (chord_duration / 2),
                            'time_signature': '4/4',  # hard-coded assumption
                            'chord_position_in_song': chord_position
                        }
                    })

        with open(self.output_path / "training_sequences.pkl", 'wb') as f:
            pickle.dump(training_sequences, f)

        print(f"Created {len(training_sequences)} chord-aligned training sequences")
        return training_sequences

    def simplify_chord_symbol(self, chord_symbol: str) -> str:
        """
        Simplify chord symbols to reduce vocabulary size and improve generalization.

        Maps ``root:quality[/bass]`` symbols to basic chord types:
        - diminished / half-diminished -> dim (triad) or dim7 (any 7th, incl. hdim)
        - major 7th -> maj7 (minor-major 7th -> minmaj7)
        - other major variations -> maj
        - minor 7th -> min7; other minor variations -> min
        - dominant extensions (7, 9, 11, 13) -> 7
        - augmented -> aug
        - suspended -> sus2 / sus4
        - sixths, bare roots and anything unrecognised -> maj

        The bass note after ``/`` is ignored, so an inversion such as ``sus4/b7``
        cannot change the chord type. Symbols without ``:`` are returned unchanged;
        empty input and ``N`` map to ``'N'``.
        """
        if chord_symbol == 'N' or not chord_symbol or chord_symbol.strip() == '':
            return 'N'

        if ':' not in chord_symbol:
            return chord_symbol

        root, quality = chord_symbol.split(':', 1)
        quality = quality.split('/', 1)[0]  # drop the inversion / bass note

        def has_any(*patterns):
            return any(pattern in quality for pattern in patterns)

        # Diminished must be tested first: 'dim7' and 'hdim7' contain 'm7' and '7',
        # which the minor-7th and dominant checks below would otherwise claim.
        if has_any('dim', 'hdim', 'ø'):
            simplified_quality = 'dim7' if ('7' in quality or 'hdim' in quality) else 'dim'
        elif has_any('maj7', 'M7'):
            simplified_quality = 'minmaj7' if has_any('min', 'm7', 'minmaj') else 'maj7'
        elif 'maj' in quality:
            simplified_quality = 'maj'
        elif has_any('min7', 'm7'):
            simplified_quality = 'min7'
        elif 'min' in quality or quality.startswith('m'):
            simplified_quality = 'min'
        elif has_any('7', '9', '11', '13'):
            simplified_quality = '7'
        elif 'aug' in quality or '+' in quality:
            simplified_quality = 'aug'
        elif 'sus' in quality:
            simplified_quality = 'sus2' if 'sus2' in quality else 'sus4'
        else:
            # Sixth chords (minor ones were already caught above), bare roots,
            # power chords and unrecognised qualities.
            simplified_quality = 'maj'

        return f"{root}:{simplified_quality}"

    @staticmethod
    def _notes_in_chord(chord_start: float, chord_end: float, melody: List[Dict]) -> List[Dict]:
        """Return notes lying mostly (>50% of their length) inside the chord, timed relative to its start."""
        chord_duration = chord_end - chord_start
        chord_notes = []
        for note in melody:
            note_start = note['start_time']
            note_end = note['end_time']
            if note_end <= chord_start or note_start >= chord_end:
                continue  # no overlap at all

            note_duration = note_end - note_start
            overlap_duration = min(note_end, chord_end) - max(note_start, chord_start)
            if note_duration > 0 and overlap_duration / note_duration > 0.5:
                chord_notes.append({
                    'pitch': note['pitch'],
                    'start_in_chord': max(0, note_start - chord_start),
                    'end_in_chord': min(chord_duration, note_end - chord_start),
                    'duration': note.get('duration', note_duration),
                    'velocity': note.get('velocity', 80),
                    'chord_duration': chord_duration
                })
        return chord_notes

    def create_chord_aligned_segments(self, chords: List[Tuple[float, float, str]], melody: List[Dict]) -> List[Dict]:
        """
        Create training segments aligned to chord boundaries.

        Windows of up to 12 chords are taken with a 2-chord overlap. If the final
        window would hold fewer than 4 chords, it is slid back over already covered
        chords so the song's trailing chords are still included. Windows that get
        no melody notes are skipped.

        Returns:
            Segment dicts with per-chord note lists, the window's chord symbols and
            durations, and a running ``segment_id``.
        """
        segments = []

        if not chords or not melody:
            return segments

        min_chords_per_segment = 4
        max_chords_per_segment = 12
        overlap_chords = 2
        step_size = max_chords_per_segment - overlap_chords
        num_chords = len(chords)

        chord_idx = 0
        while True:
            segment_end_idx = min(chord_idx + max_chords_per_segment, num_chords)
            if segment_end_idx - chord_idx < min_chords_per_segment:
                # Short tail: slide the window back instead of dropping the last chords.
                chord_idx = max(0, segment_end_idx - min_chords_per_segment)

            segment_chords = chords[chord_idx:segment_end_idx]
            if len(segment_chords) >= min_chords_per_segment:
                chord_melody_pairs = [
                    {
                        'chord': chord_symbol,
                        'chord_duration': chord_end - chord_start,
                        'chord_position_in_segment': local_idx,
                        'notes': self._notes_in_chord(chord_start, chord_end, melody),
                        'absolute_start_time': chord_start,
                        'absolute_end_time': chord_end
                    }
                    for local_idx, (chord_start, chord_end, chord_symbol) in enumerate(segment_chords)
                ]

                total_notes = sum(len(pair['notes']) for pair in chord_melody_pairs)
                if total_notes > 0:
                    segment_start_time = segment_chords[0][0]
                    segment_end_time = segment_chords[-1][1]
                    segments.append({
                        'chord_melody_pairs': chord_melody_pairs,
                        'full_chord_sequence': [pair['chord'] for pair in chord_melody_pairs],
                        'chord_durations': [pair['chord_duration'] for pair in chord_melody_pairs],
                        'segment_start_time': segment_start_time,
                        'segment_end_time': segment_end_time,
                        'total_duration': segment_end_time - segment_start_time,
                        'num_chords': len(segment_chords),
                        'total_notes': total_notes,
                        'segment_id': len(segments)
                    })

            if segment_end_idx >= num_chords:
                break
            chord_idx += step_size

        return segments

def process_dataset(dataset_path: str, output_path: str, melody_segment_length: int = 32):
    """
    Run the complete POP909 preprocessing pipeline.

    Stages, in order: parse all songs, build and save the vocabularies, save
    the processed songs with statistics, and derive the training sequences.
    All artifacts are written under ``output_path``.
    """
    pipeline = POP909DataProcessor(dataset_path, output_path, melody_segment_length)
    stages = (
        pipeline.process_dataset,
        pipeline.create_vocabularies,
        pipeline.save_processed_data,
        pipeline.create_training_sequences,
    )
    for stage in stages:
        stage()

## Dataset Builder

In [None]:
from torch.utils.data import Dataset
import pickle
import json
import torch
import numpy as np

class ChordMelodyDataset(Dataset):
    """
    Map-style dataset over chord-aligned training examples.

    Each item holds a padded chord context and the melody of the focus chord.
    The chord context consists of vocabulary indices, durations divided by the
    longest chord in the context, a validity mask and the focus index. Note
    starts and durations are expressed as fractions of the focus chord's
    duration. Examples longer than the configured maxima are dropped at load
    time rather than truncated.
    """

    def __init__(self, data_path: str, vocab_path: str, max_melody_length: int = 16, max_chord_length: int = 12):
        # NOTE(review): pickle.load can execute arbitrary code; only load files you generated.
        with open(data_path, 'rb') as pickled:
            all_sequences = pickle.load(pickled)
        with open(vocab_path, 'r') as vocab_file:
            self.vocabularies = json.load(vocab_file)

        self.chord_to_idx = self.vocabularies['chord_vocab']
        self.note_to_idx = self.vocabularies['note_vocab']
        self.max_melody_length = max_melody_length
        self.max_chord_length = max_chord_length

        def within_limits(seq):
            return (len(seq['output_melody']) <= max_melody_length
                    and len(seq['full_chord_sequence']) <= max_chord_length)

        self.training_sequences = [seq for seq in all_sequences if within_limits(seq)]

        print(f"Loaded {len(self.training_sequences)} chord-aligned training sequences")
        if not self.training_sequences:
            return
        avg_chord_len = np.mean([len(seq['full_chord_sequence']) for seq in self.training_sequences])
        avg_melody_len = np.mean([len(seq['output_melody']) for seq in self.training_sequences])
        print(f"Average chord sequence length: {avg_chord_len:.1f}")
        print(f"Average melody length: {avg_melody_len:.1f}")

    def __len__(self):
        return len(self.training_sequences)

    def __getitem__(self, idx):
        seq = self.training_sequences[idx]
        lookup = self.chord_to_idx

        # Unknown chord symbols fall back to the <UNK> index.
        chord_ids = [lookup.get(symbol, lookup['<UNK>']) for symbol in seq['full_chord_sequence']]
        n_chords = len(chord_ids)

        raw_durations = seq['chord_durations']
        scale = max(raw_durations) if raw_durations else 4.0
        durations = [value / scale for value in raw_durations]

        chord_padding = self.max_chord_length - n_chords
        if chord_padding > 0:
            chord_ids += [lookup['<PAD>']] * chord_padding
            durations += [0.0] * chord_padding

        focus = min(seq['focus_position'], n_chords - 1)
        span = seq['chord_duration']

        per_note = [(note['pitch'], note['start_time'] / span, note['duration'] / span)
                    for note in seq['output_melody']]
        n_notes = len(per_note)
        note_padding = max(self.max_melody_length - n_notes, 0)
        pitches = [pitch for pitch, _, _ in per_note] + [0] * note_padding
        starts = [start for _, start, _ in per_note] + [0.0] * note_padding
        lengths = [length for _, _, length in per_note] + [0.0] * note_padding

        chord_mask = [position < n_chords for position in range(self.max_chord_length)]
        melody_mask = [position < n_notes for position in range(self.max_melody_length)]

        return {
            'full_chord_sequence': torch.tensor(chord_ids, dtype=torch.long),
            'chord_durations': torch.tensor(durations, dtype=torch.float32),
            'chord_mask': torch.tensor(chord_mask, dtype=torch.bool),
            'focus_position': torch.tensor(focus, dtype=torch.long),
            'target_chord_duration': torch.tensor(span, dtype=torch.float32),
            'melody_pitch': torch.tensor(pitches, dtype=torch.long),
            'melody_start': torch.tensor(starts, dtype=torch.float32),
            'melody_duration': torch.tensor(lengths, dtype=torch.float32),
            'melody_mask': torch.tensor(melody_mask, dtype=torch.bool),
            'chord_length': torch.tensor(n_chords, dtype=torch.long),
            'melody_length': torch.tensor(n_notes, dtype=torch.long),
            'song_id': seq['song_id']
        }

## Model

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import math

class AttentionChordToMelodyTransformer(nn.Module):
    def __init__(self,
                chord_vocab_size: int,
                note_vocab_size: int,
                d_model: int = 256,
                nhead: int = 8,
                num_layers: int = 6,
                max_melody_length: int = 32,
                max_chord_length: int = 100):
        """
        Build a chord-context encoder and an autoregressive melody decoder.

        Args:
            chord_vocab_size: Number of chord tokens (rows of ``chord_embedding``).
            note_vocab_size: Number of pitch tokens; also the width of ``pitch_head``.
            d_model: Hidden size shared by every embedding and transformer layer.
            nhead: Attention heads for the encoder, decoder and ``focus_attention``.
            num_layers: Total layer budget, split into ``num_layers // 2`` encoder
                layers and ``num_layers // 2`` decoder layers.
            max_melody_length: Rows of the melody position table. forward() indexes
                it with positions 0..len-1 of the decoder input, so inputs must be
                shorter than this.
            max_chord_length: Rows of the chord position table. The chord context
                length must not exceed this.
        """
        super().__init__()

        self.d_model = d_model
        self.max_melody_length = max_melody_length
        self.max_chord_length = max_chord_length

        # Token and learned absolute-position embeddings.
        self.chord_embedding = nn.Embedding(chord_vocab_size, d_model)
        self.note_embedding = nn.Embedding(note_vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_melody_length, d_model)
        self.chord_position_embedding = nn.Embedding(max_chord_length, d_model)

        # NOTE(review): appears unused by forward(); confirm before removing (it is part of state_dict).
        self.segment_position_embedding = nn.Linear(1, d_model)

        # Self-attention over the chord context; batch_first=True means (batch, seq, d_model) tensors.
        self.chord_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                activation='gelu',
                batch_first=True
            ),
            num_layers=num_layers // 2
        )

        # Melody decoder; forward() passes the encoded chords as its cross-attention memory.
        self.melody_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                activation='gelu',
                batch_first=True
            ),
            num_layers=num_layers // 2
        )

        # Extra self-attention pass over the encoded chords; forward() projects its
        # output and adds it residually before layer_norm.
        self.focus_attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.focus_projection = nn.Linear(d_model, d_model)

        # Output heads: pitch logits over the note vocabulary, scalar duration and start.
        self.pitch_head = nn.Linear(d_model, note_vocab_size)
        self.duration_head = nn.Linear(d_model, 1)
        self.start_head = nn.Linear(d_model, 1)

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(d_model)

        # Re-initialize all Linear/Embedding weights, overriding the PyTorch defaults.
        self._init_weights()

    def _init_weights(self):
        """Initialize weights properly for transformer"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, std=0.02)

    def create_distance_attention_mask(self, seq_len: int, focus_positions: torch.Tensor, focus_window: int = 8):
        """
        Create attention mask that focuses on nearby chords with some global context
        """
        batch_size = focus_positions.size(0)
        device = focus_positions.device

        mask = torch.zeros(batch_size, seq_len, seq_len, device=device)

        for b in range(batch_size):
            focus_pos = focus_positions[b].item()

            distances = torch.arange(seq_len, device=device).float()
            focus_distances = torch.abs(distances - focus_pos)

            local_weights = torch.exp(-focus_distances / (focus_window / 2))

            global_weights = torch.where(
                (torch.arange(seq_len, device=device) % 4) == 0,
                torch.ones(seq_len, device=device) * 0.3,
                torch.zeros(seq_len, device=device)
            )

            combined_weights = local_weights + global_weights
            combined_weights = combined_weights / combined_weights.sum()

            for i in range(seq_len):
                mask[b, i, :] = combined_weights

        return mask

    def _align_simultaneous_notes(self, melody, tolerance=0.1):
        """
        Align notes that should be played simultaneously
        """
        if len(melody) <= 1:
            return melody

        sorted_melody = sorted(melody, key=lambda x: x['start_time'])
        aligned_melody = []

        i = 0
        while i < len(sorted_melody):
            current_note = sorted_melody[i].copy()
            simultaneous_notes = [current_note]

            j = i + 1
            while j < len(sorted_melody):
                next_note = sorted_melody[j]
                time_diff = abs(next_note['start_time'] - current_note['start_time'])

                if time_diff <= tolerance:
                    simultaneous_notes.append(next_note.copy())
                    j += 1
                else:
                    break

            if len(simultaneous_notes) > 1:
                avg_start_time = sum(note['start_time'] for note in simultaneous_notes) / len(simultaneous_notes)

                print(f"  Aligning {len(simultaneous_notes)} simultaneous notes at {avg_start_time:.2f}s")
                print(f"    Pitches: {[note['pitch'] for note in simultaneous_notes]}")

                for note in simultaneous_notes:
                    note['start_time'] = avg_start_time

                avg_duration = sum(note['duration'] for note in simultaneous_notes) / len(simultaneous_notes)
                for note in simultaneous_notes:
                    note['duration'] = avg_duration

            aligned_melody.extend(simultaneous_notes)
            i = j

        return aligned_melody

    def _quantize_note_timing(self, melody, beat_duration=0.25):
        """
        Quantize note timing to musical beats for cleaner rhythm
        """
        print(f"  Quantizing notes to {beat_duration:.3f}s grid...")

        quantized_melody = []

        for note in melody:
            quantized_note = note.copy()
            quantized_start = round(note['start_time'] / beat_duration) * beat_duration
            raw_duration = note['duration']
            musical_durations = [0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.0]
            closest_duration = min(musical_durations, key=lambda x: abs(x - raw_duration))
            quantized_note['start_time'] = quantized_start
            quantized_note['duration'] = closest_duration
            quantized_melody.append(quantized_note)

        return quantized_melody

    def _remove_timing_conflicts(self, melody, min_gap=0.05):
        """
        Remove notes that create timing conflicts while preserving harmonies
        """
        if len(melody) <= 1:
            return melody

        sorted_melody = sorted(melody, key=lambda x: (x['start_time'], x['pitch']))
        cleaned_melody = []

        for note in sorted_melody:
            should_add = True

            for existing_note in cleaned_melody:
                existing_end = existing_note['start_time'] + existing_note['duration']
                note_start = note['start_time']
                note_end = note['start_time'] + note['duration']

                if abs(existing_note['start_time'] - note_start) < 0.01:
                    continue

                elif (note_start < existing_end - min_gap and
                      note_end > existing_note['start_time'] + min_gap):

                    print(f"Removing conflicting note: pitch {note['pitch']} at {note_start:.2f}s")
                    should_add = False
                    break

            if should_add:
                cleaned_melody.append(note)

        return cleaned_melody

    def _enhance_note_timing(self, melody):
        """
        Master function to enhance note timing for better musical quality
        """
        if not melody:
            return melody

        print(f"Enhancing timing for {len(melody)} notes...")

        melody = self._remove_timing_conflicts(melody, min_gap=0.05)
        print(f"  After conflict removal: {len(melody)} notes")

        melody = self._align_simultaneous_notes(melody, tolerance=0.15)
        print(f"  After alignment: {len(melody)} notes")

        melody = sorted(melody, key=lambda x: x['start_time'])

        self._report_simultaneous_groups(melody)

        return melody

    def _report_simultaneous_groups(self, melody):
        """Report groups of simultaneous notes for debugging"""
        simultaneous_groups = []
        i = 0

        while i < len(melody):
            current_time = melody[i]['start_time']
            group = [melody[i]]

            j = i + 1
            while j < len(melody) and abs(melody[j]['start_time'] - current_time) < 0.01:
                group.append(melody[j])
                j += 1

            if len(group) > 1:
                pitches = [note['pitch'] for note in group]
                simultaneous_groups.append((current_time, pitches))

            i = j

        if simultaneous_groups:
            print(f"  Found {len(simultaneous_groups)} simultaneous note groups:")
            for time, pitches in simultaneous_groups[:5]:  # Show first 5
                print(f"    {time:.2f}s: pitches {pitches}")
            if len(simultaneous_groups) > 5:
                print(f"    ... and {len(simultaneous_groups) - 5} more groups")

    def _calculate_song_bpm(self, chord_times):
        """Calculate BPM from chord timing for better quantization"""
        if len(chord_times) < 4:
            return 120

        durations = [end - start for start, end in chord_times]
        avg_chord_duration = sum(durations) / len(durations)

        estimated_bpm = 60.0 / avg_chord_duration

        common_bpms = [60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 180]
        closest_bpm = min(common_bpms, key=lambda x: abs(x - estimated_bpm))

        print(f"  Estimated BPM: {estimated_bpm:.1f} -> Using: {closest_bpm}")
        return closest_bpm

    def _enhance_note_timing_with_bpm(self, melody, beat_duration):
        """Enhanced timing with BPM awareness"""
        if not melody:
            return melody

        print(f"Enhancing timing for {len(melody)} notes (beat={beat_duration:.3f}s)...")

        melody = self._remove_timing_conflicts(melody)

        melody = self._align_simultaneous_notes(melody, tolerance=beat_duration * 0.5)

        quantized_melody = []
        for note in melody:
            quantized_note = note.copy()

            closest_beat_time = round(note['start_time'] / beat_duration) * beat_duration
            time_diff = abs(note['start_time'] - closest_beat_time)

            if time_diff < beat_duration * 0.3:
                quantized_note['start_time'] = closest_beat_time

            quantized_melody.append(quantized_note)

        self._report_simultaneous_groups(quantized_melody)
        return sorted(quantized_melody, key=lambda x: x['start_time'])

    def forward(self, full_chord_sequence, chord_mask, focus_positions, chord_durations=None,
                melody_pitch=None, training=True):
        """
        Encode the chord context, then decode a melody from it.

        Args:
            full_chord_sequence: chord index tensor indexed as (batch, chord_seq_len).
            chord_mask: bool tensor, True for real chords and False for padding.
            focus_positions: accepted for interface compatibility; not used here.
            chord_durations: optional tensor with one value per chord. It is
                broadcast over the embedding dimension, scaled by 0.1 and added
                to the chord embeddings.
            melody_pitch: pitch index tensor (batch, melody_len). Required when
                ``training`` is True.
            training: True runs a teacher-forced pass; False samples autoregressively.

        Returns:
            training=True: ``(pitch_logits, duration_pred, start_pred)`` for the
                inputs ``melody_pitch[:, :-1]``.
            training=False: list with one dict per sampled step. Each dict has
                'pitch' (batch, 1) plus the 'duration' and 'start' head outputs.
                Sampling stops early once every sequence in the batch emits index 0.

        Raises:
            ValueError: if ``training`` is True and ``melody_pitch`` is None.
        """
        if training and melody_pitch is None:
            raise ValueError("melody_pitch is required when training=True")

        batch_size = full_chord_sequence.size(0)
        chord_seq_len = full_chord_sequence.size(1)
        device = full_chord_sequence.device

        chord_embedded = self.chord_embedding(full_chord_sequence)

        chord_positions = torch.arange(chord_seq_len, device=device)
        chord_pos_emb = self.chord_position_embedding(chord_positions).unsqueeze(0).expand(batch_size, -1, -1)
        chord_embedded = chord_embedded + chord_pos_emb

        if chord_durations is not None:
            # One scalar per chord, broadcast across the embedding dimension.
            duration_emb = chord_durations.unsqueeze(-1).expand(-1, -1, self.d_model) * 0.1
            chord_embedded = chord_embedded + duration_emb

        chord_embedded = self.dropout(chord_embedded)

        # PyTorch boolean key-padding masks use True for positions to ignore.
        padding_mask = ~chord_mask

        chord_encoded = self.chord_encoder(
            chord_embedded,
            src_key_padding_mask=padding_mask
        )

        focus_context, _ = self.focus_attention(
            chord_encoded, chord_encoded, chord_encoded,
            key_padding_mask=padding_mask
        )
        chord_encoded = self.layer_norm(chord_encoded + self.focus_projection(focus_context))

        if training:
            # Teacher forcing: the decoder sees every token except the last one.
            target_melody = melody_pitch[:, :-1]
            melody_seq_len = target_melody.size(1)

            melody_embedded = self.note_embedding(target_melody)
            melody_positions = torch.arange(melody_seq_len, device=target_melody.device)
            melody_pos_emb = self.position_embedding(melody_positions).unsqueeze(0).expand(batch_size, -1, -1)
            melody_embedded = self.dropout(melody_embedded + melody_pos_emb)

            causal_mask = torch.triu(
                torch.ones(melody_seq_len, melody_seq_len, device=target_melody.device), diagonal=1
            ).bool()

            decoded = self.melody_decoder(
                melody_embedded,
                chord_encoded,
                tgt_mask=causal_mask,
                memory_key_padding_mask=padding_mask
            )

            pitch_logits = self.pitch_head(decoded)
            duration_pred = self.duration_head(decoded).squeeze(-1)
            start_pred = self.start_head(decoded).squeeze(-1)

            return pitch_logits, duration_pred, start_pred

        generated_sequence = []
        current_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

        for _ in range(self.max_melody_length):
            melody_seq_len = current_input.size(1)

            melody_embedded = self.note_embedding(current_input)
            melody_positions = torch.arange(melody_seq_len, device=current_input.device)
            melody_pos_emb = self.position_embedding(melody_positions).unsqueeze(0).expand(batch_size, -1, -1)
            melody_embedded = self.dropout(melody_embedded + melody_pos_emb)

            causal_mask = torch.triu(
                torch.ones(melody_seq_len, melody_seq_len, device=current_input.device), diagonal=1
            ).bool()

            decoded = self.melody_decoder(
                melody_embedded,
                chord_encoded,
                tgt_mask=causal_mask,
                memory_key_padding_mask=padding_mask
            )

            last_decoded = decoded[:, -1:]
            pitch_logits = self.pitch_head(last_decoded)
            duration_pred = self.duration_head(last_decoded).squeeze(-1)
            start_pred = self.start_head(last_decoded).squeeze(-1)

            pitch_probs = torch.softmax(pitch_logits.squeeze(1), dim=-1)
            next_pitch = torch.multinomial(pitch_probs, 1)

            generated_sequence.append({
                'pitch': next_pitch,
                'duration': duration_pred,
                'start': start_pred
            })

            # Check the whole batch: .item() only works on one-element tensors,
            # so it would fail for batch_size > 1.
            if bool((next_pitch == 0).all()):
                break

            current_input = torch.cat([current_input, next_pitch], dim=1)

        return generated_sequence

    def generate_full_song_melody(self, full_chord_sequence, vocab_path, segment_length=16, overlap=4,
                    target_density=0.6, density_window=8, temperature=1.0,
                    original_midi_path=None):
        """
        Generate a melody for a whole song by sampling overlapping segments.

        Every segment encodes the same (possibly truncated) chord sequence. Notes
        are sampled with ``self._generate_with_density_control`` and mapped onto
        that segment's time span. Notes falling inside the overlap with the
        previous segment are kept only if they start more than 0.3 s after the
        last kept note.

        Args:
            full_chord_sequence: chord symbols for the song.
            vocab_path: JSON file containing 'chord_vocab' (symbol -> index) and
                'idx_to_note' (str(index) -> pitch).
            segment_length: chords per segment; must be >= 1.
            overlap: chords shared by consecutive segments.
            target_density, density_window, temperature: forwarded to
                ``_generate_with_density_control``.
            original_midi_path: optional MIDI file. When given, chord timing comes
                from ``utils.extract_chord_timing_from_midi``; otherwise each chord
                is assumed to last 2 seconds.

        Returns:
            List of {'pitch', 'duration', 'start_time'} dicts sorted by start_time.
            Times are on the same clock as the chord timing (absolute MIDI time
            when real timing is used).

        Raises:
            ValueError: if ``segment_length`` < 1.
        """
        if segment_length < 1:
            raise ValueError(f"segment_length must be >= 1, got {segment_length}")

        self.eval()
        device = next(self.parameters()).device

        with open(vocab_path, 'r') as f:
            vocabularies = json.load(f)

        chord_to_idx = vocabularies['chord_vocab']
        idx_to_note = vocabularies['idx_to_note']

        unk_idx = chord_to_idx.get('<UNK>', 0)
        chord_indices = [chord_to_idx.get(chord, unk_idx) for chord in full_chord_sequence]

        if len(chord_indices) > self.max_chord_length:
            print(f"Warning: Chord sequence too long ({len(chord_indices)} > {self.max_chord_length}), truncating")
            chord_indices = chord_indices[:self.max_chord_length]

        chord_times = []
        if original_midi_path:
            from utils import extract_chord_timing_from_midi
            chord_times = extract_chord_timing_from_midi(original_midi_path)

        original_length = len(chord_indices)
        if original_length == 0:
            print("No chords given; nothing to generate")
            return []

        pad_count = self.max_chord_length - original_length
        padded_chords = chord_indices + [chord_to_idx.get('<PAD>', 0)] * pad_count
        chord_mask = [True] * original_length + [False] * pad_count

        step_size = max(1, segment_length - overlap)

        print(f"Generating melody for {original_length} chords...")
        print(f"Segment length: {segment_length}, Overlap: {overlap}, Step size: {step_size}")
        print(f"Target density: {target_density:.2f}, Temperature: {temperature:.2f}")

        segment_starts = range(0, original_length, step_size)
        if chord_times and len(chord_times) >= original_length:
            print("Using real chord timing from original MIDI")
            # Segment start times are absolute MIDI times, so the song bounds must
            # be absolute too, measured over the chords actually used.
            song_start_time = chord_times[0][0]
            song_end_time = chord_times[original_length - 1][1]
            segment_durations = []
            for seg_start_idx in segment_starts:
                seg_end_idx = min(seg_start_idx + segment_length, original_length)
                seg_start_time = chord_times[seg_start_idx][0]
                seg_end_time = chord_times[seg_end_idx - 1][1]
                segment_durations.append((seg_start_time, seg_end_time - seg_start_time))
        else:
            print("Using estimated chord timing (2 seconds per chord)")
            song_start_time = 0.0
            song_end_time = original_length * 2.0
            segment_durations = [
                (i * 2.0, min(segment_length, original_length - i) * 2.0)
                for i in segment_starts
            ]
        total_song_duration = song_end_time - song_start_time

        # The chord context is identical for every segment, so build it once.
        chord_tensor = torch.tensor([padded_chords], dtype=torch.long, device=device)
        mask_tensor = torch.tensor([chord_mask], dtype=torch.bool, device=device)

        full_melody = []

        print(f"Total song duration: {total_song_duration:.1f}s")
        print(f"Will generate segments until position {original_length}")

        for segment_idx, (current_position, (segment_start_time, segment_duration)) in enumerate(
                zip(segment_starts, segment_durations)):
            segment_end = min(current_position + segment_length, original_length)
            focus_position = min((current_position + segment_end) // 2, original_length - 1)
            focus_tensor = torch.tensor([focus_position], dtype=torch.long, device=device)

            with torch.no_grad():
                generated_sequence = self._generate_with_density_control(
                    chord_tensor,
                    mask_tensor,
                    focus_tensor,
                    target_density=target_density,
                    density_window=density_window,
                    temperature=temperature
                )

            valid_notes = [note for note in generated_sequence if note['pitch'].item() > 0]
            segment_melody = []
            for i, note_info in enumerate(valid_notes):
                pitch_idx = note_info['pitch'].item()
                pitch = int(idx_to_note.get(str(pitch_idx), 60))

                # Blend the note's ordinal position (70%) with the model's start prediction (30%).
                note_position = i / len(valid_notes)
                relative_start = note_info['start'].item() / 8.0
                blended_start = 0.7 * note_position + 0.3 * relative_start
                actual_start = segment_start_time + (blended_start * segment_duration)

                relative_duration = min(1.0, note_info['duration'].item() / 4.0)
                actual_duration = max(0.1, min(1.0, relative_duration * segment_duration * 0.2))

                segment_melody.append({
                    'pitch': pitch,
                    'duration': actual_duration,
                    'start_time': actual_start
                })

            if segment_idx == 0:
                full_melody.extend(segment_melody)
            else:
                overlap_end_time = segment_start_time + (overlap * segment_duration / segment_length)
                for note in segment_melody:
                    in_overlap = note['start_time'] < overlap_end_time
                    # Inside the overlap, keep a note only if it clearly follows the last kept one.
                    if (not in_overlap or not full_melody
                            or note['start_time'] > full_melody[-1]['start_time'] + 0.3):
                        full_melody.append(note)

            print(f"Generated segment {segment_idx + 1} "
                f"(position: {current_position}-{segment_end}, focus: chord {focus_position}, "
                f"time: {segment_start_time:.1f}-{segment_start_time+segment_duration:.1f}s, "
                f"notes: {len(segment_melody)})")

            if segment_idx + 1 > 200:
                print("Warning: Too many segments generated, stopping")
                break

        full_melody.sort(key=lambda x: x['start_time'])

        if full_melody:
            # start_time values are absolute, so the cutoff is offset by the song start.
            max_allowed_time = song_start_time + total_song_duration * 1.1
            full_melody = [note for note in full_melody if note['start_time'] <= max_allowed_time]

            if full_melody and full_melody[-1]['start_time'] + full_melody[-1]['duration'] > max_allowed_time:
                full_melody[-1]['duration'] = max(0.1, max_allowed_time - full_melody[-1]['start_time'])

        print(f"Generated {len(full_melody)} notes for full song")
        if full_melody:
            melody_end_time = max(note['start_time'] + note['duration'] for note in full_melody)
            generated_duration = melody_end_time - song_start_time
            print(f"Generated melody duration: {generated_duration:.1f} seconds")
            print(f"Expected song duration: {total_song_duration:.1f} seconds")
            coverage_ratio = generated_duration / total_song_duration if total_song_duration > 0 else 0
            print(f"Coverage ratio: {coverage_ratio:.2f}")

        return full_melody

    def generate_melody_from_chords(self, chord_sequence, vocab_path, target_density=0.6,
                                density_window=8, temperature=1.0, chord_duration=2.0):
        """
        Generate a melody for one chord sequence on CPU, assuming a fixed chord length.

        The model is moved to CPU for generation and always moved back to its
        original device afterwards, even if generation raises.

        Args:
            chord_sequence: chord symbols. Only the first ``self.max_chord_length``
                chords are used.
            vocab_path: JSON file with 'chord_vocab' and 'idx_to_note'.
            target_density, density_window, temperature: forwarded to
                ``_generate_with_density_control``.
            chord_duration: seconds assumed per chord. The melody spans
                (number of chords used) * chord_duration seconds.

        Returns:
            List of {'pitch', 'duration', 'start_time'} dicts. Empty when
            ``chord_sequence`` is empty.
        """
        self.eval()

        with open(vocab_path, 'r') as f:
            vocabularies = json.load(f)

        chord_to_idx = vocabularies['chord_vocab']
        idx_to_note = vocabularies['idx_to_note']

        unk_idx = chord_to_idx.get('<UNK>', 0)
        chord_indices = [chord_to_idx.get(chord, unk_idx) for chord in chord_sequence]
        if not chord_indices:
            # An all-padding mask would make attention produce NaNs.
            return []

        # Truncate BEFORE building the mask so that mask, focus and chord tensor agree in length.
        chord_indices = chord_indices[:self.max_chord_length]
        used_length = len(chord_indices)
        pad_count = self.max_chord_length - used_length
        chord_indices += [chord_to_idx.get('<PAD>', 0)] * pad_count

        chord_tensor = torch.tensor([chord_indices], dtype=torch.long)
        chord_mask = torch.tensor([[True] * used_length + [False] * pad_count], dtype=torch.bool)
        focus_position = torch.tensor([used_length // 2], dtype=torch.long)

        total_duration = used_length * chord_duration

        original_device = next(self.parameters()).device
        self.to('cpu')
        try:
            with torch.no_grad():
                generated_sequence = self._generate_with_density_control(
                    chord_tensor, chord_mask, focus_position,
                    target_density, density_window, temperature
                )
        finally:
            self.to(original_device)

        melody = []
        for note_info in generated_sequence:
            pitch_idx = note_info['pitch'].item()
            if pitch_idx <= 0:
                continue
            pitch = int(idx_to_note.get(str(pitch_idx), 60))

            normalized_start = min(1.0, note_info['start'].item() / 8.0)
            normalized_duration = min(1.0, note_info['duration'].item() / 4.0)

            melody.append({
                'pitch': pitch,
                'duration': max(0.1, normalized_duration * total_duration * 0.1),
                'start_time': normalized_start * total_duration
            })

        return melody


    def _generate_with_density_control(self, chord_tensor, chord_mask, focus_position,
                             target_density=0.6, density_window=8, temperature=1.0):
        """
        Generate sequence with balanced density control
        """
        batch_size = chord_tensor.size(0)
        device = chord_tensor.device

        chord_embedded = self.chord_embedding(chord_tensor)
        chord_seq_len = chord_tensor.size(1)

        chord_positions = torch.arange(chord_seq_len, device=device)
        chord_pos_emb = self.chord_position_embedding(chord_positions).unsqueeze(0).expand(batch_size, -1, -1)
        chord_embedded = chord_embedded + chord_pos_emb
        chord_embedded = self.dropout(chord_embedded)

        padding_mask = ~chord_mask

        chord_encoded = self.chord_encoder(chord_embedded, src_key_padding_mask=padding_mask)

        focus_context, _ = self.focus_attention(chord_encoded, chord_encoded, chord_encoded, key_padding_mask=padding_mask)
        chord_encoded = chord_encoded + self.focus_projection(focus_context)
        chord_encoded = self.layer_norm(chord_encoded)

        generated_sequence = []
        current_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

        recent_predictions = []

        print(f"Generating with target density: {target_density:.2f}")

        for i in range(self.max_melody_length):
            melody_embedded = self.note_embedding(current_input)
            melody_seq_len = current_input.size(1)

            melody_positions = torch.arange(melody_seq_len, device=device)
            melody_pos_emb = self.position_embedding(melody_positions).unsqueeze(0).expand(batch_size, -1, -1)
            melody_embedded = melody_embedded + melody_pos_emb
            melody_embedded = self.dropout(melody_embedded)

            decoded = self.melody_decoder(melody_embedded, chord_encoded, memory_key_padding_mask=padding_mask)

            last_output = decoded[:, -1:]
            pitch_logits = self.pitch_head(last_output)
            duration_pred = self.duration_head(last_output).squeeze(-1)
            start_pred = self.start_head(last_output).squeeze(-1)

            if len(recent_predictions) >= density_window:
                recent_notes = sum(1 for x in recent_predictions[-density_window:] if x > 0)
                current_density = recent_notes / density_window

                density_deviation = current_density - target_density

                density_bias = self._calculate_density_bias(density_deviation, temperature)

                if len(recent_predictions) >= 3:
                    recent_note_count = sum(1 for x in recent_predictions[-3:] if x > 0)
                    if recent_note_count >= 3:
                        temporal_bias = -1.0
                    elif recent_note_count == 0:
                        temporal_bias = 0.5
                    else:
                        temporal_bias = 0.0
                else:
                    temporal_bias = 0.0

                total_bias = density_bias + temporal_bias
                pitch_logits = self._apply_density_bias(pitch_logits, total_bias)

                if i % 8 == 0:
                    print(f"Step {i}: Density: {current_density:.3f}/{target_density:.3f}, "
                        f"Bias: {total_bias:.3f}")

            if temperature != 1.0:
                pitch_logits = pitch_logits / temperature

            pitch_probs = torch.softmax(pitch_logits.squeeze(1), dim=-1)
            next_pitch = torch.multinomial(pitch_probs, 1)

            recent_predictions.append(next_pitch.item())

            if len(recent_predictions) > density_window * 2:
                recent_predictions = recent_predictions[-density_window:]

            current_input = torch.cat([current_input, next_pitch], dim=1)

            generated_sequence.append({
                'pitch': next_pitch.squeeze(-1),
                'duration': duration_pred,
                'start': start_pred
            })

        return generated_sequence

    def _calculate_density_bias(self, density_deviation, temperature):
        """
        Calculate bias to apply to pitch logits based on density deviation
        """
        bias_strength = 2.0 / temperature

        if density_deviation > 0.15:
            return -bias_strength
        elif density_deviation < -0.15:
            return bias_strength
        else:
            return 0.0

    def _apply_density_bias(self, pitch_logits, bias):
        """
        Apply density bias to pitch logits
        """
        if bias == 0.0:
            return pitch_logits

        adjusted_logits = pitch_logits.clone()

        if bias > 0:
            adjusted_logits[:, :, 1:] += bias
        else:
            adjusted_logits[:, :, 0] += abs(bias)

        return adjusted_logits

    def generate_chord_aligned_melody(self, chord_sequence, chord_times, vocab_path, temperature=1.0):
        """
        Generate a melody aligned to real chord timing, one chord at a time.

        For each timed chord, ``_generate_single_chord_melody`` samples notes with
        that chord as the focus. The notes' relative start and duration are mapped
        into the chord's [start, end] span, and the result is cleaned with
        ``_enhance_note_timing_with_bpm``. Sequences longer than
        ``self.max_chord_length`` are delegated to ``_generate_with_sliding_window``.

        Args:
            chord_sequence: chord symbols.
            chord_times: (start, end) pairs in seconds, one per chord. If fewer
                timings than chords are given, only the timed chords get notes.
            vocab_path: JSON file with 'chord_vocab' and 'idx_to_note'.
            temperature: sampling temperature.

        Returns:
            List of note dicts with 'pitch', 'start_time', 'duration', 'chord' and
            'chord_index', as returned by the timing-enhancement step.
        """
        self.eval()
        device = next(self.parameters()).device

        with open(vocab_path, 'r') as f:
            vocabularies = json.load(f)

        chord_to_idx = vocabularies['chord_vocab']
        idx_to_note = vocabularies['idx_to_note']

        unk_idx = chord_to_idx.get('<UNK>', 0)
        chord_indices = [chord_to_idx.get(chord, unk_idx) for chord in chord_sequence]

        if len(chord_indices) > self.max_chord_length:
            print(f"Warning: Chord sequence too long, using sliding window approach")
            return self._generate_with_sliding_window(chord_sequence, chord_times, vocab_path, temperature)

        num_chords = len(chord_indices)
        # Chords without a (start, end) pair cannot be placed in time.
        num_timed = min(num_chords, len(chord_times))
        if num_timed < num_chords:
            print(f"Warning: only {len(chord_times)} chord timings for {num_chords} chords; "
                  f"generating notes for the first {num_timed}")

        song_bpm = self._calculate_song_bpm(chord_times)
        # Quantization grid: a quarter of one beat at the estimated BPM.
        beat_duration = 60.0 / (song_bpm * 4)

        pad_count = self.max_chord_length - num_chords
        padded_chords = chord_indices + [chord_to_idx.get('<PAD>', 0)] * pad_count
        chord_mask = [True] * num_chords + [False] * pad_count

        chord_durations = [end - start for start, end in chord_times[:num_chords]]
        max_duration = max(chord_durations) if chord_durations else 4.0
        if max_duration <= 0:
            # Degenerate timing (all zero-length chords) would divide by zero below.
            max_duration = 1.0
        normalized_durations = [d / max_duration for d in chord_durations]
        normalized_durations += [0.0] * (self.max_chord_length - len(normalized_durations))

        # The chord context does not depend on the focus chord; build it once.
        chord_tensor = torch.tensor([padded_chords], dtype=torch.long, device=device)
        chord_duration_tensor = torch.tensor([normalized_durations], dtype=torch.float32, device=device)
        chord_mask_tensor = torch.tensor([chord_mask], dtype=torch.bool, device=device)

        full_melody = []

        print(f"Generating melody for {num_chords} chords...")

        for focus_idx in range(num_timed):
            chord_start, chord_end = chord_times[focus_idx]
            chord_duration = chord_end - chord_start
            current_chord = chord_sequence[focus_idx]

            print(f"  Chord {focus_idx + 1}/{num_chords}: {current_chord} ({chord_duration:.2f}s)")

            focus_tensor = torch.tensor([focus_idx], dtype=torch.long, device=device)
            target_duration_tensor = torch.tensor([chord_duration], dtype=torch.float32, device=device)

            with torch.no_grad():
                chord_melody = self._generate_single_chord_melody(
                    chord_tensor,
                    chord_duration_tensor,
                    chord_mask_tensor,
                    focus_tensor,
                    target_duration_tensor,
                    temperature
                )

            for note_info in chord_melody:
                pitch_idx = note_info['pitch'].item()
                if pitch_idx <= 0:
                    continue
                pitch = int(idx_to_note.get(str(pitch_idx), 60))

                relative_start = note_info['start'].item()
                relative_duration = note_info['duration'].item()

                absolute_start = chord_start + (relative_start * chord_duration)
                absolute_duration = max(0.1, relative_duration * chord_duration)

                # Never let a note ring past the end of its chord.
                if absolute_start + absolute_duration > chord_end:
                    absolute_duration = chord_end - absolute_start

                if absolute_duration > 0.05:
                    full_melody.append({
                        'pitch': pitch,
                        'start_time': absolute_start,
                        'duration': absolute_duration,
                        'chord': current_chord,
                        'chord_index': focus_idx
                    })

        full_melody.sort(key=lambda x: x['start_time'])

        cleaned_melody = self._enhance_note_timing_with_bpm(full_melody, beat_duration)

        print(f"Generated {len(cleaned_melody)} notes total")
        return cleaned_melody

    def _generate_single_chord_melody(self, chord_tensor, chord_durations, chord_mask, focus_position, target_duration, temperature):
        """Generate melody for a single chord using the trained model"""

        batch_size = chord_tensor.size(0)
        device = chord_tensor.device

        chord_embedded = self.chord_embedding(chord_tensor)
        chord_seq_len = chord_tensor.size(1)

        chord_positions = torch.arange(chord_seq_len, device=device)
        chord_pos_emb = self.chord_position_embedding(chord_positions).unsqueeze(0).expand(batch_size, -1, -1)
        chord_embedded = chord_embedded + chord_pos_emb
        chord_embedded = self.dropout(chord_embedded)

        padding_mask = ~chord_mask
        chord_encoded = self.chord_encoder(chord_embedded, src_key_padding_mask=padding_mask)

        focus_context, _ = self.focus_attention(chord_encoded, chord_encoded, chord_encoded, key_padding_mask=padding_mask)
        chord_encoded = chord_encoded + self.focus_projection(focus_context)
        chord_encoded = self.layer_norm(chord_encoded)

        generated_sequence = []
        current_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

        max_notes_per_chord = 8

        for i in range(max_notes_per_chord):
            melody_embedded = self.note_embedding(current_input)
            melody_seq_len = current_input.size(1)

            melody_positions = torch.arange(melody_seq_len, device=device)
            melody_pos_emb = self.position_embedding(melody_positions).unsqueeze(0).expand(batch_size, -1, -1)
            melody_embedded = melody_embedded + melody_pos_emb
            melody_embedded = self.dropout(melody_embedded)

            decoded = self.melody_decoder(melody_embedded, chord_encoded, memory_key_padding_mask=padding_mask)

            last_output = decoded[:, -1:]
            pitch_logits = self.pitch_head(last_output)
            duration_pred = self.duration_head(last_output).squeeze(-1)
            start_pred = self.start_head(last_output).squeeze(-1)

            if temperature != 1.0:
                pitch_logits = pitch_logits / temperature

            pitch_probs = torch.softmax(pitch_logits.squeeze(1), dim=-1)
            next_pitch = torch.multinomial(pitch_probs, 1)

            current_input = torch.cat([current_input, next_pitch], dim=1)

            generated_sequence.append({
                'pitch': next_pitch.squeeze(-1),
                'duration': torch.clamp(duration_pred, 0.0, 1.0),
                'start': torch.clamp(start_pred, 0.0, 1.0)
            })

            if next_pitch.item() == 0:
                break

        return generated_sequence

    def _remove_overlapping_notes(self, melody):
        """Remove overlapping notes to create clean melody line"""
        if len(melody) <= 1:
            return melody

        cleaned = [melody[0]]

        for note in melody[1:]:
            last_note = cleaned[-1]

            if note['start_time'] < last_note['start_time'] + last_note['duration']:
                if last_note['start_time'] + 0.1 < note['start_time']:
                    last_note['duration'] = note['start_time'] - last_note['start_time']
                    cleaned.append(note)
            else:
                cleaned.append(note)

        return cleaned

    def _generate_with_sliding_window(self, chord_sequence, chord_times, vocab_path, temperature):
        """
        Generate a melody for a chord sequence longer than the model context.

        Windows of ``max_chord_length - 2`` chords are generated with
        ``_generate_window_melody``; consecutive windows overlap by a third of a
        window. For every window after the first, only notes starting after the
        overlapping chords are kept. Generation stops once a window reaches the
        last chord, because a further window would contribute no notes.

        Args:
            chord_sequence: chord symbols for the whole song.
            chord_times: absolute (start, end) pairs in seconds, one per chord.
            vocab_path: JSON vocabulary file.
            temperature: sampling temperature.

        Returns:
            Note dicts from all windows, sorted by 'start_time'.
        """
        print(f"Using sliding window for {len(chord_sequence)} chords")

        num_chords = len(chord_sequence)
        window_size = max(1, self.max_chord_length - 2)
        overlap = window_size // 3
        stride = max(1, window_size - overlap)

        if num_chords <= window_size:
            num_windows = 1
        else:
            # ceil((num_chords - window_size) / stride) extra windows reach the last chord.
            num_windows = 1 + -(-(num_chords - window_size) // stride)
        print(f"Processing {num_windows} windows (window_size={window_size}, overlap={overlap})")

        full_melody = []

        for window_idx in range(num_windows):
            start_idx = window_idx * stride
            end_idx = min(start_idx + window_size, num_chords)

            print(f"  Window {window_idx + 1}/{num_windows}: chords {start_idx}-{end_idx} ({end_idx - start_idx} chords)")

            window_chords = chord_sequence[start_idx:end_idx]
            window_times = chord_times[start_idx:end_idx]

            # Window-relative timing. Default to 0 when this window has no timing
            # (the offset used to be left undefined in that case).
            time_offset = window_times[0][0] if window_times else 0.0
            adjusted_times = [(start - time_offset, end - time_offset) for start, end in window_times]

            # window_size <= max_chord_length, so each window fits the model input.
            window_melody = self._generate_window_melody(
                window_chords,
                adjusted_times,
                vocab_path,
                temperature,
                time_offset
            )

            if window_idx == 0:
                full_melody.extend(window_melody)
                print(f"    Added {len(window_melody)} notes from first window")
                continue

            # The previous window already covered chords up to start_idx + overlap.
            boundary_idx = start_idx + overlap
            if boundary_idx < len(chord_times):
                overlap_end_time = chord_times[boundary_idx][0]
            elif chord_times:
                overlap_end_time = chord_times[-1][1]
            else:
                overlap_end_time = 0.0

            new_notes = [note for note in window_melody if note['start_time'] >= overlap_end_time]
            full_melody.extend(new_notes)

            print(f"    Added {len(new_notes)} notes from window {window_idx + 1} (after overlap filtering)")

        full_melody.sort(key=lambda x: x['start_time'])
        print(f"Sliding window complete: {len(full_melody)} total notes")
        return full_melody
    def _generate_window_melody(self, chord_sequence, chord_times, vocab_path, temperature, time_offset):
        """
        Generate notes for one sliding-window slice of a song.

        Args:
            chord_sequence: chord symbols in this window (at most ``max_chord_length``).
            chord_times: window-relative (start, end) pairs for those chords. Only
                chords that have timing get notes.
            vocab_path: JSON file with 'chord_vocab' and 'idx_to_note'.
            temperature: sampling temperature for ``_generate_single_chord_melody``.
            time_offset: added to window-relative times to produce absolute times.

        Returns:
            The result of ``self._enhance_note_timing`` applied to the window's
            note dicts ('pitch', 'start_time', 'duration', 'chord', 'chord_index').
        """
        with open(vocab_path, 'r') as f:
            vocabularies = json.load(f)

        chord_to_idx = vocabularies['chord_vocab']
        idx_to_note = vocabularies['idx_to_note']
        device = next(self.parameters()).device

        unk_idx = chord_to_idx.get('<UNK>', 0)
        chord_indices = [chord_to_idx.get(chord, unk_idx) for chord in chord_sequence]

        original_length = len(chord_indices)
        pad_count = self.max_chord_length - original_length
        padded_chords = chord_indices + [chord_to_idx.get('<PAD>', 0)] * pad_count
        chord_mask = [True] * original_length + [False] * pad_count

        chord_durations = [end - start for start, end in chord_times[:original_length]]
        max_duration = max(chord_durations) if chord_durations else 4.0
        if max_duration <= 0:
            # All-zero (degenerate) timing would otherwise divide by zero.
            max_duration = 1.0
        normalized_durations = [d / max_duration for d in chord_durations]
        normalized_durations += [0.0] * (self.max_chord_length - len(normalized_durations))

        # The chord context does not depend on the focus chord; build it once.
        chord_tensor = torch.tensor([padded_chords], dtype=torch.long, device=device)
        chord_duration_tensor = torch.tensor([normalized_durations], dtype=torch.float32, device=device)
        chord_mask_tensor = torch.tensor([chord_mask], dtype=torch.bool, device=device)

        window_melody = []

        for focus_idx in range(min(original_length, len(chord_times))):
            chord_start, chord_end = chord_times[focus_idx]
            chord_duration = chord_end - chord_start
            current_chord = chord_sequence[focus_idx]

            focus_tensor = torch.tensor([focus_idx], dtype=torch.long, device=device)
            target_duration_tensor = torch.tensor([chord_duration], dtype=torch.float32, device=device)

            with torch.no_grad():
                chord_melody = self._generate_single_chord_melody(
                    chord_tensor,
                    chord_duration_tensor,
                    chord_mask_tensor,
                    focus_tensor,
                    target_duration_tensor,
                    temperature
                )

            chord_end_absolute = time_offset + chord_end
            for note_info in chord_melody:
                pitch_idx = note_info['pitch'].item()
                if pitch_idx <= 0:
                    continue
                pitch = int(idx_to_note.get(str(pitch_idx), 60))

                relative_start = max(0.0, min(1.0, note_info['start'].item()))
                relative_duration = max(0.0, min(1.0, note_info['duration'].item()))

                absolute_start = time_offset + chord_start + (relative_start * chord_duration)
                absolute_duration = max(0.1, relative_duration * chord_duration * 0.5)

                if absolute_start + absolute_duration > chord_end_absolute:
                    absolute_duration = max(0.1, chord_end_absolute - absolute_start)

                # Both assignments above floor the duration at 0.1 s, so every note is kept.
                window_melody.append({
                    'pitch': pitch,
                    'start_time': absolute_start,
                    'duration': absolute_duration,
                    'chord': current_chord,
                    'chord_index': focus_idx
                })

        return self._enhance_note_timing(window_melody)


## Trainer

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import time

class ChordToMelodyTrainer:
    """Training / validation driver for the chord-to-melody model.

    The model is called as ``model(full_chord_sequence, chord_mask,
    focus_positions, chord_durations=..., melody_pitch=..., training=True)``
    and must return ``(pitch_logits, duration_pred, start_pred)``.  The
    objective is cross-entropy on next-step pitch (index 0 ignored) plus
    weighted MSE on note duration and start over the masked positions.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device``.
        train_loader: Iterable of batch dicts consumed by :meth:`train_epoch`.
        val_loader: Iterable of batch dicts consumed by :meth:`validate`.
        device: Device the model and batch tensors are moved to.
        checkpoint_path: File the best model ``state_dict`` is saved to.
        history_plot_path: File the loss-curve figure is written to each epoch.
    """

    # Weights of the auxiliary regression losses relative to the pitch loss.
    DURATION_LOSS_WEIGHT = 0.3
    START_LOSS_WEIGHT = 0.3

    def __init__(self, model, train_loader, val_loader, device='cuda', *,
                 checkpoint_path='best_chord_melody_model.pt',
                 history_plot_path='training_history.png'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.history_plot_path = history_plot_path

        # Pitch index 0 is excluded from the pitch loss (treated as padding).
        self.pitch_criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.duration_criterion = nn.MSELoss()
        self.start_criterion = nn.MSELoss()

        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4, weight_decay=1e-5)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=3,
            factor=0.5,
        )

        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')

    def calculate_accuracy(self, pitch_logits, target_pitch, mask):
        """Return the fraction of masked positions whose argmax pitch equals the target.

        Args:
            pitch_logits: Logits whose last dimension indexes the pitch vocabulary.
            target_pitch: Target pitch indices with the leading shape of ``pitch_logits``.
            mask: Boolean tensor selecting the positions to score.

        Returns:
            Accuracy as a Python float, or 0.0 when ``mask`` selects nothing.
        """
        with torch.no_grad():
            pred_pitch = torch.argmax(pitch_logits, dim=-1)
            target_flat = target_pitch.reshape(-1)
            pred_flat = pred_pitch.reshape(-1)
            mask_flat = mask.reshape(-1)

            if mask_flat.sum() > 0:
                correct = (pred_flat == target_flat) & mask_flat
                accuracy = correct.sum().float() / mask_flat.sum().float()
                return accuracy.item()
            return 0.0

    def _batch_to_device(self, batch):
        """Return a copy of ``batch`` with every tensor value moved to ``self.device``."""
        return {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }

    def _forward(self, batch):
        """Call the model on a batch; returns ``(pitch_logits, duration_pred, start_pred)``."""
        return self.model(
            batch['full_chord_sequence'],
            batch['chord_mask'],
            batch['focus_position'],
            chord_durations=batch['chord_durations'],
            melody_pitch=batch['melody_pitch'],
            training=True,
        )

    def _forward_with_cpu_fallback(self, batch):
        """Like :meth:`_forward`, but retry on CPU when the device raises an MPS error.

        The model is moved back to ``self.device`` even if the CPU retry fails,
        and the outputs are returned on ``self.device``.  Other RuntimeErrors
        propagate unchanged.
        """
        try:
            return self._forward(batch)
        except RuntimeError as e:
            if "MPS" not in str(e):
                raise

        self.model.cpu()
        try:
            cpu_batch = {
                key: value.cpu() if torch.is_tensor(value) else value
                for key, value in batch.items()
            }
            outputs = self._forward(cpu_batch)
        finally:
            self.model.to(self.device)
        return tuple(output.to(self.device) for output in outputs)

    def _compute_loss_and_accuracy(self, pitch_logits, duration_pred, start_pred, batch):
        """Return ``(total_loss, pitch_accuracy)`` for one batch of model outputs.

        Pitch targets are the melody shifted one step (``[:, 1:]``); duration,
        start and the validity mask use the first ``L-1`` positions
        (``[:, :-1]``), so the model outputs must have ``L-1`` positions.
        """
        target_pitch = batch['melody_pitch'][:, 1:]
        target_duration = batch['melody_duration'][:, :-1]
        target_start = batch['melody_start'][:, :-1]
        target_mask = batch['melody_mask'][:, :-1]

        pitch_loss = self.pitch_criterion(
            pitch_logits.reshape(-1, pitch_logits.size(-1)),
            target_pitch.reshape(-1),
        )

        if target_mask.sum() > 0:
            duration_loss = self.duration_criterion(
                duration_pred[target_mask], target_duration[target_mask]
            )
            start_loss = self.start_criterion(
                start_pred[target_mask], target_start[target_mask]
            )
        else:
            duration_loss = torch.tensor(0.0, device=self.device)
            start_loss = torch.tensor(0.0, device=self.device)

        total_loss = (
            pitch_loss
            + self.DURATION_LOSS_WEIGHT * duration_loss
            + self.START_LOSS_WEIGHT * start_loss
        )

        # Score accuracy on the same positions the pitch loss trains on.  The
        # shifted mask can mark a step whose *target* is the ignored index
        # (e.g. right after the last real note); counting it would be a
        # guaranteed miss that the loss never optimises.
        scored = target_mask.bool() & (target_pitch != self.pitch_criterion.ignore_index)
        accuracy = self.calculate_accuracy(pitch_logits, target_pitch, scored)
        return total_loss, accuracy

    @staticmethod
    def _mean_over_batches(total, num_batches, loader_name):
        """Average an accumulated metric, failing clearly if the loader was empty."""
        if num_batches == 0:
            raise ValueError(f"{loader_name} yielded no batches; cannot compute averages")
        return total / num_batches

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(mean_loss, mean_accuracy)`` averaged over batches.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        for batch_idx, raw_batch in enumerate(self.train_loader):
            batch = self._batch_to_device(raw_batch)

            self.optimizer.zero_grad()
            pitch_logits, duration_pred, start_pred = self._forward(batch)
            loss, accuracy = self._compute_loss_and_accuracy(
                pitch_logits, duration_pred, start_pred, batch
            )
            loss.backward()
            # Clip the global gradient norm to 1.0 before stepping.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            total_accuracy += accuracy
            num_batches += 1

            if batch_idx % 50 == 0 and batch_idx > 0:
                current_loss = total_loss / num_batches
                current_acc = total_accuracy / num_batches
                print(f"  Batch {batch_idx}/{len(self.train_loader)}: "
                    f"Loss {current_loss:.4f}, Acc {current_acc:.3f}")

        return (
            self._mean_over_batches(total_loss, num_batches, "train_loader"),
            self._mean_over_batches(total_accuracy, num_batches, "train_loader"),
        )

    def validate(self):
        """Evaluate on ``val_loader`` without gradient tracking.

        The model is called exactly as in :meth:`train_epoch` (``training=True``
        with the ground-truth melody), so the two losses are directly comparable.

        Returns:
            ``(mean_loss, mean_accuracy)`` averaged over batches.

        Raises:
            ValueError: If ``val_loader`` yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        with torch.no_grad():
            for raw_batch in self.val_loader:
                batch = self._batch_to_device(raw_batch)
                pitch_logits, duration_pred, start_pred = self._forward_with_cpu_fallback(batch)
                loss, accuracy = self._compute_loss_and_accuracy(
                    pitch_logits, duration_pred, start_pred, batch
                )

                total_loss += loss.item()
                total_accuracy += accuracy
                num_batches += 1

        return (
            self._mean_over_batches(total_loss, num_batches, "val_loader"),
            self._mean_over_batches(total_accuracy, num_batches, "val_loader"),
        )

    def train(self, num_epochs):
        """Train for up to ``num_epochs`` epochs.

        After each epoch the plateau scheduler steps on the validation loss and
        the loss plot is refreshed.  The model is saved to ``checkpoint_path``
        whenever validation loss improves.  Training stops early once past
        epoch index 10 if validation loss exceeds 1.2x the lowest recorded one.
        """
        print(f"Starting training for {num_epochs} epochs...")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        start_time = time.time()

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            self.scheduler.step(val_loss)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.3f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.3f}")
            print(f"Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")
            self.plot_training_history()

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model.state_dict(), self.checkpoint_path)
                print("✓ New best model saved!")

            if epoch > 10 and val_loss > min(self.val_losses) * 1.2:
                print("Early stopping - validation loss increasing")
                break

            print("-" * 50)

        training_time = (time.time() - start_time) / 60
        print(f"\nTraining completed in {training_time:.1f} minutes")
        print(f"Best validation loss: {self.best_val_loss:.4f}")

    def plot_training_history(self):
        """Save raw train/val loss curves and a smoothed val-loss curve to ``history_plot_path``.

        The smoothed panel is drawn only once more than three epochs exist.  It
        is a centred moving average whose half-width is at most 3 epochs.
        """
        fig = plt.figure(figsize=(12, 5))
        try:
            plt.subplot(1, 2, 1)
            plt.plot(self.train_losses, label='Training Loss', color='blue')
            plt.plot(self.val_losses, label='Validation Loss', color='red')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Progress')
            plt.legend()
            plt.grid(True, alpha=0.3)

            plt.subplot(1, 2, 2)
            if len(self.val_losses) > 3:
                window = min(3, len(self.val_losses) // 2)
                smoothed_val = []
                for i in range(len(self.val_losses)):
                    start_idx = max(0, i - window)
                    end_idx = min(len(self.val_losses), i + window + 1)
                    smoothed_val.append(sum(self.val_losses[start_idx:end_idx]) / (end_idx - start_idx))

                plt.plot(smoothed_val, label='Smoothed Val Loss', color='green')
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.title('Smoothed Validation Loss')
                plt.legend()
                plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(self.history_plot_path, dpi=150, bbox_inches='tight')
        finally:
            # This runs every epoch; pyplot keeps each figure alive until it is
            # explicitly closed, so close it to avoid accumulating figures.
            plt.close(fig)

## Evaluator

In [None]:
import numpy as np
from typing import List, Dict, Tuple

class SimpleMelodyEvaluator:
    """Heuristic scorer for a generated melody against its chord progression.

    Every metric lies in [0.0, 1.0]:
    - chord alignment: the fraction of notes whose pitch class is a chord tone;
    - timing accuracy: the ratio between the melody's length and the expected length;
    - note density: 1.0 inside a 0.5-2.0 notes/second band, decaying outside it.
    """

    # Pitch class (0-11) of each accepted root spelling.
    ROOT_PITCH_CLASSES = {
        'C': 0, 'B#': 0, 'C#': 1, 'Db': 1, 'D': 2, 'D#': 3, 'Eb': 3,
        'E': 4, 'Fb': 4, 'F': 5, 'E#': 5, 'F#': 6, 'Gb': 6, 'G': 7, 'G#': 8,
        'Ab': 8, 'A': 9, 'A#': 10, 'Bb': 10, 'B': 11, 'Cb': 11,
    }

    def __init__(self):
        # Chord quality -> chord-tone intervals in semitones above the root.
        self.chord_pitch_maps = {
            'maj': [0, 4, 7],           # C major: C, E, G
            'maj7': [0, 4, 7, 11],      # Cmaj7: C, E, G, B
            'min': [0, 3, 7],           # C minor: C, Eb, G
            'min7': [0, 3, 7, 10],      # Cm7: C, Eb, G, Bb
            'minmaj7': [0, 3, 7, 11],   # CmMaj7: C, Eb, G, B
            '7': [0, 4, 7, 10],         # C7: C, E, G, Bb
            'dim': [0, 3, 6],           # Cdim: C, Eb, Gb
            'aug': [0, 4, 8],           # Caug: C, E, G#
            'sus2': [0, 2, 7],          # Csus2: C, D, G
            'sus4': [0, 5, 7],          # Csus4: C, F, G
            'dim7': [0, 3, 6, 9],       # Cdim7: C, Eb, Gb, Bbb
            'hdim7': [0, 3, 6, 10],     # Cm7b5: C, Eb, Gb, Bb
            'maj6': [0, 4, 7, 9],       # C6: C, E, G, A
            'min6': [0, 3, 7, 9],       # Cm6: C, Eb, G, A
        }

    def parse_chord_symbol(self, chord_symbol: str) -> Tuple[int, str]:
        """Parse a ``root:quality`` chord symbol into ``(root pitch class, quality)``.

        A bare root such as ``"G"`` means a major chord on that root (Harte
        shorthand).  A bass/inversion suffix (``"/3"``) or parenthesised
        extension list (``"(9)"``) is dropped from the returned quality.
        Unrecognised roots (e.g. ``"N"``) map to 0.
        """
        if ':' in chord_symbol:
            root_str, quality = chord_symbol.split(':', 1)
        else:
            root_str, quality = chord_symbol, 'maj'

        root_str = root_str.split('/', 1)[0]
        quality = quality.split('/', 1)[0].split('(', 1)[0]
        return self.ROOT_PITCH_CLASSES.get(root_str, 0), quality

    def get_chord_pitches(self, chord_symbol: str) -> List[int]:
        """Return the chord-tone pitch classes (0-11) of ``chord_symbol``.

        Qualities missing from ``chord_pitch_maps`` fall back to a major triad.
        """
        root, quality = self.parse_chord_symbol(chord_symbol)
        intervals = self.chord_pitch_maps.get(quality, self.chord_pitch_maps['maj'])
        return [(root + interval) % 12 for interval in intervals]

    def find_chord_at_time(self, time: float, chord_times: List[Tuple]) -> int:
        """Return the index of the first chord whose ``[start, end)`` contains ``time``.

        Falls back to the last index (or 0 for an empty list) when no chord matches.
        """
        for i, chord_time in enumerate(chord_times):
            if len(chord_time) >= 2:
                start, end = chord_time[0], chord_time[1]
                if start <= time < end:
                    return i
        return len(chord_times) - 1 if chord_times else 0

    def chord_alignment_score(self, melody: List[Dict], chord_sequence: List[str], chord_times: List[Tuple]) -> float:
        """Return the mean per-note chord-tone score (0.0 to 1.0).

        A note scores 1.0 if its pitch class is a tone of the chord active at its
        onset, 0.0 if not, and 0.5 when that chord index has no symbol.
        """
        if not melody or not chord_sequence:
            return 0.0

        alignment_scores = []

        for note in melody:
            chord_idx = self.find_chord_at_time(note['start_time'], chord_times)
            if chord_idx < len(chord_sequence):
                chord_notes = self.get_chord_pitches(chord_sequence[chord_idx])
                alignment_scores.append(1.0 if note['pitch'] % 12 in chord_notes else 0.0)
            else:
                alignment_scores.append(0.5)

        return np.mean(alignment_scores) if alignment_scores else 0.0

    def timing_accuracy_score(self, melody: List[Dict], expected_duration: float) -> float:
        """Return min/max ratio of the melody's end time and ``expected_duration`` (0.0 to 1.0)."""
        if not melody or expected_duration <= 0:
            return 0.0

        actual_duration = max(n['start_time'] + n['duration'] for n in melody)

        return min(actual_duration, expected_duration) / max(actual_duration, expected_duration)

    def note_density_score(self, melody: List[Dict]) -> float:
        """Score notes-per-second: 1.0 within [0.5, 2.0], decaying linearly outside."""
        if not melody:
            return 0.0

        total_duration = max(n['start_time'] + n['duration'] for n in melody)
        if total_duration <= 0:
            return 0.0

        density = len(melody) / total_duration

        if 0.5 <= density <= 2.0:
            return 1.0
        elif density < 0.5:
            return density / 0.5
        else:
            return max(0.0, 1.0 - (density - 2.0) / 3.0)

    def evaluate(self, melody: List[Dict], chord_sequence: List[str], chord_times: List[Tuple]) -> Dict[str, float]:
        """Return all metric scores plus their unweighted mean under ``'overall'``.

        The expected duration is the end time of the last chord when timings
        carry an end, otherwise 2.0 seconds per chord.
        """
        if chord_times and len(chord_times[0]) >= 2:
            expected_duration = chord_times[-1][1]
        else:
            expected_duration = len(chord_sequence) * 2.0

        scores = {
            'chord_alignment': self.chord_alignment_score(melody, chord_sequence, chord_times),
            'timing_accuracy': self.timing_accuracy_score(melody, expected_duration),
            'note_density': self.note_density_score(melody)
        }

        scores['overall'] = (scores['chord_alignment'] + scores['timing_accuracy'] + scores['note_density']) / 3.0

        return scores

    def print_evaluation(self, melody: List[Dict], chord_sequence: List[str], chord_times: List[Tuple]):
        """Print a human-readable evaluation report and return the score dict."""
        scores = self.evaluate(melody, chord_sequence, chord_times)

        print(f"\n🎵 MELODY EVALUATION")
        print("=" * 30)
        print(f"Overall Score:     {scores['overall']:.3f}")
        print(f"Chord Alignment:   {scores['chord_alignment']:.3f}")
        print(f"Timing Accuracy:   {scores['timing_accuracy']:.3f}")
        print(f"Note Density:      {scores['note_density']:.3f}")

        if scores['overall'] > 0.8:
            print("🌟 Excellent!")
        elif scores['overall'] > 0.6:
            print("✅ Good")
        elif scores['overall'] > 0.4:
            print("⚠️  Okay")
        else:
            print("❌ Needs work")

        if melody:
            duration = max(n['start_time'] + n['duration'] for n in melody)
            density = len(melody) / duration if duration > 0 else 0
            pitch_range = max(n['pitch'] for n in melody) - min(n['pitch'] for n in melody)

            print(f"\n📊 Stats:")
            print(f"Notes: {len(melody)}")
            print(f"Duration: {duration:.1f}s")
            print(f"Density: {density:.2f} notes/sec")
            print(f"Pitch range: {pitch_range} semitones")

        return scores


def evaluate_melody(melody, chord_sequence, chord_times):
    """Score ``melody`` with a fresh ``SimpleMelodyEvaluator``, print the report, and return the scores."""
    return SimpleMelodyEvaluator().print_evaluation(melody, chord_sequence, chord_times)

## Utils

In [None]:
from typing import List, Dict, Optional
import pretty_midi
import numpy as np

def generate_midi_from_melody(melody: List[Dict], output_path: str, tempo: int = 120,
                              original_midi_path: Optional[str] = None):
    """Write ``melody`` to a MIDI file, optionally merged with an original song's tracks.

    Each note dict needs ``'pitch'``, ``'start_time'`` and ``'duration'``
    (seconds).  Durations are floored at 0.1 s.  Pitches outside 21-108 and
    notes with missing or non-numeric fields are skipped.  If
    ``original_midi_path`` loads, its estimated tempo replaces ``tempo`` and
    its tracks are appended after the generated melody.  A first track whose
    name contains "melody" is left out.  For an empty ``melody``, only an
    empty melody track is written.

    Args:
        melody: Note dicts rendered onto one instrument.
        output_path: Destination MIDI path.
        tempo: Fallback tempo in BPM.
        original_midi_path: Optional source MIDI for tempo and accompaniment.
    """
    # Parse the original once; it is used for both tempo and accompaniment.
    original_midi = None
    load_error = None
    if original_midi_path:
        try:
            original_midi = pretty_midi.PrettyMIDI(original_midi_path)
        except Exception as e:
            load_error = e

    original_tempo = tempo
    if original_midi is not None:
        try:
            original_tempo = original_midi.estimate_tempo()
            print(f"Original tempo: {original_tempo:.1f} BPM")
        except Exception:
            print(f"Could not load original MIDI, using default tempo: {tempo} BPM")
            original_tempo = tempo
    elif original_midi_path:
        print(f"Could not load original MIDI, using default tempo: {tempo} BPM")

    midi = pretty_midi.PrettyMIDI(initial_tempo=original_tempo)
    melody_instrument = pretty_midi.Instrument(program=1, name="Generated Melody")

    if not melody:
        print("Warning: Empty melody provided")
        midi.instruments.append(melody_instrument)
        midi.write(output_path)
        return

    for note_info in melody:
        try:
            pitch = int(note_info['pitch'])
            start_time = float(note_info['start_time'])
            duration = max(0.1, float(note_info['duration']))
        except (KeyError, ValueError, TypeError) as e:
            print(f"Warning: Skipping invalid note {note_info}: {e}")
            continue

        if not (21 <= pitch <= 108):
            continue

        melody_instrument.notes.append(pretty_midi.Note(
            velocity=80,
            pitch=pitch,
            start=start_time,
            end=start_time + duration
        ))

    melody_instrument.notes.sort(key=lambda n: n.start)
    midi.instruments.append(melody_instrument)

    if original_midi_path:
        if original_midi is None:
            print(f"Could not add original tracks: {load_error}")
        else:
            added = 0
            for i, instrument in enumerate(original_midi.instruments):
                # Drop only a leading track named as the melody; keep the rest.
                if i > 0 or "melody" not in instrument.name.lower():
                    instrument.name = f"Original_{instrument.name}"
                    midi.instruments.append(instrument)
                    added += 1
            print(f"Added {added} original tracks")

    midi.write(output_path)

    if melody_instrument.notes:
        # Notes are sorted by start, so the last one need not end last.
        total_duration = max(note.end for note in melody_instrument.notes)
        note_count = len(melody_instrument.notes)
        print(f"Generated melody: {note_count} notes, {total_duration:.1f}s duration")
    else:
        print("Warning: No valid notes generated")

def extract_chord_timing_from_midi(midi_path: str, target_chord_count: Optional[int] = None) -> List[tuple]:
    """Estimate chord segment timing from the accompaniment of a MIDI file.

    The accompaniment is the non-drum track with the highest average
    polyphony.  Tracks named "melody" or "lead" are excluded.  Its chord
    changes are detected and, when ``target_chord_count`` is given, merged or
    re-gridded to that many segments.

    Args:
        midi_path: Path of the MIDI file to analyse.
        target_chord_count: Desired number of segments, or None to keep what is detected.

    Returns:
        ``(start, end, index)`` tuples.  If no accompaniment track is found,
        a uniform grid of ``target_chord_count`` segments is returned, or
        ``[]`` when no count was requested.  On any error, a uniform grid
        over 60 s is returned (or ``[]`` without a count).
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
        total_duration = midi.get_end_time()

        print(f"MIDI Analysis: {midi_path}")
        print(f"  Total duration: {total_duration:.1f}s")
        print(f"  Target chord count: {target_chord_count}")

        accompaniment_track = None
        max_polyphony = 0

        for i, instrument in enumerate(midi.instruments):
            if instrument.is_drum:
                continue

            polyphony = calculate_average_polyphony(instrument)
            print(f"  Track {i} ({instrument.name}): {len(instrument.notes)} notes, polyphony: {polyphony:.1f}")

            track_name = instrument.name.lower()
            if "melody" in track_name or "lead" in track_name:
                continue

            if polyphony > max_polyphony:
                max_polyphony = polyphony
                accompaniment_track = instrument

        if accompaniment_track is None:
            print("  No suitable accompaniment track found, using fallback")
            if target_chord_count is None:
                # Nothing to grid against; avoid passing None to the fallback.
                return []
            return create_fallback_timing(total_duration, target_chord_count)

        print(f"  Using track: {accompaniment_track.name} (polyphony: {max_polyphony:.1f})")

        chord_changes = detect_chord_changes(accompaniment_track, total_duration)

        print(f"  Detected {len(chord_changes)} chord changes")

        if target_chord_count and len(chord_changes) != target_chord_count:
            chord_changes = align_chord_timing(chord_changes, target_chord_count, total_duration)
            print(f"  Aligned to {len(chord_changes)} chord segments")

        return chord_changes

    except Exception as e:
        print(f"Error extracting chord timing: {e}")
        if target_chord_count:
            return create_fallback_timing(60.0, target_chord_count)
        return []

def calculate_average_polyphony(instrument):
    """Return the mean count of sounding notes, sampled every 0.25 s.

    Samples run from t=0 up to (not including) the latest note end.  A note
    sounds at ``t`` when ``start <= t < end``.  Returns 0.0 for an instrument
    without notes, and 0 when the notes end before the first 0.25 s step.
    """
    notes = instrument.notes
    if not notes:
        return 0.0

    num_samples = int(max(n.end for n in notes) * 4)
    if num_samples <= 0:
        return 0

    active_per_sample = (
        sum(1 for n in notes if n.start <= step * 0.25 < n.end)
        for step in range(num_samples)
    )
    return sum(active_per_sample) / num_samples

def detect_chord_changes(instrument, total_duration):
    """Group an instrument's notes into chord events and return their time spans.

    Notes are grouped when their onsets lie within ``tolerance`` (0.15 s) of
    the group's first onset.  A group is kept as an event if it has at least
    two notes, or if it is a single note followed by a gap of more than 0.5 s
    before the next onset; a single note at the very end is therefore
    dropped.  Each event spans from its onset to the next event's onset.  The
    last event instead ends at the latest end among its own notes, capped at
    ``total_duration``.  Spans shorter than 0.3 s are discarded.

    Returns:
        List of ``(start_time, end_time, index)`` tuples, where ``index`` is
        the tuple's position in the returned list.  Empty if the instrument
        has no notes.
    """
    if not instrument.notes:
        return []

    chord_events = []
    tolerance = 0.15  # max onset spread (s) for notes to count as one chord

    sorted_notes = sorted(instrument.notes, key=lambda n: n.start)

    # Walk notes in onset order, consuming one onset group per iteration.
    i = 0
    while i < len(sorted_notes):
        chord_start = sorted_notes[i].start
        chord_notes = [sorted_notes[i]]

        j = i + 1
        while j < len(sorted_notes) and sorted_notes[j].start - chord_start <= tolerance:
            chord_notes.append(sorted_notes[j])
            j += 1

        # j now indexes the first note of the next group (or len if none).
        if len(chord_notes) >= 2 or (j < len(sorted_notes) and sorted_notes[j].start - chord_start > 0.5):
            chord_events.append({
                'start': chord_start,
                'notes': chord_notes,
                'polyphony': len(chord_notes)
            })

        i = j

    chord_times = []
    for i, event in enumerate(chord_events):
        start_time = event['start']

        if i + 1 < len(chord_events):
            end_time = chord_events[i + 1]['start']
        else:
            # Last event: end with its own notes, never past the file end.
            end_time = max(note.end for note in event['notes'])
            end_time = min(end_time, total_duration)

        if end_time - start_time >= 0.3:
            # Index is assigned after filtering, so indices stay contiguous.
            chord_times.append((start_time, end_time, len(chord_times)))

    return chord_times

def align_chord_timing(detected_changes, target_count, total_duration):
    """Return chord segments whose count equals ``target_count``.

    Too many detected segments are merged with ``merge_chord_segments``.  Too
    few are handed to ``interpolate_chord_segments`` with ``total_duration``.
    An exact match is returned as-is.
    """
    if len(detected_changes) == target_count:
        return detected_changes

    print(f"  Aligning {len(detected_changes)} detected changes to {target_count} target chords")

    if len(detected_changes) > target_count:
        return merge_chord_segments(detected_changes, target_count)
    return interpolate_chord_segments(detected_changes, target_count, total_duration)

def merge_chord_segments(segments, target_count):
    """Reduce ``segments`` to ``target_count`` by keeping evenly spaced boundaries.

    Segment ``k`` starts at the start of ``segments[int(k * n / target_count)]``
    and ends where the next kept segment starts.  The final one ends at the
    original last end.  Returns ``segments`` unchanged if it already has at
    most ``target_count`` entries.
    """
    if target_count >= len(segments):
        return segments

    total = len(segments)
    picks = [int(k * total / target_count) for k in range(target_count)]

    starts = [segments[p][0] for p in picks]
    ends = starts[1:]
    if picks:
        ends.append(segments[-1][1])

    return [(seg_start, seg_end, k) for k, (seg_start, seg_end) in enumerate(zip(starts, ends))]

def interpolate_chord_segments(segments, target_count, total_duration):
    """Return ``target_count`` equal-length segments covering ``[0, total_duration]``.

    NOTE(review): ``segments`` is only checked for emptiness.  The detected
    boundaries are not used, so the result is a uniform grid either way.

    Returns:
        ``(start, end, index)`` tuples; empty when ``target_count <= 0``.
    """
    if not segments:
        return create_fallback_timing(total_duration, target_count)
    if target_count <= 0:
        # Match create_fallback_timing instead of dividing by zero.
        return []

    chord_duration = total_duration / target_count
    return [(i * chord_duration, (i + 1) * chord_duration, i) for i in range(target_count)]

def create_fallback_timing(duration, chord_count):
    """Split ``[0, duration]`` into ``chord_count`` equal ``(start, end, index)`` segments.

    Returns an empty list when ``chord_count`` is not positive.
    """
    if chord_count <= 0:
        return []

    step = duration / chord_count
    bounds = [k * step for k in range(chord_count + 1)]
    return [(lo, hi, k) for k, (lo, hi) in enumerate(zip(bounds, bounds[1:]))]


def align_melody_to_chord_timing(melody: List[Dict], chord_times: List[tuple],
                                 min_duration: float = 0.1, max_duration: float = 4.0,
                                 default_duration: float = 0.5) -> List[Dict]:
    """Linearly rescale a melody so it spans the same interval as ``chord_times``.

    The melody's span, from its earliest start to its latest note end, is
    mapped onto ``chord_times[0][0]`` .. ``chord_times[-1][1]``.  Start times
    are clamped at 0 and durations to ``[min_duration, max_duration]``.  A
    note without ``'duration'`` uses ``default_duration`` both when measuring
    the span and when rescaling, so the two agree.  Input dicts are not
    modified; shallow copies are returned.

    Returns the input unchanged when either list is empty or the span is not positive.
    """
    if not melody or not chord_times:
        return melody

    total_chord_duration = chord_times[-1][1] - chord_times[0][0]

    melody_start = min(note['start_time'] for note in melody)
    melody_end = max(note['start_time'] + note.get('duration', default_duration) for note in melody)
    melody_span = melody_end - melody_start

    if melody_span <= 0:
        return melody

    time_scale = total_chord_duration / melody_span
    # Chosen so melody_start maps exactly onto the first chord's start.
    time_offset = chord_times[0][0] - melody_start * time_scale

    print(f"Aligning melody: scale={time_scale:.3f}, offset={time_offset:.3f}s")

    aligned_melody = []
    for note in melody:
        aligned_note = note.copy()
        start = note['start_time'] * time_scale + time_offset
        duration = note.get('duration', default_duration) * time_scale

        aligned_note['start_time'] = max(0, start)
        aligned_note['duration'] = max(min_duration, min(max_duration, duration))

        aligned_melody.append(aligned_note)

    return aligned_melody

## Main Pipeline

In [None]:
import torch
import json
from data_processor import process_dataset
from dataset import ChordMelodyDataset
from torch.utils.data import DataLoader
from model import AttentionChordToMelodyTransformer
from trainer import ChordToMelodyTrainer
from utils import generate_midi_from_melody
from pathlib import Path
import argparse
import os
import pretty_midi
from evaluator import evaluate_melody, SimpleMelodyEvaluator

def evaluate_generated_melody(melody, chord_sequence, chord_times):
    """Alias for ``evaluate_melody``: print the evaluation report and return its result."""
    report = evaluate_melody(melody, chord_sequence, chord_times)
    return report

def _resolve_song_chord_times(original_midi_path, num_chords):
    """Return ``(start, end, index)`` chord segments for one song.

    Uses ``utils.extract_chord_timing_from_midi`` when the original MIDI file
    exists. Otherwise, or when extraction returns nothing, falls back to a
    uniform grid of 2 s per chord.
    """
    fallback = [(i * 2.0, (i + 1) * 2.0, i) for i in range(num_chords)]

    print("\nExtracting chord timing...")
    if Path(original_midi_path).exists():
        from utils import extract_chord_timing_from_midi
        chord_times = extract_chord_timing_from_midi(
            original_midi_path,
            target_chord_count=num_chords
        )
    else:
        print("  Original MIDI not found, using estimated timing")
        chord_times = fallback

    if chord_times:
        extracted_duration = chord_times[-1][1] - chord_times[0][0]
        print(f"  Extracted {len(chord_times)} chord segments")
        print(f"  Extracted duration: {extracted_duration:.1f}s")
        print(f"  Average chord duration: {extracted_duration/len(chord_times):.1f}s")
    else:
        print("  Failed to extract timing, using fallback")
        chord_times = fallback

    return chord_times


def _load_or_generate_melody(model, song_id, real_chords, chord_times, vocab_path):
    """Return the cached melody for ``song_id``, or generate and cache a new one.

    Melodies are cached as JSON in ``generated_melody_stored/``. Empty results
    are not cached, so a failed generation is retried on the next run instead
    of being replayed forever.
    """
    generated_melody_stored_path = f"generated_melody_stored/{song_id}.json"
    if os.path.exists(generated_melody_stored_path):
        with open(generated_melody_stored_path, 'r') as f:
            return json.load(f)

    generated_melody = model.generate_chord_aligned_melody(
        chord_sequence=real_chords,
        chord_times=[(start, end) for start, end, _ in chord_times],  # drop chord index
        vocab_path=vocab_path,
        temperature=1.1
    )
    if generated_melody:
        os.makedirs("generated_melody_stored", exist_ok=True)
        with open(generated_melody_stored_path, 'w') as f:
            json.dump(generated_melody, f, indent=2)
    return generated_melody


def _test_single_song(model, test_song, song_idx, vocab_path):
    """Generate, evaluate and export a melody for one POP909 song record.

    ``test_song`` must provide ``'full_chord_sequence'`` and ``'song_id'``.
    This writes ``evaluation_results/evaluation_<id>.json`` and
    ``generated_real_song_aligned/<id>.mid``. Generation/evaluation errors are
    printed (with a traceback) rather than raised.
    """
    import traceback

    real_chords = test_song['full_chord_sequence']
    song_id = test_song['song_id']

    print("\n" + "=" * 50)
    print(f"TESTING SONG {song_idx + 1}: {song_id}")
    print("=" * 50)

    # Path to original MIDI file
    original_midi_path = f"POP909-Dataset/POP909/{song_id}/{song_id}.mid"

    print(f"Chord progression: {len(real_chords)} chords")
    print(f"Chords: {' | '.join(real_chords[:8])}{'...' if len(real_chords) > 8 else ''}")

    # Analyze the original MIDI; fall back to 120 BPM and 2 s per chord.
    original_tempo = 120
    try:
        original_midi = pretty_midi.PrettyMIDI(original_midi_path)
        original_duration = original_midi.get_end_time()
        original_tempo = original_midi.estimate_tempo()

        print("\nOriginal MIDI Analysis:")
        print(f"  Duration: {original_duration:.1f}s")
        print(f"  Estimated tempo: {original_tempo:.1f} BPM")
        print(f"  Instruments: {len(original_midi.instruments)}")
    except Exception as e:
        print(f"Could not analyze original MIDI: {e}")
        original_duration = len(real_chords) * 2.0

    chord_times = _resolve_song_chord_times(original_midi_path, len(real_chords))

    print("\nGenerating chord-aligned melody...")
    try:
        generated_melody = _load_or_generate_melody(
            model, song_id, real_chords, chord_times, vocab_path
        )
        if not generated_melody:
            print("  ❌ No melody generated")
            return

        print(f"✓ Generated {len(generated_melody)} notes")

        generated_duration = max(note['start_time'] + note['duration'] for note in generated_melody)
        expected_duration = chord_times[-1][1] if chord_times else len(real_chords) * 2.0
        # Guard against an empty/zero-length chord grid.
        coverage_ratio = generated_duration / expected_duration if expected_duration > 0 else 0

        print(f"  Generated duration: {generated_duration:.1f}s")
        print(f"  Expected duration: {expected_duration:.1f}s")
        print(f"  Coverage ratio: {coverage_ratio:.2f}")

        evaluation_scores = evaluate_generated_melody(
            generated_melody,
            real_chords,
            [(start, end, 0) for start, end, _ in chord_times]
        )

        # Per-chord breakdown for the first five chords.
        evaluator = SimpleMelodyEvaluator()
        print("\n🔍 CHORD-BY-CHORD ANALYSIS:")
        for i, (chord, timing) in enumerate(zip(real_chords[:5], chord_times)):
            chord_start, chord_end = timing[0], timing[1]
            chord_notes = [n for n in generated_melody
                           if chord_start <= n['start_time'] < chord_end]

            if chord_notes:
                chord_alignment = evaluator.chord_alignment_score(
                    chord_notes, [chord], [(chord_start, chord_end, 0)]
                )
                pitches = [n['pitch'] % 12 for n in chord_notes]
                print(f"  Chord {i+1} ({chord}): {chord_alignment:.3f} "
                      f"| Notes: {len(chord_notes)} | Pitches: {pitches}")
            else:
                print(f"  Chord {i+1} ({chord}): No notes generated")

        all_pitches = [n['pitch'] for n in generated_melody]
        evaluation_results = {
            'song_id': song_id,
            'scores': evaluation_scores,
            'melody_stats': {
                'num_notes': len(generated_melody),
                'duration': generated_duration,
                'pitch_range': max(all_pitches) - min(all_pitches),
                'avg_note_duration': sum(n['duration'] for n in generated_melody) / len(generated_melody)
            },
            'chord_progression': real_chords[:10],
            'timing_info': {
                'expected_duration': expected_duration,
                'generated_duration': generated_duration,
                'coverage_ratio': coverage_ratio
            }
        }

        os.makedirs("evaluation_results", exist_ok=True)
        with open(f'evaluation_results/evaluation_{song_id}.json', 'w') as f:
            json.dump(evaluation_results, f, indent=2)

        print(f"✅ Evaluation results saved to evaluation_results/evaluation_{song_id}.json")

        os.makedirs("generated_real_song_aligned", exist_ok=True)
        output_path = f"generated_real_song_aligned/{song_id}.mid"

        try:
            generate_midi_from_melody(
                generated_melody,
                output_path,
                tempo=original_tempo,
                original_midi_path=original_midi_path
            )
            print(f"✓ Saved MIDI to {output_path}")
        except Exception as e:
            print(f"  Error saving MIDI: {e}")
            simple_output = f"generated_melody_only_{song_id}.mid"
            generate_midi_from_melody(generated_melody, simple_output)
            print(f"  Saved melody-only version to {simple_output}")

        # Skip the comparison when the reference duration is unusable (avoids ZeroDivisionError).
        if original_duration > 0:
            print("\nTiming Comparison:")
            print(f"  Original MIDI:     {original_duration:.1f}s")
            print(f"  Generated melody:  {generated_duration:.1f}s")
            timing_accuracy = 1 - abs(generated_duration - original_duration)/original_duration
            print(f"  Timing accuracy:   {timing_accuracy:.1%}")

            if timing_accuracy > 0.9:
                print("  ✅ Excellent timing alignment!")
            elif timing_accuracy > 0.8:
                print("  ✓ Good timing alignment")
            else:
                print("  ⚠️ Timing could be improved")

    except Exception as e:
        print(f"  ❌ Generation failed: {e}")
        traceback.print_exc()


def main(args):
    """Run the chord-to-melody pipeline selected by the CLI flags.

    - ``args.process_data``: preprocess POP909 and return.
    - Otherwise, build the model and either train it (``args.train`` for
      ``args.epochs`` epochs) or load ``best_chord_melody_model.pt``.
    - ``args.test_real_song``: generate, evaluate and export melodies for the
      last songs in ``chord_melody_data.json``. The number of songs comes from
      an optional ``args.num_test_songs`` (default 100).
    """
    import traceback

    print("Chord to Melody Generation - Attention-Based Global Context Approach")
    print("=" * 70)

    # Process dataset if needed
    if args.process_data:
        print("Processing dataset...")
        dataset_path, output_path = "POP909-Dataset", "processed_pop909_chord_melody"
        process_dataset(dataset_path, output_path, melody_segment_length=32)
        print("Dataset processing completed!")
        return

    # Prefer CUDA for both training and generation when available, else CPU.
    training_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generation_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {training_device} (training) / {generation_device} (generation)")

    # Load vocabularies & dataset
    data_path = "processed_pop909_chord_melody/training_sequences.pkl"
    vocab_path = "processed_pop909_chord_melody/vocabularies.json"

    try:
        full_dataset = ChordMelodyDataset(data_path, vocab_path, max_chord_length=12, max_melody_length=16)
        print("Dataset loaded successfully!")
    except FileNotFoundError:
        print("Processed data not found. Please run with --process_data flag first.")
        return

    # Split dataset into train / val (fixed seed for reproducibility)
    train_size = int(0.85 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # NOTE(review): CUDA gets the *smaller* batch (8) and CPU the larger (16);
    # kept as-is, but this looks inverted -- confirm intent.
    batch_size = 8 if training_device.type == 'cuda' else 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Create model
    model = AttentionChordToMelodyTransformer(
        chord_vocab_size=len(full_dataset.chord_to_idx),
        note_vocab_size=len(full_dataset.note_to_idx),
        d_model=256,
        nhead=8,
        num_layers=6,
        max_chord_length=12
    )

    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    if args.train:
        model = model.to(training_device)
        trainer = ChordToMelodyTrainer(model, train_loader, val_loader, training_device)
        trainer.train(num_epochs=args.epochs)
        trainer.plot_training_history()
        print("Training completed!")
    else:
        try:
            model.load_state_dict(torch.load('best_chord_melody_model.pt', map_location=generation_device))
            model = model.to(generation_device)
            print("Loaded trained model successfully!")
        except FileNotFoundError:
            print("No trained model found. Please train first with --train flag.")
            return

    # Move model to the generation device
    model = model.to(generation_device)

    print("\nGENERATING SAMPLE MELODY")
    print("=" * 70)

    if not args.test_real_song:
        return

    # Test with real POP909 songs
    print("\nReal POP909 song with chord-aligned timing:")
    try:
        with open("processed_pop909_chord_melody/chord_melody_data.json", 'r') as f:
            song_data = json.load(f)
    except FileNotFoundError:
        print("Processed song data not found. Please run with --process_data flag first.")
        return
    except Exception as e:
        print(f"Error in real song test: {e}")
        traceback.print_exc()
        return

    # Test the last N songs; clamp at 0 so short datasets don't produce
    # negative (wrapped-around, duplicated) indices.
    num_test_songs = getattr(args, 'num_test_songs', 100)
    first_idx = max(0, len(song_data) - num_test_songs)

    for song_idx in range(first_idx, len(song_data)):
        # One failing song (e.g. bad MIDI or missing keys) must not abort the rest.
        try:
            _test_single_song(model, song_data[song_idx], song_idx, vocab_path)
        except Exception as e:
            print(f"  ❌ Song test failed: {e}")
            traceback.print_exc()

if __name__ == "__main__":
    # Command-line interface: each flag enables one stage of the pipeline.
    cli = argparse.ArgumentParser(description="Attention-Based Chord to Melody Generation")
    cli.add_argument("--train", action="store_true", help="Train the model")
    cli.add_argument("--process_data", action="store_true", help="Process the dataset")
    cli.add_argument("--test_real_song", action="store_true",
                     help="Test with real song chord progression")
    cli.add_argument("--epochs", type=int, default=30, help="Number of training epochs")

    main(cli.parse_args())

# Task 4 - Continuous Conditioned Generation

## General Structure
Note: only the most important Python code is copied here.
It is not executable on its own, because it depends on many other files in the repository. The repository URL is below; see it for the complete implementation of this project.
https://github.com/SelenaGeRuiqi/stable-audio-tools
### Dataset Collection and Preprocessing

|-- piano_collector.py ## See [piano_collector](piano_collector.py)
|-- classical_piano_dataset
|   |-- processed_10s # Store .wav files
|   |-- classical_piano_prompts.json
|   |-- custom_metadata.py
|   |-- dataset_summary.json
|   `-- dataset_config.json

### Training

|-- stable-audio-open-1.0
|   |-- model.ckpt  ## Baseline checkpoint
|   |-- model_config.json  # Use this for training and evaluation, see [model config](stable-audio-open-1.0/model_config.json)
|   `-- ...
|-- stable_audio_tools
|   |-- models
|   |   |-- diffusion.py
|   |   |-- autoencoders.py
|   |   |-- pretransforms.py
|   |   |-- convnext.py
|   |   |-- pretrained.py
|   |   |-- conditioners.py
|   |   `-- ...
|   `-- ...
|-- train.py  ## Train the whole model, see [train instruction](train_instruction.md)
|-- train_freeze.py  ## Train the model with frozen T5 encoder and part of the unet, see [train_freeze](train_freeze.py)
`-- wandb
|-- checkpoints
|   `-- piano_1000_clips

### Music Generation (samples will be shown at the end of the presentation)

|-- unwrap_model.py  ## Unwrap the finetuned model to get the checkpoint
|-- sao_piano_1000clips.ckpt  ## After unwrap
|-- run_gradio.py  ## Run the gradio api to test the finetuned model and generate samples

### Evaluation

|-- evaluation_piano.py ## See [evaluation_piano](evaluation_piano.py)
|-- evaluation_results
|   |-- baseline_samples
|   |-- finetuned_samples
|   `-- results
`-- ...

## Data Collection and Processing

In [None]:
import requests
import json
import os
from pathlib import Path
import librosa
import numpy as np
import torchaudio
import torch
import time
import random
import re
import tempfile

class SimplePianoCollector:
    """Collect piano recordings from the Internet Archive and cut them into 10 s clips.

    Typical use:
    - ``collect_piano_samples()`` searches archive.org, downloads one audio
      file per search hit and writes ``piano_NNNN.wav`` clips to
      ``<output_dir>/processed_10s``.
    - ``generate_simple_prompts()`` assigns a random text prompt to every clip.
    - ``create_configs()`` writes the dataset config, a ``custom_metadata.py``
      module and a summary JSON for fine-tuning.
    """

    def __init__(self, output_dir="classical_piano_dataset"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        print(f"Created output directory: {self.output_dir}")

        # Create subdirectories
        (self.output_dir / "processed_10s").mkdir(exist_ok=True)
        print(f"Created processed directory: {self.output_dir / 'processed_10s'}")

        self.processed_count = 0
        # Global clip budget; overwritten by collect_piano_samples(target_clips).
        self.target_clips = 1000
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })

        # Quality criteria used when selecting files and slicing audio
        self.quality_criteria = {
            "min_duration": 25,      # seconds of decoded audio required
            "max_duration": 2400,    # 40 minutes (NOTE: not referenced by any method here)
            "min_file_size_mb": 0.3, # Lower minimum for more variety
            "max_file_size_mb": 150,
            "max_segments_per_file": 8
        }

    def collect_piano_samples(self, target_clips=1000):
        """Run every search query and download/slice results until ``target_clips`` clips exist.

        Returns:
            The total number of clips created so far (``self.processed_count``).
        """
        # Make the per-file hard stop in _process_audio_simple honor this target.
        self.target_clips = target_clips

        print(f"Starting piano collection - Target: {target_clips} clips")
        print("Enhanced approach for large-scale collection")
        print("="*60)

        # Expanded queries for more variety
        queries = [
            # Format-specific searches
            'title:"piano" AND mediatype:audio AND format:"MP3"',
            'title:"piano" AND mediatype:audio AND format:"FLAC"',
            'title:"piano" AND mediatype:audio AND format:"Ogg"',
            'title:"piano" AND mediatype:audio AND format:"VBR MP3"',

            # Subject-based searches
            'subject:"piano" AND mediatype:audio',
            'subject:"classical piano" AND mediatype:audio',
            'subject:"piano music" AND mediatype:audio',
            'subject:"solo piano" AND mediatype:audio',
            'subject:"piano solo" AND mediatype:audio',
            'subject:"piano works" AND mediatype:audio',

            # Description-based searches
            'description:"piano music" AND mediatype:audio',
            'description:"piano performance" AND mediatype:audio',
            'description:"classical piano" AND mediatype:audio',
            'description:"piano recital" AND mediatype:audio',
            'description:"solo piano" AND mediatype:audio',

            # Collection-specific searches
            'collection:opensource_audio AND title:"piano"',
            'collection:etree AND title:"piano"',
            'collection:audio_music AND title:"piano"',
            'collection:netlabels AND title:"piano"',

            # Genre-specific searches
            'title:"classical piano" AND mediatype:audio',
            'title:"romantic piano" AND mediatype:audio',
            'title:"baroque piano" AND mediatype:audio',
            'title:"piano sonata" AND mediatype:audio',
            'title:"piano concerto" AND mediatype:audio',
            'title:"piano pieces" AND mediatype:audio',

            # Composer searches
            'title:"chopin" AND title:"piano" AND mediatype:audio',
            'title:"beethoven" AND title:"piano" AND mediatype:audio',
            'title:"mozart" AND title:"piano" AND mediatype:audio',
            'title:"bach" AND title:"piano" AND mediatype:audio',
            'title:"debussy" AND title:"piano" AND mediatype:audio',
            'title:"liszt" AND title:"piano" AND mediatype:audio',

            # Broad searches for variety
            'mediatype:audio AND title:"piano"',
            'mediatype:audio AND description:"piano"',
            'format:"MP3" AND subject:"music" AND title:"piano"',
            'format:"FLAC" AND subject:"music" AND title:"piano"'
        ]

        successful_downloads = 0

        for i, query in enumerate(queries):
            if self.processed_count >= target_clips:
                break

            print(f"\nQuery {i+1}/{len(queries)}: {query[:50]}...")

            try:
                results = self._search_internet_archive(query, max_results=20)
                print(f"   Found {len(results)} potential files")

                for j, result in enumerate(results):
                    if self.processed_count >= target_clips:
                        break

                    print(f"   Trying {j+1}/{len(results)}: {result.get('title', 'Unknown')[:40]}...")

                    try:
                        success = self._download_and_process_simple(result)
                        if success:
                            successful_downloads += 1
                            print(f"   Success! Total clips: {self.processed_count}")
                        else:
                            print(f"   Failed to process")
                    except Exception as e:
                        print(f"   Error: {str(e)[:50]}")

                    time.sleep(0.3)  # Be polite between downloads

            except Exception as e:
                print(f"   Query failed: {e}")

            time.sleep(1)

        print(f"\nCollection Summary:")
        print(f"   Successful downloads: {successful_downloads}")
        print(f"   Total clips created: {self.processed_count}")

        return self.processed_count

    def _search_internet_archive(self, query, max_results=20):
        """Query the archive.org advanced-search API over up to 3 pages.

        Returns a shuffled list of at most ``2 * max_results`` result docs
        (dicts with ``identifier``/``title``/``creator``/``description``).
        """
        base_url = "https://archive.org/advancedsearch.php"

        all_results = []

        for page in range(1, 4):  # Check first 3 pages
            params = {
                'q': query,
                'fl': 'identifier,title,creator,description',
                'rows': max_results,
                'page': page,
                'output': 'json'
            }

            try:
                response = self.session.get(base_url, params=params, timeout=15)
                response.raise_for_status()
                data = response.json()
                results = data.get('response', {}).get('docs', [])

                if not results:  # No more results
                    break

                all_results.extend(results)

                # Random delay between pages
                time.sleep(random.uniform(0.5, 1.0))

            except Exception as e:
                print(f"      Page {page} error: {e}")
                break

        # Shuffle results for variety
        random.shuffle(all_results)
        return all_results[:max_results * 2]

    def _is_candidate_audio_file(self, file_info):
        """Return True if an archive.org file entry is MP3/FLAC/Ogg within the size limits."""
        format_name = file_info.get('format', '')
        filename = file_info.get('name', '')

        is_audio = (any(fmt in format_name for fmt in ['MP3', 'FLAC', 'Ogg'])
                    or filename.lower().endswith(('.mp3', '.flac', '.ogg')))
        if not is_audio:
            return False

        try:
            size_mb = int(file_info.get('size', '0')) / (1024 * 1024)
        except (TypeError, ValueError):
            return False
        return (self.quality_criteria["min_file_size_mb"] <= size_mb
                <= self.quality_criteria["max_file_size_mb"])

    def _download_and_process_simple(self, doc):
        """Download the first suitable audio file of an archive.org item and slice it.

        Args:
            doc: A search-result dict; ``'identifier'`` is required.

        Returns:
            True if at least one clip was created, False otherwise. Network and
            decoding errors are printed, not raised.
        """
        from urllib.parse import quote

        identifier = doc.get('identifier')
        title = doc.get('title', 'Unknown')

        if not identifier:
            return False

        try:
            metadata_url = f"https://archive.org/metadata/{identifier}"
            response = self.session.get(metadata_url, timeout=15)
            response.raise_for_status()
            metadata = response.json()
            audio_files = [f for f in metadata.get('files', []) if self._is_candidate_audio_file(f)]
        except Exception as e:
            print(f"      Metadata error: {e}")
            return False

        if not audio_files:
            return False

        # Take the first suitable file; quote the name so spaces/'#'/'?' survive in the URL.
        filename = audio_files[0].get('name')
        download_url = f"https://archive.org/download/{identifier}/{quote(filename)}"

        # Keep the real extension so the decoder can recognize the format.
        fd, temp_path = tempfile.mkstemp(suffix=Path(filename).suffix or '.tmp')
        os.close(fd)
        try:
            print(f"      Downloading: {filename}")
            with self.session.get(download_url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(temp_path, 'wb') as temp_file:
                    for chunk in response.iter_content(chunk_size=8192):
                        temp_file.write(chunk)

            # The temp file is closed before decoding (required on Windows).
            segments_created = self._process_audio_simple(temp_path, title)
            return segments_created > 0

        except Exception as e:
            print(f"      Download error: {e}")
            return False
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def _is_silent_or_bad_quality(self, segment, sr, silence_rms=1e-3, max_clipped_ratio=0.01):
        """Heuristic filter for unusable clips.

        A segment is rejected when it is empty, when its RMS level is below
        ``silence_rms`` (near silence), or when more than ``max_clipped_ratio``
        of its samples are at full scale (|x| >= 0.999, i.e. clipped). ``sr`` is
        accepted for interface compatibility but not used. The thresholds are
        heuristics -- tune them on real data.
        """
        if segment.size == 0:
            return True

        rms = float(np.sqrt(np.mean(np.square(segment, dtype=np.float64))))
        if rms < silence_rms:
            return True

        clipped_ratio = float(np.mean(np.abs(segment) >= 0.999))
        return clipped_ratio > max_clipped_ratio

    def _process_audio_simple(self, file_path, title):
        """Slice an audio file into up to ``max_segments_per_file`` 10 s stereo WAV clips.

        At most the first 300 s are decoded (mono, resampled to 44.1 kHz by
        librosa). Clips are spread across the file, peak-normalized to 0.9 and
        saved as ``piano_NNNN.wav``. Returns the number of clips written
        (0 on any error).
        """
        try:
            print(f"      Processing audio...")

            audio, sr = librosa.load(file_path, sr=44100, duration=300)  # Max 5 minutes
            duration = len(audio) / sr

            print(f"      Audio duration: {duration:.1f}s")

            if duration < self.quality_criteria["min_duration"]:
                print(f"      Too short (need >{self.quality_criteria['min_duration']}s)")
                return 0

            # At most one 10 s segment per 12 s of audio, capped per file
            segment_length = 10 * sr
            max_segments = min(self.quality_criteria["max_segments_per_file"], int(duration // 12))

            print(f"      Creating {max_segments} segments...")

            segments_created = 0
            processed_dir = self.output_dir / "processed_10s"

            for seg_idx in range(max_segments):
                # Single segment: take the middle; otherwise spread evenly.
                if max_segments == 1:
                    start_sample = max(0, int((len(audio) - segment_length) / 2))
                else:
                    segment_spacing = max(segment_length // 2, (len(audio) - segment_length) // max(1, max_segments - 1))
                    start_sample = min(seg_idx * segment_spacing, len(audio) - segment_length)

                start_sample = max(0, start_sample)
                segment = audio[start_sample:start_sample + segment_length]

                # Pad if needed
                if len(segment) < segment_length:
                    padding = segment_length - len(segment)
                    segment = np.pad(segment, (0, padding), mode='constant')

                segment = segment[:segment_length]

                # Filter out silent/bad quality segments
                if self._is_silent_or_bad_quality(segment, sr):
                    print(f"      Skipped segment {seg_idx+1}: silent or bad quality")
                    continue

                # Convert to stereo (channels-first, as torchaudio.save expects by default)
                if len(segment.shape) == 1:
                    segment_stereo = np.stack([segment, segment])
                else:
                    segment_stereo = segment

                # Peak-normalize
                max_val = np.max(np.abs(segment_stereo))
                if max_val > 0:
                    segment_stereo = segment_stereo / max_val * 0.9

                self.processed_count += 1
                output_filename = f"piano_{self.processed_count:04d}.wav"
                output_path = processed_dir / output_filename

                segment_tensor = torch.from_numpy(segment_stereo).float()
                torchaudio.save(str(output_path), segment_tensor, sr)

                segments_created += 1
                print(f"      Created: {output_filename}")

                # Stop once the collection-wide target is reached
                if self.processed_count >= self.target_clips:
                    break

            return segments_created

        except Exception as e:
            print(f"      Processing error: {e}")
            return 0

    def generate_simple_prompts(self, seed=None):
        """Assign a random text prompt to every clip in ``processed_10s``.

        Args:
            seed: Optional seed for reproducible prompts. When None the global
                ``random`` module is used (the original behavior).

        Returns:
            Dict mapping clip filename -> prompt, also written to
            ``classical_piano_prompts.json``. Empty if there are no clips.
        """
        print(f"\nGenerating prompts for piano clips...")

        # Sorted so a seeded run is reproducible regardless of filesystem order.
        processed_files = sorted((self.output_dir / "processed_10s").glob("*.wav"))

        if len(processed_files) == 0:
            print("No processed files found!")
            return {}

        rng = random.Random(seed) if seed is not None else random

        prompt_templates = [
            "classical piano music, {tempo} BPM, {style}, {mood}",
            "piano solo, {tempo} BPM, {character}, {acoustic}",
            "solo piano, {tempo} BPM, {expression}, {quality}",
            "piano music, {tempo} BPM, {genre}, {atmosphere}"
        ]

        # NOTE: "composer" is not referenced by any template at the moment.
        vocab = {
            "tempo": [60, 63, 66, 69, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 126, 132, 138],
            "style": ["classical style", "romantic style", "baroque style", "contemporary style", "impressionist style", "minimalist style", "modernist style", "neoclassical style", "avant-garde style", "renaissance style", "galant style", "rococo style", "early classical style", "late romantic style", "post-romantic style", "expressionist style", "serialist style", "experimental style", "traditional style", "progressive style"],
            "mood": ["contemplative", "peaceful", "dramatic", "gentle", "expressive", "melancholic", "joyful", "serene", "passionate", "nostalgic", "ethereal", "mysterious", "triumphant", "tender", "wistful", "exuberant", "reflective", "tranquil", "yearning", "majestic"],
            "character": ["lyrical melody", "flowing phrases", "expressive performance", "delicate touch", "bold gestures", "graceful movement", "virtuosic passages", "subtle nuances", "dramatic contrasts", "elegant phrasing", "rhythmic precision", "emotional depth", "technical brilliance", "poetic interpretation", "masterful control", "artistic sensitivity", "musical imagination", "dynamic range", "tonal beauty", "interpretive freedom"],
            "acoustic": ["concert grand piano", "warm acoustics", "clear articulation", "rich resonance", "bright overtones", "balanced harmonics", "full-bodied sound", "pristine clarity", "natural reverberation", "intimate ambiance", "spacious soundstage", "detailed imaging", "smooth decay", "dynamic response", "tonal purity", "harmonic richness", "acoustic presence", "sonic depth", "timbral accuracy", "spatial dimension"],
            "expression": ["cantabile", "legato phrasing", "dynamic expression", "musical phrasing", "rubato playing", "expressive timing", "tonal shaping", "articulate voicing", "emotional projection", "artistic interpretation", "dynamic contrast", "melodic shaping", "rhythmic flexibility", "tonal balance", "expressive nuance", "musical gesture", "phrase contouring", "dynamic control", "artistic freedom", "interpretive depth"],
            "quality": ["high quality recording", "professional performance", "studio recording", "audiophile quality", "pristine capture", "masterful execution", "premium production", "expert engineering", "refined performance", "superior recording", "concert quality", "flawless execution", "reference quality", "exceptional clarity", "detailed capture", "balanced mix", "precise editing", "clean recording", "polished production", "artistic excellence"],
            "genre": ["classical music", "art music", "concert music", "instrumental music", "chamber music", "solo repertoire", "recital music", "performance art", "contemporary classical", "modern classical", "baroque music", "romantic music", "piano literature", "keyboard music", "classical repertoire", "art performance", "serious music", "concert repertoire", "piano composition", "classical piano"],
            "atmosphere": ["intimate setting", "concert hall", "peaceful atmosphere", "refined ambiance", "grand auditorium", "recital room", "chamber setting", "studio environment", "performance space", "acoustic venue", "resonant hall", "quiet studio", "live room", "concert venue", "practice room", "recording studio", "performance hall", "music room", "cathedral acoustics", "concert atmosphere"],
            "composer": ["chopin", "beethoven", "mozart", "bach", "debussy", "liszt", "rachmaninoff", "schubert", "schumann", "brahms", "tchaikovsky", "ravel", "prokofiev", "scriabin", "grieg", "mendelssohn", "scarlatti", "haydn", "satie", "mussorgsky"]
        }

        placeholder = re.compile(r'\{(\w+)\}')
        prompts = {}

        for file_path in processed_files:
            template = rng.choice(prompt_templates)
            prompt_vars = {
                var: rng.choice(vocab[var])
                for var in placeholder.findall(template)
                if var in vocab
            }
            prompts[file_path.name] = template.format(**prompt_vars)

        prompts_file = self.output_dir / "classical_piano_prompts.json"
        with open(prompts_file, 'w', encoding='utf-8') as f:
            json.dump(prompts, f, indent=2)

        print(f"Generated {len(prompts)} prompts")
        print(f"Saved to: {prompts_file}")

        return prompts

    def create_configs(self):
        """Write ``dataset_config.json``, ``custom_metadata.py`` and ``dataset_summary.json``.

        Returns:
            The summary dict that was written to ``dataset_summary.json``.
        """
        print(f"\nCreating training configs...")

        processed_files = list((self.output_dir / "processed_10s").glob("*.wav"))

        # Dataset config
        dataset_config = {
            "dataset_type": "audio_dir",
            "datasets": [{
                "id": "piano_10s",
                "path": str(self.output_dir / "processed_10s") + "/",
                "custom_metadata_module": str(self.output_dir / "custom_metadata.py")
            }],
            "random_crop": True,
            "sample_rate": 44100,
            "sample_size": 441000,
            "channels": 2
        }

        config_file = self.output_dir / "dataset_config.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(dataset_config, f, indent=2)

        # Custom metadata module. The prompts directory is spliced in via repr()
        # so paths with backslashes/quotes still form a valid string literal.
        metadata_template = '''import json
from pathlib import Path

PROMPTS_FILE = Path(__PROMPTS_DIR__) / "classical_piano_prompts.json"
DEFAULT_PROMPT = "classical piano music, 96 BPM, expressive performance"

# Prompts are loaded once per process instead of on every sample.
_prompts_cache = None


def _load_prompts():
    global _prompts_cache
    if _prompts_cache is None and PROMPTS_FILE.exists():
        with open(PROMPTS_FILE, 'r') as f:
            _prompts_cache = json.load(f)
    return _prompts_cache


def get_custom_metadata(info, audio=None):
    """Return custom metadata for piano training.

    ``audio`` is accepted (and ignored) because stable-audio-tools calls
    custom metadata functions as ``fn(info, audio)``.
    """
    prompts = _load_prompts()
    if prompts is not None:
        filename = Path(info['path']).name
        prompt = prompts.get(filename, DEFAULT_PROMPT)
    else:
        prompt = DEFAULT_PROMPT

    return {
        "text": prompt,
        "seconds_start": 0,
        "seconds_total": 10
    }
'''
        metadata_code = metadata_template.replace("__PROMPTS_DIR__", repr(str(self.output_dir)))

        metadata_file = self.output_dir / "custom_metadata.py"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            f.write(metadata_code)

        # Summary
        summary = {
            'dataset_name': 'Piano Fine-tuning Dataset',
            'total_clips': len(processed_files),
            'duration_per_clip': '10 seconds',
            'format': '44.1kHz WAV, stereo',
            'files': {
                'dataset_config': str(config_file),
                'custom_metadata': str(metadata_file),
                'prompts': str(self.output_dir / "classical_piano_prompts.json"),
                'audio_dir': str(self.output_dir / "processed_10s")
            }
        }

        summary_file = self.output_dir / "dataset_summary.json"
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"Dataset config: {config_file}")
        print(f"Custom metadata: {metadata_file}")
        print(f"Summary: {summary_file}")

        return summary

def main():
    """Build the piano fine-tuning dataset end to end.

    Steps:
      1. Count ``*.wav`` clips already in ``processed_10s`` and store that
         count on the collector (resume support).
      2. Call ``collect_piano_samples(target_clips=1000)``; stop early if it
         reports 0.
      3. Generate text prompts.
      4. Write the configs, print a summary and - when at least one clip
         exists - a ready-to-run ``train.py`` command.
    """
    rule = "=" * 50
    print("SIMPLE PIANO COLLECTOR")
    print(rule)
    print("Goal: Collect piano audio and create training dataset")
    print("Simplified approach for better reliability")
    print(rule)

    collector = SimplePianoCollector()

    # Resume support: clips left over from an earlier run.
    leftover_clips = list((collector.output_dir / "processed_10s").glob("*.wav"))
    if leftover_clips:
        collector.processed_count = len(leftover_clips)
        print(f"Found {len(leftover_clips)} existing files")

    print("\nPHASE 1: Collecting Audio")
    collected = collector.collect_piano_samples(target_clips=1000)
    if collected == 0:
        print("No clips collected. Check your internet connection.")
        return

    print("\nPHASE 2: Generating Prompts")
    collector.generate_simple_prompts()

    print("\nPHASE 3: Creating Configs")
    summary = collector.create_configs()
    clip_total = summary['total_clips']
    paths = summary['files']

    print("\nCOLLECTION COMPLETE!")
    print("=" * 40)
    print(f"Total clips: {clip_total}")
    print(f"Audio directory: {paths['audio_dir']}")
    print(f"Dataset config: {paths['dataset_config']}")

    if clip_total <= 0:
        return

    print("\nREADY FOR TRAINING!")
    print("Training command:")
    print(f"""python train.py \\
  --dataset-config {paths['dataset_config']} \\
  --model-config ./stable-audio-open-1.0/model_config.json \\
  --pretrained-ckpt-path ./stable-audio-open-1.0/model.safetensors \\
  --name piano_{clip_total}_clips \\
  --batch-size 2 \\
  --accum-batches 2 \\
  --precision 16 \\
  --checkpoint-every 200 \\
  --save-dir ./checkpoints""")

if __name__ == "__main__":
    main()

## Training Process

In [None]:
import torch
import json
import os
import pytorch_lightning as pl
import argparse

from typing import Dict, Optional, Union
from prefigure.prefigure import get_all_args, push_wandb_config
from stable_audio_tools.data.dataset import create_dataloader_from_config, fast_scandir
from stable_audio_tools.models import create_model_from_config
from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict, remove_weight_norm_from_model
from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config

class ExceptionCallback(pl.Callback):
    """Lightning callback that reports exceptions raised while fitting.

    ``on_exception`` prints the exception's class name and message to stdout.
    It neither suppresses nor re-raises the exception.
    """

    def on_exception(self, trainer, module, err):
        kind = type(err).__name__
        print(f'{kind}: {err}')

class ModelConfigEmbedderCallback(pl.Callback):
    """Lightning callback that stores the model config in every checkpoint.

    When a checkpoint is saved, the dict given at construction is written to
    ``checkpoint["model_config"]`` so the architecture description travels
    with the weights.
    """

    def __init__(self, model_config):
        # Initialise the base Callback before adding our own state.
        super().__init__()
        self.model_config = model_config

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint["model_config"] = self.model_config

def freeze_t5_encoder(model):
    """Freeze T5 encoder parameters of ``model`` in place.

    Candidate modules come from these attributes: ``text_encoder``,
    ``t5_encoder``, ``encoder``, ``conditioner.text_encoder`` and
    ``conditioner.t5``. Every submodule whose qualified name contains both "t5"
    and "encoder" is also a candidate. All parameters of all candidates get
    ``requires_grad = False``.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        int: number of parameter elements belonging to the candidate modules.
        Each parameter tensor is counted once, even when it is reachable from
        several candidates (e.g. a module found both by attribute and by the
        name search, or a child of another candidate).
    """
    candidates = []

    # Well-known attribute locations for a text encoder.
    # NOTE(review): a plain `model.encoder` is not necessarily a T5 encoder -
    # confirm this is intended for the model in use.
    for attr in ('text_encoder', 't5_encoder', 'encoder'):
        if hasattr(model, attr):
            candidates.append(getattr(model, attr))
    conditioner = getattr(model, 'conditioner', None)
    for attr in ('text_encoder', 't5'):
        if hasattr(conditioner, attr):
            candidates.append(getattr(conditioner, attr))

    # Any submodule whose qualified name mentions both "t5" and "encoder".
    candidates.extend(
        module for name, module in model.named_modules()
        if 't5' in name.lower() and 'encoder' in name.lower()
    )

    frozen_params = 0
    seen = set()
    for component in candidates:
        if not hasattr(component, 'parameters'):
            continue  # attribute exists but is not a module
        for param in component.parameters():
            # The same tensor may be reached through several candidates;
            # counting it again would inflate the frozen total.
            if id(param) in seen:
                continue
            seen.add(id(param))
            param.requires_grad = False
            frozen_params += param.numel()

    print(f"Frozen T5 encoder parameters: {frozen_params:,}")
    return frozen_params

def freeze_unet_layers(model, freeze_ratio=0.5):
    """Freeze the front part of the model's denoising backbone ("UNet").

    How the backbone is found: take the first existing attribute among
    ``unet``, ``diffusion_model``, ``model`` and ``backbone``. If that yields
    nothing, take the first submodule whose qualified name contains "unet" or
    "diffusion".

    How layers are chosen: use the backbone submodules (in ``named_modules()``
    order) whose names contain ``downsample``, ``upsample``, ``resblock``,
    ``attn`` or ``conv``. If none match, use every submodule that directly
    owns parameters. The first ``int(len(layers) * ratio)`` layers are frozen.

    Args:
        model: a ``torch.nn.Module``.
        freeze_ratio: fraction of layers to freeze from the front. It is
            clamped to [0, 1]; a negative value would otherwise act as a
            negative slice index and freeze almost everything.

    Returns:
        int: number of parameter elements switched from trainable to frozen.
    """
    unet_component = None
    for attr in ('unet', 'diffusion_model', 'model', 'backbone'):
        if hasattr(model, attr):
            unet_component = getattr(model, attr)
            break

    # Try to find the backbone by name
    if unet_component is None:
        for name, module in model.named_modules():
            lowered = name.lower()
            if 'unet' in lowered or 'diffusion' in lowered:
                unet_component = module
                break

    if unet_component is None:
        print("Warning: Could not find UNet component to freeze")
        return 0

    layer_keywords = ('downsample', 'upsample', 'resblock', 'attn', 'conv')
    unet_layers = [
        (name, module)
        for name, module in unet_component.named_modules()
        if any(keyword in name.lower() for keyword in layer_keywords)
    ]

    if not unet_layers:
        # Fall back to modules that own parameters *directly*. Checking the
        # recursive parameters() would make the backbone root (name "") the
        # first "layer", so freezing a single layer froze the whole backbone.
        unet_layers = [
            (name, module)
            for name, module in unet_component.named_modules()
            if next(module.parameters(recurse=False), None) is not None
        ]

    ratio = min(max(float(freeze_ratio), 0.0), 1.0)
    num_layers_to_freeze = int(len(unet_layers) * ratio)

    print(f"Found {len(unet_layers)} UNet layers, freezing first {num_layers_to_freeze} layers ({ratio*100:.1f}%)")

    frozen_params = 0
    total_params = 0
    seen = set()
    for _, module in unet_layers[:num_layers_to_freeze]:
        for param in module.parameters():
            # Matching layers can be nested (e.g. "attn" and "attn.to_qkv");
            # count each parameter tensor only once.
            if id(param) in seen:
                continue
            seen.add(id(param))
            total_params += param.numel()
            if param.requires_grad:  # only count what this call actually froze
                param.requires_grad = False
                frozen_params += param.numel()

    print(f"Frozen UNet parameters: {frozen_params:,} out of {total_params:,}")
    return frozen_params

def add_freezing_args(parser):
    """Register the layer-freezing command-line options and return ``parser``.

    Adds an argument group titled "Freezing Options" containing the boolean
    flags ``--freeze-t5`` and ``--freeze-unet`` and the float option
    ``--unet-freeze-ratio`` (default 0.5).
    """
    group = parser.add_argument_group('Freezing Options')

    boolean_flags = (
        ('--freeze-t5', 'Freeze T5 encoder parameters'),
        ('--freeze-unet', 'Freeze part of UNet layers (front to mid part)'),
    )
    for flag, description in boolean_flags:
        group.add_argument(flag, action='store_true', help=description)

    group.add_argument(
        '--unet-freeze-ratio',
        type=float,
        default=0.5,
        help='Ratio of UNet layers to freeze from the front (default: 0.5)',
    )
    return parser

def main():
    """Fine-tune a stable-audio-tools model, optionally freezing layers.

    Arguments come from prefigure's ``get_all_args()``. The function:
      1. builds train (and optional validation) dataloaders and the model from
         their JSON configs;
      2. loads the pretrained / pretransform checkpoints;
      3. optionally freezes the T5 encoder and the front of the diffusion
         backbone;
      4. trains with a PyTorch Lightning ``Trainer``.

    Freezing is enabled by ``args.freeze_t5`` / ``args.freeze_unet`` when those
    exist, or by the environment variables ``FREEZE_T5`` / ``FREEZE_UNET`` set
    to ``"true"``. The ratio comes from ``args.unet_freeze_ratio`` (default
    0.5) and can be overridden by ``UNET_FREEZE_RATIO``. ``args.max_epochs``
    (default 20) caps training length.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')

    args = get_all_args()

    # The freezing options may be missing from prefigure's args, so read them
    # with getattr and let environment variables enable them as a fallback.
    freeze_t5 = getattr(args, 'freeze_t5', False) or os.environ.get('FREEZE_T5', '').lower() == 'true'
    freeze_unet = getattr(args, 'freeze_unet', False) or os.environ.get('FREEZE_UNET', '').lower() == 'true'
    try:
        unet_freeze_ratio = float(getattr(args, 'unet_freeze_ratio', 0.5))
    except (TypeError, ValueError):
        unet_freeze_ratio = 0.5
    env_ratio = os.environ.get('UNET_FREEZE_RATIO')
    if env_ratio is not None:
        try:
            unet_freeze_ratio = float(env_ratio)
        except ValueError:
            # Keep the args value rather than silently resetting to 0.5.
            print(f"Warning: invalid UNET_FREEZE_RATIO={env_ratio!r}, using {unet_freeze_ratio}")

    seed = args.seed

    # Set a different seed for each process if using SLURM
    if os.environ.get("SLURM_PROCID") is not None:
        seed += int(os.environ.get("SLURM_PROCID"))

    pl.seed_everything(seed, workers=True)

    #Get JSON config from args.model_config
    with open(args.model_config) as f:
        model_config = json.load(f)

    with open(args.dataset_config) as f:
        dataset_config = json.load(f)

    train_dl = create_dataloader_from_config(
        dataset_config,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        sample_rate=model_config["sample_rate"],
        sample_size=model_config["sample_size"],
        audio_channels=model_config.get("audio_channels", 2),
    )

    val_dl = None
    val_dataset_config = None

    if args.val_dataset_config:
        with open(args.val_dataset_config) as f:
            val_dataset_config = json.load(f)

        val_dl = create_dataloader_from_config(
            val_dataset_config,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            sample_rate=model_config["sample_rate"],
            sample_size=model_config["sample_size"],
            audio_channels=model_config.get("audio_channels", 2),
            shuffle=False
        )

    model = create_model_from_config(model_config)

    if args.pretrained_ckpt_path:
        copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path))

    if args.remove_pretransform_weight_norm == "pre_load":
        remove_weight_norm_from_model(model.pretransform)

    if args.pretransform_ckpt_path:
        model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path))

    # Remove weight_norm from the pretransform if specified
    if args.remove_pretransform_weight_norm == "post_load":
        remove_weight_norm_from_model(model.pretransform)

    # Apply freezing after all weights are loaded
    total_frozen_params = 0
    if freeze_t5:
        print("Freezing T5 encoder...")
        total_frozen_params += freeze_t5_encoder(model)

    if freeze_unet:
        print(f"Freezing UNet layers (ratio: {unet_freeze_ratio})...")
        total_frozen_params += freeze_unet_layers(model, unet_freeze_ratio)

    if total_frozen_params > 0:
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # Derive the frozen count from the model itself, so the three figures
        # always add up and the ratio cannot exceed 100% even if the helpers'
        # counts overlap.
        frozen_params = total_params - trainable_params
        print(f"\nParameter Summary:")
        print(f"Total parameters: {total_params:,}")
        print(f"Frozen parameters: {frozen_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        if total_params > 0:
            print(f"Frozen ratio: {(frozen_params/total_params)*100:.1f}%")

    training_wrapper = create_training_wrapper_from_config(model_config, model)

    exc_callback = ExceptionCallback()

    if args.logger == 'wandb':
        logger = pl.loggers.WandbLogger(project=args.name)
        logger.watch(training_wrapper)

        if args.save_dir and isinstance(logger.experiment.id, str):
            checkpoint_dir = os.path.join(args.save_dir, logger.experiment.project, logger.experiment.id, "checkpoints")
        else:
            checkpoint_dir = None
    elif args.logger == 'comet':
        logger = pl.loggers.CometLogger(project_name=args.name)
        if args.save_dir and isinstance(logger.version, str):
            checkpoint_dir = os.path.join(args.save_dir, logger.name, logger.version, "checkpoints")
        else:
            checkpoint_dir = args.save_dir if args.save_dir else None
    else:
        logger = None
        checkpoint_dir = args.save_dir if args.save_dir else None

    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1)
    save_model_config_callback = ModelConfigEmbedderCallback(model_config)

    if args.val_dataset_config:
        demo_callback = create_demo_callback_from_config(model_config, demo_dl=val_dl)
    else:
        demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl)

    #Combine args and config dicts
    args_dict = vars(args)
    args_dict.update({"model_config": model_config})
    args_dict.update({"dataset_config": dataset_config})
    args_dict.update({"val_dataset_config": val_dataset_config})
    # Add freezing info to logged parameters
    args_dict.update({
        "freeze_t5": freeze_t5,
        "freeze_unet": freeze_unet,
        "unet_freeze_ratio": unet_freeze_ratio
    })

    if args.logger == 'wandb':
        push_wandb_config(logger, args_dict)
    elif args.logger == 'comet':
        logger.log_hyperparams(args_dict)

    #Set multi-GPU strategy if specified
    if args.strategy:
        if args.strategy == "deepspeed":
            from pytorch_lightning.strategies import DeepSpeedStrategy
            strategy = DeepSpeedStrategy(stage=2,
                                        contiguous_gradients=True,
                                        overlap_comm=True,
                                        reduce_scatter=True,
                                        reduce_bucket_size=5e8,
                                        allgather_bucket_size=5e8,
                                        load_full_weights=True)
        else:
            strategy = args.strategy
    else:
        strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"

    val_args = {}

    if args.val_every > 0:
        val_args.update({
            "check_val_every_n_epoch": None,
            "val_check_interval": args.val_every,
        })

    trainer = pl.Trainer(
        devices="auto",
        accelerator="gpu",
        num_nodes = args.num_nodes,
        strategy=strategy,
        precision=args.precision,
        accumulate_grad_batches=args.accum_batches,
        callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback],
        logger=logger,
        log_every_n_steps=1,
        max_epochs=getattr(args, 'max_epochs', 20),
        default_root_dir=args.save_dir,
        gradient_clip_val=args.gradient_clip_val,
        reload_dataloaders_every_n_epochs = 0,
        num_sanity_val_steps=0, # If you need to debug validation, change this line
        **val_args
    )

    trainer.fit(training_wrapper, train_dl, val_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)

if __name__ == '__main__':
    main()

## Evaluation Process

In [None]:
"""
Piano Music Generation Evaluation Script
"""

import warnings
# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")

import os
import json
import torch
import torchaudio
import numpy as np
import librosa
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple
from datetime import datetime
from scipy import signal
from scipy.stats import pearsonr, entropy
from sklearn.metrics.pairwise import cosine_similarity

# Import evaluation libraries
try:
    from audioldm_eval import EvaluationHelper
    FAD_AVAILABLE = True
except ImportError:
    print("Warning: audioldm_eval not found. Install with: pip install git+https://github.com/haoheliu/audioldm_eval.git")
    FAD_AVAILABLE = False

# Essentia is optional; ESSENTIA_AVAILABLE records whether it could be imported.
try:
    import essentia.standard as es
    ESSENTIA_AVAILABLE = True
except ImportError:
    print("Warning: Essentia not found. Some metrics will be unavailable.")
    ESSENTIA_AVAILABLE = False
except Exception:
    # Essentia (or one of its dependencies) failed with something other than
    # ImportError. Retry once with FutureWarnings silenced, and treat any second
    # failure - of any type - as "unavailable" instead of aborting the script.
    warnings.filterwarnings("ignore", category=FutureWarning)
    try:
        import essentia.standard as es
        ESSENTIA_AVAILABLE = True
    except Exception as retry_err:
        print(f"Warning: Essentia failed to import ({retry_err}). Some metrics will be unavailable.")
        ESSENTIA_AVAILABLE = False

try:
    import laion_clap
    CLAP_AVAILABLE = True
except ImportError:
    print("Warning: CLAP not found. Semantic alignment will use alternative metric.")
    CLAP_AVAILABLE = False

class EnhancedPianoEvaluator:
    """
    Comprehensive evaluation class for piano music generation models
    Focus on showing improvements in fine-tuned models
    """

    def __init__(self, baseline_dir: str, finetuned_dir: str):
        """
        Initialize the evaluator with sample directories

        Args:
            baseline_dir: Directory containing baseline model samples
            finetuned_dir: Directory containing fine-tuned model samples
        """
        self.baseline_dir = Path(baseline_dir)
        self.finetuned_dir = Path(finetuned_dir)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Test prompts for reference
        self.test_prompts = [
            "piano music, 88 BPM, concert music, concert hall",
            "classical piano music, 72 BPM, romantic style, contemplative",
            "beautiful and peaceful classical piano music",
            "dramatic classical piano piece, forte dynamics, powerful",
            "gentle piano melody, legato, soft dynamics, expressive",
            "virtuosic piano music, fast tempo, technical brilliance",
            "melancholic piano ballad, minor key, emotional",
            "baroque style piano music, ornamental, structured",
            "impressionistic piano piece, colorful harmonies, flowing",
            "modern classical piano composition, contemporary harmony"
        ]

        # Create output directory in evaluation_results folder
        self.output_dir = Path("evaluation_results") / "results"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def get_audio_files(self, directory: Path) -> List[str]:
        """Get all audio files from directory"""
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a']
        audio_files = []

        for ext in audio_extensions:
            audio_files.extend(list(directory.glob(f"*{ext}")))

        return sorted([str(f) for f in audio_files])

    def evaluate_fad_metrics(self, baseline_files: List[str], finetuned_files: List[str]) -> Dict[str, float]:
        """
        Evaluate FAD, IS, and KL divergence metrics
        """
        if not FAD_AVAILABLE:
            return {
                'fad_baseline_vs_finetuned': 0.0,
                'inception_score_baseline': 0.0,
                'inception_score_finetuned': 0.0,
                'kl_divergence': 0.0
            }

        try:
            print("Calculating FAD and related metrics...")

            # Calculate FAD between baseline and fine-tuned
            temp_baseline_dir = self.output_dir / "temp_baseline"
            temp_finetuned_dir = self.output_dir / "temp_finetuned"
            temp_baseline_dir.mkdir(exist_ok=True)
            temp_finetuned_dir.mkdir(exist_ok=True)

            # Copy files to temp directories (audioldm_eval expects directories)
            import shutil
            for i, file in enumerate(baseline_files):
                shutil.copy(file, temp_baseline_dir / f"baseline_{i:02d}.wav")
            for i, file in enumerate(finetuned_files):
                shutil.copy(file, temp_finetuned_dir / f"finetuned_{i:02d}.wav")

            # Calculate metrics
            metrics = self.fad_evaluator.main(
                str(temp_finetuned_dir),
                str(temp_baseline_dir),
                limit_num=None
            )

            # Clean up temp directories
            shutil.rmtree(temp_baseline_dir)
            shutil.rmtree(temp_finetuned_dir)

            return {
                'fad_baseline_vs_finetuned': float(metrics.get('fad', 0.0)),
                'inception_score_baseline': float(metrics.get('is_baseline', 0.0)),
                'inception_score_finetuned': float(metrics.get('is_finetuned', 0.0)),
                'kl_divergence': float(metrics.get('kl', 0.0))
            }

        except Exception as e:
            print(f"Error calculating FAD metrics: {e}")
            return {
                'fad_baseline_vs_finetuned': 0.0,
                'inception_score_baseline': 0.0,
                'inception_score_finetuned': 0.0,
                'kl_divergence': 0.0
            }

    def evaluate_piano_authenticity(self, audio_files: List[str]) -> Dict[str, List[float]]:
        """
        Evaluate how authentic the piano sound is (optimized metrics)
        """
        metrics = {
            'sustain_decay_quality': [],
            'frequency_range_coverage': []
        }

        for audio_file in audio_files:
            try:
                y, sr = librosa.load(audio_file, sr=22050)

                # Sustain and decay quality with NaN handling
                rms = librosa.feature.rms(y=y)[0]
                # Piano notes have characteristic decay
                decay_score = 0.0
                if len(rms) > 10:
                    # Look for exponential decay patterns
                    for i in range(len(rms) - 10):
                        segment = rms[i:i+10]
                        if np.max(segment) > 0.01:  # Only analyze audible segments
                            # Fit exponential decay
                            x = np.arange(len(segment))
                            log_segment = np.log(segment + 1e-10)
                            if np.std(log_segment) > 0 and not np.any(np.isnan(log_segment)):
                                slope = np.polyfit(x, log_segment, 1)[0]
                                if slope < 0:  # Decay should be negative
                                    decay_score += abs(slope)

                decay_score = min(1.0, decay_score / max(1, len(rms)))
                if np.isnan(decay_score):
                    decay_score = 0.0
                metrics['sustain_decay_quality'].append(float(decay_score))

                # Frequency range coverage (piano covers wide range)
                spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
                spectral_bandwidth = np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr))

                # Good piano samples should have reasonable spread
                range_score = min(1.0, (spectral_bandwidth / 2000.0) * (spectral_centroid / 2000.0))
                metrics['frequency_range_coverage'].append(float(range_score))

            except Exception as e:
                print(f"Error analyzing piano authenticity for {audio_file}: {e}")
                for key in metrics.keys():
                    metrics[key].append(0.0)

        return metrics

    def evaluate_musical_quality(self, audio_files: List[str]) -> Dict[str, List[float]]:
        """
        Evaluate musical quality aspects
        """
        metrics = {
            'rhythmic_consistency': [],
            'melodic_coherence': [],
            'harmonic_progression_quality': [],
            'dynamic_expression': [],
            'phrase_structure': []
        }

        for audio_file in audio_files:
            try:
                y, sr = librosa.load(audio_file, sr=22050)

                # Rhythmic consistency with safe beat tracking
                try:
                    tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
                    if len(beats) > 3:  # Need at least 3 beats for meaningful analysis
                        beat_intervals = np.diff(librosa.frames_to_time(beats, sr=sr))
                        if len(beat_intervals) > 0:
                            rhythm_consistency = 1.0 / (1.0 + np.std(beat_intervals))
                        else:
                            rhythm_consistency = 0.5
                    else:
                        rhythm_consistency = 0.5
                except Exception:
                    rhythm_consistency = 0.5
                metrics['rhythmic_consistency'].append(float(rhythm_consistency))

                # Melodic coherence with safe pitch tracking
                try:
                    pitches, magnitudes = librosa.piptrack(y=y, sr=sr, threshold=0.1)
                    pitch_track = []
                    for t in range(pitches.shape[1]):
                        frame_pitches = pitches[:, t]
                        frame_mags = magnitudes[:, t]
                        valid_indices = frame_mags > 0
                        if np.any(valid_indices):
                            strongest_pitch_idx = np.argmax(frame_mags)
                            if frame_pitches[strongest_pitch_idx] > 0:
                                pitch_track.append(frame_pitches[strongest_pitch_idx])

                    if len(pitch_track) > 2:
                        pitch_changes = np.abs(np.diff(pitch_track))
                        if len(pitch_changes) > 0 and np.mean(pitch_changes) > 0:
                            melodic_score = 1.0 / (1.0 + np.mean(pitch_changes) / 100.0)
                        else:
                            melodic_score = 0.5
                    else:
                        melodic_score = 0.5
                except Exception:
                    melodic_score = 0.5
                metrics['melodic_coherence'].append(float(melodic_score))

                # Harmonic progression quality with safe processing
                try:
                    chroma = librosa.feature.chroma_stft(y=y, sr=sr)
                    if chroma.shape[1] > 1:
                        chord_changes = np.sum(np.abs(np.diff(chroma, axis=1)), axis=0)
                        if len(chord_changes) > 0 and np.mean(chord_changes) > 0:
                            harmonic_smoothness = 1.0 / (1.0 + np.mean(chord_changes))
                        else:
                            harmonic_smoothness = 0.5
                    else:
                        harmonic_smoothness = 0.5
                except Exception:
                    harmonic_smoothness = 0.5
                metrics['harmonic_progression_quality'].append(float(harmonic_smoothness))

                # Dynamic expression with safe RMS processing
                try:
                    rms = librosa.feature.rms(y=y)[0]
                    if len(rms) > 0 and np.mean(rms) > 0:
                        rms_variation = np.std(rms) / (np.mean(rms) + 1e-10)
                        dynamic_expression = min(1.0, rms_variation)
                    else:
                        dynamic_expression = 0.0
                except Exception:
                    dynamic_expression = 0.0
                metrics['dynamic_expression'].append(float(dynamic_expression))

                # Phrase structure (musical phrasing) with NaN handling
                # Detect phrase boundaries using onset strength and spectral changes
                onset_strength = librosa.onset.onset_strength(y=y, sr=sr)
                spectral_contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
                contrast_changes = np.mean(np.abs(np.diff(spectral_contrast, axis=1)), axis=0)

                # Good phrasing has clear structural boundaries
                if len(onset_strength) > len(contrast_changes):
                    onset_strength = onset_strength[:len(contrast_changes)]
                elif len(contrast_changes) > len(onset_strength):
                    contrast_changes = contrast_changes[:len(onset_strength)]

                if len(onset_strength) > 1 and len(contrast_changes) > 1:
                    phrase_clarity = np.corrcoef(onset_strength, contrast_changes)[0, 1]
                    phrase_score = abs(phrase_clarity) if not np.isnan(phrase_clarity) else 0.5
                else:
                    phrase_score = 0.5
                metrics['phrase_structure'].append(float(phrase_score))

            except Exception as e:
                print(f"Error analyzing musical quality for {audio_file}: {e}")
                for key in metrics.keys():
                    metrics[key].append(0.5)

        return metrics

    def evaluate_technical_quality(self, audio_files: List[str]) -> Dict[str, List[float]]:
        """
        Evaluate technical audio quality (optimized metrics)
        """
        metrics = {
            'signal_noise_ratio': [],
            'dynamic_range': [],
            'frequency_response': [],
            'artifacts_score': []
        }

        for audio_file in audio_files:
            try:
                y, sr = librosa.load(audio_file, sr=22050, mono=False)

                # Handle mono/stereo
                if len(y.shape) == 1:
                    y_mono = y
                else:
                    y_mono = np.mean(y, axis=0)

                # Signal-to-noise ratio with safe processing
                try:
                    signal_power = np.mean(y_mono**2)
                    rms = librosa.feature.rms(y=y_mono)[0]

                    if len(rms) > 0:
                        noise_threshold = np.percentile(rms, 10)
                        quiet_sections = rms[rms < noise_threshold]
                        if len(quiet_sections) > 0:
                            noise_power = np.mean(quiet_sections)**2 + 1e-10
                        else:
                            noise_power = 1e-10

                        if signal_power > 0:
                            snr = 10 * np.log10(signal_power / noise_power)
                            snr_normalized = min(1.0, max(0.0, snr / 60.0))
                        else:
                            snr_normalized = 0.0
                    else:
                        snr_normalized = 0.0
                except Exception:
                    snr_normalized = 0.0
                metrics['signal_noise_ratio'].append(float(snr_normalized))

                # Dynamic range with safe processing
                try:
                    rms = librosa.feature.rms(y=y_mono)[0]
                    if len(rms) > 0:
                        # Filter out very quiet sections for more meaningful dynamic range
                        audible_rms = rms[rms > 1e-6]
                        if len(audible_rms) > 0:
                            rms_db = 20 * np.log10(audible_rms + 1e-10)
                            dynamic_range = np.max(rms_db) - np.min(rms_db)
                            dr_normalized = min(1.0, max(0.0, dynamic_range / 60.0))
                        else:
                            dr_normalized = 0.0
                    else:
                        dr_normalized = 0.0
                except Exception:
                    dr_normalized = 0.0
                metrics['dynamic_range'].append(float(dr_normalized))

                # Frequency response with safe processing
                try:
                    stft = librosa.stft(y_mono)
                    magnitude = np.abs(stft)
                    if magnitude.size > 0:
                        freq_response = np.mean(magnitude, axis=1)
                        # Good piano should have energy across frequency spectrum
                        if len(freq_response) > 0 and np.max(freq_response) > 0:
                            freq_coverage = np.sum(freq_response > np.max(freq_response) * 0.1) / len(freq_response)
                        else:
                            freq_coverage = 0.0
                    else:
                        freq_coverage = 0.0
                except Exception:
                    freq_coverage = 0.0
                metrics['frequency_response'].append(float(freq_coverage))

                # Artifacts score with safe processing
                try:
                    # Check for clipping
                    clipping_ratio = np.sum(np.abs(y_mono) > 0.95) / max(len(y_mono), 1)

                    # Check for digital artifacts using high-frequency analysis
                    stft = librosa.stft(y_mono)
                    magnitude = np.abs(stft)
                    if magnitude.size > 0:
                        mid_point = len(magnitude) // 2
                        high_freq_energy = np.mean(magnitude[mid_point:, :]) if mid_point < len(magnitude) else 0
                        total_energy = np.mean(magnitude)
                        hf_ratio = high_freq_energy / (total_energy + 1e-10) if total_energy > 0 else 0
                    else:
                        hf_ratio = 0

                    # Lower artifacts score is better
                    artifacts_penalty = clipping_ratio + min(0.5, hf_ratio)
                    artifacts_score = max(0.0, 1.0 - artifacts_penalty)
                except Exception:
                    artifacts_score = 0.5
                metrics['artifacts_score'].append(float(artifacts_score))

            except Exception as e:
                print(f"Error analyzing technical quality for {audio_file}: {e}")
                for key in metrics.keys():
                    metrics[key].append(0.5)

        return metrics

    def evaluate_semantic_alignment(self, audio_files: List[str]) -> List[float]:
        """
        Heuristically score how well each generated clip matches its prompt.

        File ``i`` is paired with ``self.test_prompts[i]``; files beyond the
        number of prompts reuse ``self.test_prompts[0]``. Every clip starts at
        0.5 and gains fixed bonuses when its estimated tempo, mean RMS energy,
        chroma spread, relative RMS variation or mean spectral centroid agree
        with keywords found in the lower-cased prompt. The total is clipped to
        [0.0, 1.0]. If anything fails for a file (load error, empty prompt
        list, ...), the error is printed and that file scores 0.5.

        Args:
            audio_files: Paths to audio files, in prompt order.

        Returns:
            One score per input file, in the same order.
        """
        def mentions(text: str, *keywords: str) -> bool:
            # True when any keyword occurs as a substring of text.
            return any(keyword in text for keyword in keywords)

        alignment = []
        for index, clip_path in enumerate(audio_files):
            try:
                prompts = self.test_prompts
                chosen = prompts[index] if index < len(prompts) else prompts[0]
                signal, rate = librosa.load(clip_path, sr=22050)

                total = 0.5  # neutral starting score
                text = chosen.lower()

                # Tempo estimate; fall back to 100 BPM when tracking fails or is non-positive.
                try:
                    bpm, _ = librosa.beat.beat_track(y=signal, sr=rate)
                    if bpm is None or bpm <= 0:
                        bpm = 100
                except Exception:
                    bpm = 100

                # Only the first matching tempo cue is checked (elif chain).
                if mentions(text, 'fast', 'virtuosic'):
                    if bpm > 120:
                        total += 0.15
                elif mentions(text, 'slow', 'gentle', 'peaceful'):
                    if bpm < 100:
                        total += 0.15
                elif '88 bpm' in text:
                    if 80 <= bpm <= 96:
                        total += 0.2
                elif '72 bpm' in text:
                    if 65 <= bpm <= 80:
                        total += 0.2

                # Loudness cue from mean frame RMS.
                energy = librosa.feature.rms(y=signal)[0]
                mean_energy = np.mean(energy)

                if mentions(text, 'forte', 'powerful', 'dramatic'):
                    if mean_energy > 0.1:
                        total += 0.1
                elif mentions(text, 'soft', 'gentle'):
                    if mean_energy < 0.05:
                        total += 0.1

                # Style cues.
                if mentions(text, 'classical', 'baroque'):
                    chroma_spread = np.std(librosa.feature.chroma_stft(y=signal, sr=rate))
                    if chroma_spread > 0.15:
                        total += 0.1

                if mentions(text, 'romantic', 'emotional'):
                    # Coefficient of variation of RMS as a proxy for expressive dynamics.
                    relative_spread = np.std(energy) / (mean_energy + 1e-10)
                    if relative_spread > 0.3:
                        total += 0.1

                # Timbre cue: mean spectral centroid in the 800-3000 Hz window.
                if 'piano' in text:
                    centroid = np.mean(librosa.feature.spectral_centroid(y=signal, sr=rate))
                    if 800 < centroid < 3000:
                        total += 0.1

                alignment.append(min(1.0, max(0.0, total)))

            except Exception as e:
                print(f"Error in semantic alignment for {clip_path}: {e}")
                alignment.append(0.5)

        return alignment

    def calculate_improvement_focused_scores(self, baseline_metrics: Dict, finetuned_metrics: Dict) -> Dict:
        """
        Compare per-metric means between the baseline and fine-tuned runs.

        Categories present in both inputs are compared. A category whose
        baseline value is a dict is treated as ``{metric_name: values}`` and
        yields a nested dict of per-metric records. A category whose baseline
        value is a list yields a single record. Any other type yields an empty
        dict.

        Each record contains:
            baseline_mean / finetuned_mean: means of the value lists.
            improvement_percent: relative change versus ``abs(baseline_mean)``
                in percent; 0.0 when the baseline mean is 0 (undefined).
            is_improvement: True when the fine-tuned mean is strictly higher.

        Args:
            baseline_metrics: Metric values for the baseline model.
            finetuned_metrics: Metric values for the fine-tuned model. For a
                nested category it must contain every metric name the baseline
                has, otherwise KeyError is raised.

        Returns:
            Mapping of category -> record, or category -> {metric -> record}.
        """
        def compare(baseline_vals, finetuned_vals) -> Dict:
            baseline_mean = np.mean(baseline_vals)
            finetuned_mean = np.mean(finetuned_vals)

            if baseline_mean != 0:
                improvement_pct = ((finetuned_mean - baseline_mean) / abs(baseline_mean)) * 100
            else:
                # A relative change from zero is undefined; report 0% instead.
                improvement_pct = 0.0

            return {
                'baseline_mean': float(baseline_mean),
                'finetuned_mean': float(finetuned_mean),
                'improvement_percent': float(improvement_pct),
                # Compare the means directly so that a rise from a zero baseline
                # is still counted (improvement_pct is 0.0 in that case).
                'is_improvement': bool(finetuned_mean > baseline_mean),
            }

        improvements = {}

        for category, baseline_entry in baseline_metrics.items():
            if category not in finetuned_metrics:
                continue
            finetuned_entry = finetuned_metrics[category]

            if isinstance(baseline_entry, dict):
                # Nested metrics: one record per metric name.
                improvements[category] = {
                    metric_name: compare(baseline_vals, finetuned_entry[metric_name])
                    for metric_name, baseline_vals in baseline_entry.items()
                }
            elif isinstance(baseline_entry, list):
                # Direct list metric: a single record for the whole category.
                improvements[category] = compare(baseline_entry, finetuned_entry)
            else:
                improvements[category] = {}

        return improvements

    def generate_comparison_table(self, results: Dict) -> str:
        """
        Build a fixed-width text table of baseline vs. fine-tuned metric means.

        The nested categories (piano authenticity, musical quality, technical
        quality) each get one row per metric. Semantic alignment gets a single
        "Prompt Adherence" row. Means are recomputed from the raw value lists.
        Percent changes come from ``results['improvements']``. The table is
        written to ``self.output_dir / "comparison_table.txt"`` and returned.
        """
        rule = "=" * 80
        rows = [
            rule,
            "DETAILED METRIC COMPARISON TABLE",
            rule,
            f"{'Metric Category':<25} {'Metric Name':<25} {'Baseline':<12} {'Fine-tuned':<12} {'Change':<8}",
            "-" * 80,
        ]

        base = results['baseline']
        tuned = results['finetuned']
        deltas = results['improvements']

        def format_row(category_label, metric_label, base_value, tuned_value, change):
            return (
                f"{category_label:<25} {metric_label:<25} "
                f"{base_value:<12.4f} {tuned_value:<12.4f} {change:+7.1f}%"
            )

        nested_categories = (
            ('piano_authenticity', 'Piano Authenticity'),
            ('musical_quality', 'Musical Quality'),
            ('technical_quality', 'Technical Quality'),
        )
        for key, label in nested_categories:
            if key not in base:
                continue
            for metric in base[key]:
                base_value = np.mean(base[key][metric])
                tuned_value = np.mean(tuned[key][metric])
                change = deltas[key][metric]['improvement_percent']
                rows.append(format_row(label, metric.replace('_', ' ').title(),
                                       base_value, tuned_value, change))

        if 'semantic_alignment' in base:
            base_value = np.mean(base['semantic_alignment'])
            tuned_value = np.mean(tuned['semantic_alignment'])
            change = deltas['semantic_alignment']['improvement_percent']
            rows.append(format_row('Semantic Alignment', 'Prompt Adherence',
                                   base_value, tuned_value, change))

        rows.append(rule)

        table_text = "\n".join(rows)
        with open(self.output_dir / "comparison_table.txt", 'w') as handle:
            handle.write(table_text)

        return table_text

    def run_comprehensive_evaluation(self):
        """
        Evaluate every sample in the baseline and fine-tuned directories.

        Runs the piano-authenticity, musical-quality, technical-quality and
        semantic-alignment analyses on both sample sets. It then derives
        per-metric improvements and a summary, writes a comparison table and a
        JSON dump of all results into ``self.output_dir``, and returns the
        results dict. Returns ``None`` after printing an error when either
        directory yields no audio files.
        """
        print("Starting Comprehensive Piano Music Evaluation")
        print("=" * 60)

        baseline_files = self.get_audio_files(self.baseline_dir)
        finetuned_files = self.get_audio_files(self.finetuned_dir)

        print(f"Found {len(baseline_files)} baseline samples")
        print(f"Found {len(finetuned_files)} fine-tuned samples")

        if not (baseline_files and finetuned_files):
            print("Error: No audio files found in specified directories")
            return None

        results = {
            'baseline': {},
            'finetuned': {},
            'improvements': {},
            'summary': {},
            'metadata': {
                'timestamp': datetime.now().isoformat(),
                'baseline_samples': len(baseline_files),
                'finetuned_samples': len(finetuned_files),
                'evaluation_type': 'comprehensive_piano_analysis'
            }
        }

        # Steps 1-4: run each analysis on both sample sets (baseline first).
        analyses = (
            ("1. Evaluating piano authenticity...", 'piano_authenticity', self.evaluate_piano_authenticity),
            ("2. Analyzing musical quality...", 'musical_quality', self.evaluate_musical_quality),
            ("3. Assessing technical quality...", 'technical_quality', self.evaluate_technical_quality),
            ("4. Evaluating semantic alignment...", 'semantic_alignment', self.evaluate_semantic_alignment),
        )
        for banner, key, analyse in analyses:
            print(banner)
            baseline_scores = analyse(baseline_files)
            finetuned_scores = analyse(finetuned_files)
            results['baseline'][key] = baseline_scores
            results['finetuned'][key] = finetuned_scores

        # Step 5: improvements, then the summary that is derived from them.
        print("5. Calculating improvement metrics...")
        results['improvements'] = self.calculate_improvement_focused_scores(
            results['baseline'], results['finetuned']
        )
        results['summary'] = self.generate_improvement_summary(results)

        # Step 6: human-readable table (also written to disk by the callee).
        print("6. Generating comparison table...")
        comparison_table = self.generate_comparison_table(results)

        results_file = self.output_dir / "comprehensive_evaluation_results.json"

        def to_builtin(value):
            """Recursively replace numpy scalars/arrays with plain Python values for json."""
            if isinstance(value, dict):
                return {k: to_builtin(v) for k, v in value.items()}
            if isinstance(value, list):
                return [to_builtin(v) for v in value]
            if isinstance(value, np.integer):
                return int(value)
            if isinstance(value, np.floating):
                return float(value)
            if isinstance(value, np.bool_):
                return bool(value)
            if isinstance(value, np.ndarray):
                return value.tolist()
            if hasattr(value, 'item'):
                # Any other numpy-like scalar.
                return value.item()
            return value

        serializable = to_builtin(results)
        with open(results_file, 'w') as out:
            json.dump(serializable, out, indent=2)

        print("\n" + comparison_table)
        print("\n✓ Comprehensive evaluation complete!")
        print(f"Results saved to {results_file}")
        print("Comparison table saved to comparison_table.txt")

        return results

    def generate_improvement_summary(self, results: Dict) -> Dict:
        """
        Summarize how many metrics the fine-tuned model improved.

        ``results['improvements']`` may hold flat records (category -> record)
        or nested ones (category -> metric -> record). A record is any dict
        with a ``'baseline_mean'`` key. Nested metrics are labelled
        ``"category.metric"``. A metric counts as a significant improvement
        when it improved by more than 5%.

        Returns:
            Dict with ``total_metrics_evaluated``, ``metrics_improved``,
            ``significant_improvements`` (list of ``{'metric', 'improvement'}``),
            ``overall_improvement_score`` (percent of metrics improved) and
            ``key_findings`` (list of sentences).
        """
        # Flatten flat and nested records into (label, record) pairs, preserving order.
        flat_records = []
        for category, entry in results['improvements'].items():
            if isinstance(entry, dict) and 'baseline_mean' in entry:
                flat_records.append((category, entry))
                continue
            for metric_name, record in entry.items():
                if isinstance(record, dict) and 'baseline_mean' in record:
                    flat_records.append((f"{category}.{metric_name}", record))

        improved = [(label, record) for label, record in flat_records if record['is_improvement']]
        significant = [
            {'metric': label, 'improvement': record['improvement_percent']}
            for label, record in improved
            if record['improvement_percent'] > 5
        ]

        evaluated_count = len(flat_records)
        improved_count = len(improved)
        score = (improved_count / evaluated_count) * 100 if evaluated_count > 0 else 0.0

        findings = []
        if improved_count > evaluated_count * 0.6:
            findings.append("Fine-tuned model shows improvements across majority of metrics")
        if significant:
            findings.append(f"Significant improvements in {len(significant)} key areas")

        return {
            'total_metrics_evaluated': evaluated_count,
            'metrics_improved': improved_count,
            'significant_improvements': significant,
            'overall_improvement_score': score,
            'key_findings': findings,
        }

    def generate_comprehensive_report(self, results: Dict, fad_metrics: Optional[Dict] = None):
        """
        Print a human-readable evaluation report to stdout.

        Args:
            results: Dict with a ``'summary'`` entry (shape produced by
                ``generate_improvement_summary``) and an ``'improvements'``
                entry (shape produced by ``calculate_improvement_focused_scores``).
            fad_metrics: Optional FAD results (presumably Frechet Audio
                Distance; not computed in this method). Only the
                ``'fad_baseline_vs_finetuned'`` key is read. When omitted, the
                method uses ``results['fad_metrics']`` if present, then a
                module-level ``fad_metrics`` global if one exists. Without any
                of these, the FAD recommendation is skipped.
        """
        # Resolve FAD data without relying on an undeclared name: previously the
        # bare ``fad_metrics`` reference raised NameError unless a global existed.
        if fad_metrics is None:
            fad_metrics = results.get('fad_metrics')
        if fad_metrics is None:
            fad_metrics = globals().get('fad_metrics')
        if fad_metrics is None:
            fad_metrics = {}

        rule = "=" * 80
        print("\n" + rule)
        print("COMPREHENSIVE PIANO MUSIC GENERATION EVALUATION REPORT")
        print(rule)

        # Overall summary
        summary = results['summary']
        print("\nOVERALL SUMMARY:")
        print(f"Total Metrics Evaluated: {summary['total_metrics_evaluated']}")
        print(f"Metrics Showing Improvement: {summary['metrics_improved']}")
        print(f"Overall Improvement Score: {summary['overall_improvement_score']:.1f}%")

        # Significant improvements, largest first
        if summary['significant_improvements']:
            print("\nSIGNIFICANT IMPROVEMENTS (>5%):")
            for improvement in sorted(summary['significant_improvements'],
                                      key=lambda item: item['improvement'], reverse=True):
                print(f"  ✓ {improvement['metric']}: +{improvement['improvement']:.1f}%")

        improvements = results['improvements']

        def print_change(label: str, data: Dict) -> None:
            """Print one metric's baseline -> fine-tuned change with a direction arrow."""
            arrow = "↗" if data['is_improvement'] else "↘"
            print(f"  {arrow} {label}: {data['baseline_mean']:.3f} → {data['finetuned_mean']:.3f} "
                  f"({data['improvement_percent']:+.1f}%)")

        # Category-wise analysis; headings are printed even if a category is absent.
        category_sections = (
            ('piano_authenticity', 'PIANO AUTHENTICITY ANALYSIS:'),
            ('musical_quality', 'MUSICAL QUALITY ANALYSIS:'),
            ('technical_quality', 'TECHNICAL QUALITY ANALYSIS:'),
        )
        for key, heading in category_sections:
            print(f"\n{heading}")
            for metric, data in improvements.get(key, {}).items():
                print_change(metric, data)

        print("\nSEMANTIC ALIGNMENT:")
        if 'semantic_alignment' in improvements:
            print_change("Prompt Adherence", improvements['semantic_alignment'])

        # Key findings
        print("\nKEY FINDINGS:")
        for finding in summary['key_findings']:
            print(f"  • {finding}")

        # Recommendations
        print("\nRECOMMENDATIONS:")
        overall_score = summary['overall_improvement_score']
        if overall_score > 60:
            print("  ✅ Fine-tuning has been successful - model shows clear improvements")
        elif overall_score > 40:
            print("  🔶 Fine-tuning shows moderate improvements - consider additional training")
        else:
            print("  ⚠️  Fine-tuning shows limited improvements - review training strategy")

        if len(summary['significant_improvements']) > 3:
            print("  ✅ Strong improvements across multiple quality dimensions")

        # Missing FAD data defaults to 100, i.e. "not low", so the line is skipped.
        if fad_metrics.get('fad_baseline_vs_finetuned', 100) < 10:
            print("  ✅ Low FAD indicates successful quality transfer")

        print(rule)


def main():
    """
    Run the comprehensive evaluation on the default sample directories.

    Checks that both sample directories exist, then runs the evaluator. On
    success it draws the improvement chart. Returns the results dict, or
    ``None`` if a directory is missing or the evaluation produced nothing.
    """
    baseline_dir = "evaluation_results/baseline_samples"
    finetuned_dir = "evaluation_results/finetuned_samples"

    # Bail out on the first missing sample directory.
    for label, directory in (("Baseline", baseline_dir), ("Fine-tuned", finetuned_dir)):
        if not os.path.exists(directory):
            print(f"Error: {label} samples directory not found: {directory}")
            return None

    evaluator = EnhancedPianoEvaluator(
        baseline_dir=baseline_dir,
        finetuned_dir=finetuned_dir
    )
    results = evaluator.run_comprehensive_evaluation()

    if not results:
        print("\nEvaluation failed!")
        return None

    print("\nComprehensive evaluation completed successfully!")
    create_improvement_visualization(results)
    return results


def create_improvement_visualization(results: Dict):
    """
    Create visualization showing improvements
    """
    try:
        import matplotlib.pyplot as plt

        # Collect improvement percentages
        improvements_data = []
        labels = []

        for category, metrics in results['improvements'].items():
            if isinstance(metrics, dict) and 'improvement_percent' in metrics:
                improvements_data.append(metrics['improvement_percent'])
                labels.append(category.replace('_', ' ').title())
            else:
                for metric_name, metric_data in metrics.items():
                    if isinstance(metric_data, dict) and 'improvement_percent' in metric_data:
                        improvements_data.append(metric_data['improvement_percent'])
                        labels.append(f"{category.replace('_', ' ').title()}\n{metric_name.replace('_', ' ')}")

        if improvements_data:
            # Create horizontal bar chart
            fig, ax = plt.subplots(figsize=(12, 8))

            colors = ['green' if x > 0 else 'red' for x in improvements_data]
            bars = ax.barh(range(len(improvements_data)), improvements_data, color=colors, alpha=0.7)

            ax.set_yticks(range(len(labels)))
            ax.set_yticklabels(labels, fontsize=10)
            ax.set_xlabel('Improvement Percentage (%)', fontsize=12)
            ax.set_title('Fine-tuned Model Improvements by Metric', fontsize=14, fontweight='bold')
            ax.axvline(x=0, color='black', linestyle='-', alpha=0.5)
            ax.grid(axis='x', alpha=0.3)

            # Add value labels on bars
            for i, (bar, value) in enumerate(zip(bars, improvements_data)):
                ax.text(value + (1 if value > 0 else -1), i, f'{value:.1f}%',
                       va='center', ha='left' if value > 0 else 'right', fontsize=9)

            plt.tight_layout()
            plt.savefig('evaluation_results/results/improvement_visualization.png', dpi=300, bbox_inches='tight')
            plt.close()

            print("✓ Improvement visualization saved to evaluation_results/results/improvement_visualization.png")

    except ImportError:
        print("Matplotlib not available - skipping visualization")
    except Exception as e:
        print(f"Error creating visualization: {e}")


# Script entry point: run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()