In [None]:
%reset -f
import numpy as np
import matplotlib.pyplot as plt
import pyroomacoustics as pra
import librosa
import librosa.display
from matplotlib import cm
import time
import torch

# Run tensor work on the GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

## RIR parameter

In [None]:
# --- Simulation constants ---
SorNum = 1       # number of sources
MicNum = 30      # number of microphones
c = 343          # presumably the speed of sound in m/s (not referenced in the visible code)
fs = 16000       # sampling rate handed to the room simulator and to librosa
Ts = 1/fs        # sampling period

# --- Alternative geometry: uniform circular array (UCA), kept for reference ---
# radius_UCA = 0.5
# angle_step = 360/MicNum
# MicCenter = [1, 3, 1]    # centre point of the microphone array
# MicPos = np.zeros((3, MicNum))
# for mic_num in range(MicNum) :
#     MicPos[:, mic_num] = np.array([MicCenter[0] + radius_UCA*np.cos(angle_step*mic_num*np.pi/180), MicCenter[1] + radius_UCA*np.sin(angle_step*mic_num*np.pi/180), MicCenter[2]])

# --- Active geometry: uniform linear array (ULA) laid out along the x-axis ---
MicStart = [1, 1.5, 1]    # position of the first microphone
spacing = 0.02            # distance between neighbouring microphones
MicPos = np.zeros((3, MicNum))              # row 0/1/2 = x/y/z, one column per microphone
MicPos[0, :] = MicStart[0] + np.arange(MicNum)*spacing
MicPos[1, :] = MicStart[1]
MicPos[2, :] = MicStart[2]

# --- Room and source description ---
SorPos = [2, 2.5, 1]
room_dim = [5, 6, 2.4]
reverberation_time = 0.2    # T60
points_rir = 4096           # number of RIR samples kept per microphone
dim = 3
orientation = [0, 90]       # [horizontal, vertical] angles in degrees (used as azimuth / colatitude)

# Plot the room geometry: microphones as red circles, the source as a blue cross #
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection='3d')

# One scatter call per microphone; each column of MicPos is an (x, y, z) triple.
for mic_xyz in MicPos.T :
    ax1.scatter(mic_xyz[0], mic_xyz[1], mic_xyz[2], c='r', marker='o')
ax1.scatter(*SorPos, c='b', marker='x')

# Axis limits span the full room along every dimension.
ax1.set_title('geometry plot')
ax1.set_xlabel('x-axis')
ax1.set_ylabel('y-axis')
ax1.set_zlabel('z-axis')
ax1.set_xlim((0, room_dim[0]))
ax1.set_ylim((0, room_dim[1]))
ax1.set_zlim((0, room_dim[2]))

## generate RIR

In [None]:
# Generate the room impulse responses (RIRs) #
# inverse_sabine turns the target T60 and the room size into an energy
# absorption coefficient and a maximum image-source reflection order.
e_absorption, max_order = pra.inverse_sabine(reverberation_time, room_dim)
room = pra.ShoeBox(room_dim, fs=fs, materials=pra.Material(e_absorption), max_order=max_order)

from pyroomacoustics.directivities import (
    DirectivityPattern,
    DirectionVector,
    CardioidFamily
)

# Every microphone shares the same omnidirectional pattern.
dir_mic = CardioidFamily(
    orientation=DirectionVector(azimuth=orientation[0], colatitude=orientation[1], degrees=True),
    pattern_enum=DirectivityPattern.OMNI
)

room.add_source(position=SorPos)
room.add_microphone(loc=MicPos, directivity=dir_mic)
room.compute_rir()

# h[m, :] holds the first points_rir samples of the RIR from source 0 to microphone m.
# A simulated RIR may be shorter than points_rir; in that case the row is
# zero-padded instead of failing with a shape-mismatch error.
h = np.zeros((MicNum, points_rir))
for mic_num in range(MicNum) :
    rir = np.asarray(room.rir[mic_num][0])[:points_rir]
    h[mic_num, :rir.shape[0]] = rir

# Plot the ground-truth RIR of one microphone in the time domain #
look_mic = 5    # index of the microphone to inspect

# Pad the y-range by 0.025 on both sides of the extreme values.
h_yaxis_underlimit = np.min(h[look_mic]) - 0.025
h_yaxis_upperlimit = np.max(h[look_mic]) + 0.025

