In [None]:
# imports and settings

import os
import time
import pickle
import warnings
import pandas as pd
import networkx as nx
from networkx.algorithms.coloring import greedy_color
import matplotlib.pyplot as plt
from copy import deepcopy

import numpy as np
from numpy import linalg as LA
from numpy import histogram2d

from scipy import signal
from scipy.fft import fft, fftfreq, fftshift
from scipy.signal import find_peaks, butter, filtfilt, welch, get_window
from scipy.ndimage import gaussian_filter
from scipy.io import wavfile
from scipy.stats import wasserstein_distance_nd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics.pairwise import rbf_kernel, polynomial_kernel, linear_kernel

import utils as ut
# Reload imported modules (including the project-local `utils`) before each cell
# runs, so edits to utils.py take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# do not show warnings
# NOTE(review): this silences *every* warning for the whole session (e.g. numpy /
# scipy deprecation or runtime warnings); consider narrowing it to specific
# categories/modules so real problems are not hidden.
warnings.filterwarnings("ignore")

print("Imports complete.")

In [None]:
# settings
# -- signal --
fs = 8000      # sampling frequency [Hz]
f0 = 2000      # frequency of the simulated tone [Hz]
duration = 60  # signal length [s]

# -- SNR sweep: 0 dB down to -10 dB in 1 dB steps, both ends included --
snr_range = (0, -10)
snr_vec = np.arange(snr_range[1], snr_range[0] + 1)[::-1]

# -- Monte-Carlo sizes --
num_thresholds = 10
num_iterations = 10

# -- spectrogram / Welch arguments --
nperseg = 1024
noverlap = 0.5      # passed to calc_spectrogram as `percent_overlap`
window = 'hanning'  # NOTE(review): scipy's canonical window name is 'hann'; confirm ut accepts 'hanning'
dc = 20             # passed to calc_spectrogram as `remove_dc`
crop_freq = None    # passed to calc_spectrogram as `crop_freq`
norm_size = 5       # passed to calc_welch_from_spectrogram as `normalization_window_size`

# -- S2G arguments --
quantization_bins = 10


In [None]:
# Visual check across the SNR sweep. Each column is one SNR; the rows show
#   1: spectrogram of the clean signal,
#   2: spectrogram of the noisy signal,
#   3: Welch PSD of the noisy signal (blue) with Welch-detector hits (red),
#   4: S2G (Laplacian) curve K (purple) with S2G-detector hits (red).
# The mean wall-clock time of each processing stage is printed as well.
raw_signal = ut.simulate_raw_signal(f0, fs, duration)
# Pass `noverlap` unchanged, exactly like every other calc_spectrogram call in this
# notebook; dividing it by nperseg gave the clean spectrogram a different (near-zero)
# overlap from the noisy spectrograms it is displayed next to.
F, T, raw_Sxx, raw_phasogram = ut.calc_spectrogram(raw_signal, fs, nperseg=nperseg, percent_overlap=noverlap, window=window, remove_dc=dc, crop_freq=crop_freq)

SNR = snr_vec
welch_threshold = 0.04
s2g_threshold_laplacian = 0.2

welch_detector = ut.WelchDetector(fs, nperseg, noverlap, window, dc, crop_freq, norm_size)
s2g_detector_laplacian = ut.S2GDetector(fs, nperseg, noverlap, window, dc, crop_freq, quantization_bins, mode="laplacian")

# Per-SNR wall-clock durations [s] of each processing stage.
_noise_generation_time = []
_spectrogram_time = []
_welch_time = []
_welch_detection_time = []
_s2g_detection_time_lap = []

