In [1]:
import numpy as np
import plotly.graph_objects as go
from scipy import signal
from scipy.io import wavfile

In [2]:
data_file = "../data/scooter_example_1.wav"
fs, data = wavfile.read(data_file)

# crop data
start_time = 80 # seconds
end_time = 150 # seconds
data = data[int(start_time*fs):int(end_time*fs)]

  fs, data = wavfile.read(data_file)


In [3]:
nperseg=65536
noverlap=None
window='hann'
title='Spectrogram'
colorscale='Viridis'
crop_freq=2000

In [4]:
if noverlap is None:
    noverlap = nperseg // 8

# Compute spectrogram
frequencies, times, Sxx = signal.spectrogram(
    data, 
    fs=fs,
    window=window,
    nperseg=nperseg,
    noverlap=noverlap
)

# Crop frequencies if specified
if crop_freq is not None:
    freq_mask = frequencies <= crop_freq
    frequencies = frequencies[freq_mask]
    Sxx = Sxx[freq_mask, :]

# Convert to dB scale
Sxx_db = 10 * np.log10(Sxx + 1e-10)  # Add small value to avoid log(0)

# Create the heatmap
fig = go.Figure(data=go.Heatmap(
    z=Sxx_db,
    x=times,
    y=frequencies,
    colorscale=colorscale,
    colorbar=dict(title='Power (dB)')
))

# Update layout
fig.update_layout(
    title=title,
    xaxis_title='Time (s)',
    yaxis_title='Frequency (Hz)',
    width=800,
    height=600
)

fig.show()

In [5]:
if noverlap is None:
    noverlap = nperseg // 8

# Compute spectrogram
frequencies, times, Sxx = signal.spectrogram(
    data, 
    fs=fs,
    window=window,
    nperseg=nperseg,
    noverlap=noverlap,
    mode='phase'
)

# Crop frequencies if specified
if crop_freq is not None:
    freq_mask = frequencies <= crop_freq
    frequencies = frequencies[freq_mask]
    Sxx = Sxx[freq_mask, :]

# Convert to dB scale
Sxx_phase = 10 * np.log10(Sxx + 1e-10)  # Add small value to avoid log(0)

# Create the heatmap
fig = go.Figure(data=go.Heatmap(
    z=Sxx_phase,
    x=times,
    y=frequencies,
    colorscale=colorscale,
    colorbar=dict(title='Power (dB)')
))

# Update layout
fig.update_layout(
    title=title,
    xaxis_title='Time (s)',
    yaxis_title='Frequency (Hz)',
    width=800,
    height=600
)

fig.show()


invalid value encountered in log10



In [6]:
# Combined spectrogram plots - Power and Phase side by side
from plotly.subplots import make_subplots

if noverlap is None:
    noverlap = nperseg // 8

# Compute power spectrogram
frequencies, times, Sxx_power = signal.spectrogram(
    data, 
    fs=fs,
    window=window,
    nperseg=nperseg,
    noverlap=noverlap
)

# Compute phase spectrogram  
frequencies, times, Sxx_phase = signal.spectrogram(
    data, 
    fs=fs,
    window=window,
    nperseg=nperseg,
    noverlap=noverlap,
    mode='phase'
)

# Crop frequencies if specified
if crop_freq is not None:
    freq_mask = frequencies <= crop_freq
    frequencies = frequencies[freq_mask]
    Sxx_power = Sxx_power[freq_mask, :]
    Sxx_phase = Sxx_phase[freq_mask, :]

# Convert to dB scale
Sxx_power_db = 10 * np.log10(Sxx_power + 1e-10)
Sxx_phase_db = 10 * np.log10(np.abs(Sxx_phase) + 1e-10)  # Use abs for phase to avoid invalid values

# Create subplots
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=('Power Spectrogram', 'Phase Spectrogram')
)

# Add power spectrogram
fig.add_trace(
    go.Heatmap(
        z=Sxx_power_db,
        x=times,
        y=frequencies,
        colorscale=colorscale
    ),
    row=1, col=1
)

# Add phase spectrogram
fig.add_trace(
    go.Heatmap(
        z=Sxx_phase_db,
        x=times,
        y=frequencies,
        colorscale=colorscale
    ),
    row=2, col=1
)

# Update layout
fig.update_layout(
    title='Combined Spectrogram Analysis',
    width=1000,
    height=800
)

# Update x and y axis labels
fig.update_xaxes(title_text='Time (s)')
fig.update_yaxes(title_text='Frequency (Hz)')

fig.show()

In [46]:
# perform FFT on each segment to get phase information

Sxx_phase_fft = np.fft.fft(Sxx_phase, axis=0)
phase_frequencies = np.fft.fftfreq(Sxx_phase.shape[1], d=1/fs)

# Only keep the positive frequencies
positive_freqs = phase_frequencies > 0
Sxx_phase_fft = Sxx_phase_fft[:, positive_freqs]
phase_frequencies = phase_frequencies[positive_freqs]

Sxx_phase_fft_db = 10 * np.log10(np.abs(Sxx_phase_fft) + 1e-10)

# Create the heatmap
fig = go.Figure(data=go.Heatmap(
    z=Sxx_phase_fft_db,
    x=phase_frequencies,
    y=frequencies,
    colorscale=colorscale,
    colorbar=dict(title='Power (dB)')
))

# Update layout
fig.update_layout(
    title=title,
    xaxis_title='Frequency (Hz)',
    yaxis_title='Frequency - original (Hz)',
    width=800,
    height=600
)

fig.show()