In [1]:
# imports and settings

import os
import time
import pickle
import warnings
import pandas as pd
import networkx as nx
from networkx.algorithms.coloring import greedy_color
import matplotlib.pyplot as plt
from copy import deepcopy

import numpy as np
from numpy import linalg as LA
from numpy import histogram2d

from scipy import signal
from scipy.fft import fft, fftfreq, fftshift
from scipy.signal import find_peaks, butter, filtfilt, welch, get_window
from scipy.ndimage import gaussian_filter
from scipy.io import wavfile
from scipy.stats import wasserstein_distance_nd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics.pairwise import rbf_kernel, polynomial_kernel, linear_kernel

import utils as ut
%load_ext autoreload
%autoreload 2

# do not show warnings
warnings.filterwarnings("ignore")

print("Imports complete.")

Imports complete.
Settings: height=800, width=1400, font_size=16
Imports complete.


In [2]:
def add_bg(audio, fs, snr, noise_path):
    """Mix a background-noise recording into ``audio``.

    The WAV file at ``noise_path`` is read, reduced to its first channel when
    it has several, cropped to the length of ``audio`` and added to ``audio``
    after being multiplied by ``snr``.

    Args:
        audio: 1-D array of audio samples.
        fs: Sample rate of ``audio`` in Hz. It is not compared against the
            sample rate of the noise file.
        snr: Multiplicative factor applied to the noise samples before mixing.
            NOTE(review): despite the name, the code uses this as a plain
            linear gain, not as a ratio in dB - confirm the intended meaning.
        noise_path: Path of the noise WAV file.

    Returns:
        ``audio + snr * noise`` where ``noise`` holds the first ``len(audio)``
        samples of the noise recording.

    Raises:
        ValueError: If the noise recording has fewer samples than ``audio``.
    """
    noise_fs, noise = wavfile.read(noise_path)
    # NOTE(review): fs and noise_fs are not checked for equality; a mismatch
    # would mix the noise at the wrong speed - TODO decide whether to resample
    # or reject.

    # Keep only the first channel so the noise can be added to a 1-D signal.
    if noise.ndim > 1:
        noise = noise[:, 0]

    n_samples = len(audio)
    if len(noise) < n_samples:
        raise ValueError(
            f"Noise file {noise_path!r} has {len(noise)} samples, "
            f"fewer than the {n_samples} samples of the audio"
        )

    # Crop the noise so both operands have the same length.
    return audio + snr * noise[:n_samples]

def get_detection_stats(feature, labels, step_size=0.001):
    """Sweep a threshold over ``[0, 1)`` and flag hits and false alarms.

    For every threshold ``t`` in ``np.arange(0, 1, step_size)`` two flags are
    recorded:

    * detection flag: 1 if any element with ``labels == 1`` has
      ``feature >= t``, otherwise 0;
    * false-alarm flag: 1 if any element with ``labels == 0`` has
      ``feature >= t``, otherwise 0.

    Args:
        feature: Array of feature values, compared element-wise to each
            threshold.
        labels: Array aligned with ``feature``; 1 marks target elements and 0
            marks background elements.
        step_size: Spacing between consecutive thresholds.

    Returns:
        ``(detections_scores, fa_scores)``: two lists of 0/1 integers with one
        entry per threshold.
    """
    target_mask = labels == 1
    background_mask = labels == 0

    hit_flags = []
    false_alarm_flags = []
    for level in np.arange(0, 1, step_size):
        above = feature >= level
        # A single element at or above the threshold is enough to set a flag.
        hit_flags.append(int(above[target_mask].any()))
        false_alarm_flags.append(int(above[background_mask].any()))
    return hit_flags, false_alarm_flags