fig = make_subplots(rows=4, cols=len(SNR), shared_xaxes=False, column_titles=[f"{snr} dB" for snr in SNR])
for i, snr in enumerate(SNR, start=1):
    # perf_counter is monotonic and high-resolution, so it is the right clock for
    # measuring short intervals (time.time() can jump with system clock changes).
    _t = time.perf_counter()
    rx = ut.add_noise_to_signal(raw_signal, snr_db=snr, fs=fs, signal_bw=1, noise_type='white')
    _noise_generation_time.append(time.perf_counter() - _t)
    fig.add_trace(go.Heatmap(z=raw_Sxx, x=T, y=F, colorscale='Viridis', showlegend=False, showscale=False), row=1, col=i)

    _t = time.perf_counter()
    _F, _T, rx_Sxx, _phasogram = ut.calc_spectrogram(rx, fs, nperseg=nperseg, percent_overlap=noverlap, window=window, remove_dc=dc, crop_freq=crop_freq)
    _spectrogram_time.append(time.perf_counter() - _t)
    fig.add_trace(go.Heatmap(z=rx_Sxx, x=_T, y=_F, colorscale='Viridis', showlegend=False, showscale=False), row=2, col=i)

    _t = time.perf_counter()
    rx_pxx = ut.calc_welch_from_spectrogram(rx_Sxx, normalization_window_size=norm_size)
    _welch_time.append(time.perf_counter() - _t)
    fig.add_trace(go.Scatter(y=rx_pxx, x=_F, showlegend=False, line=dict(color='blue')), row=3, col=i)

    _t = time.perf_counter()
    _, welch_detections, _ = welch_detector.detect(rx, threshold=welch_threshold)
    _welch_detection_time.append(time.perf_counter() - _t)
    # NOTE(review): assumes the detector's indices refer to the same frequency grid
    # as rx_pxx / _F computed above (same settings) — confirm in ut.WelchDetector.
    if len(welch_detections) > 0:
        fig.add_trace(go.Scatter(y=rx_pxx[welch_detections], x=_F[welch_detections], mode='markers', marker=dict(color='red', size=10), showlegend=False), row=3, col=i)

    _t = time.perf_counter()
    _, s2g_detections, (K, K_f) = s2g_detector_laplacian.detect(rx, threshold=s2g_threshold_laplacian)
    _s2g_detection_time_lap.append(time.perf_counter() - _t)
    fig.add_trace(go.Scatter(y=K, x=K_f, showlegend=False, line=dict(color='purple')), row=4, col=i)
    if len(s2g_detections) > 0:
        fig.add_trace(go.Scatter(y=K[s2g_detections], x=K_f[s2g_detections], mode='markers', marker=dict(color='red', size=10), showlegend=False), row=4, col=i)


print("Noise generation times:", np.average(_noise_generation_time))
print("Spectrogram calculation times:", np.average(_spectrogram_time))
print("Welch calculation times:", np.average(_welch_time))
print("Welch detection times:", np.average(_welch_detection_time))
print("S2G detection times (Laplacian):", np.average(_s2g_detection_time_lap))

fig.update_layout(height=800, width=300*len(SNR), title_text="Raw vs Noisy Signal")
fig.show()

In [None]:
# settings
# NOTE(review): these values duplicate the settings cell at the top of the notebook;
# keep the two in sync (or drop this copy) so the detectors built here match the
# ones used in the earlier SNR overview plot.
fs, f0, duration = 8000, 2000, 60  # sampling frequency, signal frequency, seconds

# spectrogram / Welch arguments
nperseg = 1024
noverlap = 0.5      # passed to calc_spectrogram as `percent_overlap` elsewhere
window = 'hanning'
dc = 20             # passed to calc_spectrogram as `remove_dc` elsewhere
crop_freq = None
norm_size = 5       # normalisation window size used by the Welch detector

# S2G arguments
quantization_bins = 10

# Rebuild both detectors from the values above. They share the same leading
# spectral arguments, so those are collected once.
_spectral_args = (fs, nperseg, noverlap, window, dc, crop_freq)
welch_detector = ut.WelchDetector(*_spectral_args, norm_size)
s2g_detector_laplacian = ut.S2GDetector(*_spectral_args, quantization_bins, mode="laplacian")

In [None]:
# sanity check - run one noisy realisation and print, for every threshold, whether each
# detector found the tone (detected) and whether it fired anywhere else (false alarm)
snr = -6
tolerance = 5  # Hz, for matching detections to true signal frequency
# NOTE(review): with fs=8000 and nperseg=1024 the frequency grid is presumably
# fs / nperseg = 7.8125 Hz per bin, so this tolerance only accepts the exact f0 bin and a
# hit in an adjacent bin (e.g. window leakage) is scored as a false alarm. Confirm intent.
num_thresholds = 10
welch_thresholds = np.linspace(0.01, 0.1, num_thresholds)
s2g_thresholds = np.linspace(0.1, 0.99, num_thresholds)
raw_signal = ut.simulate_raw_signal(f0, fs, duration)
rx = ut.add_noise_to_signal(raw_signal, snr_db=snr, fs=fs, signal_bw=1, noise_type='white')
# Frequency axis comes first, matching how the other calc_spectrogram calls unpack it.
F, T, rx_Sxx, rx_phasogram = ut.calc_spectrogram(rx, fs, nperseg=nperseg, percent_overlap=noverlap, window=window, remove_dc=dc, crop_freq=crop_freq)

welch_detections_all_thresholds = []
welch_fa_all_thresholds = []
s2g_detections_all_thresholds = []
s2g_fa_all_thresholds = []

# Scoring per threshold:
#   detected    -> at least one detection within `tolerance` Hz of f0
#   false alarm -> at least one detection farther than `tolerance` Hz from f0
for welch_threshold in welch_thresholds:
    _, welch_detections, (pxx, F_pxx) = welch_detector.detect(rx, threshold=welch_threshold)
    welch_offsets = np.abs(F_pxx[welch_detections] - f0)
    welch_is_detected = bool(np.any(welch_offsets <= tolerance))
    welch_false_alarm = bool(np.any(welch_offsets > tolerance))
    welch_detections_all_thresholds.append(int(welch_is_detected))
    welch_fa_all_thresholds.append(int(welch_false_alarm))
