In [1]:
import os
os.environ["CUDA_VISIBLE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math
import pickle
gpu_options = tf.GPUOptions(allow_growth = True)
config=tf.ConfigProto(gpu_options=gpu_options)
config.gpu_options.per_process_gpu_memory_fraction = 0.34

In [2]:
import librosa

In [3]:
n_file = 120

max_length = 200

sr = 16000

In [4]:
def IBM(S, N):
    M = []
    
    for i in range(len(S)):
        m_ibm = 1 * (S[i] > N[i])
        M.append(m_ibm)
    
    return M

In [5]:
def loadfile(path,stri, flag = 1):
    list_tr = []
    list_stft = []
    list_stft_abs = []
    list_length = []
    
    for i in range(n_file):
        s, sr = librosa.load(path + 'ratio_'+str(i+1) + '_'+ stri+'.wav', sr = None)
        if (flag == 1):
            list_tr.append(s)
        
        #Calculating STFT
        stft = librosa.stft(s, n_fft= 1024, hop_length= 512)
        
        stft_len = stft.shape[1]
        
        #Appending STFT to list
        if (flag == 1):
            list_stft.append(stft)
        
        #Calculating Absolute of STFT
        stft_abs = np.abs(stft)
        
        #Padding zeros to make length 300
        stft_abs = np.pad(stft_abs, ((0,0),(0, max_length-stft_len)), 'constant')
        
        #Appending abs to list
        list_stft_abs.append(stft_abs.T)
        
        #Appending time-length of STFT to list
        list_length.append(stft_len)
        
    return list_tr, list_stft, list_stft_abs, list_length

In [6]:
trs, S, S_abs, S_len = loadfile("clean/","clean")
trx, X, X_abs, X_len = loadfile("noisy/","mixture")
tri, I , I_abs, I_len= loadfile("IBM/","ideal")

In [7]:
x = tf.placeholder(tf.float32, [10, None, 513])
y_ = tf.placeholder(tf.float32, [10, None, 513])

hidden_units = 256
out_weights = tf.Variable(tf.random_normal([hidden_units, 513], stddev=2/(hidden_units+513), mean=0)) # xavier init
out_bias = tf.Variable(tf.zeros([513]))

In [8]:
lstm_cell = tf.nn.rnn_cell.LSTMCell(hidden_units, initializer=tf.contrib.layers.xavier_initializer())
outputs, _ = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)

In [9]:
weights = tf.expand_dims(tf.ones([10, 1]), 1) * out_weights

y = tf.nn.sigmoid(tf.matmul(outputs, weights) + out_bias)

In [10]:
mse = tf.reduce_mean(tf.losses.mean_squared_error(y_, y))
train_step = tf.train.AdamOptimizer().minimize(mse) # adam optimizer with default learning rate

In [11]:
init = tf.global_variables_initializer()
saver = tf.train.Saver() 
sess = tf.Session(config=config)
sess.run(init)

In [12]:
epochs = 100

for epoch in range(epochs):
    avg_cost = 0
    for i in range(0, 120, 10):
        batch_x = X_abs[i:i+10] 
        batch_y = I_abs[i:i+10]
        _, cost = sess.run([train_step, mse], feed_dict={x: batch_x, y_: batch_y})
        avg_cost += cost/120
    print("Epoch:", '%02d'%(epoch+1), "\tcost={:.9f}".format(avg_cost))