def eval_audio(audio, fs, labels, feature_func, snr=None, args=None):
    """Compute a frequency feature for one signal and score it against labels.

    Args:
        audio: 1-D array of audio samples.
        fs: Sample rate of ``audio`` in Hz.
        labels: Iterable of target frequencies (same unit as the frequency
            axis returned by ``feature_func``). Each one is snapped to the
            nearest frequency bin.
        feature_func: Callable invoked as
            ``feature_func(audio, fs, **kwargs, return_freqs=True)`` that
            returns ``(feature, freqs)``.
        snr: When not ``None``, background noise is mixed in through
            ``add_bg`` with this scale factor before the feature is computed.
        args: Optional dict of keyword arguments for ``feature_func``. The
            key ``"noise_path"`` is consumed by the noise-mixing step and is
            not forwarded to ``feature_func``.

    Returns:
        ``(detections_scores, fa_scores)`` as produced by
        ``get_detection_stats`` for the feature and the label vector.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    args = {} if args is None else args

    if snr is not None:
        audio = add_bg(audio, fs, snr, noise_path=args.get("noise_path"))

    # "noise_path" only configures the mixing step above; forwarding it would
    # raise TypeError in feature functions that do not accept it.
    feature_kwargs = {key: value for key, value in args.items() if key != "noise_path"}
    feature, freqs = feature_func(audio, fs, **feature_kwargs, return_freqs=True)

    # Calculate label indices: mark the bin closest to each labelled frequency.
    label_vec = np.zeros(len(freqs))
    for f in labels:
        idx = (np.abs(freqs - f)).argmin()
        print(f"Frequency {f} Hz matched to index {idx}, frequency {freqs[idx]} Hz")
        label_vec[idx] = 1

    detections_scores, fa_scores = get_detection_stats(feature, label_vec)

    return detections_scores, fa_scores

def process_audio_file(path, slice_duration=10):
    """Load a WAV file as a single channel and optionally keep a centred excerpt.

    Args:
        path: Path of the WAV file to read.
        slice_duration: Length in seconds of the excerpt to keep, taken around
            the middle of the recording. ``None`` keeps the whole recording.
            Odd and fractional durations are honoured because the window is
            computed in samples rather than in whole seconds.

    Returns:
        ``(fs, audio)``: the sample rate reported by the file and the 1-D
        sample array. The excerpt is shorter than requested when the
        recording itself is shorter than ``slice_duration``.
    """
    # Load the recording.
    fs, audio = wavfile.read(path)

    # Multi-channel recordings: keep only the first channel.
    if audio.ndim > 1:
        audio = audio[:, 0]

    if slice_duration is not None:
        # Work in samples so that e.g. 1 s or 5 s are not truncated by an
        # integer division in seconds (1 // 2 == 0 gave an empty excerpt).
        n_slice = int(round(fs * slice_duration))
        mid = len(audio) // 2
        start = max(0, mid - n_slice // 2)
        end = min(len(audio), start + n_slice)
        audio = audio[start:end]

    return fs, audio

def eval_folder(folder_path, label_dict, feature_func, feature_args, slice_duration=10, snr=None):
    """Evaluate every ``.wav`` file of a folder and average the score curves.

    Each file is loaded with ``process_audio_file`` and scored with
    ``eval_audio``; the per-file score lists are then averaged element-wise.

    Args:
        folder_path: Directory scanned (non-recursively) for ``.wav`` files.
        label_dict: Mapping from file name (without directory) to the labels
            passed to ``eval_audio`` for that file.
        feature_func: Feature function forwarded to ``eval_audio``.
        feature_args: Dict forwarded to ``eval_audio`` as ``args``.
        slice_duration: Excerpt length in seconds forwarded to
            ``process_audio_file``.
        snr: Noise scale factor forwarded to ``eval_audio``.

    Returns:
        ``(avg_detection_scores, avg_fa_scores)``: averages over files (axis 0)
        of the detection and false-alarm score lists.
    """
    per_file_detections = []
    per_file_false_alarms = []

    for file_name in os.listdir(folder_path):
        # Anything that is not a WAV file is ignored.
        if not file_name.endswith(".wav"):
            continue

        print(f"Processing file: {file_name}")
        fs, audio = process_audio_file(
            os.path.join(folder_path, file_name),
            slice_duration=slice_duration,
        )

        file_detections, file_false_alarms = eval_audio(
            audio,
            fs,
            label_dict[file_name],
            feature_func,
            snr=snr,
            args=feature_args,
        )
        per_file_detections.append(file_detections)
        per_file_false_alarms.append(file_false_alarms)

    detection_matrix = np.array(per_file_detections)
    false_alarm_matrix = np.array(per_file_false_alarms)
    return np.average(detection_matrix, axis=0), np.average(false_alarm_matrix, axis=0)

def welch_feature(audio, fs, nperseg, overlap, window, dc, crop_freq, norm_size, return_freqs=False):
    """Compute the feature returned by ``ut.calc_welch_from_spectrogram``.

    The spectrogram is obtained from ``ut.calc_spectrogram`` and its third
    return value is handed to ``ut.calc_welch_from_spectrogram``.

    Args:
        audio: Audio samples.
        fs: Sample rate of ``audio``.
        nperseg: Forwarded to ``ut.calc_spectrogram`` as ``nperseg``.
        overlap: Forwarded as ``percent_overlap``.
        window: Forwarded as ``window``.
        dc: Forwarded as ``remove_dc``.
        crop_freq: Forwarded as ``crop_freq``.
        norm_size: Forwarded to ``ut.calc_welch_from_spectrogram`` as
            ``normalization_window_size``.
        return_freqs: When true, the frequency axis is returned as well.

    Returns:
        The feature alone, or ``(feature, freqs)`` when ``return_freqs`` is
        true.
    """
    freqs, _times, spectrogram, _phase = ut.calc_spectrogram(
        audio,
        fs,
        nperseg=nperseg,
        percent_overlap=overlap,
        window=window,
        remove_dc=dc,
        crop_freq=crop_freq,
    )
    feature = ut.calc_welch_from_spectrogram(spectrogram, normalization_window_size=norm_size)
    return (feature, freqs) if return_freqs else feature

def s2g_detections(audio, threshold, fs, nperseg, overlap, window, dc, crop_freq, quantization_levels):
    """Threshold the scores from ``ut.get_all_Ks`` into binary detections.

    The fourth return value of ``ut.calc_spectrogram`` is passed, together
    with the frequency axis, to ``ut.get_all_Ks``. An element is flagged when
    its score is at least ``mean + threshold * std`` of all scores.

    Args:
        audio: Audio samples.
        threshold: Number of standard deviations above the mean score
            required for a detection.
        fs: Sample rate of ``audio``.
        nperseg: Forwarded to ``ut.calc_spectrogram`` as ``nperseg``.
        overlap: Forwarded as ``percent_overlap``.
        window: Forwarded as ``window``.
        dc: Forwarded as ``remove_dc``.
        crop_freq: Forwarded as ``crop_freq``.
        quantization_levels: Forwarded to ``ut.get_all_Ks`` as ``n_levels``.

    Returns:
        ``(detections, freqs)``: a 0/1 array with one entry per score, and the
        frequency axis returned by ``ut.calc_spectrogram``.
    """
    freqs, _times, _spectrogram, phase = ut.calc_spectrogram(
        audio,
        fs,
        nperseg=nperseg,
        percent_overlap=overlap,
        window=window,
        remove_dc=dc,
        crop_freq=crop_freq,
    )
    scores = ut.get_all_Ks(phase, freqs, n_levels=quantization_levels)

    # Adaptive cut-off: mean score plus `threshold` standard deviations.
    cutoff = np.mean(scores) + threshold * np.std(scores)
    return np.where(scores >= cutoff, 1, 0), freqs


§
# Print a marker so the cell output confirms that this cell was executed.
print("Functions defined.")

Functions defined.


In [44]:
# Overview figure: for every positive recording, draw the spectrogram next to
# the Welch-style feature and the K feature, with peaks above
# mean + 3 * std marked in red.

# Analysis settings. Not every key is read in this cell ("prominence" and
# "wlen" are unused here).
args = dict(
    nperseg=16000,
    overlap=0.,
    window="hanning",
    dc=20,
    crop_freq=4000,
    norm_size=5,
    slice_duration=10,
    quantization_levels=30,
    prominence=0,
    wlen=16,
    # distance=None,
    snr=0.,
)


folder_path = '../data/ds2/pos'
noise_path = '../data/ds2/neg/bg_noise_1m.wav'
noise = wavfile.read(noise_path)[1]

# Sorted full paths of all WAV files in the folder.
all_files = sorted(
    os.path.join(folder_path, file_name)
    for file_name in os.listdir(folder_path)
    if file_name.endswith(".wav")
)
# all_files = all_files[2:4]

# One row per file: wide spectrogram column plus two narrow feature columns
# sharing the frequency (y) axis.
fig = make_subplots(
    rows=len(all_files),
    cols=3,
    column_widths=[0.8, 0.1, 0.1],
    shared_yaxes=True,
    row_titles=all_files,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
    shared_xaxes=False,
)

# Keyword arguments that stay the same for every file.
spectrogram_kwargs = dict(
    nperseg=args['nperseg'],
    percent_overlap=args['overlap'],
    window=args['window'],
    remove_dc=args['dc'],
    crop_freq=args['crop_freq'],
)
peak_marker = dict(color='red', size=6)

for i, file in enumerate(all_files, start=1):
    fs, audio = process_audio_file(file, slice_duration=args['slice_duration'])
    # Mix in the background noise, cropped to the excerpt length.
    audio = audio + args['snr'] * noise[:len(audio)]

    F, T, Sxx, phasogram = ut.calc_spectrogram(audio, fs, **spectrogram_kwargs)

    pxx = ut.calc_welch_from_spectrogram(Sxx, normalization_window_size=args['norm_size'])
    pxx_min_height = np.mean(pxx) + 3 * np.std(pxx)
    pxx_peaks = find_peaks(pxx, height=pxx_min_height)[0]

    K = ut.get_all_Ks(phasogram, F, n_levels=args['quantization_levels'])
    K_min_height = np.mean(K) + 3 * np.std(K)
    K_peaks = find_peaks(K, height=K_min_height)[0]

    # (trace, column) pairs, added in drawing order.
    row_traces = [
        (go.Heatmap(z=Sxx, x=T, y=F, colorscale='Viridis', showlegend=False, showscale=False), 1),
        (go.Scatter(x=pxx, y=F, mode='lines', line=dict(color='blue'), showlegend=False), 2),
        (go.Scatter(x=pxx[pxx_peaks], y=F[pxx_peaks], mode='markers', marker=peak_marker, showlegend=False), 2),
        (go.Scatter(x=K, y=F, mode='lines', line=dict(color='purple'), showlegend=False), 3),
        (go.Scatter(x=K[K_peaks], y=F[K_peaks], mode='markers', marker=peak_marker, showlegend=False), 3),
    ]
    for trace, column in row_traces:
        fig.add_trace(trace, row=i, col=column)

fig.update_layout(
    height=600 * len(all_files),
    width=1400,
    title_text="Spectrograms and Welch PSDs of Ship Audio Files",
)
fig.show()