for s2g_threshold in s2g_thresholds:
    _, s2g_detections, (K, F_s2g) = s2g_detector_laplacian.detect(rx, threshold=s2g_threshold)
    s2g_offsets = np.abs(F_s2g[s2g_detections] - f0)
    s2g_is_detected = bool(np.any(s2g_offsets <= tolerance))
    s2g_false_alarm = bool(np.any(s2g_offsets > tolerance))
    s2g_detections_all_thresholds.append(int(s2g_is_detected))
    s2g_fa_all_thresholds.append(int(s2g_false_alarm))

print(f"welch detected: {welch_detections_all_thresholds} with false alarms: {welch_fa_all_thresholds}")
print(f"s2g detected: {s2g_detections_all_thresholds} with false alarms: {s2g_fa_all_thresholds}")

In [None]:
# ## MAIN SIMULATION CODE
# Monte-Carlo evaluation of the Welch and S2G (Laplacian) detectors over a grid of
# thresholds and SNRs. For each SNR, every one of `num_iterations` noisy realisations
# is scored per threshold as
#   detected    = 1 if any detection lies within `tolerance` Hz of f0,
#   false alarm = 1 if any detection lies farther than `tolerance` Hz from f0,
# and the per-threshold averages over iterations are pickled to RESULTS_DIR.
# The run is expensive, so it is disabled by default; the next cell only loads the
# saved results. Set RUN_MAIN_SIMULATION = True to regenerate them.
RUN_MAIN_SIMULATION = False
RESULTS_DIR = "../results"

if RUN_MAIN_SIMULATION:
    num_thresholds = 20
    welch_thresholds = np.linspace(0.01, 0.1, num_thresholds)
    s2g_thresholds = np.linspace(0.1, 0.99, num_thresholds)
    num_iterations = 100
    snr_values = [0, -2, -4, -6, -8, -10]  # from 0 dB to -10 dB
    raw_signal = ut.simulate_raw_signal(f0, fs, duration)
    tolerance = 5  # Hz, for matching detections to true signal frequency
    # NOTE(review): no random seed is set here, so runs are only reproducible if
    # ut.add_noise_to_signal seeds itself — confirm.
    os.makedirs(RESULTS_DIR, exist_ok=True)  # open(..., 'wb') fails if the folder is missing

    print("Starting simulation...")
    for snr in snr_values:
        print(f"Simulating for SNR = {snr} dB...")
        # rows = iterations, columns = thresholds; entries are 0/1 outcomes
        welch_all_detections = np.zeros((num_iterations, num_thresholds))
        welch_all_fa = np.zeros((num_iterations, num_thresholds))
        s2g_all_detections = np.zeros((num_iterations, num_thresholds))
        s2g_all_fa = np.zeros((num_iterations, num_thresholds))
        for i in range(num_iterations):
            print(f"Iteration {i+1}/{num_iterations} for SNR = {snr} dB...")
            rx = ut.add_noise_to_signal(raw_signal, snr_db=snr, fs=fs, signal_bw=1, noise_type='white')

            for j, welch_threshold in enumerate(welch_thresholds):
                _, welch_detections, (pxx, F_pxx) = welch_detector.detect(rx, threshold=welch_threshold)
                welch_offsets = np.abs(F_pxx[welch_detections] - f0)
                welch_all_detections[i, j] = np.any(welch_offsets <= tolerance)
                welch_all_fa[i, j] = np.any(welch_offsets > tolerance)
            for j, s2g_threshold in enumerate(s2g_thresholds):
                _, s2g_detections, (K, F_s2g) = s2g_detector_laplacian.detect(rx, threshold=s2g_threshold)
                s2g_offsets = np.abs(F_s2g[s2g_detections] - f0)
                s2g_all_detections[i, j] = np.any(s2g_offsets <= tolerance)
                s2g_all_fa[i, j] = np.any(s2g_offsets > tolerance)

        # Average over iterations -> per-threshold detection / false-alarm probability.
        welch_all_detections = np.average(welch_all_detections, axis=0)
        welch_all_fa = np.average(welch_all_fa, axis=0)
        s2g_all_detections = np.average(s2g_all_detections, axis=0)
        s2g_all_fa = np.average(s2g_all_fa, axis=0)

        # save results as pickle
        welch_path = os.path.join(RESULTS_DIR, f"welch_eval_snr_{snr}.pkl")
        with open(welch_path, 'wb') as file:
            pickle.dump((welch_all_detections, welch_all_fa), file)
        print(f"Saved Welch results for SNR = {snr} dB to {welch_path}")
        s2g_path = os.path.join(RESULTS_DIR, f"s2g_eval_snr_{snr}.pkl")
        with open(s2g_path, 'wb') as file:
            pickle.dump((s2g_all_detections, s2g_all_fa), file)
        print(f"Saved S2G results for SNR = {snr} dB to {s2g_path}")