fig2, ax2 = plt.subplots(figsize=(10, 5))
ax2.plot(h[look_mic, :])
ax2.set_ylim((h_yaxis_underlimit, h_yaxis_upperlimit))
ax2.set(title='h', xlabel='points', ylabel='magnitude')
plt.show()

## Load an audio file or generate a white-noise source

In [None]:
Second = 20             # requested source duration in seconds
SorLen = Second*fs      # requested source length in samples

# Load the source signal #
source, fs = librosa.load('245.wav', sr=fs, mono=True, offset=0.0, duration=Second, dtype=np.float32)    # audio source
# source = np.random.normal(0, 1, size=SorLen)                                                            # white noise source

# The file may contain less audio than `Second` seconds, in which case fewer
# samples are returned. Keep SorLen equal to the number of samples actually
# loaded so that later truncation and framing use the real signal length.
SorLen = source.shape[0]

# Plot the ground-truth source signal in the time domain #
fig3, ax3 = plt.subplots(figsize=(10, 5))
ax3.plot(source)
ax3.set(title='source', xlabel='points', ylabel='magnitude')
plt.show()

## compute ground-truth CTF (H)

In [None]:
# --- STFT analysis settings ---
NWIN = 64               # window length in samples
NFFT = 64               # FFT size
center = False          # True or False
hopsize = NWIN//2       # half-window hop
# L = int(np.floor((points_rir - NWIN)/hopsize)) + 1
# NumOfFrame = np.floor((SorLen - NWIN)/hopsize) + 1
frequency = NFFT//2 + 1                                      # number of one-sided frequency bins
freqs_vec_half_fs = np.linspace(0, 1, frequency)*(fs/2)      # bin centres from 0 to fs/2

# Ground-truth CTF: STFT of every RIR with a rectangular window.
H = librosa.stft(
    h,
    n_fft=NFFT,
    hop_length=hopsize,
    win_length=NWIN,
    window='boxcar',
    center=center,
)
L = H.shape[2]    # number of CTF frames (last axis of H)

## compute the STFT of the source signal (S)

In [None]:
# STFT of the source (Hann window); axis 0 is frequency, axis 1 is the frame index.
S = librosa.stft(
    source,
    n_fft=NFFT,
    hop_length=hopsize,
    win_length=NWIN,
    window='hann',
    center=center,
)
NumOfFrame = S.shape[1]
# Keep S as a complex64 torch tensor for the torch-based processing below.
S = torch.from_numpy(S).to(torch.complex64)

## Mix the RIR with the source: convolve in the time domain, then take the STFT (Y)

In [None]:
# Convolve each microphone's RIR with the source (full linear convolution),
# then keep only the first SorLen samples of every channel.
h_conv_source = np.stack([np.convolve(h[mic_idx, :], source) for mic_idx in range(MicNum)])
h_conv_source = h_conv_source[:, :SorLen]

# Optional: add white noise as an interferer #
# SNR = 0
# def awgn(x, SNR, seed=7) :
#     np.random.seed(seed)
#     snr = 10**(SNR/10)
#     x_power = np.sum(x**2)/len(x)
#     noise_power = x_power/snr
#     noise = np.random.randn(len(x))*np.sqrt(noise_power)
#
#     return x + noise

# y = np.zeros((MicNum, SorLen))
# for i in range(MicNum) :
#    y[i, :] =  awgn(h_conv_source[i, :], SNR)

# Noise-free case: the microphone signals are the convolved source #
y = h_conv_source

# STFT of the microphone signals (Hann window), stored as a complex64 torch tensor.
Y = torch.from_numpy(
    librosa.stft(y, n_fft=NFFT, hop_length=hopsize, win_length=NWIN, window='hann', center=center)
).to(torch.complex64)

# Plot the source spectrogram and the spectrogram received at microphone look_mic #
S_db = librosa.amplitude_to_db(np.abs(S))
Y_db = librosa.amplitude_to_db(np.abs(Y[look_mic]))

# Both panels share the colour range of the source spectrogram.
S_underlimit = np.min(S_db)
S_upperlimit = np.max(S_db)

fig6 = plt.figure(figsize=(7, 10))
panel_axes = []
for row, (panel_db, panel_title) in enumerate(((S_db, 'S'), (Y_db, 'Y')), start=1) :
    panel_ax = fig6.add_subplot(2, 1, row)
    img = librosa.display.specshow(panel_db, sr=fs, hop_length=hopsize, y_axis='linear', x_axis='frames',
                                   ax=panel_ax, cmap=cm.rainbow, vmin=S_underlimit, vmax=S_upperlimit)
    panel_ax.set_title(panel_title)
    fig6.colorbar(img, ax=panel_ax, format="%+2.0f dB")
    panel_axes.append(panel_ax)

