In [129]:
import librosa
import numpy as np
import sounddevice as sd
from scipy import fft
from scipy.io import wavfile
from scipy.signal import butter, lfilter, sosfilt, windows

In [130]:
# Largest positive value of signed 16-bit PCM; used when quantising output.
max16 = np.iinfo(np.int16).max

# Band crossover frequencies in Hz: low < mid <= band < hi <= high.
mid, hi = 100, 400

In [131]:
rate, wav_raw = wavfile.read(filename="hw2assets/collectathon.wav")

# Scale to float32 based on the file's actual sample type instead of assuming
# 16-bit PCM: wavfile.read can return float32, int32, int16 or uint8 data.
if np.issubdtype(wav_raw.dtype, np.floating):
    wav = wav_raw.astype(np.float32)
elif wav_raw.dtype == np.uint8:
    # 8-bit PCM is unsigned and centred on 128.
    wav = ((wav_raw.astype(np.float64) - 128) / 128).astype(np.float32)
else:
    full_scale = -np.iinfo(wav_raw.dtype).min  # 32768 for int16
    wav = (wav_raw / full_scale).astype(np.float32)

# Down-mix to mono. Multichannel data comes back as (samples, channels), so
# average across the channel axis (this covers any channel count, not just 2).
if wav.ndim == 2:
    wav = wav.mean(axis=1)
assert wav.ndim == 1

In [132]:
# Keep only the opening clip so every FFT / filter pass below runs quickly.
duration_seconds = 10
num_samples = duration_seconds * rate
wav = wav[0:num_samples]

In [133]:
# Audio smoke test: start non-blocking playback, let ~2 s play, then halt it.
sd.play(wav, rate)
sd.sleep(2 * 1000)
sd.stop()

In [134]:
blockSizes = [1 << k for k in range(1, 13)]  # powers of two, 2 through 4096

In [135]:
# FFT: blockSize is a parameter because this runs once per entry in blockSizes.
def do_fft(wav: np.ndarray, blockSize: int, sample_rate=None, mid_hz=None, hi_hz=None) -> tuple:
    """Window and FFT ``wav`` block by block, then sum |FFT| per frequency band.

    The signal is zero-padded at the end to a whole number of ``blockSize``
    blocks. Each non-overlapping block is multiplied by a periodic
    Blackman-Harris window and transformed with a real FFT.

    Parameters
    ----------
    wav : array_like
        1-D signal.
    blockSize : int
        Samples per block.
    sample_rate : float, optional
        Sample rate in Hz. Defaults to the notebook-level ``rate``.
    mid_hz, hi_hz : float, optional
        Band edges in Hz. Default to the notebook-level ``mid`` / ``hi``.

    Returns
    -------
    bands : list of [low, mid, high]
        For each block, the sum of |FFT| over bins below ``mid_hz``, between
        ``mid_hz`` and ``hi_hz``, and from ``hi_hz`` up. These are summed
        magnitudes, not squared (power) values.
    freqs : ndarray of complex, shape (n_blocks, blockSize // 2 + 1)
        The rfft of every windowed block.
    """
    # Fall back to the notebook globals so existing calls keep their behaviour.
    sample_rate = rate if sample_rate is None else sample_rate
    mid_hz = mid if mid_hz is None else mid_hz
    hi_hz = hi if hi_hz is None else hi_hz

    samples = np.asarray(wav)
    # Zero-pad so the length is an exact multiple of blockSize, then view as rows.
    samples = np.pad(samples, (0, -len(samples) % blockSize))
    blocks = samples.reshape(-1, blockSize)
    n_bins = blockSize // 2 + 1
    if blocks.shape[0] == 0:
        # Empty input: there are no blocks to transform.
        return [], np.empty((0, n_bins), dtype=np.complex128)

    window = windows.blackmanharris(blockSize, sym=False)
    freqs = fft.rfft(blocks * window, axis=-1)

    # rfft bin k is centred at k * sample_rate / blockSize Hz, so that is the
    # bin spacing. Dividing by a further 2 would double every band edge.
    bin_width = sample_rate / blockSize
    bin_mid = max(1, round(mid_hz / bin_width))  # keep the low band non-empty
    bin_hi = max(bin_mid + 1, round(hi_hz / bin_width))  # keep the mid band non-empty

    magnitudes = np.abs(freqs)
    band_sums = np.stack(
        [
            magnitudes[:, :bin_mid].sum(axis=1),
            magnitudes[:, bin_mid:bin_hi].sum(axis=1),
            magnitudes[:, bin_hi:].sum(axis=1),
        ],
        axis=1,
    )
    return band_sums.tolist(), freqs