In [None]:
# Load the averaged Monte-Carlo results pickled by the main simulation cell and plot,
# for every SNR, the detection probability (top row) and false-alarm probability
# (bottom row) of each detector against its threshold.
welch_pd = []
welch_fa = []
s2g_pd = []
s2g_fa = []
snr_values = [0, -2, -4, -6, -8, -10]
# These grids must match the ones the results were generated with; the length checks
# below catch a num_thresholds mismatch instead of silently mis-plotting.
num_thresholds = 20
welch_thresholds = np.linspace(0.01, 0.1, num_thresholds)
s2g_thresholds = np.linspace(0.1, 0.99, num_thresholds)
for snr in snr_values:
    # pickle.load can execute arbitrary code: only load result files you produced yourself.
    with open(f"../results/welch_eval_snr_{snr}.pkl", 'rb') as file:
        detections, false_alarms = pickle.load(file)
    assert len(detections) == len(false_alarms) == num_thresholds, \
        f"Welch results for SNR={snr} dB do not match num_thresholds={num_thresholds}"
    welch_pd.append(detections)
    welch_fa.append(false_alarms)

    with open(f"../results/s2g_eval_snr_{snr}.pkl", 'rb') as file:
        detections, false_alarms = pickle.load(file)
    assert len(detections) == len(false_alarms) == num_thresholds, \
        f"S2G results for SNR={snr} dB do not match num_thresholds={num_thresholds}"
    s2g_pd.append(detections)
    s2g_fa.append(false_alarms)

fig = make_subplots(rows=2, cols=2, subplot_titles=["Welch Probability of Detection", "S2G (Laplacian) Probability of Detection", "Welch False Alarm Rate", "S2G (Laplacian) False Alarm Rate"])
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown']
for i, snr in enumerate(snr_values):
    color = colors[i % len(colors)]  # cycle instead of IndexError when there are more SNRs than colours
    fig.add_trace(go.Scatter(x=welch_thresholds, y=welch_pd[i], mode='lines+markers', name=f'Welch SNR={snr}', line=dict(color=color)), row=1, col=1)
    fig.add_trace(go.Scatter(x=welch_thresholds, y=welch_fa[i], mode='lines+markers', name=f'Welch FA SNR={snr}', line=dict(color=color)), row=2, col=1)
    fig.add_trace(go.Scatter(x=s2g_thresholds, y=s2g_pd[i], mode='lines+markers', name=f'S2G (Laplacian) SNR={snr}', line=dict(color=color)), row=1, col=2)
    fig.add_trace(go.Scatter(x=s2g_thresholds, y=s2g_fa[i], mode='lines+markers', name=f'S2G FA SNR={snr}', line=dict(color=color)), row=2, col=2)
fig.update_xaxes(title_text="Threshold")
fig.update_yaxes(title_text="Probability of detection", row=1)
fig.update_yaxes(title_text="False-alarm probability", row=2)
fig.update_layout(title='Pd / FAR vs Threshold')
fig.show()

In [None]:
# plot Pd vs FAR for all SNRs on the same plot (ROC-style: one curve per SNR, one point
# per threshold; hovering a point shows the threshold that produced it)
fig = make_subplots(rows=1, cols=2, subplot_titles=["Welch Pd vs FAR", "S2G (Laplacian) Pd vs FAR"])
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown']
# welch_pd / welch_fa / s2g_pd / s2g_fa are stored in snr_values order, so the
# enumerate index is the SNR's position directly (no list.index lookup needed).
for i, snr in enumerate(snr_values):
    color = colors[i % len(colors)]  # cycle instead of IndexError when there are more SNRs than colours
    fig.add_trace(go.Scatter(x=welch_fa[i], y=welch_pd[i], hovertemplate='Threshold: %{text:.3f}', text=welch_thresholds, mode='lines+markers', name=f'Welch SNR={snr}', line=dict(color=color)), row=1, col=1)
    fig.add_trace(go.Scatter(x=s2g_fa[i], y=s2g_pd[i], hovertemplate='Threshold: %{text:.2f}', text=s2g_thresholds, mode='lines+markers', name=f'S2G (Laplacian) SNR={snr}', line=dict(color=color)), row=1, col=2)
fig.update_xaxes(title_text="False-alarm probability")
fig.update_yaxes(title_text="Probability of detection")
fig.update_layout(title='Pd vs FAR')
fig.show()