In [2]:
import matplotlib.pyplot as plt
import numpy as np
import time
from scipy.fft import rfft, irfft
import math

In [5]:
def parse_wav(b):
    """Parse the header of an uncompressed (PCM) RIFF/WAVE byte string.

    Walks the RIFF chunks that follow the ``fmt `` chunk (skipping metadata
    such as ``LIST``) until the ``data`` chunk is found.

    Parameters
    ----------
    b : bytes
        Entire contents of a WAV file.

    Returns
    -------
    dict
        Header fields; "Sample Beginning Index" is the byte offset of the first
        audio sample and "Num Samples" is the number of sample frames
        (``data`` size divided by ``block_align``).

    Raises
    ------
    AssertionError
        If the bytes are not a PCM WAVE file with consistent header fields,
        or no ``data`` chunk is present.
    """
    assert b[0:4] == b'RIFF', "Chunk ID did not match 'RIFF'"
    chunk_size = int.from_bytes(b[4:8], byteorder="little")
    assert b[8:12] == b'WAVE', "Format did not match 'WAVE'"
    assert b[12:16] == b'fmt ', "Format subchunk ID did not match 'fmt '"
    chunk1_size = int.from_bytes(b[16:20], byteorder="little")
    audio_format = int.from_bytes(b[20:22], byteorder="little")
    assert audio_format == 1, "Audio Format is not PCM (i.e. data is compressed)"
    channels = int.from_bytes(b[22:24], byteorder="little")
    sample_rate = int.from_bytes(b[24:28], byteorder="little")
    byte_rate = int.from_bytes(b[28:32], byteorder="little")
    block_align = int.from_bytes(b[32:34], byteorder="little")
    bits_per_sample = int.from_bytes(b[34:36], byteorder="little")
    assert (bits_per_sample % 8) == 0, "bits per sample is not a multiple of 8"
    # Integer arithmetic: bits_per_sample is a multiple of 8 (checked above).
    assert sample_rate * (bits_per_sample // 8) * channels == byte_rate, "byte_rate != bytes_per_sample * sample_rate * num_channels"
    assert block_align > 0, "block_align is zero"

    # The fmt chunk is not always 16 bytes (18 with a cbSize field, 40 for
    # WAVE_FORMAT_EXTENSIBLE), so start scanning right after the actual fmt
    # chunk rather than at a fixed offset. RIFF chunks are padded to an even
    # length, hence the "& 1" pad byte.
    data_idx = 20 + chunk1_size + (chunk1_size & 1)
    # Bounded by len(b) so a file without a 'data' chunk fails the assert
    # below instead of looping forever.
    while data_idx + 8 <= len(b) and b[data_idx:data_idx + 4] != b'data':
        temp_chunk_size = int.from_bytes(b[data_idx + 4:data_idx + 8], byteorder="little")
        data_idx += 8 + temp_chunk_size + (temp_chunk_size & 1)  # 8 = chunk id + chunk size fields

    assert data_idx + 8 <= len(b) and b[data_idx: data_idx + 4] == b'data', "Format subchunk ID did not match 'data'"
    sample_size = int.from_bytes(b[data_idx + 4: data_idx + 8], byteorder="little")
    sample_begin = data_idx + 8
    assert sample_size % block_align == 0, "size of sample in bytes is not divisible by number of bytes per sample"
    return {
        "File Size": chunk_size + 8,
        "Audio Format": "PCM",
        "Channels": channels,
        "Sample Rate": sample_rate,
        "Byte Rate": byte_rate,
        "Block Align (bytes per sample)": block_align,
        "Bits Per Sample Per Channel": bits_per_sample,
        "Samples Size": sample_size,
        "Num Samples": sample_size // block_align,
        "File Length (seconds)": round(sample_size/(byte_rate), 3),
        "Sample Beginning Index": sample_begin,
    }

In [6]:
def get_percent_yinv(winnum, samples, recon=0.9, window=1504, visual=False, sample_rate=44100):
    """Approximate one window of ``samples`` with its strongest FFT bins.

    The slice ``samples[window*winnum : window*(winnum + 1)]`` is transformed
    with ``rfft``. Bins are ranked by magnitude, and a binary search finds the
    smallest count k such that keeping the bins whose magnitude is strictly
    greater than the (k+1)-th largest magnitude reaches the target score.
    The score of a reconstruction ``y`` of the segment ``x`` is
    ``1 - ||y - x|| / ||x||`` (1.0 means exact).

    Parameters
    ----------
    winnum : int
        Index of the window to process.
    samples : array_like
        1-D signal the window is sliced from.
    recon : float
        Target reconstruction score. If it cannot be reached (e.g. ``recon > 1``),
        every bin is kept.
    window : int
        Window length in samples.
    visual : bool
        If True, plot the spectrum and the kept bins and print search progress.
    sample_rate : int
        Sampling rate in Hz. It is used only to convert bin indices to
        frequencies in ``tops`` and defaults to the previously hard-coded 44100.

    Returns
    -------
    tops : list of (int, float, float)
        ``(bin index, frequency in Hz, magnitude)`` for every kept bin, sorted by
        decreasing magnitude.
    yinv : numpy.ndarray
        Reconstruction of the window from the kept bins (same length as the slice).
    pfin : float
        Reconstruction score of ``yinv``.

    Raises
    ------
    ValueError
        If the requested window lies entirely outside ``samples``.
    """
    segment = np.asarray(samples[window * winnum:window * (winnum + 1)], dtype=np.float64)
    if segment.size == 0:
        raise ValueError("window %d is empty (outside of samples)" % winnum)
    n = segment.size
    yf = rfft(segment)
    ay = np.abs(yf)
    seg_norm = np.linalg.norm(segment)

    if seg_norm == 0:
        # Silent window: the all-zero output is exact. Without this guard the
        # score is 0/0 = NaN and no search could ever terminate.
        return [], np.zeros(n), 1.0

    # Magnitudes in decreasing order; ranked[k] is the (k+1)-th largest.
    ranked = np.sort(ay)[::-1]

    def keep_mask(k):
        """Bins kept when the k strongest are requested (strictly above ranked[k])."""
        if k >= ranked.size:
            return np.ones(ranked.size, dtype=bool)
        return ay > ranked[k]

    def score(y):
        """Reconstruction score of y against the original segment."""
        return 1 - np.linalg.norm(y - segment) / seg_norm

    # keep_mask(k) only grows with k, and by Parseval's theorem the residual
    # energy is a weighted sum of the dropped bins' squared magnitudes, so the
    # score is non-decreasing in k. Binary search over [1, number of bins] is
    # therefore valid, and it always terminates.
    lo, hi = 1, ranked.size
    while lo < hi:
        mid = (lo + hi) // 2
        # n= keeps odd-length windows the same length as the segment.
        s = score(irfft(np.where(keep_mask(mid), yf, 0), n=n))
        if visual:
            print("reconstruction with", mid, "entries:", s)
        if s >= recon:
            hi = mid
        else:
            lo = mid + 1
    num_kept = lo

    mask = keep_mask(num_kept)
    tops = [(int(i), int(i) * (sample_rate / window), ay[i]) for i in np.flatnonzero(mask)]
    tops.sort(key=lambda x: x[2], reverse=True)

    yf_kept = np.where(mask, yf, 0)
    yinv = irfft(yf_kept, n=n)
    pfin = score(yinv)

    if visual:
        plt.figure(figsize=(24, 8))
        plt.bar(range(len(ay)), ay)
        plt.bar(range(len(ay)), np.abs(yf_kept))
        print("final number of entries: ", num_kept)
        print("final reconstruction: ", pfin)

    return (tops, yinv, pfin)

In [37]:
# Rebuild the input track window-by-window from its strongest FFT bins
# (via get_percent_yinv) and write the result as a new WAV file that reuses
# the input's header.
start_time = time.time()

input_path = "sirduke.wav"
output_path = "sirduke-95-8820.wav"
window = 8820        # samples per window (0.2 s if the file is 44.1 kHz)
recon_target = 0.95  # target reconstruction score for get_percent_yinv

with open(input_path, "rb") as f:
    bb = f.read()

file_info = parse_wav(bb)
# The data is processed as one stream of int16 values, so only mono 16-bit
# PCM is supported.
assert file_info["Channels"] == 1, "only mono input is supported"
assert file_info["Bits Per Sample Per Channel"] == 16, "only 16-bit input is supported"

num_samples = file_info['Num Samples']
print("num samples: ", num_samples)
data_begin = file_info['Sample Beginning Index']
# Read only the data chunk (not trailing chunks); WAV samples are little-endian.
raw = np.frombuffer(bb[data_begin:data_begin + file_info['Samples Size']], dtype="<i2")

# Zero-pad up to a whole number of windows and scale to [-1, 1).
samples = np.zeros(math.ceil(num_samples / window) * window, dtype=np.float64)
samples[:raw.size] = raw / 2**15

ns = num_samples
output = np.zeros(ns, dtype=np.int16)  # windows that are not reconstructed stay silent
notes = []

for i in range(0, len(samples), window):
    print(round((i/ns)*100, 4), end="\r")
    # Silent windows and the trailing partial window are written as silence.
    if np.linalg.norm(samples[i:i + window]) == 0 or ns - i < window:
        continue
    tops, yinv, pfin = get_percent_yinv(i // window, samples, recon=recon_target, window=window, visual=False)
    # Clip before the int16 cast: a value of exactly 1.0 would otherwise wrap
    # around to -32768.
    output[i:i + window] = np.clip(yinv[:window] * 2**15, -2**15, 2**15 - 1).astype(np.int16)
    notes.append(tops)

with open(output_path, "wb") as f:
    # Reuse the original header up to the first sample; the output has the
    # same number of 16-bit samples, so the recorded chunk sizes stay valid.
    f.write(bb[:data_begin])
    f.write(output.astype("<i2").tobytes())

end_time = time.time()
print("Done: ", round(end_time - start_time, 3), " seconds taken")

num samples:  10255360
0.0

KeyboardInterrupt: 