### EDA Playground

- GUI 환경에서 dataset classes 선택시 해당하는 class별로 볼 수 있음
- anomaly - normal pair data를 찾아서 묶어준 다음 idx별로 볼 수 있게 만듦.
- 스펙트로그램을 볼 때 n_fft 사이즈를 조절 가능함. html로 따지면 range와 number input으로 조작 가능
- 소리를 들을 수 있도록 재생 버튼도 있어야함


In [8]:
import os
import pandas as pd
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display, HTML
import warnings
import ipywidgets as widgets
from ipywidgets import interact, fixed
from skimage.metrics import structural_similarity as ssim
import base64
from io import BytesIO
import soundfile as sf
import uuid
from scipy.spatial.distance import cosine

# Silence all warnings for this session (optional). NOTE: the 'ignore' filter
# is global -- it hides warnings from every library, not only librosa.
warnings.filterwarnings('ignore')

# -----------------------------
# Configuration Parameters
# -----------------------------
DATASETS_DIR = "../../datasets/dev"  # Path to the datasets directory
MATCHING_PAIRS_FILE = "all_class_matching_pairs.csv"  # Path to the matching pairs CSV
N_FFT = 160  # Initial value of the n_fft slider in the interactive viewer

# -----------------------------
# Step 1: Load Matching Pairs
# -----------------------------
def load_matching_pairs(matching_pairs_file):
    """
    Read the anomaly/normal matching-pairs table from a CSV file.

    Args:
        matching_pairs_file (str): Path to the matching pairs CSV file.

    Returns:
        pd.DataFrame: The parsed CSV. The columns are not validated here; the
            viewer code later reads 'class', 'anomaly', 'normal',
            'similarity' and 'method'.

    Raises:
        FileNotFoundError: If the path does not point to an existing regular file.
    """
    # Happy path first: the file is present, so parse it and hand it back.
    if os.path.isfile(matching_pairs_file):
        return pd.read_csv(matching_pairs_file)

    raise FileNotFoundError(f"Matching pairs file not found: {matching_pairs_file}")

# -----------------------------
# Step 2: Compute Spectrogram
# -----------------------------
def compute_spectrogram(y, n_fft, hop_length):
    """
    Return the magnitude of the short-time Fourier transform of a signal.

    Args:
        y (np.ndarray): Audio time series.
        n_fft (int): FFT window size passed to librosa.stft.
        hop_length (int): Number of samples between successive frames.

    Returns:
        np.ndarray: Element-wise absolute value of the complex STFT.
    """
    # The phase is discarded; only the per-bin magnitude is kept.
    return np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))

# -----------------------------
# Step 3: Find Most Similar Normal Using Cosine Similarity
# -----------------------------
def find_most_similar_normal_cosine(anomaly_spectrogram, normal_spectrograms):
    """
    Find the most similar normal spectrogram to the anomaly spectrogram using cosine similarity.

    Each spectrogram is flattened to a 1-D vector and compared with
    ``1 - scipy.spatial.distance.cosine``. All spectrograms must therefore
    contain the same number of elements as the anomaly spectrogram.

    Args:
        anomaly_spectrogram (np.ndarray): Spectrogram of the anomaly file.
        normal_spectrograms (list of tuples): (spectrogram, path) pairs for the
            candidate normal files.

    Returns:
        tuple: (most_similar_normal_path, max_similarity). When no candidate
            scores above -1 (for example the list is empty), the result is
            ``(None, -1)``. On ties the earliest candidate wins.
    """
    # Flatten the anomaly once: it is the same for every candidate, and
    # np.ravel avoids a copy when the array is already contiguous.
    anomaly_vector = np.ravel(anomaly_spectrogram)

    max_similarity = -1
    most_similar_normal = None

    for normal_spectrogram, path in normal_spectrograms:
        similarity = 1 - cosine(anomaly_vector, np.ravel(normal_spectrogram))

        # Strict '>' keeps the first of equal candidates. A NaN similarity
        # never compares greater, so such a candidate is skipped.
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_normal = path

    return most_similar_normal, max_similarity

