# Data Exploration

Explore the preprocessed data (see [data_preprocessing](./data_preprocessing.ipynb)), both in its raw time-domain form and under several feature transforms (STFT, MELS, MFCC).

In [None]:
# Autoreloading makes development easier: with mode 2, IPython re-imports
# modules before running each cell, so edits to imported source files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import librosa as lb
import matplotlib.pyplot as plt
from tools.audio_tools import read_audio, play_audio, write_audio
from tools.feature_tools import compute_stft, compute_istft, compute_mels, compute_imels, compute_mfcc, compute_imfcc
from tools.constants import cut_classical_path, cut_jazz_path, default_sample_rate
from tools.plot_tools import plot_spectral_feature, plot_audio, make_figax

In [None]:
# Small positive floor added to magnitudes before taking log10 in the dB
# conversions below, so that zero-valued coefficients do not map to -inf.
epsilon = 1e-6

# Every figure and audio file produced by this notebook is saved here.
results_dir = "./results/DataExploration"
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Feature transforms under study: (label, forward transform, inverse transform).
# The three parallel lists used by the exploration loop are derived from this
# single table so a label can never get out of step with its functions.
_transform_table = [
    ("STFT", compute_stft, compute_istft),
    ("MELS", compute_mels, compute_imels),
    ("MFCC", compute_mfcc, compute_imfcc),
]
transform_labels = [row[0] for row in _transform_table]
transform_transforms = [row[1] for row in _transform_table]
transform_inverse_transforms = [row[2] for row in _transform_table]

# Genres under study: (label, directory the tracks are read from).
_genre_table = [
    ("CLASSICAL", cut_classical_path),
    ("JAZZ", cut_jazz_path),
]
genre_labels = [row[0] for row in _genre_table]
genre_paths = [row[1] for row in _genre_table]

# Share of transform coefficients kept in the "top k" reconstruction.
fraction = 0.25

In [None]:
# Explore transform spaces for all genres.
#
# For one randomly chosen track per genre this cell:
#   1. plots, plays and saves the time-domain signal;
#   2. for every transform (STFT, MELS, MFCC) plots the feature in dB, the sorted
#      coefficient magnitudes and the singular values;
#   3. reconstructs audio from the full transform and from only (at most) its k
#      largest coefficients, then plays and saves both reconstructions.
# All figures and audio files are written to `results_dir`.


def _db_20log10(x):
    """Return 20*log10(|x| + sqrt(epsilon)).

    The sqrt(epsilon) floor sits at the same dB level as the epsilon floor of
    `_db_10log10`, since 20*log10(sqrt(e)) == 10*log10(e).
    """
    return 20*np.log10(np.abs(x) + np.sqrt(epsilon))


def _db_10log10(x):
    """Return 10*log10(|x| + epsilon).

    Used for MELS only -- presumably because compute_mels yields a power-scale
    spectrogram; confirm against tools.feature_tools.
    """
    return 10 * np.log10(np.abs(x) + epsilon)


def _finish_figure(fig, filename):
    """Tighten the layout of `fig`, save it under `results_dir`, then display it."""
    fig.tight_layout()
    fig.savefig(os.path.join(results_dir, filename), dpi=300, facecolor="white")
    plt.show()


for genre_label, genre_path in zip(genre_labels, genre_paths):
    print(genre_label)
    genre_tag = genre_label.lower()  # prefix of every output file for this genre

    # Pick a random audio track
    audio_file = np.random.choice(os.listdir(genre_path))
    audio_path = os.path.join(genre_path, audio_file)
    audio = read_audio(audio_path)
    # Sample i sits at i / sample_rate. endpoint=False keeps the spacing at exactly
    # one sample period; the default (endpoint included) would stretch the axis.
    time = np.linspace(0, len(audio)/default_sample_rate, len(audio), endpoint=False)
    print(f"{audio_file = }")
    print(f"{len(audio) = }")

    # Time domain plot
    fig, ax = make_figax()
    ax.plot(time, audio)
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Audio Track")
    ax.set_title(f"{genre_label} AUDIO TRACK")
    ax.grid()
    _finish_figure(fig, f"{genre_tag}_time.png")

    # Play audio
    print(f"{genre_label} PLAYER")
    player = play_audio(audio)

    # Write audio
    write_audio(audio, os.path.join(results_dir, f"{genre_tag}_sample.wav"))

    # Transform to feature spaces and make plots
    for transform_label, transform_transform, transform_inverse_transform in zip(transform_labels, transform_transforms, transform_inverse_transforms):
        print(transform_label)
        transform_tag = transform_label.lower()

        # Compute transform
        transform = transform_transform(audio)
        # Number of coefficients to keep in the compressed reconstruction below
        k = int(transform.size * fraction)
        print(f"{transform.shape = }")
        print(f"{transform.size = }")
        print(f"transform element = {np.random.choice(transform.reshape(-1))}")

        # Plot Transform in 2D (dB scale; MELS uses the 10*log10 variant)
        fn = _db_10log10 if transform_label == "MELS" else _db_20log10
        fig, ax = plot_spectral_feature(transform, fn=fn)
        ax.set_title(f"{genre_label} {transform_label}")
        _finish_figure(fig, f"{genre_tag}_{transform_tag}.png")

        # Plot transform coefficient roll-off: all dB magnitudes, largest first
        transform_fn = fn(transform)
        transform_magnitudes = np.sort(transform_fn.reshape(-1))[::-1]
        fig, ax = make_figax()
        ax.plot(transform_magnitudes)
        ax.set_xlabel("Transform coefficient")
        ax.set_ylabel("Transform magnitude [dB]")
        ax.set_title(f"{genre_label} {transform_label} COEFFICIENTS")
        ax.grid()
        _finish_figure(fig, f"{genre_tag}_{transform_tag}_coefficients.png")

        # Plot transform singular values
        print(f"{len(transform.shape) = }")
        singular_values = np.linalg.svd(transform, full_matrices=True, compute_uv=False, hermitian=False)
        fig, ax = make_figax()
        ax.plot(fn(np.sort(singular_values))[::-1])
        ax.set_xlabel(r"i")
        # Raw string: "\s" is not a valid escape sequence in a normal string literal
        ax.set_ylabel(r"Singular Value $\sigma_i$ [dB]")
        ax.set_title(f"{genre_label} {transform_label} SINGULAR VALUES")
        ax.grid()
        _finish_figure(fig, f"{genre_tag}_{transform_tag}_singular_values.png")

        # Transform reconstruct & play
        print(f"INVERSE {transform_label} - {genre_label}")
        inverse = transform_inverse_transform(transform)
        play_audio(inverse)
        # Save the reconstruction itself (not the original track) so the file
        # matches what was just played.
        write_audio(inverse, os.path.join(results_dir, f"{genre_tag}_{transform_tag}_inverse.wav"))

        # Transform reconstruct from top k & play
        print(f"INVERSE {transform_label} - {genre_label} - TOP K")
        print(f"Reconstructing from 1/{1/fraction:.2f} of the samples: {k = }")
        # Zero every coefficient whose dB magnitude does not exceed the entry at
        # index k of the descending list, keeping at most the k strongest.
        # `transform` is modified in place. When fraction >= 1, k reaches
        # transform.size: there is nothing to discard and index k is out of range.
        if k < transform.size:
            transform[transform_fn <= transform_magnitudes[k]] = 0.
        inverse = transform_inverse_transform(transform)
        play_audio(inverse)
        # Save the compressed reconstruction, not the original track
        write_audio(inverse, os.path.join(results_dir, f"{genre_tag}_{transform_tag}_compressed_inverse.wav"))