ax6_1, ax6_2 = panel_axes
plt.show()

## initial Rss Rsy

In [None]:
# Number of leading frames used to initialise the correlation matrices.
ini_frame = int(np.floor(NumOfFrame/10))

# Frames L-1 .. ini_frame-1 each contribute one outer product; without at least
# one of them the averages below would divide by zero (or a negative count).
num_init_frames = ini_frame - L + 1
if num_init_frames <= 0 :
    raise ValueError(f'ini_frame ({ini_frame}) must be at least L ({L}) to initialise Rss and Rsy')

# Initialise Rss (source auto-correlation) and Rsy (source/microphone cross-correlation) #
Rss = torch.zeros(L, L, frequency, dtype=torch.complex64).to(device)
Rsy = torch.zeros(L, MicNum, frequency, dtype=torch.complex64).to(device)
for FrameNo in range(L-1, ini_frame) :
    # (L, frequency) window of the L most recent source frames, newest frame first.
    S_window = torch.flipud(S[:, FrameNo-L+1:FrameNo+1].T).to(device)
    # (MicNum, frequency) microphone spectra of the current frame.
    Y_frame = Y[:, :, FrameNo].to(device)

    # Per-frequency outer products, formed for all bins at once by broadcasting:
    #   Rss[:, :, n] += s_n s_n^H   and   Rsy[:, :, n] += s_n y_n^H
    Rss += S_window[:, None, :]*S_window.conj()[None, :, :]
    Rsy += S_window[:, None, :]*Y_frame.conj()[None, :, :]

# Turn the accumulated sums into averages over the initialisation frames.
Rss = Rss/num_init_frames
Rsy = Rsy/num_init_frames

## Plot the initial A in the frequency domain

In [None]:
# Initial CTF estimate per frequency bin: A = Rsy^H (Rss + dia_load_ini * I)^-1
dia_load_ini = 10**(-10)    # diagonal loading added to Rss before inversion
identity_L = torch.eye(L).to(device)

A_ini = torch.zeros(MicNum, L, frequency, dtype=torch.complex64).to(device)
for n in range(frequency) :
    Rss_loaded = Rss[:, :, n] + dia_load_ini*identity_L
    A_ini[:, :, n] = torch.mm(Rsy[:, :, n].conj().T, torch.inverse(Rss_loaded))

# Reorder the axes from (MicNum, L, frequency) to (MicNum, frequency, L), the layout
# passed to librosa.istft below, and copy the values into a NumPy array.
A_ini_forplot = np.zeros((MicNum, frequency, L), dtype='complex64')
A_ini_forplot[...] = A_ini.permute(0, 2, 1).cpu().numpy()

# Back to the time domain with the same rectangular window used to compute H.
ctf_ini_tdomain = librosa.istft(
    A_ini_forplot,
    hop_length=hopsize,
    win_length=NWIN,
    n_fft=NFFT,
    window='boxcar',
    center=center,
)

# Plots: ground-truth H versus the initial estimate A_ini for microphone look_mic #
H_db = librosa.amplitude_to_db(np.abs(H[look_mic]))
A_ini_db = librosa.amplitude_to_db(np.abs(A_ini_forplot[look_mic]))

# Shared dB range: the extremes of H, widened by 5 dB on each side.
H_underlimit = np.min(H_db) - 5
H_upperlimit = np.max(H_db) + 5

# Figure 8: spectrogram-style view (frequency versus CTF frame).
fig8 = plt.figure(figsize=(20, 5))
spec_axes = []
for col, (panel_db, panel_title) in enumerate(((H_db, 'H'), (A_ini_db, 'A_ini')), start=1) :
    panel_ax = fig8.add_subplot(1, 2, col)
    img = librosa.display.specshow(panel_db, sr=fs, hop_length=hopsize, y_axis='linear', x_axis='frames',
                                   ax=panel_ax, cmap=cm.rainbow, vmin=H_underlimit, vmax=H_upperlimit)
    panel_ax.set_title(panel_title)
    fig8.colorbar(img, ax=panel_ax, format="%+2.0f dB")
    spec_axes.append(panel_ax)
ax8_1, ax8_2 = spec_axes
plt.show()

