In [3]:
pip install numpy librosa pretty_midi matplotlib torch tqdm soundfile pandas

Defaulting to user installation because normal site-packages is not writeable
[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
import os
import librosa
import pretty_midi
import numpy as np
from pathlib import Path
import tensorflow as tf

# Directory setup
base_dir = Path("/nfsshare/selva/maestro-v3.0.0/maestro-v3.0.0")
output_dir = Path("./preprocessed_output_all_480_minimal")
mel_output_dir = output_dir / "mel_spectrograms"
piano_roll_output_dir = output_dir / "piano_rolls"
output_dir.mkdir(exist_ok=True)
mel_output_dir.mkdir(exist_ok=True)
piano_roll_output_dir.mkdir(exist_ok=True)

# Parameters
sr = 16000
hop_length = 512
n_fft = 2048
n_mels = 229
n_pitches = 88
segment_length = 480  # Fixed input size for CNN

# Preprocessing functions
def preprocess_wav_to_mel(wav_path, midi_duration):
    audio, _ = librosa.load(str(wav_path), sr=sr)
    # Trim audio to MIDI duration
    max_samples = int(midi_duration * sr)
    audio = audio[:max_samples] if len(audio) > max_samples else audio
    S = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    S_db = librosa.power_to_db(S, ref=np.max)
    expected_T = (len(audio) - n_fft) // hop_length + 1
    if S_db.shape[1] > expected_T:
        S_db = S_db[:, :expected_T]
    return S_db

def preprocess_midi_to_piano_roll(midi_path, T_target, frame_rate):
    midi = pretty_midi.PrettyMIDI(str(midi_path))
    piano_roll = midi.get_piano_roll(fs=frame_rate)
    piano_roll = piano_roll[21:109, :]  # 88 pitches
    if piano_roll.shape[1] < T_target:
        padding = np.zeros((n_pitches, T_target - piano_roll.shape[1]), dtype=np.uint8)
        piano_roll = np.hstack((piano_roll, padding))
    elif piano_roll.shape[1] > T_target:
        piano_roll = piano_roll[:, :T_target]
    piano_roll = (piano_roll > 0).astype(np.uint8)
    return piano_roll

# Segment data into fixed-length chunks
def segment_data(data, segment_length, axis=1):
    T = data.shape[axis]
    if T < segment_length:
        pad_width = [(0, 0)] * len(data.shape)
        pad_width[axis] = (0, segment_length - T)
        return [np.pad(data, pad_width, mode='constant', constant_values=0)]
    segments = []
    for start in range(0, T, segment_length):
        end = min(start + segment_length, T)
        segment = data[:, start:end] if axis == 1 else data[start:end, :]
        if segment.shape[axis] < segment_length:
            pad_width = [(0, 0)] * len(segment.shape)
            pad_width[axis] = (0, segment_length - segment.shape[axis])
            segment = np.pad(segment, pad_width, mode='constant', constant_values=0)
        segments.append(segment)
    return segments

# Collect all WAV and MIDI files across subfolders
wav_files = []
for year_dir in base_dir.iterdir():
    if year_dir.is_dir():
        wav_files.extend(sorted([f for f in year_dir.glob("*.wav") if f.with_suffix(".midi").exists()]))
midi_files = [f.with_suffix(".midi") for f in wav_files]

print(f"Found {len(wav_files)} files across all subfolders in {base_dir}")

# Process all files with minimal validation
segment_counts = []  # Track number of segments per file
skipped_files = 0
for i, (wav_file, midi_file) in enumerate(zip(wav_files, midi_files)):
    print(f"Processing file {i+1}/{len(wav_files)}: {wav_file.name} (from {wav_file.parent.name})")
    
    # Get MIDI duration
    try:
        midi = pretty_midi.PrettyMIDI(str(midi_file))
        midi_duration = midi.get_end_time()
    except Exception as e:
        print(f"  Skipping {wav_file.name}: Error reading MIDI file ({e})")
        skipped_files += 1
        continue
    
    # Mel spectrogram
    try:
        S_db = preprocess_wav_to_mel(wav_file, midi_duration)
        if S_db.shape[1] < segment_length:
            print(f"  Skipping {wav_file.name}: Too short (< 480 frames)")
            skipped_files += 1
            continue
    except Exception as e:
        print(f"  Skipping {wav_file.name}: Error processing WAV file ({e})")
        skipped_files += 1
        continue
    
    # Piano-roll
    frame_rate = sr / hop_length
    piano_roll = preprocess_midi_to_piano_roll(midi_file, S_db.shape[1], frame_rate)
    
    # Segment into 480-frame chunks
    mel_segments = segment_data(S_db, segment_length, axis=1)
    piano_roll_segments = segment_data(piano_roll, segment_length, axis=1)
    
    # Minimal shape validation for the first segment
    assert mel_segments[0].shape == (n_mels, segment_length), f"Mel shape mismatch: {mel_segments[0].shape}"
    assert piano_roll_segments[0].shape == (n_pitches, segment_length), f"Piano-Roll shape mismatch: {piano_roll_segments[0].shape}"
    
    # Save segments and count
    segment_count = len(mel_segments)
    segment_counts.append(segment_count)
    for j in range(segment_count):
        mel_output_path = mel_output_dir / f"{wav_file.stem}_seg{j:04d}_mel.npy"
        piano_roll_output_path = piano_roll_output_dir / f"{midi_file.stem}_seg{j:04d}_piano_roll.npy"
        np.save(mel_output_path, mel_segments[j])
        np.save(piano_roll_output_path, piano_roll_segments[j])
    print(f"  Saved {segment_count} segments")

# Summary of processing
total_segments = sum(segment_counts)
print(f"\nProcessing Summary:")
print(f"Total files processed: {len(wav_files) - skipped_files}")
print(f"Total files skipped: {skipped_files}")
print(f"Total segments generated: {total_segments}")
print(f"Mel spectrograms saved in: {mel_output_dir}")
print(f"Piano-rolls saved in: {piano_roll_output_dir}")

2025-03-14 23:41:15.582597: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-14 23:41:16.023006: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741975876.178176   32621 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741975876.185638   32621 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-14 23:41:16.727828: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Found 1276 files across all subfolders in /nfsshare/selva/maestro-v3.0.0/maestro-v3.0.0
Processing file 1/1276: MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav (from 2004)
  Saved 64 segments
Processing file 2/1276: MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.wav (from 2004)
  Saved 18 segments
Processing file 3/1276: MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_08_Track08_wav.wav (from 2004)
  Saved 13 segments
Processing file 4/1276: MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_10_Track10_wav.wav (from 2004)
  Saved 21 segments
Processing file 5/1276: MIDI-Unprocessed_SMF_05_R1_2004_01_ORIG_MID--AUDIO_05_R1_2004_02_Track02_wav.wav (from 2004)
  Saved 92 segments
Processing file 6/1276: MIDI-Unprocessed_SMF_05_R1_2004_01_ORIG_MID--AUDIO_05_R1_2004_03_Track03_wav.wav (from 2004)
  Saved 22 segments
Processing file 7/1276: MIDI-Unprocessed_SMF_05_R1_2004_02-03_ORIG_MID--AUDIO_

  Saved 35 segments
Processing file 59/1276: MIDI-Unprocessed_XP_04_R2_2004_01_ORIG_MID--AUDIO_04_R2_2004_01_Track01_wav.wav (from 2004)
  Saved 67 segments
Processing file 60/1276: MIDI-Unprocessed_XP_04_R2_2004_01_ORIG_MID--AUDIO_04_R2_2004_02_Track02_wav.wav (from 2004)
  Saved 85 segments
Processing file 61/1276: MIDI-Unprocessed_XP_06_R1_2004_01_ORIG_MID--AUDIO_06_R1_2004_01_Track01_wav.wav (from 2004)
  Saved 157 segments
Processing file 62/1276: MIDI-Unprocessed_XP_06_R1_2004_02-03_ORIG_MID--AUDIO_06_R1_2004_05_Track05_wav.wav (from 2004)
  Saved 49 segments
Processing file 63/1276: MIDI-Unprocessed_XP_06_R2_2004_01_ORIG_MID--AUDIO_06_R2_2004_01_Track01_wav.wav (from 2004)
  Saved 46 segments
Processing file 64/1276: MIDI-Unprocessed_XP_06_R2_2004_01_ORIG_MID--AUDIO_06_R2_2004_02_Track02_wav.wav (from 2004)
  Saved 35 segments
Processing file 65/1276: MIDI-Unprocessed_XP_06_R2_2004_01_ORIG_MID--AUDIO_06_R2_2004_03_Track03_wav.wav (from 2004)
  Saved 49 segments
Processing file 6

  Saved 137 segments
Processing file 119/1276: MIDI-Unprocessed_XP_19_R2_2004_01_ORIG_MID--AUDIO_19_R2_2004_01_Track01_wav.wav (from 2004)
  Saved 30 segments
Processing file 120/1276: MIDI-Unprocessed_XP_19_R2_2004_01_ORIG_MID--AUDIO_19_R2_2004_02_Track02_wav.wav (from 2004)
  Saved 31 segments
Processing file 121/1276: MIDI-Unprocessed_XP_19_R2_2004_01_ORIG_MID--AUDIO_19_R2_2004_03_Track03_wav.wav (from 2004)
  Saved 46 segments
Processing file 122/1276: MIDI-Unprocessed_XP_20_R2_2004_01_ORIG_MID--AUDIO_20_R1_2004_01_Track01_wav.wav (from 2004)
  Saved 50 segments
Processing file 123/1276: MIDI-Unprocessed_XP_20_R2_2004_01_ORIG_MID--AUDIO_20_R1_2004_02_Track02_wav.wav (from 2004)
  Saved 29 segments
Processing file 124/1276: MIDI-Unprocessed_XP_20_R2_2004_01_ORIG_MID--AUDIO_20_R1_2004_03_Track03_wav.wav (from 2004)
  Saved 45 segments
Processing file 125/1276: MIDI-Unprocessed_XP_21_R1_2004_01_ORIG_MID--AUDIO_21_R1_2004_01_Track01_wav.wav (from 2004)
  Saved 57 segments
Processing fi

  Saved 61 segments
Processing file 179/1276: MIDI-Unprocessed_09_R2_2006_01_ORIG_MID--AUDIO_09_R2_2006_02_Track02_wav.wav (from 2006)
  Saved 39 segments
Processing file 180/1276: MIDI-Unprocessed_09_R2_2006_01_ORIG_MID--AUDIO_09_R2_2006_03_Track03_wav.wav (from 2006)
  Saved 51 segments
Processing file 181/1276: MIDI-Unprocessed_10_R1_2006_01-04_ORIG_MID--AUDIO_10_R1_2006_02_Track02_wav.wav (from 2006)
  Saved 92 segments
Processing file 182/1276: MIDI-Unprocessed_10_R1_2006_01-04_ORIG_MID--AUDIO_10_R1_2006_03_Track03_wav.wav (from 2006)
  Saved 83 segments
Processing file 183/1276: MIDI-Unprocessed_10_R1_2006_01-04_ORIG_MID--AUDIO_10_R1_2006_05_Track05_wav.wav (from 2006)
  Saved 60 segments
Processing file 184/1276: MIDI-Unprocessed_11_R1_2006_01-06_ORIG_MID--AUDIO_11_R1_2006_01_Track01_wav.wav (from 2006)
  Saved 22 segments
Processing file 185/1276: MIDI-Unprocessed_11_R1_2006_01-06_ORIG_MID--AUDIO_11_R1_2006_02_Track02_wav.wav (from 2006)
  Saved 73 segments
Processing file 186/

  Saved 123 segments
Processing file 239/1276: MIDI-Unprocessed_23_R2_2006_01_ORIG_MID--AUDIO_23_R2_2006_01_Track01_wav.wav (from 2006)
  Saved 43 segments
Processing file 240/1276: MIDI-Unprocessed_23_R2_2006_01_ORIG_MID--AUDIO_23_R2_2006_02_Track02_wav.wav (from 2006)
  Saved 48 segments
Processing file 241/1276: MIDI-Unprocessed_23_R2_2006_01_ORIG_MID--AUDIO_23_R2_2006_03_Track03_wav.wav (from 2006)
  Saved 28 segments
Processing file 242/1276: MIDI-Unprocessed_23_R2_2006_01_ORIG_MID--AUDIO_23_R2_2006_04_Track04_wav.wav (from 2006)
  Saved 18 segments
Processing file 243/1276: MIDI-Unprocessed_24_R1_2006_01-05_ORIG_MID--AUDIO_24_R1_2006_01_Track01_wav.wav (from 2006)
  Saved 62 segments
Processing file 244/1276: MIDI-Unprocessed_24_R1_2006_01-05_ORIG_MID--AUDIO_24_R1_2006_02_Track02_wav.wav (from 2006)
  Saved 45 segments
Processing file 245/1276: MIDI-Unprocessed_24_R1_2006_01-05_ORIG_MID--AUDIO_24_R1_2006_03_Track03_wav.wav (from 2006)
  Saved 28 segments
Processing file 246/1276:

  Saved 7 segments
Processing file 302/1276: MIDI-Unprocessed_07_R2_2008_01-05_ORIG_MID--AUDIO_07_R2_2008_wav--3.wav (from 2008)
  Saved 9 segments
Processing file 303/1276: MIDI-Unprocessed_07_R2_2008_01-05_ORIG_MID--AUDIO_07_R2_2008_wav--4.wav (from 2008)
  Saved 28 segments
Processing file 304/1276: MIDI-Unprocessed_07_R3_2008_01-05_ORIG_MID--AUDIO_07_R3_2008_wav--1.wav (from 2008)
  Saved 16 segments
Processing file 305/1276: MIDI-Unprocessed_07_R3_2008_01-05_ORIG_MID--AUDIO_07_R3_2008_wav--2.wav (from 2008)
  Saved 45 segments
Processing file 306/1276: MIDI-Unprocessed_07_R3_2008_01-05_ORIG_MID--AUDIO_07_R3_2008_wav--3.wav (from 2008)
  Saved 12 segments
Processing file 307/1276: MIDI-Unprocessed_07_R3_2008_01-05_ORIG_MID--AUDIO_07_R3_2008_wav--4.wav (from 2008)
  Saved 14 segments
Processing file 308/1276: MIDI-Unprocessed_08_R1_2008_01-05_ORIG_MID--AUDIO_08_R1_2008_wav--1.wav (from 2008)
  Saved 11 segments
Processing file 309/1276: MIDI-Unprocessed_08_R1_2008_01-05_ORIG_MID--AU

  Saved 41 segments
Processing file 366/1276: MIDI-Unprocessed_12_R3_2008_01-04_ORIG_MID--AUDIO_12_R3_2008_wav--3.wav (from 2008)
  Saved 12 segments
Processing file 367/1276: MIDI-Unprocessed_13_R1_2008_01-04_ORIG_MID--AUDIO_13_R1_2008_wav--1.wav (from 2008)
  Saved 16 segments
Processing file 368/1276: MIDI-Unprocessed_13_R1_2008_01-04_ORIG_MID--AUDIO_13_R1_2008_wav--2.wav (from 2008)
  Saved 13 segments
Processing file 369/1276: MIDI-Unprocessed_13_R1_2008_01-04_ORIG_MID--AUDIO_13_R1_2008_wav--4.wav (from 2008)
  Saved 44 segments
Processing file 370/1276: MIDI-Unprocessed_14_R1_2008_01-05_ORIG_MID--AUDIO_14_R1_2008_wav--1.wav (from 2008)
  Saved 15 segments
Processing file 371/1276: MIDI-Unprocessed_14_R1_2008_01-05_ORIG_MID--AUDIO_14_R1_2008_wav--2.wav (from 2008)
  Saved 18 segments
Processing file 372/1276: MIDI-Unprocessed_14_R1_2008_01-05_ORIG_MID--AUDIO_14_R1_2008_wav--3.wav (from 2008)
  Saved 12 segments
Processing file 373/1276: MIDI-Unprocessed_14_R1_2008_01-05_ORIG_MID--

  Saved 77 segments
Processing file 427/1276: MIDI-Unprocessed_07_R2_2009_01_ORIG_MID--AUDIO_07_R2_2009_07_R2_2009_01_WAV.wav (from 2009)
  Saved 32 segments
Processing file 428/1276: MIDI-Unprocessed_07_R2_2009_01_ORIG_MID--AUDIO_07_R2_2009_07_R2_2009_02_WAV.wav (from 2009)
  Saved 31 segments
Processing file 429/1276: MIDI-Unprocessed_07_R2_2009_01_ORIG_MID--AUDIO_07_R2_2009_07_R2_2009_03_WAV.wav (from 2009)
  Saved 11 segments
Processing file 430/1276: MIDI-Unprocessed_07_R2_2009_01_ORIG_MID--AUDIO_07_R2_2009_07_R2_2009_04_WAV.wav (from 2009)
  Saved 34 segments
Processing file 431/1276: MIDI-Unprocessed_08_R1_2009_01-04_ORIG_MID--AUDIO_08_R1_2009_08_R1_2009_01_WAV.wav (from 2009)
  Saved 44 segments
Processing file 432/1276: MIDI-Unprocessed_08_R1_2009_01-04_ORIG_MID--AUDIO_08_R1_2009_08_R1_2009_02_WAV.wav (from 2009)
  Saved 22 segments
Processing file 433/1276: MIDI-Unprocessed_08_R1_2009_01-04_ORIG_MID--AUDIO_08_R1_2009_08_R1_2009_03_WAV.wav (from 2009)
  Saved 16 segments
Proce

  Saved 26 segments
Processing file 486/1276: MIDI-Unprocessed_16_R1_2009_01-02_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_01_WAV.wav (from 2009)
  Saved 100 segments
Processing file 487/1276: MIDI-Unprocessed_16_R1_2009_01-02_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_02_WAV.wav (from 2009)
  Saved 46 segments
Processing file 488/1276: MIDI-Unprocessed_16_R1_2009_03-06_ORIG_MID--AUDIO_16_R1_2009_16_R1_2009_03_WAV.wav (from 2009)
  Saved 41 segments
Processing file 489/1276: MIDI-Unprocessed_16_R2_2009_01_ORIG_MID--AUDIO_16_R2_2009_16_R2_2009_01_WAV.wav (from 2009)
  Saved 47 segments
Processing file 490/1276: MIDI-Unprocessed_16_R2_2009_01_ORIG_MID--AUDIO_16_R2_2009_16_R2_2009_02_WAV.wav (from 2009)
  Saved 36 segments
Processing file 491/1276: MIDI-Unprocessed_16_R2_2009_01_ORIG_MID--AUDIO_16_R2_2009_16_R2_2009_03_WAV.wav (from 2009)
  Saved 21 segments
Processing file 492/1276: MIDI-Unprocessed_16_R2_2009_01_ORIG_MID--AUDIO_16_R2_2009_16_R2_2009_04_WAV.wav (from 2009)
  Saved 35 segments
Proc

  Saved 39 segments
Processing file 549/1276: MIDI-Unprocessed_04_R2_2011_MID--AUDIO_R2-D2_03_Track03_wav.wav (from 2011)
  Saved 36 segments
Processing file 550/1276: MIDI-Unprocessed_04_R3_2011_MID--AUDIO_R3-D2_03_Track03_wav.wav (from 2011)
  Saved 24 segments
Processing file 551/1276: MIDI-Unprocessed_04_R3_2011_MID--AUDIO_R3-D2_04_Track04_wav.wav (from 2011)
  Saved 44 segments
Processing file 552/1276: MIDI-Unprocessed_04_R3_2011_MID--AUDIO_R3-D2_05_Track05_wav.wav (from 2011)
  Saved 14 segments
Processing file 553/1276: MIDI-Unprocessed_04_R3_2011_MID--AUDIO_R3-D2_06_Track06_wav.wav (from 2011)
  Saved 53 segments
Processing file 554/1276: MIDI-Unprocessed_05_R1_2011_MID--AUDIO_R1-D2_08_Track08_wav.wav (from 2011)
  Saved 8 segments
Processing file 555/1276: MIDI-Unprocessed_05_R1_2011_MID--AUDIO_R1-D2_09_Track09_wav.wav (from 2011)
  Saved 13 segments
Processing file 556/1276: MIDI-Unprocessed_05_R1_2011_MID--AUDIO_R1-D2_10_Track10_wav.wav (from 2011)
  Saved 13 segments
Proce

  Saved 15 segments
Processing file 617/1276: MIDI-Unprocessed_16_R1_2011_MID--AUDIO_R1-D6_14_Track14_wav.wav (from 2011)
  Saved 26 segments
Processing file 618/1276: MIDI-Unprocessed_16_R1_2011_MID--AUDIO_R1-D6_15_Track15_wav.wav (from 2011)
  Saved 24 segments
Processing file 619/1276: MIDI-Unprocessed_16_R2_2011_MID--AUDIO_R2-D4_08_Track08_wav.wav (from 2011)
  Saved 54 segments
Processing file 620/1276: MIDI-Unprocessed_16_R2_2011_MID--AUDIO_R2-D4_09_Track09_wav.wav (from 2011)
  Saved 17 segments
Processing file 621/1276: MIDI-Unprocessed_16_R3_2011_MID--AUDIO_R3-D5_02_Track02_wav.wav (from 2011)
  Saved 44 segments
Processing file 622/1276: MIDI-Unprocessed_17_R1_2011_MID--AUDIO_R1-D7_02_Track02_wav.wav (from 2011)
  Saved 11 segments
Processing file 623/1276: MIDI-Unprocessed_17_R1_2011_MID--AUDIO_R1-D7_03_Track03_wav.wav (from 2011)
  Saved 23 segments
Processing file 624/1276: MIDI-Unprocessed_17_R1_2011_MID--AUDIO_R1-D7_04_Track04_wav.wav (from 2011)
  Saved 13 segments
Proc

  Saved 30 segments
Processing file 685/1276: ORIG-MIDI_01_7_10_13_Group_MID--AUDIO_07_R3_2013_wav--2.wav (from 2013)
  Saved 53 segments
Processing file 686/1276: ORIG-MIDI_01_7_10_13_Group_MID--AUDIO_07_R3_2013_wav--3.wav (from 2013)
  Saved 36 segments
Processing file 687/1276: ORIG-MIDI_01_7_10_13_Group_MID--AUDIO_08_R3_2013_wav--1.wav (from 2013)
  Saved 40 segments
Processing file 688/1276: ORIG-MIDI_01_7_10_13_Group_MID--AUDIO_08_R3_2013_wav--2.wav (from 2013)
  Saved 54 segments
Processing file 689/1276: ORIG-MIDI_01_7_10_13_Group_MID--AUDIO_08_R3_2013_wav--3.wav (from 2013)
  Saved 34 segments
Processing file 690/1276: ORIG-MIDI_01_7_6_13_Group__MID--AUDIO_01_R1_2013_wav--1.wav (from 2013)
  Saved 29 segments
Processing file 691/1276: ORIG-MIDI_01_7_6_13_Group__MID--AUDIO_01_R1_2013_wav--2.wav (from 2013)
  Saved 22 segments
Processing file 692/1276: ORIG-MIDI_01_7_6_13_Group__MID--AUDIO_01_R1_2013_wav--3.wav (from 2013)
  Saved 20 segments
Processing file 693/1276: ORIG-MIDI_

  Saved 15 segments
Processing file 755/1276: ORIG-MIDI_02_7_6_13_Group__MID--AUDIO_08_R1_2013_wav--5.wav (from 2013)
  Saved 11 segments
Processing file 756/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_15_R1_2013_wav--1.wav (from 2013)
  Saved 21 segments
Processing file 757/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_15_R1_2013_wav--2.wav (from 2013)
  Saved 22 segments
Processing file 758/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_15_R1_2013_wav--3.wav (from 2013)
  Saved 14 segments
Processing file 759/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_15_R1_2013_wav--4.wav (from 2013)
  Saved 20 segments
Processing file 760/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_16_R1_2013_wav--1.wav (from 2013)
  Saved 12 segments
Processing file 761/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_16_R1_2013_wav--2.wav (from 2013)
  Saved 28 segments
Processing file 762/1276: ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_16_R1_2013_wav--3.wav (from 2013)
  Saved 16 segments
Processing file 763/1276: ORIG-MIDI_

  Saved 10 segments
Processing file 824/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_04_R1_2014_wav--2.wav (from 2014)
  Saved 79 segments
Processing file 825/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_04_R1_2014_wav--3.wav (from 2014)
  Saved 74 segments
Processing file 826/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_04_R1_2014_wav--4.wav (from 2014)
  Saved 19 segments
Processing file 827/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_04_R1_2014_wav--5.wav (from 2014)
  Saved 13 segments
Processing file 828/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_05_R1_2014_wav--1.wav (from 2014)
  Saved 49 segments
Processing file 829/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_05_R1_2014_wav--2.wav (from 2014)
  Saved 14 segments
Processing file 830/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_05_R1_2014_wav--4.wav (from 2014)
  Saved 17 segments
Processing file 831/1276: MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_05_R1_2014_wav--5.wav (from 2014)
  Saved 8 segments
Proce

  Saved 72 segments
Processing file 891/1276: MIDI-UNPROCESSED_16-18_R1_2014_MID--AUDIO_18_R1_2014_wav--3.wav (from 2014)
  Saved 10 segments
Processing file 892/1276: MIDI-UNPROCESSED_16-18_R1_2014_MID--AUDIO_18_R1_2014_wav--4.wav (from 2014)
  Saved 15 segments
Processing file 893/1276: MIDI-UNPROCESSED_16-18_R1_2014_MID--AUDIO_18_R1_2014_wav--5.wav (from 2014)
  Saved 31 segments
Processing file 894/1276: MIDI-UNPROCESSED_19-20-21_R2_2014_MID--AUDIO_19_R2_2014_wav.wav (from 2014)
  Saved 160 segments
Processing file 895/1276: MIDI-UNPROCESSED_19-20-21_R2_2014_MID--AUDIO_20_R2_2014_wav.wav (from 2014)
  Saved 104 segments
Processing file 896/1276: MIDI-UNPROCESSED_19-20-21_R2_2014_MID--AUDIO_21_R2_2014_wav.wav (from 2014)
  Saved 116 segments
Processing file 897/1276: MIDI-UNPROCESSED_19-20_R1_2014_MID--AUDIO_19_R1_2014_wav--1.wav (from 2014)
  Saved 40 segments
Processing file 898/1276: MIDI-UNPROCESSED_19-20_R1_2014_MID--AUDIO_19_R1_2014_wav--2.wav (from 2014)
  Saved 104 segments


  Saved 23 segments
Processing file 957/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_10_R1_2015_wav--2.wav (from 2015)
  Saved 8 segments
Processing file 958/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_10_R1_2015_wav--3.wav (from 2015)
  Saved 82 segments
Processing file 959/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_11_R1_2015_wav--1.wav (from 2015)
  Saved 60 segments
Processing file 960/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_11_R1_2015_wav--2.wav (from 2015)
  Saved 14 segments
Processing file 961/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_11_R1_2015_wav--3.wav (from 2015)
  Saved 10 segments
Processing file 962/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_11_R1_2015_wav--4.wav (from 2015)
  Saved 18 segments
Processing file 963/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_11_R1_2015_wav--5.wav (from 2015)
  Saved 8 segments
Processing file 964/1276: MIDI-Unprocessed_R1_D1-9-12_mid--AUDIO-from_mp3_12_R1_2015_w

  Saved 45 segments
Processing file 1020/1276: MIDI-Unprocessed_R2_D1-2-3-6-7-8-11_mid--AUDIO-from_mp3_08_R2_2015_wav--3.wav (from 2015)
  Saved 28 segments
Processing file 1021/1276: MIDI-Unprocessed_R2_D1-2-3-6-7-8-11_mid--AUDIO-from_mp3_11_R2_2015_wav--1.wav (from 2015)
  Saved 45 segments
Processing file 1022/1276: MIDI-Unprocessed_R2_D1-2-3-6-7-8-11_mid--AUDIO-from_mp3_11_R2_2015_wav--2.wav (from 2015)
  Saved 24 segments
Processing file 1023/1276: MIDI-Unprocessed_R2_D1-2-3-6-7-8-11_mid--AUDIO-from_mp3_11_R2_2015_wav--4.wav (from 2015)
  Saved 54 segments
Processing file 1024/1276: MIDI-Unprocessed_R2_D2-12-13-15_mid--AUDIO-from_mp3_12_R2_2015_wav--1.wav (from 2015)
  Saved 38 segments
Processing file 1025/1276: MIDI-Unprocessed_R2_D2-12-13-15_mid--AUDIO-from_mp3_12_R2_2015_wav--2.wav (from 2015)
  Saved 23 segments
Processing file 1026/1276: MIDI-Unprocessed_R2_D2-12-13-15_mid--AUDIO-from_mp3_12_R2_2015_wav--3.wav (from 2015)
  Saved 49 segments
Processing file 1027/1276: MIDI-U

  Saved 25 segments
Processing file 1080/1276: MIDI-Unprocessed_050_PIANO050_MID--AUDIO-split_07-06-17_Piano-e_3-01_wav--3.wav (from 2017)
  Saved 32 segments
Processing file 1081/1276: MIDI-Unprocessed_050_PIANO050_MID--AUDIO-split_07-06-17_Piano-e_3-01_wav--4.wav (from 2017)
  Saved 31 segments
Processing file 1082/1276: MIDI-Unprocessed_051_PIANO051_MID--AUDIO-split_07-06-17_Piano-e_3-02_wav--1.wav (from 2017)
  Saved 10 segments
Processing file 1083/1276: MIDI-Unprocessed_051_PIANO051_MID--AUDIO-split_07-06-17_Piano-e_3-02_wav--2.wav (from 2017)
  Saved 21 segments
Processing file 1084/1276: MIDI-Unprocessed_051_PIANO051_MID--AUDIO-split_07-06-17_Piano-e_3-02_wav--3.wav (from 2017)
  Saved 15 segments
Processing file 1085/1276: MIDI-Unprocessed_051_PIANO051_MID--AUDIO-split_07-06-17_Piano-e_3-02_wav--4.wav (from 2017)
  Saved 42 segments
Processing file 1086/1276: MIDI-Unprocessed_051_PIANO051_MID--AUDIO-split_07-06-17_Piano-e_3-02_wav--5.wav (from 2017)
  Saved 16 segments
Process

  Saved 19 segments
Processing file 1139/1276: MIDI-Unprocessed_067_PIANO067_MID--AUDIO-split_07-07-17_Piano-e_3-03_wav--1.wav (from 2017)
  Saved 21 segments
Processing file 1140/1276: MIDI-Unprocessed_067_PIANO067_MID--AUDIO-split_07-07-17_Piano-e_3-03_wav--2.wav (from 2017)
  Saved 19 segments
Processing file 1141/1276: MIDI-Unprocessed_067_PIANO067_MID--AUDIO-split_07-07-17_Piano-e_3-03_wav--3.wav (from 2017)
  Saved 8 segments
Processing file 1142/1276: MIDI-Unprocessed_067_PIANO067_MID--AUDIO-split_07-07-17_Piano-e_3-03_wav--4.wav (from 2017)
  Saved 28 segments
Processing file 1143/1276: MIDI-Unprocessed_070_PIANO070_MID--AUDIO-split_07-08-17_Piano-e_1-02_wav--1.wav (from 2017)
  Saved 19 segments
Processing file 1144/1276: MIDI-Unprocessed_070_PIANO070_MID--AUDIO-split_07-08-17_Piano-e_1-02_wav--2.wav (from 2017)
  Saved 6 segments
Processing file 1145/1276: MIDI-Unprocessed_070_PIANO070_MID--AUDIO-split_07-08-17_Piano-e_1-02_wav--3.wav (from 2017)
  Saved 48 segments
Processin

  Saved 133 segments
Processing file 1201/1276: MIDI-Unprocessed_Recital1-3_MID--AUDIO_03_R1_2018_wav--4.wav (from 2018)
  Saved 17 segments
Processing file 1202/1276: MIDI-Unprocessed_Recital1-3_MID--AUDIO_03_R1_2018_wav--5.wav (from 2018)
  Saved 11 segments
Processing file 1203/1276: MIDI-Unprocessed_Recital12_MID--AUDIO_12_R1_2018_wav--1.wav (from 2018)
  Saved 53 segments
Processing file 1204/1276: MIDI-Unprocessed_Recital12_MID--AUDIO_12_R1_2018_wav--2.wav (from 2018)
  Saved 46 segments
Processing file 1205/1276: MIDI-Unprocessed_Recital12_MID--AUDIO_12_R1_2018_wav--3.wav (from 2018)
  Saved 42 segments
Processing file 1206/1276: MIDI-Unprocessed_Recital13-15_MID--AUDIO_13_R1_2018_wav--1.wav (from 2018)
  Saved 70 segments
Processing file 1207/1276: MIDI-Unprocessed_Recital13-15_MID--AUDIO_13_R1_2018_wav--2.wav (from 2018)
  Saved 100 segments
Processing file 1208/1276: MIDI-Unprocessed_Recital13-15_MID--AUDIO_13_R1_2018_wav--3.wav (from 2018)
  Saved 107 segments
Processing fil

  Saved 124 segments
Processing file 1269/1276: MIDI-Unprocessed_Schubert10-12_MID--AUDIO_18_R2_2018_wav.wav (from 2018)
  Saved 130 segments
Processing file 1270/1276: MIDI-Unprocessed_Schubert10-12_MID--AUDIO_20_R2_2018_wav.wav (from 2018)
  Saved 153 segments
Processing file 1271/1276: MIDI-Unprocessed_Schubert4-6_MID--AUDIO_08_R2_2018_wav.wav (from 2018)
  Saved 114 segments
Processing file 1272/1276: MIDI-Unprocessed_Schubert4-6_MID--AUDIO_09_R2_2018_wav.wav (from 2018)
  Saved 167 segments
Processing file 1273/1276: MIDI-Unprocessed_Schubert4-6_MID--AUDIO_10_R2_2018_wav.wav (from 2018)
  Saved 158 segments
Processing file 1274/1276: MIDI-Unprocessed_Schubert7-9_MID--AUDIO_11_R2_2018_wav.wav (from 2018)
  Saved 107 segments
Processing file 1275/1276: MIDI-Unprocessed_Schubert7-9_MID--AUDIO_15_R2_2018_wav.wav (from 2018)
  Saved 136 segments
Processing file 1276/1276: MIDI-Unprocessed_Schubert7-9_MID--AUDIO_16_R2_2018_wav.wav (from 2018)
  Saved 115 segments

Processing Summary:
To

In [7]:
import os
import shutil
import numpy as np
from pathlib import Path
from collections import defaultdict

# Input directories (from previous preprocessing)
base_output_dir = Path("./preprocessed_output_all_480_minimal")
mel_input_dir = base_output_dir / "mel_spectrograms"
piano_roll_input_dir = base_output_dir / "piano_rolls"

# Output directories for train, val, test splits
mel_train_dir = mel_input_dir / "train"
mel_val_dir = mel_input_dir / "val"
mel_test_dir = mel_input_dir / "test"
piano_roll_train_dir = piano_roll_input_dir / "train"
piano_roll_val_dir = piano_roll_input_dir / "val"
piano_roll_test_dir = piano_roll_input_dir / "test"

# Create directories
for d in [mel_train_dir, mel_val_dir, mel_test_dir, piano_roll_train_dir, piano_roll_val_dir, piano_roll_test_dir]:
    d.mkdir(exist_ok=True)

# Step 1: Collect all Mel spectrogram files and group by original file
mel_files = sorted(list(mel_input_dir.glob("*_mel.npy")))
print(f"Total Mel spectrogram segments: {len(mel_files)}")

# Group segments by original file (based on filename before "_segXXXX_mel.npy")
file_groups = defaultdict(list)
for mel_file in mel_files:
    # Extract the base filename (e.g., "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav")
    base_name = mel_file.stem.split("_seg")[0]
    file_groups[base_name].append(mel_file)

# Step 2: Split files into train, val, test (80-10-10)
file_names = list(file_groups.keys())
np.random.seed(42)  # For reproducibility
np.random.shuffle(file_names)

total_files = len(file_names)
train_split = int(0.8 * total_files)  # 80%
val_split = int(0.1 * total_files)    # 10%
test_split = total_files - train_split - val_split  # Remaining 10%

train_files = file_names[:train_split]
val_files = file_names[train_split:train_split + val_split]
test_files = file_names[train_split + val_split:]

print(f"Total files: {total_files}")
print(f"Train files: {len(train_files)} ({len(train_files)/total_files*100:.1f}%)")
print(f"Val files: {len(val_files)} ({len(val_files)/total_files*100:.1f}%)")
print(f"Test files: {len(test_files)} ({len(test_files)/total_files*100:.1f}%)")

# Step 3: Move Mel spectrograms and piano-rolls to appropriate directories
def move_files(file_list, split_name, mel_dest_dir, piano_roll_dest_dir):
    mel_count = 0
    piano_roll_count = 0
    for base_name in file_list:
        # Move Mel spectrograms
        mel_segments = file_groups[base_name]
        for mel_file in mel_segments:
            dest_path = mel_dest_dir / mel_file.name
            shutil.move(str(mel_file), str(dest_path))
            mel_count += 1
        
        # Move corresponding piano-rolls
        piano_roll_segments = [piano_roll_input_dir / mel_file.name.replace("_mel.npy", "_piano_roll.npy") for mel_file in mel_segments]
        for piano_roll_file in piano_roll_segments:
            if piano_roll_file.exists():
                dest_path = piano_roll_dest_dir / piano_roll_file.name
                shutil.move(str(piano_roll_file), str(dest_path))
                piano_roll_count += 1
            else:
                print(f"Warning: Piano-roll file {piano_roll_file} not found for {base_name}")
    
    print(f"\nMoved to {split_name}:")
    print(f"  Mel spectrograms: {mel_count} segments")
    print(f"  Piano-rolls: {piano_roll_count} segments")

# Move files to train, val, test directories
move_files(train_files, "train", mel_train_dir, piano_roll_train_dir)
move_files(val_files, "val", mel_val_dir, piano_roll_val_dir)
move_files(test_files, "test", mel_test_dir, piano_roll_test_dir)

print("\nDataset split completed.")
print(f"Train: {mel_train_dir}, {piano_roll_train_dir}")
print(f"Val: {mel_val_dir}, {piano_roll_val_dir}")
print(f"Test: {mel_test_dir}, {piano_roll_test_dir}")

Total Mel spectrogram segments: 47268
Total files: 1276
Train files: 1020 (79.9%)
Val files: 127 (10.0%)
Test files: 129 (10.1%)

Moved to train:
  Mel spectrograms: 37993 segments
  Piano-rolls: 37993 segments

Moved to val:
  Mel spectrograms: 4113 segments
  Piano-rolls: 4113 segments

Moved to test:
  Mel spectrograms: 5162 segments
  Piano-rolls: 5162 segments

Dataset split completed.
Train: preprocessed_output_all_480_minimal/mel_spectrograms/train, preprocessed_output_all_480_minimal/piano_rolls/train
Val: preprocessed_output_all_480_minimal/mel_spectrograms/val, preprocessed_output_all_480_minimal/piano_rolls/val
Test: preprocessed_output_all_480_minimal/mel_spectrograms/test, preprocessed_output_all_480_minimal/piano_rolls/test


In [10]:
import numpy as np
from pathlib import Path
from scipy.stats import pearsonr
import tensorflow as tf

# Directory for training piano-rolls
piano_roll_train_dir = Path("./preprocessed_output_all_480_minimal/piano_rolls/train")

# Parameters
n_pitches = 88
max_files_to_process = 37993
target_sparsity = (0.94, 0.96)  # Adjusted to hit ~0.95
threshold_range = np.linspace(0.15, 0.25, 21)  # Narrow range to target ~200 edges

# Step 1: Load training piano-rolls and compute note activations
piano_roll_files = sorted(list(piano_roll_train_dir.glob("*_piano_roll.npy")))
print(f"Found {len(piano_roll_files)} training piano-roll segments")
print(f"Processing up to {max_files_to_process} segments for efficiency")

total_frames = 0
processed_files = 0
for i, piano_roll_file in enumerate(piano_roll_files):
    if processed_files >= max_files_to_process:
        break
    piano_roll = np.load(piano_roll_file)
    total_frames += piano_roll.shape[1]
    processed_files += 1

note_activations = np.zeros((n_pitches, total_frames), dtype=np.uint8)
current_frame = 0
processed_files = 0
for i, piano_roll_file in enumerate(piano_roll_files):
    if processed_files >= max_files_to_process:
        break
    if i % 1000 == 0:
        print(f"Processing piano-roll {i}/{len(piano_roll_files)}")
    
    piano_roll = np.load(piano_roll_file)
    piano_roll = (piano_roll > 0).astype(np.uint8)
    num_frames = piano_roll.shape[1]
    note_activations[:, current_frame:current_frame + num_frames] = piano_roll
    current_frame += num_frames
    processed_files += 1

# Step 2: Compute Pearson Correlation Matrix
P_pearson = np.zeros((n_pitches, n_pitches), dtype=np.float32)
for i in range(n_pitches):
    for j in range(i + 1, n_pitches):
        if np.sum(note_activations[i, :]) == 0 or np.sum(note_activations[j, :]) == 0:
            corr = 0.0
        else:
            corr, _ = pearsonr(note_activations[i, :], note_activations[j, :])
        P_pearson[i, j] = corr
        P_pearson[j, i] = corr

print("\nPearson Correlation Matrix Statistics:")
print(f"Min value: {P_pearson.min():.4f}")
print(f"Max value: {P_pearson.max():.4f}")
print(f"Mean value: {P_pearson.mean():.4f}")

# Step 3: Find optimal threshold based on sparsity
def get_adjacency_matrix(P, tau):
    A = (P >= tau).astype(np.int32)
    np.fill_diagonal(A, 0)
    A_symmetric = np.logical_or(A, A.T).astype(np.int32)
    np.fill_diagonal(A_symmetric, 0)
    n_edges = np.sum(A_symmetric) // 2
    sparsity = np.sum(A_symmetric == 0) / (n_pitches * n_pitches)
    return A_symmetric, n_edges, sparsity

best_tau = None
best_A = None
best_edges = 0
best_sparsity = 0
for tau in threshold_range:
    A_pearson, edges_pearson, sparsity_pearson = get_adjacency_matrix(P_pearson, tau)
    if target_sparsity[0] <= sparsity_pearson <= target_sparsity[1]:
        best_tau = tau
        best_A = A_pearson
        best_edges = edges_pearson
        best_sparsity = sparsity_pearson
        break
    elif abs(sparsity_pearson - (target_sparsity[0] + target_sparsity[1]) / 2) < abs(best_sparsity - (target_sparsity[0] + target_sparsity[1]) / 2):
        best_tau = tau
        best_A = A_pearson
        best_edges = edges_pearson
        best_sparsity = sparsity_pearson

# Step 4: Validate final adjacency matrix
def validate_adjacency_matrix(A, label="", tau_used=None):
    n_nodes = A.shape[0]
    n_edges = np.sum(A) // 2
    sparsity = np.sum(A == 0) / (n_nodes * n_nodes)
    is_symmetric = np.array_equal(A, A.T)

    print(f"\nFinal Adjacency Matrix Validation ({label}, tau = {tau_used:.3f}):")
    print(f"Total number of nodes: {n_nodes} (should be 88)")
    print(f"Total number of edges: {n_edges}")
    print(f"Sparsity: {sparsity:.4f}")
    print(f"Is symmetric: {is_symmetric} (should be True)")

validate_adjacency_matrix(best_A, label="pearson", tau_used=best_tau)

# Convert to TensorFlow tensor
A_tf = tf.convert_to_tensor(best_A, dtype=tf.float32)

# Save adjacency matrix
np.save("adjacency_matrix_final.npy", best_A)
print("Adjacency matrix saved as adjacency_matrix_final.npy")

Found 37993 training piano-roll segments
Processing up to 37993 segments for efficiency
Processing piano-roll 0/37993
Processing piano-roll 1000/37993
Processing piano-roll 2000/37993
Processing piano-roll 3000/37993
Processing piano-roll 4000/37993
Processing piano-roll 5000/37993
Processing piano-roll 6000/37993
Processing piano-roll 7000/37993
Processing piano-roll 8000/37993
Processing piano-roll 9000/37993
Processing piano-roll 10000/37993
Processing piano-roll 11000/37993
Processing piano-roll 12000/37993
Processing piano-roll 13000/37993
Processing piano-roll 14000/37993
Processing piano-roll 15000/37993
Processing piano-roll 16000/37993
Processing piano-roll 17000/37993
Processing piano-roll 18000/37993
Processing piano-roll 19000/37993
Processing piano-roll 20000/37993
Processing piano-roll 21000/37993
Processing piano-roll 22000/37993
Processing piano-roll 23000/37993
Processing piano-roll 24000/37993
Processing piano-roll 25000/37993
Processing piano-roll 26000/37993
Process

In [38]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, metrics, callbacks
from pathlib import Path
import mir_eval
from sklearn.metrics import precision_recall_fscore_support

# Parameters
n_pitches = 88
n_mels = 229
frames_per_segment = 480
d = 88
d_prime = 768
hop_length = 512
sr = 22050

# Custom Learning Rate Schedule with Warmup
class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, initial_lr, target_lr, warmup_steps, decay_steps, decay_rate):
        super(WarmupDecaySchedule, self).__init__()
        self.initial_lr = tf.cast(initial_lr, tf.float32)
        self.target_lr = tf.cast(target_lr, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
        self.decay_steps = tf.cast(decay_steps, tf.float32)
        self.decay_rate = tf.cast(decay_rate, tf.float32)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_lr = self.initial_lr + (self.target_lr - self.initial_lr) * (step / self.warmup_steps)
        warmup_lr = tf.minimum(self.target_lr, warmup_lr)
        decay_steps = tf.maximum(step - self.warmup_steps, 0.0)
        decay_lr = self.target_lr * tf.pow(self.decay_rate, decay_steps / self.decay_steps)
        lr = tf.where(step <= self.warmup_steps, warmup_lr, decay_lr)
        return lr

    def get_config(self):
        return {
            "initial_lr": self.initial_lr,
            "target_lr": self.target_lr,
            "warmup_steps": self.warmup_steps,
            "decay_steps": self.decay_steps,
            "decay_rate": self.decay_rate
        }

# Load and normalize adjacency matrix
A = np.load("adjacency_matrix_final.npy")
A = A + np.eye(n_pitches) + 0.5
row_sums = np.sum(A, axis=1, keepdims=True)
row_sums = np.where(row_sums == 0, 1e-8, row_sums)
A = A / row_sums
A = np.nan_to_num(A, nan=0.0)
A_tf = tf.convert_to_tensor(A, dtype=tf.float32)
if np.any(np.isnan(A_tf.numpy())):
    raise ValueError("Adjacency matrix A_tf contains NaN values after normalization")

# Load precomputed mel-spectrograms and piano-rolls
mel_train_dir = Path("./preprocessed_output_all_480_minimal/mel_spectrograms/train")
piano_roll_train_dir = Path("./preprocessed_output_all_480_minimal/piano_rolls/train")
mel_files = sorted(list(mel_train_dir.glob("*.npy")))
piano_roll_files = sorted(list(piano_roll_train_dir.glob("*_piano_roll.npy")))

print(f"Number of mel files: {len(mel_files)}, Number of piano-roll files: {len(piano_roll_files)}")
if not mel_files or not piano_roll_files:
    raise ValueError("No files found in mel_spectrograms or piano_rolls directories.")
if len(mel_files) != len(piano_roll_files):
    raise ValueError(f"Mismatch in number of mel files ({len(mel_files)}) and piano-roll files ({len(piano_roll_files)}).")

spectrograms = []
labels = []
for mel_file, pr_file in zip(mel_files[:1000], piano_roll_files[:1000]):
    mel_spec = np.load(mel_file)
    if mel_spec.shape != (n_mels, frames_per_segment):
        raise ValueError(f"Mel-spectrogram shape {mel_spec.shape} does not match expected ({n_mels}, {frames_per_segment})")
    mel_spec = (mel_spec - mel_spec.min()) / (mel_spec.max() - mel_spec.min() + 1e-8)
    spectrograms.append(mel_spec)
    pr = np.load(pr_file)
    if pr.size == 0 or np.all(pr == 0):
        print(f"Warning: Skipping invalid piano-roll file {pr_file} (empty or all zeros)")
        continue
    pr = (pr > 0).astype(np.float32)
    labels.append(pr)

spectrograms = np.array(spectrograms)
labels = np.array(labels)

print(f"Number of loaded spectrograms: {len(spectrograms)}, Number of loaded labels: {len(labels)}")
if len(spectrograms) == 0 or len(labels) == 0:
    raise ValueError("No valid samples loaded. Check your data files or subset size.")

num_positives = np.sum(labels == 1)
num_negatives = np.sum(labels == 0)
print(f"Number of positive labels: {num_positives}, Number of negative labels: {num_negatives}")
if num_positives == 0:
    print("Warning: No positive labels found in the dataset. Setting pos_weight to a default value.")
    pos_weight = 2.0
else:
    pos_weight = num_negatives / num_positives
    pos_weight = min(pos_weight * 0.5, 2.0)
print(f"Computed pos_weight: {pos_weight}")

# Custom Layers
class ClipLayer(layers.Layer):
    def __init__(self, min_value, max_value, **kwargs):
        super(ClipLayer, self).__init__(**kwargs)
        self.min_value = min_value
        self.max_value = max_value

    def call(self, inputs):
        return tf.clip_by_value(inputs, self.min_value, self.max_value)

class TransposeLayer(layers.Layer):
    def __init__(self, perm, **kwargs):
        super(TransposeLayer, self).__init__(**kwargs)
        self.perm = perm

    def call(self, inputs):
        return tf.transpose(inputs, perm=self.perm)

class DotProductLayer(layers.Layer):
    def __init__(self, **kwargs):
        super(DotProductLayer, self).__init__(**kwargs)

    def call(self, inputs):
        W, x = inputs
        return tf.matmul(W, x)

# Custom Focal Loss
def focal_loss_with_class_weight(gamma=1.5, alpha=0.8, pos_weight=2.0):
    def focal_loss_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0 - tf.keras.backend.epsilon())
        cross_entropy = -y_true * tf.math.log(y_pred) - (1 - y_true) * tf.math.log(1 - y_pred)
        class_weight = y_true * pos_weight + (1 - y_true) * 1.0
        weight = alpha * y_true * tf.pow(1 - y_pred, gamma) + (1 - alpha) * (1 - y_true) * tf.pow(y_pred, gamma)
        return tf.reduce_mean(class_weight * weight * cross_entropy)
    return focal_loss_fixed

