In [1]:
import sounddevice as sd
import soundfile as sf
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal

from scipy.io.wavfile import write

In [None]:
def buffer(x, wlen, p):
    '''
    Parameters
    ----------
    x: ndarray
        Signal array
    wlen: int
        Window length
    p: int
        Number of samples to overlap

    Returns
    -------
    (n,wlen) ndarray
        Buffer array 
    n int
        Number of windows
    '''
    #n = x.size // wlen + 1 #number of windows
    n = 0 #number of windows
    buffer = []
    i = 0
    while i + wlen <= x.size:
        buffer.append(x[i:min(i + wlen, x.size)])
        i += (wlen - p)
        n += 1
        
    return np.array(buffer,dtype=object), n

#print(buffer(np.arange(10),3,2))

In [103]:
#with a known input signal
def channelest1(x, y, xwin,ywin,overlap):
    hsize = ywin - xwin
    x_w, wnumx = buffer(x,xwin,overlap)
    y_w, wnumy = buffer(y,ywin,overlap + hsize)
    
    if wnumx != wnumy:
        print('mismatch')
        return
    
    H = np.zeros(ywin, dtype=complex)
    for i in range(wnumx):
        currentx = x_w[i]
        currenty = y_w[i]
        extendx = np.zeros(currenty.size)
        extendx[:currentx.size] = currentx 
        X = np.fft.fft(extendx)
        Y = np.fft.fft(currenty)

        H += np.divide(Y,X)
        
    H = H / wnumx
    h_est = np.fft.ifft(H)[:hsize]

    return np.real(h_est)



In [154]:
#with input noise
def channelest2(x, y, xwin,ywin,overlap,sigma):
    hsize = ywin - xwin
    x_w, wnumx = buffer(x,xwin,overlap)
    y_w, wnumy = buffer(y,ywin,overlap + hsize)

    if wnumx != wnumy:
        print('mismatch')
        return
    
    for i in range(wnumx):
        currentx = x_w[i]
        currenty = y_w[i]
        
        extendx = np.zeros(currenty.size)
        extendx[:currentx.size] = currentx 

        _, Sxx = scipy.signal.csd(extendx,extendx)
        _, Sxy = scipy.signal.csd(extendx,currenty)
        if i == 0:
            H = Sxy / Sxx
        else:
            H += Sxy / Sxx

    H = H / wnumx
    h_est = np.fft.ifft(H)[:hsize]
    h_est = np.real(h_est)
    return h_est

In [163]:
#with chirp
def channelest3(x, y, fs, hsize):
    
    w = np.convolve(y,x)
    h_est = w[fs-1:fs-1+hsize]/240 #hard code!!

    return h_est
        


In [41]:
#testing
h_test = np.random.randn(10)


In [121]:
#estimation 1
x = np.random.randn(10000)
y = np.convolve(h_test,x) + np.random.randn(h_test.size + x.size - 1)
xwin = 5
hsize = 10
result = channelest1(x,y,xwin,xwin + hsize,3)

print(h_test)
print(result)

[ 0.65590724 -1.49163429  0.37696037  0.51752404 -0.21346702  0.93670073
  1.9392701  -1.05138389  0.41004389 -0.80741243]
[ 0.95795982 -1.3114297   0.63948747  0.7522445  -0.04157234  1.14499136
  2.19137771 -0.80953444  0.58065557 -0.67963508]


In [161]:
#estimation 2
x = np.random.randn(10000)
y = np.convolve(h_test,x) + np.random.randn(h_test.size + x.size - 1)
xwin = 200
hsize = 10
result2 = channelest2(x, y, xwin,xwin + hsize,3,1)

print(h_test)
print(result2)

[ 0.65590724 -1.49163429  0.37696037  0.51752404 -0.21346702  0.93670073
  1.9392701  -1.05138389  0.41004389 -0.80741243]
[ 0.68251071  0.32472216 -0.23761746  2.02069825  0.26734261 -0.20362454
 -0.0626896  -0.01857291  0.00251138  0.00338228]


In [165]:
#estimation 3
duration = 1
fs = 48000 #sample freq
sample_times = np.linspace(0, duration, fs * duration)
x = 0.1 * scipy.signal.chirp(sample_times, 20, duration, 20000) 
y = np.convolve(h_test,x) + np.random.randn(h_test.size + x.size - 1)
hsize = 10

xinv = np.flip(x)

result3 = channelest3(xinv,y,fs,hsize)

print(h_test)
print(result3)



[ 0.65590724 -1.49163429  0.37696037  0.51752404 -0.21346702  0.93670073
  1.9392701  -1.05138389  0.41004389 -0.80741243]
[ 0.40418158 -1.38297943  0.15793921  1.08633598 -0.90845713  1.5352095
  1.81845739 -0.73622367 -0.17385931 -0.50177688]