# Figure 10: one magnitude curve per CTF frame, plotted against frequency.
# A legend is drawn only when H has fewer than 10 frames.
show_legend = H[look_mic].shape[1] < 10
curve_panels = (
    (H_db, 'H', H[look_mic].shape[1]),
    (A_ini_db, 'A_ini', A_ini_forplot[look_mic].shape[1]),
)
fig10 = plt.figure(figsize=(20, 5))
curve_axes = []
for col, (panel_db, panel_title, curve_count) in enumerate(curve_panels, start=1) :
    panel_ax = fig10.add_subplot(1, 2, col)
    panel_ax.plot(freqs_vec_half_fs, panel_db)
    panel_ax.set_ylim((H_underlimit, H_upperlimit))
    if show_legend :
        panel_ax.legend([str(i) for i in range(curve_count)])
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('frequency(Hz)')
    panel_ax.set_ylabel('dB')
    curve_axes.append(panel_ax)
ax10_1, ax10_2 = curve_axes
plt.show()

## Plot the initial A in the time domain (the plot rescales the peak to match h; the underlying data are not rescaled)

In [None]:
# Scale the estimate so that its peak matches the peak of the true RIR (plot only) #
look_mic = 5    # declared earlier; change to inspect another microphone
h_max = np.max(h[look_mic, :])
ctf_ini_tdomain_max = np.max(ctf_ini_tdomain[look_mic, :])
ratio = h_max/ctf_ini_tdomain_max

# Plots #
h_yaxis_underlimit = np.min(h[look_mic]) - 0.025
h_yaxis_upperlimit = np.max(h[look_mic]) + 0.025

fig9 = plt.figure(figsize=(10, 7.5))

# Top panel: true RIR (red) with the rescaled initial estimate (blue) on top.
ax9_1 = fig9.add_subplot(2, 1, 1)
ax9_1.plot(h[look_mic, :], 'r')
ax9_1.plot(ctf_ini_tdomain[look_mic, :]*ratio, 'b')
ax9_1.set_ylim((h_yaxis_underlimit, h_yaxis_upperlimit))
ax9_1.set(title='ctf_ini_tdomain', xlabel='points', ylabel='magnitude')

# Bottom panel: true RIR alone, for reference.
ax9_2 = fig9.add_subplot(2, 1, 2)
ax9_2.plot(h[look_mic, :], 'r')
ax9_2.set_ylim((h_yaxis_underlimit, h_yaxis_upperlimit))
ax9_2.set(title='h', xlabel='points', ylabel='magnitude')

# The file name records the microphone index and the STFT settings.
plt.savefig(f'look_mic={look_mic}_NWIN={NWIN}_NFFT={NFFT}_hopsize={hopsize}.png')
plt.show()

## initial A RAA rAy

In [None]:
# Zero-initialised state for the recursive process, allocated directly on `device`.
A = torch.zeros((MicNum, L, frequency), dtype=torch.complex64, device=device)      # CTF estimate
RAA = torch.zeros((L, L, frequency), dtype=torch.complex64, device=device)         # running A^H A
rAy = torch.zeros((L, 1, frequency), dtype=torch.complex64, device=device)         # running A^H y

## recursive process

In [None]:
# --- Recursion parameters ---
alpha = 0.99    # weight on the previous Rss in its recursive update
beta = 0.99     # weight on the previous Rsy in its recursive update
gamma = 0.001   # weight on the previous RAA in its recursive update
delta = 0.001   # weight on the previous rAy in its recursive update
# Diagonal loading terms. `**` is exponentiation; `^` is bitwise XOR in Python
# (10^(-2) evaluates to -12, which would subtract 12 from the diagonal).
dia_load_A = 10**(-2)
dia_load_S_predict = 10**(-2)
save_mode = 'front'    # 'front' or 'back'
S_predict = torch.zeros(L, 1, frequency, dtype=torch.complex64).to(device)
S_save = torch.zeros(frequency, NumOfFrame, dtype=torch.complex64)
S_save_L = torch.zeros(frequency, NumOfFrame, dtype=torch.complex64)

identity_L = torch.eye(L).to(device)

# A is set to the ground-truth CTF H. This does not depend on the frame or the
# frequency bin, so it is done once here rather than inside the loops.
for i in range(A.shape[1]) :
    A[:, i, :] = torch.from_numpy(H[:, :, i]).to(device)

