In [1]:
# Add soundsig to path to get coherence2.
# At some point this can be packaged and pip installed
import sys
sys.path.append("../soundsig")

In [2]:
## Dependencies
from scipy import stats
from coherence import multitapered_coherence
from scipy import signal
import numpy as np
import matplotlib.pyplot as plt

# Display and GUI
import ipywidgets as widgets



### Define Input Widgets

In [3]:
# Settings shared by every slider. continuous_update=False means a value is only
# committed when the handle is released, so callbacks do not fire mid-drag.
_slider_common = dict(
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
)

# Sampling rate of the simulated signals (Hz).
samprate_picker = widgets.FloatSlider(
    value=25000, min=2, max=44100, step=1,
    description='Sampling Rate:',
    **_slider_common,
    readout_format='.1f',
)

# Length of the simulated signals (seconds).
duration_picker = widgets.FloatSlider(
    value=10.0, min=1, max=100.0, step=0.1,
    description='Duration (s):',
    **_slider_common,
    readout_format='.1f',
)

# Cutoff frequency of the low-pass filter (Hz).
cf_picker = widgets.FloatSlider(
    value=250.0, min=1, max=1000, step=1,
    description='Cutoff F:',
    **_slider_common,
    readout_format='.0f',
)

# Broadband signal-to-noise ratio.
SNR_picker = widgets.FloatSlider(
    value=1.0, min=0.01, max=10, step=0.01,
    description='SNR:',
    **_slider_common,
    readout_format='.2f',
)

# Delay between input and output (milliseconds).
delay_picker = widgets.FloatSlider(
    value=0.0, min=-50, max=50, step=1,
    description='Delay (ms):',
    **_slider_common,
    readout_format='.0f',
)

# Spectral segment length, given as a power of two.
chunk_picker = widgets.IntSlider(
    value=10, min=4, max=20, step=1,
    description='Chunk size (log2):',
    **_slider_common,
)

# Time-bandwidth product NW for the multitaper estimate.
NW_picker = widgets.IntSlider(
    value=3, min=1, max=10, step=1,
    description='NW:',
    **_slider_common,
)

# Settings shared by both action buttons.
_button_common = dict(disabled=False, button_style='success')

genSig_button = widgets.Button(
    description='Generate Signal',
    **_button_common,
    tooltip='Click me to generate signals',
    icon='check',
)

calcCoh_button = widgets.Button(
    description='Calc Coh',
    **_button_common,
    tooltip='Click me to calculate coherence',
    icon='check',
)

# The helper dicts are only needed while building the widgets.
del _slider_common, _button_common

In [4]:
def plot_expectation(cutoffFreq, fs, SNR):
    """Design the low-pass filter and plot its gain, SNR and expected coherence.

    Driven by ``widgets.interactive_output``, so it re-runs whenever the cutoff,
    sampling-rate or SNR slider changes. As a side effect it (re)defines these
    module-level globals:

    - ``b, a``: filter coefficients, read by ``plot_signals``.
    - ``w, h``: the ``signal.freqz`` response, read by ``plot_coherence`` for
      the ground-truth curve.

    Parameters
    ----------
    cutoffFreq : float
        Requested critical frequency (Hz) of the 8th-order Butterworth
        low-pass. If it is at or above Nyquist (fs/2), it is clamped to
        0.99 * fs/2, because ``signal.butter`` rejects such values.
    fs : float
        Sampling rate in Hz.
    SNR : float
        Broadband signal-to-noise ratio. The plotted per-frequency SNR is
        ``|h|**2 * SNR``.
    """
    # Globals shared with the button callbacks.
    global b, a, w, h

    # The sampling-rate slider can go well below twice the cutoff slider's
    # range, and signal.butter raises ValueError unless cutoff < fs/2.
    # Clamping (instead of crashing) keeps b, a and w, h consistent with fs.
    nyquist = fs / 2.0
    if cutoffFreq >= nyquist:
        clamped = 0.99 * nyquist
        print(f"Cutoff {cutoffFreq:g} Hz is not below Nyquist ({nyquist:g} Hz); "
              f"using {clamped:g} Hz instead.")
        cutoffFreq = clamped

    # Design the filter
    b, a = signal.butter(8, cutoffFreq, btype='low', analog=False, output='ba', fs=fs)

    # freqz returns w in rad/sample; convert to Hz for plotting.
    w, h = signal.freqz(b, a)
    freq_hz = w * fs / (2 * np.pi)
    power_gain = np.abs(h) ** 2
    snr_f = power_gain * SNR  # per-frequency SNR of the filtered output

    # Plot the frequency gain, SNR and expected coherence.
    plt.figure()
    plt.subplot(1, 3, 1)
    plt.plot(freq_hz, np.abs(h))
    plt.ylabel('Filter Gain')
    plt.xlim([0, 1000])

    plt.subplot(1, 3, 2)
    plt.plot(freq_hz, snr_f)
    plt.ylabel('SNR')
    plt.xlim([0, 1000])

    plt.subplot(1, 3, 3)
    plt.plot(freq_hz, snr_f / (1 + snr_f))  # expected coherence = SNR / (1 + SNR)
    plt.ylabel('Coherence')
    plt.xlim([0, 1000])
    plt.ylim([-0.05, 1])
    plt.xlabel('Frequency (Hz)')
    plt.show()


    

