In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import rfft, irfft
import math
from scipy.stats import norm as normal_dist

# NOTES=[((' ABb B CDb DEb E FGb GAb'[2*i%24:2*i%24+2]+str(i//12)).strip(),27.5*2**(i/12)) for i in range(120)]

import time

In [None]:
def parse_wav(b):
    """Parse the header of an uncompressed (PCM) RIFF/WAVE file held in memory.

    Validates the RIFF / WAVE / ``fmt `` headers, skips any chunks that sit
    between ``fmt `` and ``data`` (for example ``LIST`` metadata), and returns a
    summary describing the audio and where its sample bytes start in ``b``.

    Args:
        b: the complete file contents as ``bytes``.

    Returns:
        dict with keys:
          - "File Size": RIFF chunk size + 8.
          - "Audio Format": always "PCM".
          - "Channels", "Sample Rate", "Byte Rate".
          - "Block Align (bytes per sample)", "Bits Per Sample Per Channel".
          - "Samples Size": byte length of the data chunk, as declared in its header.
          - "Num Samples": number of sample frames (Samples Size // Block Align).
          - "File Length (seconds)".
          - "Sample Begin": offset of the first sample byte in ``b``.

    Raises:
        AssertionError: if the file is not PCM WAVE, header fields are
            inconsistent, or no ``data`` chunk exists. Validation uses ``assert``,
            so it is skipped when Python runs with ``-O``.
    """
    assert b[0:4] == b'RIFF', "Chunk ID did not match 'RIFF'"
    chunk_size = int.from_bytes(b[4:8], byteorder="little")
    assert b[8:12] == b'WAVE', "Format did not match 'WAVE'"
    assert b[12:16] == b'fmt ', "Format subchunk ID did not match 'fmt '"
    chunk1_size = int.from_bytes(b[16:20], byteorder="little")
    assert chunk1_size >= 16, "'fmt ' subchunk is too short to hold a PCM header"
    audio_format = int.from_bytes(b[20:22], byteorder="little")
    assert audio_format == 1, "Audio Format is not PCM (i.e. data is compressed)"
    channels = int.from_bytes(b[22:24], byteorder="little")
    sample_rate = int.from_bytes(b[24:28], byteorder="little")
    byte_rate = int.from_bytes(b[28:32], byteorder="little")
    block_align = int.from_bytes(b[32:34], byteorder="little")
    bits_per_sample = int.from_bytes(b[34:36], byteorder="little")
    assert (bits_per_sample % 8) == 0, "bits per sample is not a multiple of 8"
    assert sample_rate * (bits_per_sample // 8) * channels == byte_rate, "byte_rate != bytes_per_sample * sample_rate * num_channels"
    # Both are divisors below; reject them here rather than crash with ZeroDivisionError.
    assert byte_rate > 0 and block_align > 0, "byte_rate and block_align must be positive"

    # The chunk after 'fmt ' starts right after its body. Do not assume a fixed
    # offset of 36, because the 'fmt ' body may be longer than 16 bytes.
    # RIFF chunks are word-aligned, so an odd-sized chunk is followed by one pad byte.
    data_idx = 20 + chunk1_size + (chunk1_size & 1)
    while True:
        # Stop at end of input: without this, a file lacking a 'data' chunk looped forever.
        assert data_idx + 8 <= len(b), "No 'data' subchunk found before end of file"
        if b[data_idx:data_idx + 4] == b'data':
            break
        temp_chunk_size = int.from_bytes(b[data_idx + 4:data_idx + 8], byteorder="little")
        # Skip id (4) + size field (4) + body + optional pad byte.
        data_idx += 8 + temp_chunk_size + (temp_chunk_size & 1)

    sample_size = int.from_bytes(b[data_idx + 4:data_idx + 8], byteorder="little")
    sample_begin = data_idx + 8
    assert sample_size % block_align == 0, "size of sample in bytes is not divisible by number of bytes per sample"

    return {
        "File Size": chunk_size + 8,
        "Audio Format": "PCM",
        "Channels": channels,
        "Sample Rate": sample_rate,
        "Byte Rate": byte_rate,
        "Block Align (bytes per sample)": block_align,
        "Bits Per Sample Per Channel": bits_per_sample,
        "Samples Size": sample_size,
        "Num Samples": sample_size // block_align,
        "File Length (seconds)": round(sample_size / byte_rate, 3),
        "Sample Begin": sample_begin,
    }

In [None]:
def note_similarity_matrix(window, rate, sigma=1/24, base_freq=27.5, num_notes=120):
    """Build a matrix that maps rfft magnitude bins onto equal-tempered semitones.

    Rows and columns:
      - Row ``k`` is rfft bin ``k`` of a ``window``-sample frame, i.e. frequency
        ``k * rate / window`` Hz.
      - Column ``j`` is the note ``base_freq * 2**(j / 12)``. With the defaults
        this covers 120 semitones upward from 27.5 Hz.

    Each entry is the two-sided normal tail probability ``2 * (1 - Phi(|d| / sigma))``,
    where ``d`` is the log2-frequency distance between the bin and the note. A bin
    right on a note gets weight 1 before normalisation. Every row with a non-zero
    total is then scaled to sum to 1, so each bin's magnitude is split across
    nearby notes. Row 0 (the zero-frequency bin) is all zeros.

    Args:
        window: frame length in samples passed to rfft.
        rate: sample rate in Hz.
        sigma: spread of the weighting in octaves (default 1/24, half a semitone).
        base_freq: frequency in Hz of column 0 (default 27.5).
        num_notes: number of semitone columns (default 120).

    Returns:
        numpy array of shape ``(window // 2 + 1, num_notes)``.
    """
    len_fft = window // 2 + 1
    # log2 frequency of each note. Integer arange avoids the float-step length
    # ambiguity of np.arange(0, 10, 1/12).
    note_logs = np.log2(base_freq) + np.arange(num_notes) / 12
    # log2 frequency of bins 1..len_fft-1. Bin 0 is skipped because log2(0) is undefined.
    bin_logs = np.log2(np.arange(1, len_fft) * (rate / window))

    mat = np.zeros((len_fft, num_notes))
    # All bin-vs-note distances in one broadcast instead of a Python double loop.
    distance = np.abs(bin_logs[:, np.newaxis] - note_logs[np.newaxis, :]) / sigma
    # 1 - cdf (rather than sf) keeps the original behaviour: very distant
    # pairs become exactly 0, so a bin far from every note stays an all-zero row.
    mat[1:] = 2 * (1 - normal_dist().cdf(distance))

    # Normalise each non-empty row to sum to 1. Row 0 and all-zero rows are left untouched.
    row_sums = mat.sum(axis=1, keepdims=True)
    np.divide(mat, row_sums, out=mat, where=row_sums > 0)
    return mat

In [None]:
# Load whistle.wav, split it into fixed-size windows, take each window's rfft
# magnitude, and project those spectra onto semitone pitches.
# `song` ends up with one row per window and one column per note.
start = time.perf_counter()  # monotonic clock intended for measuring intervals

# `with` closes the file even if read() raises.
with open("whistle.wav", "rb") as wav_file:
    bb = wav_file.read()

file_info = parse_wav(bb)

rate = file_info['Sample Rate']
channels = file_info['Channels']
block_align = file_info['Block Align (bytes per sample)']
#TODO: determine window size as a function of sample rate
window = 2048 + 1024

# TODO: support bit depths other than 16.
assert file_info['Bits Per Sample Per Channel'] == 16, "only 16-bit PCM is supported"
assert block_align == 2 * channels, "block_align is inconsistent with 16-bit interleaved samples"

# Decode only the bytes of the 'data' chunk. Bytes after it (e.g. a trailing
# metadata chunk) are not audio and must not be read as samples.
data_begin = file_info['Sample Begin']
raw = bb[data_begin:data_begin + file_info['Samples Size']]
raw = raw[:len(raw) - len(raw) % block_align]  # drop a partial trailing frame if the file is truncated
# Interleaved little-endian int16 -> (frames, channels), then average the channels to mono.
mono = np.frombuffer(raw, dtype='<i2').reshape(-1, channels).mean(axis=1)
num_samples = len(mono)

# Zero-pad up to a whole number of windows (ceil division). The old formula added
# a full extra silent window when the length was already a multiple of `window`.
# At least one window is kept so `song` is never empty.
num_windows = max(1, -(-num_samples // window))
samples = np.zeros(num_windows * window)
samples[:num_samples] = mono

num_fft_coefs = window//2 + 1

# One rfft per row of the (num_windows, window) frame matrix.
ffts = np.abs(rfft(samples.reshape(num_windows, window), axis=-1))
assert ffts.shape == (num_windows, num_fft_coefs)

song = np.matmul(ffts, note_similarity_matrix(window, rate))

end = time.perf_counter()

print("time:", end-start)

In [None]:
# Pitch-over-time view of `song`: one row per window (time runs downward) and
# one column per semitone. Every 12th column is an A, starting at A0 = 27.5 Hz.
fig, ax = plt.subplots(figsize=(10,100))
im = ax.imshow(song, cmap='BuPu')
a_columns = list(range(0, song.shape[1], 12))
ax.set_xticks(a_columns)
ax.set_xticklabels([f"A{octave}" for octave in range(len(a_columns))])
# The figure is very tall, so repeat the note labels at the top as well.
ax.tick_params(axis='x', labeltop=True)
ax.set_xlabel("Note")
ax.set_ylabel(f"Window index ({window} samples per window)")
ax.set_title("Note similarity per window")