In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io.wavfile
import IPython

def setup_graph(title='', x_label='', y_label='', fig_size=None):
    """Create a new figure with a single labelled set of axes.

    Parameters
    ----------
    title : str
        Axes title.
    x_label, y_label : str
        Labels for the x and y axes.
    fig_size : (width, height) or None
        Figure size in inches; ``None`` keeps matplotlib's default size.

    Returns
    -------
    matplotlib.axes.Axes
        The newly created axes. Earlier versions returned ``None``; callers that
        ignore the result are unaffected. Because the axes belong to the current
        figure, following ``plt.plot`` / ``plt.specgram`` calls draw onto them.
    """
    fig = plt.figure()
    # Identity test, not `!= None`: if an ndarray is passed, `!=` compares
    # element-wise and the `if` then raises an "ambiguous truth value" error.
    if fig_size is not None:
        fig.set_size_inches(fig_size[0], fig_size[1])
    ax = fig.add_subplot(111)
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    return ax

# Audio Filtering

Procedure:
    
* Read in audio
* Apply STFT (Window and FFT)
* Remove the unwanted frequency range by zeroing those FFT bins in every frame
* Apply Inverse STFT (to resynthesize)
* Write new audio


# Functions

In [None]:
def stft(input_data, window_size, hop_size):
    """Short-time Fourier transform of a 1-D signal.

    A Hamming window of ``window_size`` samples slides across ``input_data`` in
    steps of ``hop_size`` samples. Each windowed frame is FFT'd.

    Parameters
    ----------
    input_data : array_like
        1-D sequence of samples.
    window_size : int
        Frame length in samples (also the FFT length).
    hop_size : int
        Distance in samples between the starts of consecutive frames.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(n_frames, window_size)``. Row ``i`` is the full
        (two-sided) spectrum of the frame starting at ``i * hop_size``.
        ``n_frames`` is 0 when the signal is shorter than one window.
    """
    samples = np.asarray(input_data)
    # numpy, not the old `scipy.hamming` / `scipy.fft(...)` aliases: those
    # re-exports are deprecated and no longer work on current SciPy.
    window = np.hamming(window_size)
    # "+ 1" so a final frame ending exactly at the last sample is included.
    starts = range(0, len(samples) - window_size + 1, hop_size)
    if len(starts) == 0:
        # Keep a 2-D shape so callers can still iterate/broadcast per frame.
        return np.empty((0, window_size), dtype=complex)
    return np.array([np.fft.fft(window * samples[start:start + window_size])
                     for start in starts])

def istft(input_data, window_size, hop_size, len_signal):
    """Resynthesize a signal from STFT frames by overlap-add.

    Each frame is inverse-FFT'd, its real part kept, and the result added into
    the output starting at sample ``i * hop_size``. No synthesis window or
    amplitude normalisation is applied. For example, with a Hamming analysis
    window at 50% overlap, the interior comes out roughly 1.08x the original
    amplitude, and the first/last half-window stay tapered.

    Parameters
    ----------
    input_data : sequence of array_like
        STFT frames, each of length ``window_size`` (as produced by ``stft``).
    window_size : int
        Frame length in samples.
    hop_size : int
        Distance in samples between the starts of consecutive frames.
    len_signal : int
        Desired output length, normally ``len`` of the original signal. The
        overlap-added result is zero-padded or truncated to this length.

    Returns
    -------
    numpy.ndarray
        Real float array of length ``len_signal``.
    """
    n_frames = len(input_data)
    # Samples actually covered by the frames; 0 when there are no frames
    # (the old formula produced a negative size and crashed).
    covered = (n_frames - 1) * hop_size + window_size if n_frames else 0
    output = np.zeros(max(covered, len_signal))
    for i, frame in enumerate(input_data):
        start = i * hop_size
        output[start:start + window_size] += np.real(np.fft.ifft(frame))
    # Honour len_signal so the output lines up with the input signal in time;
    # stft() drops trailing samples that don't fill a whole frame.
    return output[:len_signal]

def low_pass_filter(max_freq, window_size, sample_rate):
    """Build a 0/1 FFT-domain mask that keeps frequencies below ``max_freq``.

    The mask is meant to be multiplied element-wise with a full (two-sided)
    ``window_size``-point FFT such as ``np.fft.fft`` returns. Bin ``k`` has
    centre frequency ``k * sample_rate / window_size`` in the lower half of the
    spectrum, and the corresponding negative (mirror) frequency in the upper
    half. A bin is kept (1.0) when the absolute value of its frequency is below
    ``max_freq``, otherwise zeroed (0.0). The mask is therefore symmetric
    (``mask[k] == mask[-k]``), so a real signal stays real after filtering.

    Parameters
    ----------
    max_freq : float
        Cutoff frequency in Hz.
    window_size : int
        FFT length (number of bins).
    sample_rate : float
        Sampling rate in Hz.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(window_size,)`` containing only 0.0 and 1.0.
    """
    # fftfreq(n) yields k/n cycles per sample (negative for the upper half);
    # multiplying by sample_rate converts to Hz. This fixes the earlier bin
    # width of sample_rate / (window_size / 2), which put the real cutoff at
    # about half of max_freq, and the asymmetric slice that kept bin N-b while
    # zeroing its mirror bin b.
    bin_freqs = np.abs(np.fft.fftfreq(window_size)) * sample_rate
    return (bin_freqs < max_freq).astype(float)