In [5]:
# Map each keyword argument of plot_expectation to the slider that supplies it.
# interactive_output redraws the expectation figure whenever one of them changes.
_expectation_controls = {
    "cutoffFreq": cf_picker,
    "fs": samprate_picker,
    "SNR": SNR_picker,
}
expectation_plot = widgets.interactive_output(plot_expectation, _expectation_controls)
del _expectation_controls

# Output areas that the two button callbacks draw into.
signal_plot, coherence_plot = widgets.Output(), widgets.Output()

In [6]:
def plot_signals(but):
    """Simulate an input/output pair, then plot the traces and their spectra.

    Button callback for ``genSig_button``; ``but`` is the clicked button and
    is unused. Using the current slider values, it:

    - draws unit-variance white Gaussian input noise,
    - filters it with the global filter ``b, a`` designed by
      ``plot_expectation``,
    - shifts input and output against each other by the requested delay,
    - adds independent Gaussian noise of variance ``1/SNR`` to the output.

    The resulting signals are stored in the globals ``x`` (input) and ``y``
    (output), which ``plot_coherence`` consumes.
    """
    global b, a
    global x, y

    fs = samprate_picker.value
    T = duration_picker.value
    Delay = delay_picker.value
    SNR = SNR_picker.value
    chunkSize = chunk_picker.value

    N = int(fs * T)                   # number of samples in each signal
    delay = int(Delay * fs / 1000.0)  # delay in samples (slider is in ms)

    ## Generate signals
    sigInput = np.random.normal(scale=1, size=N + abs(delay))  # input signal
    # NOTE(review): filtfilt runs the filter forward and backward, giving a
    # zero-phase response with magnitude |h|**2, not |h|. The ground-truth
    # curves use |h|**2 * SNR as the output SNR, which corresponds to a
    # single pass. Confirm which of the two is intended.
    sigFilt = signal.filtfilt(b, a, sigInput)

    ## The x and y signals: y[n] ~ x[n - delay], i.e. y lags x by `delay`
    ## samples (a negative delay makes y lead).
    if delay >= 0:
        x = sigInput[delay:N + delay]
        y = sigFilt[0:N]
    else:
        x = sigInput[0:N]
        y = sigFilt[abs(delay):abs(delay) + N]

    ## Adding noise
    y += np.random.normal(scale=np.sqrt(1 / SNR), size=N)

    # Welch requires noverlap < nperseg. If 2**chunkSize exceeds the signal
    # length (e.g. chunk >= 18 at the default 25 kHz x 10 s), scipy shrinks
    # nperseg to N but keeps noverlap and raises, so clamp both here.
    nperseg = min(2 ** chunkSize, N)
    noverlap = nperseg // 2
    f, Pxx = signal.welch(x, fs=fs, window='hann', nperseg=nperseg, noverlap=noverlap)
    f, Pyy = signal.welch(y, fs=fs, window='hann', nperseg=nperseg, noverlap=noverlap)

    ## Plot signals
    with signal_plot:
        signal_plot.clear_output(True)
        plt.subplot(2, 1, 1)
        plt.plot(x[0:500], 'k', label='Input')
        plt.plot(y[0:500], 'r', label='Output')
        plt.xlabel('Time (pts)')
        plt.legend()
        plt.subplot(2, 1, 2)
        plt.plot(f, Pxx, 'k', label='Input')
        plt.plot(f, Pyy, 'r', label='Output')
        plt.xlim(0, 1000)
        plt.xlabel('Frequency (Hz)')
        plt.legend()
        plt.show()
    

In [7]:
# Regenerate the simulated x/y signals each time "Generate Signal" is clicked.
# ipywidgets passes the clicked Button to the callback (plot_signals's `but`).
genSig_button.on_click(plot_signals)