# -----------------------------
# Step 4: Plot and Play Normal and Matching Anomaly Spectrograms
# -----------------------------
def plot_and_play_normal_anomaly_pair(anom_path, norm_path, n_fft, hop_length):
    """
    Plot the spectrograms of the anomaly and normal files and provide audio playback controls.

    Both files are loaded with librosa (sr=None), converted to dB-scaled
    magnitude spectrograms, rendered to PNG, and embedded together with the
    audio as base64 data URIs in one HTML snippet. The snippet's script draws
    a red cursor over each spectrogram while its audio plays and applies a
    shared playback-speed selector to both players. The snippet is shown via
    IPython's display(); nothing is returned.

    Args:
        anom_path (str): Path to the anomaly file.
        norm_path (str): Path to the matching normal file.
        n_fft (int): FFT window size.
        hop_length (int): Number of samples between successive frames.
    """
    y_anomaly, sr_anomaly = librosa.load(anom_path, sr=None)
    y_normal, sr_normal = librosa.load(norm_path, sr=None)

    # Durations in seconds; the script uses them to map playback time to an
    # x position on the overlay canvas.
    duration_anomaly = len(y_anomaly) / sr_anomaly
    duration_normal = len(y_normal) / sr_normal

    S_anomaly = compute_spectrogram(y_anomaly, n_fft, hop_length)
    S_normal = compute_spectrogram(y_normal, n_fft, hop_length)

    # Convert spectrograms to decibel scale (each relative to its own peak)
    S_db_anomaly = librosa.amplitude_to_db(S_anomaly, ref=np.max)
    S_db_normal = librosa.amplitude_to_db(S_normal, ref=np.max)

    # Generate data URIs for spectrogram images
    def spectrogram_to_data_uri(S_db, hop_length, sr, duration, title):
        """Render one dB spectrogram to a PNG and return it as a data URI."""
        fig, ax = plt.subplots(figsize=(10, 3))

        # Plot spectrogram
        librosa.display.specshow(
            S_db,
            sr=sr,
            hop_length=hop_length,
            x_axis='time',
            y_axis='linear',
            ax=ax
        )

        # Set x-axis limits to match audio duration
        ax.set_xlim(0, duration)

        # Keep only the title; drop axis labels, ticks and tick labels
        ax.set_title(title)
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.tick_params(axis='both', which='both', length=0)
        ax.set_yticklabels([])
        ax.set_xticklabels([])

        # Remove whitespace around the plot
        plt.tight_layout(pad=0)
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

        # Save the figure to an in-memory PNG
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        # Close the figure so repeated renders do not accumulate open figures
        plt.close(fig)
        # Encode to base64 string
        img_base64 = base64.b64encode(buf.read()).decode('utf-8')
        return 'data:image/png;base64,' + img_base64

    anomaly_img_uri = spectrogram_to_data_uri(S_db_anomaly, hop_length, sr_anomaly, duration_anomaly, 'Anomaly Spectrogram')
    normal_img_uri = spectrogram_to_data_uri(S_db_normal, hop_length, sr_normal, duration_normal, 'Normal Spectrogram')

    # Generate data URIs for audio
    def audio_to_data_uri(y, sr):
        """Encode a waveform as an in-memory WAV and return it as a data URI."""
        buf = BytesIO()
        sf.write(buf, y, sr, format='WAV')
        buf.seek(0)
        # Encode to base64
        audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
        return 'data:audio/wav;base64,' + audio_base64

    anomaly_audio_uri = audio_to_data_uri(y_anomaly, sr_anomaly)
    normal_audio_uri = audio_to_data_uri(y_normal, sr_normal)

    # Generate unique IDs so several rendered widgets can coexist on one page.
    # The speed selector gets its own ID too: with a fixed ID, getElementById
    # would bind to whichever selector appears first in the document.
    anomaly_id = 'anomaly_' + uuid.uuid4().hex
    normal_id = 'normal_' + uuid.uuid4().hex
    speed_id = 'speed_select_' + uuid.uuid4().hex

    # Generate HTML (literal braces are doubled because of str.format)
    html_template = """
    <div style="display: flex; flex-direction: column; align-items: center;">
        <label for="{speed_id}">재생 속도 (Playback Speed):</label>
        <select id="{speed_id}" style="margin-bottom: 10px;">
            <option value="0.5">0.5x</option>
            <option value="0.75">0.75x</option>
            <option value="1.0" selected>1.0x</option>
            <option value="1.25">1.25x</option>
            <option value="1.5">1.5x</option>
            <option value="2.0">2.0x</option>
        </select>

        <div style="display: flex; align-items: center; margin-bottom: 20px;">
            <div style="position: relative; display: inline-block;">
                <img src="{anomaly_img_uri}" id="img_{anomaly_id}" style="display: block; margin: 0; padding: 0;" />
                <canvas id="cursor_canvas_{anomaly_id}" style="position: absolute; top: 0; left: 0;"></canvas>
            </div>
            <audio id="audio_{anomaly_id}" controls style="margin-left: 10px;">
                <source src="{anomaly_audio_uri}" type="audio/wav">
                Your browser does not support the audio element.
            </audio>
        </div>

        <div style="display: flex; align-items: center;">
            <div style="position: relative; display: inline-block;">
                <img src="{normal_img_uri}" id="img_{normal_id}" style="display: block; margin: 0; padding: 0;" />
                <canvas id="cursor_canvas_{normal_id}" style="position: absolute; top: 0; left: 0;"></canvas>
            </div>
            <audio id="audio_{normal_id}" controls style="margin-left: 10px;">
                <source src="{normal_audio_uri}" type="audio/wav">
                Your browser does not support the audio element.
            </audio>
        </div>
    </div>
    <script>
    (function() {{
        var audioAnomaly = document.getElementById('audio_{anomaly_id}');
        var audioNormal = document.getElementById('audio_{normal_id}');
        var canvasAnomaly = document.getElementById('cursor_canvas_{anomaly_id}');
        var canvasNormal = document.getElementById('cursor_canvas_{normal_id}');
        var imgAnomaly = document.getElementById('img_{anomaly_id}');
        var imgNormal = document.getElementById('img_{normal_id}');
        var durationAnomaly = {duration_anomaly};
        var durationNormal = {duration_normal};

        var speedSelect = document.getElementById('{speed_id}');
        speedSelect.addEventListener('change', function() {{
            var playbackRate = parseFloat(this.value);
            audioAnomaly.playbackRate = playbackRate;
            audioNormal.playbackRate = playbackRate;
        }});

        function resizeCanvas(canvas, img) {{
            canvas.width = img.naturalWidth;
            canvas.height = img.naturalHeight;
            canvas.style.width = img.width + 'px';
            canvas.style.height = img.height + 'px';
        }}

        imgAnomaly.onload = function() {{
            resizeCanvas(canvasAnomaly, imgAnomaly);
        }};
        imgNormal.onload = function() {{
            resizeCanvas(canvasNormal, imgNormal);
        }};
        // A data-URI image may already be loaded when this script runs, in
        // which case 'load' will not fire again; size the canvas right away.
        if (imgAnomaly.complete) {{
            resizeCanvas(canvasAnomaly, imgAnomaly);
        }}
        if (imgNormal.complete) {{
            resizeCanvas(canvasNormal, imgNormal);
        }}
        window.addEventListener('resize', function() {{
            resizeCanvas(canvasAnomaly, imgAnomaly);
            resizeCanvas(canvasNormal, imgNormal);
        }});

        function drawCursor(audio, canvas, duration) {{
            var ctx = canvas.getContext('2d');
            var currentTime = audio.currentTime;
            var x = (currentTime / duration) * canvas.width;
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            ctx.beginPath();
            ctx.moveTo(x, 0);
            ctx.lineTo(x, canvas.height);
            ctx.strokeStyle = 'red';
            ctx.lineWidth = 2;
            ctx.stroke();

            if (!audio.paused) {{
                requestAnimationFrame(function() {{
                    drawCursor(audio, canvas, duration);
                }});
            }}
        }}

        audioAnomaly.addEventListener('play', function() {{
            requestAnimationFrame(function() {{
                drawCursor(audioAnomaly, canvasAnomaly, durationAnomaly);
            }});
        }});
        audioNormal.addEventListener('play', function() {{
            requestAnimationFrame(function() {{
                drawCursor(audioNormal, canvasNormal, durationNormal);
            }});
        }});

        audioAnomaly.addEventListener('pause', function() {{
            canvasAnomaly.getContext('2d').clearRect(0, 0, canvasAnomaly.width, canvasAnomaly.height);
        }});
        audioNormal.addEventListener('pause', function() {{
            canvasNormal.getContext('2d').clearRect(0, 0, canvasNormal.width, canvasNormal.height);
        }});

        audioAnomaly.addEventListener('seeked', function() {{
            drawCursor(audioAnomaly, canvasAnomaly, durationAnomaly);
        }});
        audioNormal.addEventListener('seeked', function() {{
            drawCursor(audioNormal, canvasNormal, durationNormal);
        }});
    }})();
    </script>
    """

    # Generate combined HTML
    combined_html = html_template.format(
        anomaly_img_uri=anomaly_img_uri,
        normal_img_uri=normal_img_uri,
        anomaly_audio_uri=anomaly_audio_uri,
        normal_audio_uri=normal_audio_uri,
        anomaly_id=anomaly_id,
        normal_id=normal_id,
        speed_id=speed_id,
        duration_anomaly=duration_anomaly,
        duration_normal=duration_normal
    )

    # Display the HTML
    display(HTML(combined_html))