Epoch: 01 	cost=0.244389131
Epoch: 02 	cost=0.232434666
Epoch: 03 	cost=0.229284970
Epoch: 04 	cost=0.226815259
Epoch: 05 	cost=0.225271697
Epoch: 06 	cost=0.224424066
Epoch: 07 	cost=0.223996287
Epoch: 08 	cost=0.223718130
Epoch: 09 	cost=0.223523233
Epoch: 10 	cost=0.223368681
Epoch: 11 	cost=0.223229909
Epoch: 12 	cost=0.223102104
Epoch: 13 	cost=0.222995716
Epoch: 14 	cost=0.222894757
Epoch: 15 	cost=0.222793529
Epoch: 16 	cost=0.222701914
Epoch: 17 	cost=0.222619606
Epoch: 18 	cost=0.222541101
Epoch: 19 	cost=0.222463336
Epoch: 20 	cost=0.222390724
Epoch: 21 	cost=0.222333739
Epoch: 22 	cost=0.222270543
Epoch: 23 	cost=0.222207232
Epoch: 24 	cost=0.222144685
Epoch: 25 	cost=0.222069002
Epoch: 26 	cost=0.222019869
Epoch: 27 	cost=0.221962179
Epoch: 28 	cost=0.221910851
Epoch: 29 	cost=0.221867323
Epoch: 30 	cost=0.221808627
Epoch: 31 	cost=0.221778184
Epoch: 32 	cost=0.221756242
Epoch: 33 	cost=0.221769155
Epoch: 34 	cost=0.221707160
Epoch: 35 	cost=0.221624671
Epoch: 36 	cost=0.22

KeyboardInterrupt: 

In [None]:
model_path = "model.ckpt"
# saving the model
save_path = saver.save(sess, model_path)
# restore trained model to use for validation
saver.restore(sess, model_path)
# load pickled time domain S_v: clean signals from validation set

with open('data/validation_clean.pkl', 'rb') as f:
    s = pickle.load(f)
# loading picked validation set files X_v, S_v, N_v in stft domain

with open('data/validation.pkl', 'rb') as f:
    X_v, S_v, N_v = pickle.load(f)

# taking magnitude and transpose

X_v_mod = [np.abs(signal).T for signal in X_v]
S_v_mod = [np.abs(signal).T for signal in S_v]
N_v_mod = [np.abs(signal).T for signal in N_v]
X_v_T = [signal.T for signal in X_v]

M_v = []
for i in range(1200):
    M_v.append(np.greater(S_v_mod[i], N_v_mod[i]).astype(int))

# checking validation loss and calculating snr

avg_cost = 0
snr = []

for i in range(0, 1200, 10):
    batch_x = X_v_mod[i:i+10] 
    batch_y = M_v[i:i+10]

    cost, M_hat = sess.run([mse, y], feed_dict={x: batch_x, y_: batch_y})
    
    avg_cost += cost/120
    
    batch_x_complex = X_v_T[i:i+10]
    batch_s = s[i:i+10]
    for j in range(10):
        S_hat = np.multiply(M_hat[j], batch_x_complex[j])
        s_hat = librosa.istft(S_hat.T, win_length=1024, hop_length=512)
        
        t = min(len(s_hat), len(batch_s[j]))
        snr.append(10*np.log10((np.sum(np.square(batch_s[j][:t])))/np.sum(np.square(batch_s[j][:t]-s_hat[:t]))))

print("Validation loss = {:.9f}".format(avg_cost))
print("SNR = ", sum(snr)/1200)

In [None]:
# filenames of test files (400 test files)
te_filenames = ['tex{}.wav'.format(id) for id in ids[:400]]

# load pickled test files X_te in stft domain

with open('data/test.pkl', 'rb') as f:
    X_te, srs = pickle.load(f)
    

# taking magnitude and transpose

X_te_mod = [np.abs(signal).T for signal in X_te]
X_te_T = [signal.T for signal in X_te]

# reconstructing test signals

for i in range(0, 400, 10):
    batch_x = X_te_mod[i:i+10] 
    M_hat = sess.run(y, feed_dict={x: batch_x})
    
    batch_x_complex = X_te_T[i:i+10]
    batch_filenames = te_filenames[i:i+10]
    batch_sr = srs[i:i+10]
    for j in range(10):
        S_hat = np.multiply(M_hat[j], batch_x_complex[j])
        s_hat = librosa.istft(S_hat.T, win_length=1024, hop_length=512)
        librosa.output.write_wav('recons'+batch_filenames[j][-11:], s_hat, srs[j])
        
sess.close()