In [31]:
import sys
import os
sys.path.append(os.getcwd())

from ddc.learn.chart import Chart
import tensorflow as tf
import numpy as np
from ddc.learn.util import open_dataset_fps, flatten_dataset_to_charts, select_channels

In [32]:
class DiagnosticNet:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def prepare_train_batch(self, charts, **kwargs):
        """Wrapper around the original prepare_train_batch to print shapes"""
        # Create a random batch
        feats_audio, feats_other, targets, target_weights = charts[0].prepare_train_batch(
            [charts[0]], 
            **kwargs
        )
        
        print("\nDiagnostic Batch Information:")
        print(f"Audio Features Shape: {feats_audio.shape}")
        print(f"Other Features Shape: {feats_other.shape}")
        print(f"Targets Shape: {targets.shape}")
        print(f"Target Weights Shape: {target_weights.shape}")
        
        print("\nData Statistics:")
        print(f"Audio Features - Mean: {np.mean(feats_audio):.4f}, Std: {np.std(feats_audio):.4f}")
        print(f"Other Features - Mean: {np.mean(feats_other):.4f}, Std: {np.std(feats_other):.4f}")
        print(f"Targets - Mean: {np.mean(targets):.4f} (Class balance)")
        print(f"Target Weights - Mean: {np.mean(target_weights):.4f}")
        
        # Look at a single example in detail
        print("\nSingle Example Details:")
        print(f"Audio time steps: {feats_audio.shape[2]}")
        print(f"Audio frequency bands: {feats_audio.shape[3]}")
        print(f"Audio channels: {feats_audio.shape[4]}")
        
        return feats_audio, feats_other, targets, target_weights

def run_diagnostic(batch_size=256, audio_context_radius=7):
    """Run a complete diagnostic on the dataset"""
    print("Loading data...")
    train_txt_fp = "ddc/data/chart_onset/speirmix/mel80hop441/speirmix_train.txt"
    valid_txt_fp = "ddc/data/chart_onset/speirmix/mel80hop441/speirmix_valid.txt"
    test_txt_fp = "ddc/data/chart_onset/speirmix/mel80hop441/speirmix_test.txt"
    
    train_data, valid_data, test_data = open_dataset_fps(train_txt_fp, valid_txt_fp, test_txt_fp)
    
    print("\nDataset Overview:")
    print(f"Number of songs in dataset: {len(train_data)}")
    
    # Look at first song's metadata
    first_song = train_data[0]
    print("\nExample Song Metadata:")
    print(f"Metadata: {first_song[0]}")
    print(f"Features shape: {first_song[1].shape}")
    print(f"Number of charts: {len(first_song[2])}")
    
    # Flatten into charts
    charts = flatten_dataset_to_charts(train_data)
    print(f"\nTotal number of charts: {len(charts)}")
    
    # Create diagnostic network
    net = DiagnosticNet(batch_size)
    
    # Create a sample batch with default parameters
    print("\nCreating sample batch...")
    feats_config = {
        'time_context_radius': audio_context_radius,
    }
    
    train_config = {
        'randomize_charts': False,
        'exclude_onset_neighbors': 0,
        'exclude_pre_onsets': False,
        'exclude_post_onsets': False,
        'include_onsets': True
    }
    
    batch_config = feats_config.copy()
    batch_config.update(train_config)
    
    # Get a batch and print diagnostics
    net.prepare_train_batch(charts, **batch_config)

if __name__ == "__main__":
    run_diagnostic()

Loading data...


ModuleNotFoundError: No module named 'chart'