In [136]:
# Filter
def do_filters(wav, rate, mid_hz=None, hi_hz=None):
    """Split ``wav`` into low / mid / high bands with Butterworth IIR filters.

    Parameters
    ----------
    wav : array_like
        Signal to filter (filtered along its last axis).
    rate : float
        Sample rate in Hz.
    mid_hz, hi_hz : float, optional
        Crossover frequencies in Hz. Default to the notebook-level ``mid`` / ``hi``.

    Returns
    -------
    (low_filtered, mid_filtered, high_filtered)
        The outputs of a 4th-order low-pass at ``mid_hz``, a 4th-order
        (8 poles) band-pass between ``mid_hz`` and ``hi_hz``, and a
        4th-order high-pass at ``hi_hz``. All three are causal (single-pass).
    """
    mid_hz = mid if mid_hz is None else mid_hz
    hi_hz = hi if hi_hz is None else hi_hz
    nyquist = rate / 2

    # Use second-order sections. The [b, a] polynomial form is numerically
    # fragile for cut-offs this small relative to Nyquist (e.g. 100 Hz at 44.1 kHz).
    low_sos = butter(N=4, Wn=mid_hz / nyquist, btype='low', output='sos')
    mid_sos = butter(N=4, Wn=[mid_hz / nyquist, hi_hz / nyquist], btype='band', output='sos')
    high_sos = butter(N=4, Wn=hi_hz / nyquist, btype='high', output='sos')

    low_filtered = sosfilt(low_sos, wav)
    mid_filtered = sosfilt(mid_sos, wav)
    high_filtered = sosfilt(high_sos, wav)

    return low_filtered, mid_filtered, high_filtered

In [None]:
# The filtered bands and their RMS levels do not depend on blockSize, so filter
# once here instead of re-running three IIR filters on every iteration.
low_filtered, mid_filtered, high_filtered = do_filters(wav, rate)
band_signals = (low_filtered, mid_filtered, high_filtered)
band_rms = [np.sqrt(np.mean(np.square(sig))) for sig in band_signals]

# Set to True to audition each result before it is written.
play_preview = False

for blockSize in blockSizes:

    wav_fft_bands, wav_fft_freqs = do_fft(wav, blockSize=blockSize)

    # Average each band's per-block |FFT| sum over all blocks -> [low, mid, high].
    avg_band_energy = np.mean(np.asarray(wav_fft_bands).reshape(-1, 3), axis=0)

    # Rescale each filtered band so its RMS equals the FFT-derived band level.
    # A silent band (RMS 0) contributes an all-zero array, not a scalar 0, so the
    # sum below is always an ndarray.
    scaled_bands = [
        sig * (target / sig_rms) if sig_rms != 0 else np.zeros_like(sig)
        for sig, sig_rms, target in zip(band_signals, band_rms, avg_band_energy)
    ]
    combined = scaled_bands[0] + scaled_bands[1] + scaled_bands[2]

    # Defensive down-mix, done BEFORE quantising so int16 addition cannot overflow
    # and the written data stays int16. Assumes a (channels, samples) layout,
    # matching the load cell's original transpose.
    if combined.ndim == 2:
        combined = combined.mean(axis=0)

    # Peak-normalise to [-1, 1] (skipped for an all-silent mix to avoid 0/0),
    # then quantise to int16 for the .wav file.
    peak = np.max(np.abs(combined)) if combined.size else 0.0
    if peak > 0:
        combined = combined / peak
    combined_wav = (combined * max16).astype(np.int16)

    if play_preview:
        print(f"Playing audio with block size: {blockSize}")
        sd.play(data=combined_wav, samplerate=rate)
        sd.sleep(3000)
        sd.stop()

    print(f"Writing .wav audio with block size: {blockSize}")
    wavfile.write(
        filename=f"./hw2assets/collectathon-{blockSize}-{mid}_{hi}-{duration_seconds}sec.wav",
        rate=rate,
        data=combined_wav,
    )