for FrameNo in range(ini_frame, NumOfFrame) :
    # Microphone spectra of the previous and the current frame, (MicNum, frequency).
    Y_before_all = Y[:, :, FrameNo-1].to(device)
    Y_now_all = Y[:, :, FrameNo].to(device)

    for n in range(frequency) :
        Y_before = Y_before_all[:, n]
        Y_now = Y_now_all[:, n]
        if FrameNo != ini_frame :
            # S_predict still holds the previous frame's estimate here, hence Y_before.
            Rss[:, :, n] = alpha*Rss[:, : ,n] + (1 - alpha)*torch.mm(S_predict[:, :, n], S_predict[:, :, n].conj().T)
            Rsy[:, :, n] = beta*Rsy[:, :, n] + (1 - beta)*torch.mm(S_predict[:, :, n], Y_before.conj().reshape((1, MicNum)))

        # A[:, :, n] = torch.mm(Rsy[:, :, n].conj().T, torch.inverse(Rss[:, :, n] + dia_load_A*torch.eye(L).to(device)))    # alternative: estimate A from the formula

        RAA[:, :, n] = gamma*RAA[:, :, n] + (1 - gamma)*torch.mm(A[:, :, n].conj().T, A[:, :, n])
        rAy[:, :, n] = delta*rAy[:, :, n] + (1 - delta)*torch.mm(A[:, :, n].conj().T, Y_now.reshape((MicNum, 1)))
        S_predict[:, :, n] = torch.mm(torch.inverse(RAA[:, :, n] + dia_load_S_predict*identity_L), rAy[:, :, n])

    if save_mode == 'front' :
        # Store the first entry of the estimate (and, separately, the last one) at the current frame #
        S_save[:, FrameNo] = S_predict[0, 0, :].cpu()
        S_save_L[:, FrameNo] = S_predict[L-1, 0, :].cpu()
    else :
        # Store the last entry of the estimate, which belongs to frame FrameNo-(L-1) #
        S_save[:, FrameNo-(L-1)] = S_predict[L-1, 0, :].cpu()
        # On the final iteration (FrameNo == NumOfFrame-1) flush all L entries:
        # entry `count` belongs to frame NumOfFrame-1-count.
        if FrameNo == NumOfFrame - 1 :
            for count in range(L) :
                S_save[:, NumOfFrame-1-count] = S_predict[count, 0, :].cpu()

    print(f'right now processing frame = {FrameNo}')

## Plot the final A in the time domain

In [None]:
# Reorder the final A from (MicNum, L, frequency) to (MicNum, frequency, L) as a NumPy array.
A_forplot = np.zeros((MicNum, frequency, L), dtype='complex64')
for i in range(A.shape[1]) :
    A_forplot[:, :, i] = A[:, i, :].cpu()

# Invert the FINAL estimate (A_forplot), not the initial one (A_ini_forplot).
ctf_fin_tdomain = librosa.istft(A_forplot, hop_length=hopsize, win_length=NWIN, n_fft=NFFT, window='boxcar', center=center)

# Scale the estimate so that its peak matches the peak of the true RIR (plot only) #
look_mic = 5    # declared earlier; change to inspect another microphone
h_max = np.max(h[look_mic, :])
ctf_fin_tdomain_max = np.max(ctf_fin_tdomain[look_mic, :])
ratio = h_max/ctf_fin_tdomain_max

# Plots #
h_yaxis_upperlimit = np.max(h[look_mic]) + 0.025
h_yaxis_underlimit = np.min(h[look_mic]) - 0.025

fig10 = plt.figure(figsize=(10, 7.5))

# Top panel: true RIR (red) with the rescaled final estimate (blue) on top.
ax10_1 = fig10.add_subplot(2, 1, 1)
ax10_1.plot(h[look_mic, :], 'r')
ax10_1.plot(ctf_fin_tdomain[look_mic, :]*ratio, 'b')
ax10_1.set_ylim((h_yaxis_underlimit, h_yaxis_upperlimit))
ax10_1.set_title('ctf_fin_tdomain')
ax10_1.set_xlabel('points')
ax10_1.set_ylabel('magnitude')

# Bottom panel: true RIR alone, for reference.
ax10_2 = fig10.add_subplot(2, 1, 2)
ax10_2.plot(h[look_mic, :], 'r')
ax10_2.set_ylim((h_yaxis_underlimit, h_yaxis_upperlimit))
ax10_2.set_title('h')
ax10_2.set_xlabel('points')
ax10_2.set_ylabel('magnitude')
plt.show()