# -----------------------------
# Interactive Widgets and Main Execution Flow
# -----------------------------
def interactive_visualization(matching_pairs_df):
    """
    Build the interactive pair viewer (class, n_fft, hop ratio, pair index).

    The widgets are wired to an inner ``visualize`` callback through
    ipywidgets' ``interact``. The callback prints the selected pair's
    metadata and calls ``plot_and_play_normal_anomaly_pair`` for it.

    Args:
        matching_pairs_df (pd.DataFrame): Matching-pairs table. The columns
            'class', 'anomaly', 'normal', 'similarity' and 'method' are read.

    Returns:
        None. If the table contains no classes, a message is printed and no
        widgets are created.
    """
    # Convert to a plain list so emptiness can be tested safely.
    classes = matching_pairs_df['class'].unique().tolist()
    if not classes:
        # Without this guard, classes[0] below raises IndexError on an empty table.
        print("No matching pairs available to visualize.")
        return

    def pairs_for_class(class_name):
        """Return the rows of one class, re-indexed from 0."""
        selected = matching_pairs_df[matching_pairs_df['class'] == class_name]
        return selected.reset_index(drop=True)

    # Determine the maximum pair index per class (0 when the class has no rows)
    def get_max_pair_index(class_name):
        return max(len(pairs_for_class(class_name)) - 1, 0)

    class_dropdown = widgets.Dropdown(
        options=classes,
        value=classes[0],
        description='Class:',
        disabled=False,
    )

    n_fft_slider = widgets.IntSlider(
        value=N_FFT,
        min=32,
        max=512,
        step=2,
        description='n_fft:',
        continuous_update=False
    )

    # hop_length is derived in visualize() as int(n_fft * ratio)
    hop_length_ratio_dropdown = widgets.Dropdown(
        options=[1/2, 1/3, 1/4],
        value=1/2,
        description='Hop Ratio:',
        disabled=False
    )

    index_input = widgets.IntSlider(
        value=0,
        min=0,
        max=get_max_pair_index(class_dropdown.value),
        step=1,
        description='Pair Index:',
        continuous_update=False
    )

    # Update the maximum value of the pair index slider when class changes
    def update_index_slider(change):
        index_input.max = get_max_pair_index(change['new'])
        index_input.value = 0  # Reset to first pair

    class_dropdown.observe(update_index_slider, names='value')

    def visualize(class_name, n_fft, hop_ratio, pair_index):
        hop_length = int(n_fft * hop_ratio)
        class_df = pairs_for_class(class_name)

        if pair_index >= len(class_df):
            print("Pair index out of range.")
            return

        # Fetch the anomaly and normal paths from the matching pairs DataFrame
        row = class_df.loc[pair_index]
        anomaly_path = row['anomaly']
        normal_path = row['normal']
        similarity = row['similarity']
        method = row['method']

        print(f"Pair {pair_index}:")
        print(f"  Anomaly: {anomaly_path}")
        print(f"  Normal: {normal_path}")
        print(f"  Similarity: {similarity:.4f} (Method: {method})")

        plot_and_play_normal_anomaly_pair(anomaly_path, normal_path, n_fft, hop_length)

    interact(
        visualize,
        class_name=class_dropdown,
        n_fft=n_fft_slider,
        hop_ratio=hop_length_ratio_dropdown,
        pair_index=index_input
    )

def main():
    """Load the matching-pairs CSV, report its size, and start the viewer."""
    print("Loading matching pairs from CSV...")
    pairs = load_matching_pairs(MATCHING_PAIRS_FILE)
    print(f"Total matching pairs loaded: {len(pairs)}")

    # Hand the table to the widget-based viewer
    interactive_visualization(pairs)


# Run the viewer as soon as this cell is executed.
main()


Loading matching pairs from CSV...
Total matching pairs loaded: 700


interactive(children=(Dropdown(description='Class:', options=('ToyCar', 'gearbox', 'valve', 'bearing', 'slider…