In [8]:
def plot_coherence(but):
    """Estimate the coherence between the globals ``x`` and ``y`` and plot it.

    Button callback for ``calcCoh_button``; ``but`` is unused. It computes:

    - a multitaper estimate via ``multitapered_coherence`` ("MT-JN" in the
      legend),
    - a Welch estimate via ``scipy.signal.coherence``.

    It then plots three panels:

    1. The real and imaginary parts of the MT coherency.
    2. Both coherence estimates, the MT bounds, and the ground truth built
       from the globals ``w, h`` (set by ``plot_expectation``).
    3. The real part of the time-domain coherency, with a vertical marker at
       the requested delay.

    The MT result dict is kept in the global ``result``.
    """
    global x, y, w, h
    global result

    fs = samprate_picker.value
    chunkSize = chunk_picker.value
    SNR = SNR_picker.value
    Delay = delay_picker.value
    NW = NW_picker.value

    # Multi-tapered + JN
    result = multitapered_coherence([
        np.array([x, y]),
    ], sampling_rate=fs, chunk_size=2**chunkSize, overlap=0.5, NW=NW)

    # Welch. Clamp the segment length to the signal length: otherwise scipy
    # shrinks nperseg but keeps noverlap >= nperseg and raises.
    nperseg = min(2 ** chunkSize, len(x))
    f, Cxy = signal.coherence(x, y, fs=fs, window='hann', nperseg=nperseg, noverlap=nperseg // 2)

    mt_freqs = np.fft.fftshift(result["freqs"])

    with coherence_plot:
        coherence_plot.clear_output(True)

        ## Coherency in the spectral domain.
        # TODO: at some point make this a radial plot with colored segments.
        coherency = result["coherency"][0, 1]
        plt.subplot(3, 1, 1)
        plt.plot(mt_freqs, np.fft.fftshift(np.real(coherency)), label='Real')
        plt.plot(mt_freqs, np.fft.fftshift(np.imag(coherency)), label='Imag')
        plt.xlim([0, 1000.0])
        plt.xlabel('Frequency (Hz)')
        plt.legend()

        ## Coherence estimates
        power_gain = np.abs(h) ** 2
        lower, upper = result["coherence_bounds"][0], result["coherence_bounds"][1]
        plt.subplot(3, 1, 2)
        plt.plot(mt_freqs, np.fft.fftshift(np.abs(result["coherence"][0, 1])), label='MT-JN Estimate')
        plt.plot(w * fs / (2 * np.pi), power_gain * SNR / (1 + power_gain * SNR), 'k', label='Ground Truth')
        plt.plot(f, Cxy, 'k--', label='Welch Estimate')
        plt.fill_between(
            mt_freqs,
            np.fft.fftshift(np.abs(lower[0, 1])),
            np.fft.fftshift(np.abs(upper[0, 1])),
            alpha=0.5,
        )
        plt.ylim([0, 1])
        plt.xlim([0, 1000.0])
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Coherence')
        plt.legend()

        ## Coherency in the time domain.
        cxy = np.fft.fftshift(np.real(result["coherency_t"][0, 1]))
        npts = cxy.size
        # After fftshift, zero lag sits at index npts // 2, so sample k has lag
        # k - npts // 2 (-npts/2 .. npts/2 - 1 for even npts). The previous
        # axis, arange(-npts/2 + 1, npts/2 + 1), was shifted by one sample.
        lags = np.arange(npts) - npts // 2
        t = lags * 1000.0 / fs
        plt.subplot(3, 1, 3)
        plt.plot(t, cxy)
        plt.plot([Delay, Delay], [cxy.min(), cxy.max()], 'k')
        plt.xlabel('Time (ms)')
        plt.ylabel('Coherency')

        plt.show()
    return

In [9]:
# Recompute and redraw the coherence estimates each time "Calc Coh" is clicked.
# ipywidgets passes the clicked Button to the callback (plot_coherence's `but`).
calcCoh_button.on_click(plot_coherence)

In [10]:
# Assemble the dashboard: the sliders next to the live expectation plot, then
# each action button next to the output area it fills. The final VBox is the
# cell's last expression, so Jupyter renders it.
_controls = widgets.VBox([
    samprate_picker,
    cf_picker,
    SNR_picker,
    delay_picker,
    duration_picker,
    chunk_picker,
    NW_picker,
])
_rows = [
    widgets.HBox([_controls, expectation_plot]),
    widgets.HBox([genSig_button, signal_plot]),
    widgets.HBox([calcCoh_button, coherence_plot]),
]
widgets.VBox(_rows)

VBox(children=(HBox(children=(VBox(children=(FloatSlider(value=25000.0, continuous_update=False, description='…