def high_pass_filter(min_freq, window_size, sample_rate):
    """Return the complement of ``low_pass_filter`` with cutoff ``min_freq``.

    Every FFT bin that the low-pass mask keeps (1.0) becomes 0.0 here and
    vice versa, giving a mask of length ``window_size`` that passes the band
    at and above ``min_freq``.
    """
    low_band = low_pass_filter(min_freq, window_size, sample_rate)
    return 1.0 - low_band



In [None]:
def filter_audio(input_signal, filter_window, window_size=256):
    """Apply a fixed FFT-domain mask to a signal via STFT and inverse STFT.

    Parameters
    ----------
    input_signal : array_like
        Samples, either 1-D (mono) or 2-D with shape ``(n_samples, n_channels)``.
        ``scipy.io.wavfile.read`` returns the 2-D form for multi-channel files.
        Each channel is filtered independently.
    filter_window : array_like
        Mask of length ``window_size`` multiplied into every STFT frame, e.g.
        from ``low_pass_filter`` / ``high_pass_filter``.
    window_size : int
        STFT frame length in samples. The hop is ``window_size // 2``.

    Returns
    -------
    numpy.ndarray
        Resynthesized samples as produced by ``istft``. For 2-D input, one
        column per channel.
    """
    signal = np.asarray(input_signal)
    if signal.ndim == 2:
        # stft() windows 1-D frames, so split channels and filter each one.
        channels = [filter_audio(signal[:, ch], filter_window, window_size)
                    for ch in range(signal.shape[1])]
        return np.column_stack(channels)

    hop_size = window_size // 2
    spectra = stft(signal, window_size, hop_size)
    masked = [frame * filter_window for frame in spectra]
    return istft(masked, window_size, hop_size, len(signal))


# A high-pass filter example

In [None]:
infile = "audio_files/ohm_scale.wav"
outfile = "audio_files/high_pass_out.wav"
window_size = 256
cutoff_hz = 2500  # content below this frequency is removed

# Input
(sample_rate, input_signal) = scipy.io.wavfile.read(infile)

# Create filter window
filter_window = high_pass_filter(cutoff_hz, window_size, sample_rate)

# Run filter
resynth = filter_audio(input_signal, filter_window, window_size)

# Output: saturate to the int16 range before converting. A bare
# astype('int16') wraps out-of-range values around (32768 -> -32768),
# which is audible as loud clicks.
scipy.io.wavfile.write(outfile, sample_rate,
                       np.clip(resynth, -32768, 32767).astype('int16'))


# Results

In [None]:
# Play the unfiltered input for comparison.
IPython.display.Audio(data="audio_files/ohm_scale.wav")

In [None]:
# Play the high-pass filtered result.
IPython.display.Audio(data="audio_files/high_pass_out.wav")

## Spectrogram Before

In [None]:
setup_graph(title='Spectrogram (Before)', x_label='time (in seconds)', y_label='frequency (Hz)', fig_size=(14,7))
# Trailing ';' suppresses the return-value repr without assigning to `_`.
# Assigning `_` stops IPython from updating its output-history variables.
plt.specgram(input_signal, Fs=sample_rate);

## Spectrogram After

In [None]:
setup_graph(title='Spectrogram (After)', x_label='time (in seconds)', y_label='frequency (Hz)', fig_size=(14,7))
# Trailing ';' suppresses the return-value repr without clobbering IPython's `_`.
plt.specgram(resynth, Fs=sample_rate);

## Sound Wave Before

In [None]:
setup_graph(title='Sound wave (Before)', x_label='time (in seconds)', y_label='amplitude', fig_size=(14,7))
# Plot against sample_index / sample_rate so the x-axis really is in seconds,
# as labelled (plotting the bare array would show sample indices).
plt.plot(np.arange(len(input_signal)) / sample_rate, input_signal);

## Sound Wave After

In [None]:
setup_graph(title='Sound wave (After)', x_label='time (in seconds)', y_label='amplitude', fig_size=(14,7))
# Convert sample indices to seconds so the x-axis matches its label.
plt.plot(np.arange(len(resynth)) / sample_rate, resynth);

# A low-pass filter example

In [None]:
infile = "audio_files/doremi_xylo.wav"
outfile = "audio_files/low_pass_out.wav"
window_size = 256
cutoff_hz = 1700  # content at or above this frequency is removed

# Input
(sample_rate, input_signal) = scipy.io.wavfile.read(infile)

# Create filter window
filter_window = low_pass_filter(cutoff_hz, window_size, sample_rate)

# Run filter
resynth = filter_audio(input_signal, filter_window, window_size)

# Output: saturate to the int16 range before converting. A bare
# astype('int16') wraps out-of-range values around (32768 -> -32768),
# which is audible as loud clicks.
scipy.io.wavfile.write(outfile, sample_rate,
                       np.clip(resynth, -32768, 32767).astype('int16'))


In [None]:
# Play the unfiltered xylophone recording for comparison.
IPython.display.Audio(data="audio_files/doremi_xylo.wav")

In [None]:
# Play the low-pass filtered result.
IPython.display.Audio(data="audio_files/low_pass_out.wav")

### Notice that in the filtered ("after") audio, you can still hear the xylophone mallet but not the keys

In [None]:
setup_graph(title='Sound wave (Before)', x_label='time (in seconds)', y_label='amplitude', fig_size=(14,7))
# Plot against sample_index / sample_rate so the x-axis really is in seconds.
plt.plot(np.arange(len(input_signal)) / sample_rate, input_signal);

In [None]:
setup_graph(title='Sound wave (After)', x_label='time (in seconds)', y_label='amplitude', fig_size=(14,7))
# Plot against sample_index / sample_rate so the x-axis really is in seconds.
plt.plot(np.arange(len(resynth)) / sample_rate, resynth);