# Custom Metrics
class Precision(metrics.Metric):
    def __init__(self, threshold=0.3, name='precision', **kwargs):
        super(Precision, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.false_positives = self.add_weight(name='fp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred > self.threshold, tf.bool)
        tp = tf.reduce_sum(tf.cast(tf.logical_and(y_true, y_pred), tf.float32))
        fp = tf.reduce_sum(tf.cast(tf.logical_and(tf.logical_not(y_true), y_pred), tf.float32))
        self.true_positives.assign_add(tp)
        self.false_positives.assign_add(fp)

    def result(self):
        return self.true_positives / (self.true_positives + self.false_positives + tf.keras.backend.epsilon())

    def reset_state(self):
        self.true_positives.assign(0.)
        self.false_positives.assign(0.)

class Recall(metrics.Metric):
    def __init__(self, threshold=0.3, name='recall', **kwargs):
        super(Recall, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.false_negatives = self.add_weight(name='fn', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred > self.threshold, tf.bool)
        tp = tf.reduce_sum(tf.cast(tf.logical_and(y_true, y_pred), tf.float32))
        fn = tf.reduce_sum(tf.cast(tf.logical_and(y_true, tf.logical_not(y_pred)), tf.float32))
        self.true_positives.assign_add(tp)
        self.false_negatives.assign_add(fn)

    def result(self):
        return self.true_positives / (self.true_positives + self.false_negatives + tf.keras.backend.epsilon())

    def reset_state(self):
        self.true_positives.assign(0.)
        self.false_negatives.assign(0.)

class F1Score(metrics.Metric):
    def __init__(self, threshold=0.3, name='f1_score', **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.precision = Precision(threshold=self.threshold)
        self.recall = Recall(threshold=self.threshold)

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        p = self.precision.result()
        r = self.recall.result()
        return 2 * ((p * r) / (p + r + tf.keras.backend.epsilon()))

    def reset_state(self):
        self.precision.reset_state()
        self.recall.reset_state()

# Build CNN+LSTM branch
def build_cnn_lstm():
    inputs = layers.Input(shape=(n_mels, frames_per_segment, 1))
    x = layers.Conv2D(32, (3, 1), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 1))(x)
    x = layers.Conv2D(57, (3, 1), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 1))(x)
    x = layers.Reshape((frames_per_segment, -1))(x)
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)
    x = layers.Dropout(0.5)(x)
    features = layers.Dense(d_prime, activation=None)(x)
    features = ClipLayer(min_value=-1e6, max_value=1e6)(features)
    pred = layers.Dense(n_pitches, activation='sigmoid')(features)
    pred = TransposeLayer(perm=[0, 2, 1])(pred)
    return Model(inputs=inputs, outputs={'cnn_lstm_features': features, 'cnn_lstm_pred': pred}, name="cnn_lstm")

# Custom GCN Layer
class GCNLayer(layers.Layer):
    def __init__(self, units, activation="relu", **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.dense = layers.Dense(units, activation=activation, use_bias=True)
        self.layer_norm = layers.LayerNormalization()

    def call(self, inputs):
        node_features, adjacency = inputs
        x = self.dense(node_features)
        x_gcn = tf.matmul(adjacency, x)
        x = x + x_gcn
        x = self.layer_norm(x)
        x = ClipLayer(min_value=-1e6, max_value=1e6)(x)
        return x

# Build GCN branch
def build_gcn(scale):
    node_features = layers.Input(shape=(n_pitches, d), name="node_features")
    adjacency = layers.Input(shape=(n_pitches, n_pitches), name="adjacency", dtype=tf.float32)
    x = GCNLayer(d_prime)([node_features, adjacency])
    x = GCNLayer(d_prime)([x, adjacency])
    x = GCNLayer(d_prime)([x, adjacency])
    W = x
    W = layers.Dense(d_prime, activation='relu')(W)
    W = W * scale
    tf.print("GCN Output W (first batch, first 10 values): ", W[0, :10, 0])
    return Model(inputs=[node_features, adjacency], outputs={'gcn_pred': W}, name="gcn")

# Combined CR-GCN model
class CRGCN(Model):
    def __init__(self, A):
        super(CRGCN, self).__init__()
        self.cnn_lstm = build_cnn_lstm()
        self.scale = tf.Variable(1.0, trainable=True, name='gcn_output_scale', dtype=tf.float32)  # Increased initial value
        self.gcn = build_gcn(self.scale)
        self.A = A
        self.transpose_layer = TransposeLayer(perm=[0, 2, 1])
        self.dot_product_layer = DotProductLayer()
        self.H0 = self.add_weight(
            name="H0",
            shape=(n_pitches, d),
            initializer=tf.keras.initializers.RandomNormal(mean=0.1, stddev=0.05),
            trainable=True
        )
        self.H0_transform = layers.Dense(d, activation='relu')
        self.feature_projection = layers.Dense(n_pitches, activation='relu')
        self.feature_projection_d_prime = layers.Dense(d_prime, activation='relu')

    def call(self, inputs, training=False):
        spectrogram = inputs
        cnn_lstm_outputs = self.cnn_lstm(spectrogram)
        x = cnn_lstm_outputs['cnn_lstm_features']
        cnn_lstm_pred = cnn_lstm_outputs['cnn_lstm_pred']
        cnn_lstm_features_projected = self.feature_projection(x)
        cnn_lstm_features_projected = tf.transpose(cnn_lstm_features_projected, perm=[0, 2, 1])
        cnn_lstm_features_projected = self.feature_projection_d_prime(cnn_lstm_features_projected)
        
        x = self.transpose_layer(x)
        batch_size = tf.shape(spectrogram)[0]
        H0_expanded = tf.expand_dims(self.H0, 0)
        H0_tiled = tf.tile(H0_expanded, [batch_size, 1, 1])
        H0_transformed = self.H0_transform(H0_tiled)
        A_tiled = tf.tile(tf.expand_dims(self.A, 0), [batch_size, 1, 1])
        gcn_outputs = self.gcn([H0_transformed, A_tiled])
        W = gcn_outputs['gcn_pred']
        W = W + cnn_lstm_features_projected
        W = tf.debugging.check_numerics(W, "GCN output contains NaN or Inf")
        gcn_pred = self.dot_product_layer([W, x])
        gcn_pred = layers.Activation('sigmoid')(gcn_pred)
        
        if training:
            gcn_loss = tf.reduce_mean(tf.square(W)) * 0.01  # Auxiliary loss
            gcn_loss = tf.debugging.check_numerics(gcn_loss, "Auxiliary GCN loss contains NaN or Inf")
            self.add_loss(gcn_loss)
        return gcn_pred

    def build(self, input_shape):
        super(CRGCN, self).build(input_shape)
        self.cnn_lstm.build(input_shape)
        self.gcn.build([(None, n_pitches, d), (None, n_pitches, n_pitches)])

# Instantiate and compile
model = CRGCN(A_tf)
model.build(input_shape=(None, n_mels, frames_per_segment, 1))

# Learning rate schedule
steps_per_epoch = len(spectrograms) // 64
warmup_epochs = 15
warmup_steps = steps_per_epoch * warmup_epochs
lr_schedule = WarmupDecaySchedule(
    initial_lr=0.00001,
    target_lr=0.0001,
    warmup_steps=warmup_steps,
    decay_steps=3000,
    decay_rate=0.95
)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipnorm=1.0)

model.compile(
    optimizer=optimizer,
    loss=focal_loss_with_class_weight(gamma=1.5, alpha=0.8, pos_weight=2.0),
    metrics=[Precision(threshold=0.3), Recall(threshold=0.3), F1Score(threshold=0.3), 'accuracy']
)

# Debugging callback
class DebugCallback(callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        batch = spectrograms[:16][..., np.newaxis]
        with tf.GradientTape() as tape:
            predictions = self.model(batch, training=True)
            loss = self.model.compute_loss(batch, labels[:16], predictions, None, training=True)
        grads = tape.gradient(loss, self.model.trainable_variables)
        gcn_grads = [g for g, v in zip(grads, self.model.trainable_variables) if 'gcn' in v.name]
        print(f"\nEpoch {epoch + 1} - Prediction Stats:")
        preds = self.model.predict(batch, verbose=0)
        print(f"Mean prediction: {np.mean(preds):.4f}")
        print(f"Min prediction: {np.min(preds):.4f}")
        print(f"Max prediction: {np.max(preds):.4f}")
        print(f"Percentage of predictions > 0.5: {np.mean(preds > 0.5) * 100:.2f}%")
        print(f"Percentage of predictions > 0.3: {np.mean(preds > 0.3) * 100:.2f}%")
        print("GCN Gradients Present:", any(g is not None for g in gcn_grads))

# Callbacks
checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras',
    monitor='val_f1_score',
    mode='max',
    save_best_only=True,
    verbose=1
)

early_stopping = callbacks.EarlyStopping(
    monitor='val_f1_score',
    patience=20,
    mode='max',
    restore_best_weights=True,
    verbose=1
)

# Train for 100 epochs
spectrograms = spectrograms[..., np.newaxis]
history = model.fit(
    spectrograms,
    labels,
    epochs=100,
    batch_size=64,
    validation_split=0.2,
    callbacks=[DebugCallback(), checkpoint, early_stopping],
    verbose=1
)

# Updated post-processing functions
def smooth_predictions(predictions, kernel_size=3):
    kernel = np.ones(kernel_size) / kernel_size
    smoothed = np.zeros_like(predictions)
    for b in range(predictions.shape[0]):
        for p in range(predictions.shape[1]):
            smoothed[b, p, :] = np.convolve(predictions[b, p, :], kernel, mode='same')
    return smoothed

def filter_short_notes(predictions, threshold=0.3, min_duration_frames=12):
    filtered = np.zeros_like(predictions)
    for b in range(predictions.shape[0]):
        for p in range(predictions.shape[1]):
            binary = predictions[b, p, :] > threshold
            diff = np.diff(binary, prepend=0)
            onsets = np.where(diff == 1)[0]
            offsets = np.where(diff == -1)[0]
            for onset, offset in zip(onsets, offsets):
                if offset - onset >= min_duration_frames:
                    filtered[b, p, onset:offset] = 1.0
    return filtered

def piano_roll_to_notes(piano_roll, fs=22050, hop_length=512, threshold=0.3, min_duration=0.05):
    notes = []
    piano_roll = np.pad(piano_roll, ((0, 0), (1, 0)), mode='constant')
    for pitch in range(n_pitches):
        onsets = np.where(np.diff(piano_roll[pitch] > threshold, prepend=False) == 1)[0]
        offsets = np.where(np.diff(piano_roll[pitch] > threshold, prepend=False) == -1)[0]
        for onset, offset in zip(onsets, offsets):
            onset_time = onset * hop_length / fs
            offset_time = offset * hop_length / fs
            duration = offset_time - onset_time
            if duration >= min_duration and offset_time > onset_time:
                notes.append((onset_time, offset_time, pitch + 21))
    return np.array(notes) if notes else np.empty((0, 3))

# Find optimal threshold
val_indices = np.random.choice(len(spectrograms), int(0.2 * len(spectrograms)), replace=False)
val_spectrograms = spectrograms[val_indices]
val_labels = labels[val_indices]

val_predictions = model.predict(val_spectrograms)
val_predictions_probs = smooth_predictions(val_predictions)
val_predictions_filtered = filter_short_notes(val_predictions_probs)

thresholds = np.arange(0.2, 0.5, 0.05)
best_f1 = 0
best_threshold = 0.3
for thresh in thresholds:
    val_preds_binary = (val_predictions_filtered > thresh).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(val_labels.flatten(), val_preds_binary.flatten(), average='binary', zero_division=0)
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = thresh

print(f"Optimal threshold: {best_threshold}, F1-score: {best_f1}")

# Evaluate
val_predictions_binary = (val_predictions_filtered > best_threshold).astype(int)
val_labels_flat = val_labels.flatten()
val_preds_flat = val_predictions_binary.flatten()
precision, recall, f1, _ = precision_recall_fscore_support(val_labels_flat, val_preds_flat, average='binary', zero_division=0)
print(f"Frame-level Precision: {precision:.4f}")
print(f"Frame-level Recall: {recall:.4f}")
print(f"Frame-level F1-score: {f1:.4f}")

ref_notes = [piano_roll_to_notes(pr, fs=sr, hop_length=hop_length, threshold=best_threshold, min_duration=0.05) for pr in val_labels]
est_notes = [piano_roll_to_notes(pr, fs=sr, hop_length=hop_length, threshold=best_threshold, min_duration=0.05) for pr in val_predictions_binary]
ref_intervals = [notes[:, :2] if len(notes) > 0 else np.empty((0, 2)) for notes in ref_notes]
ref_pitches = [notes[:, 2] if len(notes) > 0 else np.array([]) for notes in ref_notes]
est_intervals = [notes[:, :2] if len(notes) > 0 else np.empty((0, 2)) for notes in est_notes]
est_pitches = [notes[:, 2] if len(notes) > 0 else np.array([]) for notes in est_notes]

note_precision, note_recall, note_f1 = 0, 0, 0
num_samples = len(ref_intervals)
valid_samples = 0
for i in range(num_samples):
    if len(ref_intervals[i]) == 0 and len(est_intervals[i]) == 0:
        continue
    scores = mir_eval.transcription.evaluate(
        ref_intervals[i], ref_pitches[i], est_intervals[i], est_pitches[i],
        onset_tolerance=0.05, offset_ratio=0.2
    )
    note_precision += scores['Precision']
    note_recall += scores['Recall']
    note_f1 += scores['F-measure']
    valid_samples += 1

if valid_samples > 0:
    note_precision /= valid_samples
    note_recall /= valid_samples
    note_f1 /= valid_samples
else:
    note_precision, note_recall, note_f1 = 0.0, 0.0, 0.0

print(f"Note-level Precision: {note_precision:.4f}")
print(f"Note-level Recall: {note_recall:.4f}")
print(f"Note-level F1-score: {note_f1:.4f}")

model.save("cr_gcn_model.keras")

Number of mel files: 37993, Number of piano-roll files: 37993
Number of loaded spectrograms: 1000, Number of loaded labels: 1000
Number of positive labels: 2450463, Number of negative labels: 39789537
Computed pos_weight: 2.0
GCN Output W (first batch, first 10 values):  <KerasTensor shape=(10,), dtype=float32, sparse=False, name=keras_tensor_602>
Epoch 1/100




[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0020 - f1_score: 0.1081 - loss: 0.1120 - precision: 0.0580 - recall: 0.7892
Epoch 1 - Prediction Stats:
Mean prediction: 0.3845
Min prediction: 0.1660
Max prediction: 0.5244
Percentage of predictions > 0.5: 0.10%
Percentage of predictions > 0.3: 90.60%
GCN Gradients Present: False

Epoch 1: val_f1_score improved from -inf to 0.10794, saving model to best_model.keras
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 3s/step - accuracy: 0.0019 - f1_score: 0.1081 - loss: 0.1113 - precision: 0.0581 - recall: 0.7892 - val_accuracy: 4.5455e-04 - val_f1_score: 0.1079 - val_loss: 0.0661 - val_precision: 0.0572 - val_recall: 0.9467
Epoch 2/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0022 - f1_score: 0.1087 - loss: 0.0811 - precision: 0.0581 - recall: 0.8483
Epoch 2 - Prediction Stats:
Mean prediction: 0.4516
Min prediction: 0.2878
Max prediction: 0.5

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0068 - f1_score: 0.1456 - loss: 0.0541 - precision: 0.0787 - recall: 0.9713
Epoch 12 - Prediction Stats:
Mean prediction: 0.3597
Min prediction: 0.0128
Max prediction: 0.5698
Percentage of predictions > 0.5: 9.62%
Percentage of predictions > 0.3: 70.67%
GCN Gradients Present: False

Epoch 12: val_f1_score improved from 0.13860 to 0.14669, saving model to best_model.keras
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 3s/step - accuracy: 0.0070 - f1_score: 0.1459 - loss: 0.0542 - precision: 0.0789 - recall: 0.9711 - val_accuracy: 0.0034 - val_f1_score: 0.1467 - val_loss: 0.0527 - val_precision: 0.0794 - val_recall: 0.9621
Epoch 13/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0140 - f1_score: 0.1521 - loss: 0.0555 - precision: 0.0825 - recall: 0.9667
Epoch 13 - Prediction Stats:
Mean prediction: 0.3697
Min prediction: 0.0100
Max prediction: 

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 3s/step - accuracy: 0.0276 - f1_score: 0.1553 - loss: 0.0528 - precision: 0.0845 - recall: 0.9625 - val_accuracy: 0.0264 - val_f1_score: 0.1450 - val_loss: 0.0527 - val_precision: 0.0783 - val_recall: 0.9736
Epoch 24/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0397 - f1_score: 0.1523 - loss: 0.0526 - precision: 0.0826 - recall: 0.9697
Epoch 24 - Prediction Stats:
Mean prediction: 0.3739
Min prediction: 0.0174
Max prediction: 0.5415
Percentage of predictions > 0.5: 31.48%
Percentage of predictions > 0.3: 71.73%
GCN Gradients Present: False

Epoch 24: val_f1_score did not improve from 0.15227
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 3s/step - accuracy: 0.0395 - f1_score: 0.1523 - loss: 0.0526 - precision: 0.0826 - recall: 0.9696 - val_accuracy: 0.0082 - val_f1_score: 0.1478 - val_loss: 0.0525 - val_precision: 0.0800 - val_recall: 0.9668
Epoch 25/100
[1m1

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0247 - f1_score: 0.1543 - loss: 0.0528 - precision: 0.0839 - recall: 0.9628
Epoch 35 - Prediction Stats:
Mean prediction: 0.3960
Min prediction: 0.0238
Max prediction: 0.5583
Percentage of predictions > 0.5: 35.46%
Percentage of predictions > 0.3: 77.45%
GCN Gradients Present: False

Epoch 35: val_f1_score did not improve from 0.15317
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 3s/step - accuracy: 0.0247 - f1_score: 0.1544 - loss: 0.0528 - precision: 0.0839 - recall: 0.9630 - val_accuracy: 0.0261 - val_f1_score: 0.1399 - val_loss: 0.0529 - val_precision: 0.0753 - val_recall: 0.9843
Epoch 36/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0384 - f1_score: 0.1504 - loss: 0.0531 - precision: 0.0815 - recall: 0.9746
Epoch 36 - Prediction Stats:
Mean prediction: 0.3783
Min prediction: 0.0168
Max prediction: 0.5493
Percentage of predictions > 0.

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 3s/step - accuracy: 0.0605 - f1_score: 0.1549 - loss: 0.0526 - precision: 0.0842 - recall: 0.9710 - val_accuracy: 0.0041 - val_f1_score: 0.1498 - val_loss: 0.0522 - val_precision: 0.0812 - val_recall: 0.9661
Epoch 47/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0394 - f1_score: 0.1532 - loss: 0.0521 - precision: 0.0832 - recall: 0.9665
Epoch 47 - Prediction Stats:
Mean prediction: 0.3673
Min prediction: 0.0129
Max prediction: 0.5551
Percentage of predictions > 0.5: 35.69%
Percentage of predictions > 0.3: 66.75%
GCN Gradients Present: False

Epoch 47: val_f1_score improved from 0.15474 to 0.15518, saving model to best_model.keras
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 3s/step - accuracy: 0.0399 - f1_score: 0.1533 - loss: 0.0522 - precision: 0.0832 - recall: 0.9668 - val_accuracy: 0.0082 - val_f1_score: 0.1552 - val_loss: 0.0522 - val_precision: 0.0844 -

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0797 - f1_score: 0.1550 - loss: 0.0517 - precision: 0.0842 - recall: 0.9741
Epoch 58 - Prediction Stats:
Mean prediction: 0.3712
Min prediction: 0.0216
Max prediction: 0.5604
Percentage of predictions > 0.5: 30.29%
Percentage of predictions > 0.3: 69.55%
GCN Gradients Present: False

Epoch 58: val_f1_score did not improve from 0.15518
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 3s/step - accuracy: 0.0798 - f1_score: 0.1550 - loss: 0.0517 - precision: 0.0842 - recall: 0.9739 - val_accuracy: 0.0694 - val_f1_score: 0.1506 - val_loss: 0.0518 - val_precision: 0.0817 - val_recall: 0.9664
Epoch 59/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0950 - f1_score: 0.1576 - loss: 0.0522 - precision: 0.0858 - recall: 0.9694
Epoch 59 - Prediction Stats:
Mean prediction: 0.3776
Min prediction: 0.0194
Max prediction: 0.5633
Percentage of predictions > 0.

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 3s/step - accuracy: 0.0837 - f1_score: 0.1661 - loss: 0.0508 - precision: 0.0909 - recall: 0.9651 - val_accuracy: 0.1056 - val_f1_score: 0.1627 - val_loss: 0.0500 - val_precision: 0.0889 - val_recall: 0.9599
Epoch 70/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.0894 - f1_score: 0.1656 - loss: 0.0505 - precision: 0.0905 - recall: 0.9679
Epoch 70 - Prediction Stats:
Mean prediction: 0.3608
Min prediction: 0.0147
Max prediction: 0.6636
Percentage of predictions > 0.5: 27.36%
Percentage of predictions > 0.3: 66.39%
GCN Gradients Present: False

Epoch 70: val_f1_score improved from 0.16268 to 0.16441, saving model to best_model.keras
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 3s/step - accuracy: 0.0892 - f1_score: 0.1657 - loss: 0.0504 - precision: 0.0906 - recall: 0.9680 - val_accuracy: 0.1595 - val_f1_score: 0.1644 - val_loss: 0.0499 - val_precision: 0.0900 -

Epoch 81/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1610 - f1_score: 0.1742 - loss: 0.0477 - precision: 0.0957 - recall: 0.9672
Epoch 81 - Prediction Stats:
Mean prediction: 0.3546
Min prediction: 0.0227
Max prediction: 0.7062
Percentage of predictions > 0.5: 25.79%
Percentage of predictions > 0.3: 64.68%
GCN Gradients Present: False

Epoch 81: val_f1_score did not improve from 0.17408
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 3s/step - accuracy: 0.1604 - f1_score: 0.1743 - loss: 0.0477 - precision: 0.0958 - recall: 0.9672 - val_accuracy: 0.3738 - val_f1_score: 0.1735 - val_loss: 0.0478 - val_precision: 0.0954 - val_recall: 0.9542
Epoch 82/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1723 - f1_score: 0.1776 - loss: 0.0484 - precision: 0.0978 - recall: 0.9679
Epoch 82 - Prediction Stats:
Mean prediction: 0.3394
Min prediction: 0.0104
Max prediction: 0.7105
Percentage of pre

[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 3s/step - accuracy: 0.1262 - f1_score: 0.1859 - loss: 0.0464 - precision: 0.1029 - recall: 0.9629 - val_accuracy: 0.3579 - val_f1_score: 0.1863 - val_loss: 0.0462 - val_precision: 0.1034 - val_recall: 0.9437
Epoch 93/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1478 - f1_score: 0.1857 - loss: 0.0455 - precision: 0.1028 - recall: 0.9647
Epoch 93 - Prediction Stats:
Mean prediction: 0.3520
Min prediction: 0.0178
Max prediction: 0.7157
Percentage of predictions > 0.5: 22.19%
Percentage of predictions > 0.3: 63.62%
GCN Gradients Present: False

Epoch 93: val_f1_score did not improve from 0.18927
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 3s/step - accuracy: 0.1480 - f1_score: 0.1857 - loss: 0.0455 - precision: 0.1027 - recall: 0.9648 - val_accuracy: 0.3899 - val_f1_score: 0.1770 - val_loss: 0.0457 - val_precision: 0.0975 - val_recall: 0.9614
Epoch 94/100
[1m1

ValueError: Exception encountered when calling CRGCN.call().

[1mas_list() is not defined on an unknown TensorShape.[0m

Arguments received by CRGCN.call():
  • inputs=tf.Tensor(shape=<unknown>, dtype=float32)
  • training=False