In [None]:
import os
import time
import sys
import itertools
import json
import librosa as lr
import librosa.display as lrd
import numpy as np
import os.path as osp

import matplotlib.pyplot as plt
import IPython.display as ipd
from sklearn.metrics import mean_squared_error

%matplotlib inline

sys.path.append("..")
from libs.utilities import load_autoencoder_lossfunc, load_autoencoder_model, get_func_name
from libs.processing import pink_noise, take_file_as_noise, \
    make_fragments, unmake_fragments, unmake_fragments_slice, \
    s_to_exp, exp_to_s, s_to_reim, reim_to_s, s_to_db, db_to_s , \
    normalize_spectrum, normalize_spectrum_clean, unnormalize_spectrum  


In [None]:
# Experiment selection. Loading of the training logs (below) is disabled;
# keep in mind that the log order is different from the argument list order!
mag_or_db = 'db'        # spectrum representation: 'mag' or 'db'
memory_type = 'tcn'     # one of: gru, rnn, tcn
file_type = 'train'     # 'valid' or 'train'
model_name = 'conv_' + '{}'.format(memory_type)

# logs_path = '/home/christie/SingleChannelDenoising_source/train_logs/logs_'+memory_type+'_'+mag_or_db+'.json'
# print(logs_path)
# with open(logs_path) as f:
#     logs = json.load(f)

In [None]:
# --- input file ---------------------------------------------------------
# train: NPR_News__04-03-2018_12AM_ET.wav
# valid: newscast230834.wav
if file_type == 'train':
    file_name = 'NPR_News__04-03-2018_12AM_ET.wav'
else:
    file_name = 'newscast230834.wav'
# NOTE(review): the directory is '.../train/' for both file types -- confirm
# that the validation file really lives there.
input_path = osp.join('/data/riccardo_datasets/npr_news/ds1/train/', file_name)

# --- signal / STFT settings ---------------------------------------------
sr = 16000          # sampling rate handed to librosa
snr = 15            # value passed as `snr` to pink_noise
n_fft = 512
hop_length = 128
win_length = 512

# --- fragment settings ----------------------------------------------------
frag_hop_length = 1
frag_win_length = 32

# Restrict the process to a single GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

# --- argument grid ----------------------------------------------------------
normalize = [False]
slice_width = [1]
trim_negatives = False

# One (pre-processing, inverse) function pair, chosen by the representation.
if mag_or_db == 'mag':
    proc_func_args_2 = [(s_to_exp(1.0), exp_to_s(1.0))]
else:
    proc_func_args_2 = [(s_to_db, db_to_s)]

# Every combination of normalize flag x function pair x slice width.
args_list_2 = list(itertools.product(normalize, proc_func_args_2, slice_width))

# Numbered overview of the combinations (index, normalize, pre-proc name, width).
print('\n'.join(
    '{:2}. {} {} {}'.format(idx, norm_flag, get_func_name(funcs[0]), width)
    for idx, (norm_flag, funcs, width) in enumerate(args_list_2)
))



In [None]:
 # load data from file name
# Read a 30 s excerpt that starts 120 s into the recording, at sampling rate `sr`.
x, _ = lr.load(input_path, sr=sr, offset=120, duration=30)

# Corrupt the clean excerpt with pink noise (snr=snr).
x_noisy = pink_noise(x=x, sr=sr, snr=snr)

# Move both signals to the time-frequency domain with identical STFT settings.
s, s_noisy = (
    lr.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    for signal in (x, x_noisy)
)
    

In [None]:
def run_the_whole_thing(args_index, use_phase=True, trim_negatives=False, time_slice=slice(0, 3)):
    """Denoise the notebook's noisy spectrogram with the trained model and plot the outcome.

    Reads the notebook-level globals ``s``, ``s_noisy``, ``args_list_2``,
    ``frag_hop_length``, ``frag_win_length``, ``mag_or_db``, ``model_name``,
    ``sr``, ``n_fft``, ``hop_length`` and ``win_length``. The model file is
    loaded from disk on every call.

    Parameters
    ----------
    args_index : int
        Index into ``args_list_2`` selecting one ``(normalize, (proc_func,
        unproc_func), slice_width)`` combination. ``slice_width`` is currently
        not used by this function.
    use_phase : bool
        If true, the noisy fragments are passed to ``unproc_func`` as a second
        argument (presumably so their phase can be reused -- confirm in
        ``libs.processing``); otherwise only the prediction is passed.
    trim_negatives : bool
        If true, negative values of the (de-normalized) prediction are set to
        zero before the inverse transform.
    time_slice : slice
        Slice handed to ``load_autoencoder_model`` and
        ``unmake_fragments_slice``. Defaults to ``slice(0, 3)``.

    Returns
    -------
    tuple
        ``(x_pred, s_pred, sample_specs)``: the waveform returned by
        ``lr.istft``, the spectrogram returned by ``unmake_fragments_slice``,
        and a list of ``(noisy, predicted)`` sample fragments (channel 0, as
        fed to / returned by the model).
    """
    # collect variable arguments
    normalize, (proc_func, unproc_func), _slice_width = args_list_2[args_index]

    # split into fragments
    s_frags = make_fragments(s, frag_hop_len=frag_hop_length, frag_win_len=frag_win_length)
    s_frags_noisy = make_fragments(s_noisy, frag_hop_len=frag_hop_length, frag_win_len=frag_win_length)
    # apply pre-processing (data representation)
    y_frags = proc_func(s_frags)
    y_frags_noisy = proc_func(s_frags_noisy)

    # normalize fragments (individually); skipped entirely when not requested
    if normalize:
        nf_frags_noisy = np.empty((len(y_frags_noisy), 2))
        y_frags_n = np.zeros_like(y_frags)
        y_frags_noisy_n = np.zeros_like(y_frags_noisy)
        for i in range(len(y_frags_noisy)):
            y_frags_n[i], _ = normalize_spectrum(y_frags[i])
            y_frags_noisy_n[i], nf_frags_noisy[i] = normalize_spectrum(y_frags_noisy[i])
    else:
        y_frags_n = np.asarray(y_frags)
        y_frags_noisy_n = np.asarray(y_frags_noisy)

    ### LOAD TRAINED MODEL ###
    model_path = '/data/riccardo_models/denoising/phase2/ds1/'+mag_or_db+'/'+model_name+'.h5'
    print("model_path :", model_path)
    # the returned loss function is not needed for prediction
    model, _lossfunc = load_autoencoder_model(model_path, time_slice=time_slice)

    ### PREDICT DATA ###
    y_frags_pred = model.predict(y_frags_noisy_n)

    # un-normalize (individually) into a separate array, so the fragments kept
    # in `sample_specs` are not altered by later processing
    y_frags_pred_dn = np.array(y_frags_pred, copy=True)
    if normalize:
        for i in range(len(y_frags_pred)):
            y_frags_pred_dn[i] = unnormalize_spectrum(y_frags_pred[i], nf_frags_noisy[i])

    ### PLOT A FEW PREDICTED FRAGMENTS ###
    n_samples = 3
    sample_specs = []
    plt.figure(figsize=(18, 6))
    plt.suptitle('Clean (true), noisy, and predicted sample fragments (normalized)')
    if len(y_frags_pred) > 0:
        # The last valid index is len - 1; using len as the endpoint would make
        # the final sample unreachable and leave the last column empty.
        sample_idx = np.unique(np.linspace(0, len(y_frags_pred) - 1, n_samples, dtype=int))
        for col, i in enumerate(sample_idx, start=1):
            sample_specs.append((y_frags_noisy_n[i, ..., 0], y_frags_pred[i, ..., 0]))
            # rows: clean, noisy, predicted
            for row, frags in enumerate((y_frags_n, y_frags_noisy_n, y_frags_pred)):
                plt.subplot(3, n_samples, row * n_samples + col)
                lrd.specshow(lr.amplitude_to_db(frags[i, ..., 0]), vmin=-10, vmax=5, cmap='coolwarm')

    # trim negative values
    if trim_negatives:
        y_frags_pred_dn[y_frags_pred_dn < 0] = 0

    # Invert the data representation on the de-normalized prediction (it equals
    # the raw prediction when normalization is off).
    if use_phase:
        s_pred = unproc_func(y_frags_pred_dn, s_frags_noisy)
    else:
        s_pred = unproc_func(y_frags_pred_dn)

    # undo fragments
    s_pred = unmake_fragments_slice(s_pred, frag_hop_len=frag_hop_length, frag_win_len=frag_win_length, time_slice=time_slice)

    # get waveform
    x_pred = lr.istft(s_pred, hop_length=hop_length, win_length=win_length)

    ### PLOT RESULT AND LISTEN ###
    duration = 20
    offset = 2
    n_frames = lr.time_to_frames(duration, sr=sr, n_fft=n_fft, hop_length=hop_length)
    first_frame = lr.time_to_frames(offset, sr=sr, n_fft=n_fft, hop_length=hop_length)
    print(first_frame, n_frames)
    shown = slice(first_frame, first_frame + n_frames)

    plt.figure(figsize=(18, 5))
    plt.suptitle('True and predicted spectrograms (de-normalized, {} seconds)'.format(duration))

    plt.subplot(211)
    lrd.specshow(lr.amplitude_to_db(np.abs(s[:, shown])), vmin=-50, vmax=25, cmap='coolwarm')
    plt.colorbar()

    plt.subplot(212)
    lrd.specshow(lr.amplitude_to_db(np.abs(s_pred[:, shown])), vmin=-50, vmax=25, cmap='coolwarm')
    plt.colorbar()

    print('done!')
    return x_pred, s_pred, sample_specs
    

In [None]:
# TEST VARIOUS MODELS HERE
# Run the pipeline for the first entry of the argument grid.
x_pred, s_pred, sample_specs = run_the_whole_thing(
    args_index=0,
    use_phase=True,
    trim_negatives=trim_negatives,
)


In [None]:
# Listen to the predicted waveform. Tagged "valid" -- presumably the saved
# output of this cell came from a run with file_type = 'valid'.
ipd.Audio(data=x_pred, rate=sr)

In [None]:
# Listen to the predicted waveform. Tagged "train" -- presumably the saved
# output of this cell came from a run with file_type = 'train'.
ipd.Audio(data=x_pred, rate=sr)

In [None]:
# Listen to the noisy signal that the STFT input was computed from.
ipd.Audio(data=x_noisy, rate=sr)

In [None]:
# Slice around the fragment centre, built from the first configured slice width.
# The built-in int() is used because the np.int alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
time_slice = slice(int(frag_win_length//2 - slice_width[0]/2),
                   int(frag_win_length//2 + slice_width[0]/2))
t_start = time_slice.start

## trimming
# Drop frag_win_length frames at both ends of the prediction and cut the
# noisy / clean spectrograms to the same frame range.
s_pred_2 = s_pred[:, frag_win_length:-frag_win_length]
s_noisy_2 = s_noisy[:, frag_win_length: s_pred_2.shape[1]+frag_win_length]
s_2 = s[:, frag_win_length: s_pred_2.shape[1]+frag_win_length]

# Per-bin ratios get their own names: assigning them to `snr` would overwrite
# the scalar noise level of the parameter cell and break a re-run of the
# noise-generation cell.
snr_bins = s_pred_2**2 / (s_noisy_2 - s_2)**2
sdr_bins = s_2**2 / (s_2 - s_pred_2)**2
print('snr :',np.mean(lr.power_to_db(np.abs(snr_bins))))
print('sdr :',np.mean(lr.power_to_db(np.abs(sdr_bins))))

In [None]:
# time_slice_len = time_slices_len[2]
# time_slice = slice(np.int(frag_win_length//2 - time_slice_len/2), 
#                      np.int(frag_win_length//2 + time_slice_len/2)) 
# i = time_slice.start

In [None]:
## trimming
# Discard frag_win_length frames at each end of the prediction ...
s_pred_2 = s_pred[:, frag_win_length:-frag_win_length]
# ... and cut the noisy and clean spectrograms to the same frame range.
s_noisy_2, s_2 = (
    spec[:, frag_win_length:frag_win_length + s_pred_2.shape[1]]
    for spec in (s_noisy, s)
)

# Shapes before and after trimming.
print(*(spec.shape for spec in (s_pred, s_noisy, s)))
print(*(spec.shape for spec in (s_pred_2, s_noisy_2, s_2)))

In [None]:
## Evaluation

## SNR / SDR
# Per-bin ratios are kept under their own names so the scalar `snr` of the
# parameter cell is not overwritten (re-running the noise cell would break).
snr_bins = s_pred_2**2 / (s_noisy_2 - s_2)**2
sdr_bins = s_2**2 / (s_2 - s_pred_2)**2

# Take the magnitude once, for both ratios, before converting to dB; passing
# the raw ratio would hand complex values to power_to_db.
snr_db = lr.power_to_db(np.abs(snr_bins))
sdr_db = lr.power_to_db(np.abs(sdr_bins))
print('snr :', np.mean(snr_db))
print('sdr :', np.mean(sdr_db))

plt.figure(figsize=(18, 5))

# Each panel gets its own title; a second suptitle would replace the first
# and leave the whole figure labelled 'SDR'.
plt.subplot(2,1,1)
lrd.specshow(snr_db, cmap='coolwarm')
plt.title('SNR')
plt.colorbar()

plt.subplot(2,1,2)
lrd.specshow(sdr_db, cmap='coolwarm')
plt.title('SDR')
plt.colorbar();

In [None]:
# First 100 trimmed frames: noisy input on top, prediction below. The
# prediction is reversed along its first axis ([::-1]) and the vertical gap
# between the panels is removed, so the two spectrograms mirror each other.
plt.figure(figsize=(18, 5))
spec_style = dict(vmin=-50, vmax=25, cmap='coolwarm')

plt.subplot(2, 1, 1)
noisy_db = lr.amplitude_to_db(np.abs(s_noisy_2[:, :100]))
lrd.specshow(noisy_db, **spec_style)

plt.subplot(2, 1, 2)
pred_db_flipped = lr.amplitude_to_db(np.abs(s_pred_2[::-1, :100]))
lrd.specshow(pred_db_flipped, **spec_style)

plt.subplots_adjust(wspace=0, hspace=0)

In [None]:
# plt.figure(figsize=(18, 5))
# plt.subplot(211)
# lrd.specshow(lr.amplitude_to_db(np.abs(s_noisy[:,-(100):])), vmin=-50, vmax=25, cmap='coolwarm')

# plt.subplot(212)
# lrd.specshow(lr.amplitude_to_db(np.abs(s_pred[::-1,-100:])), vmin=-50, vmax=25, cmap='coolwarm')
# plt.subplots_adjust(wspace=0, hspace=0)

In [None]:
# Shapes of the predicted and the noisy spectrogram, side by side.
print(*(spec.shape for spec in (s_pred, s_noisy)))


In [None]:
# Listen to the clean excerpt that was loaded from disk.
ipd.Audio(data=x, rate=sr)

In [None]:
###############################

In [None]:
### PLOT RESULT AND LISTEN ###
# Window to display, given as times and converted to STFT frame counts.
duration, offset = 20, 2
l, o = (
    lr.time_to_frames(t, sr=sr, n_fft=n_fft, hop_length=hop_length)
    for t in (duration, offset)
)
print(o,l)

plt.figure(figsize=(18, 5))
plt.suptitle('True and predicted spectrograms (de-normalized, {} seconds)'.format(duration))

# top: clean spectrogram
plt.subplot(2, 1, 1)
lrd.specshow(lr.amplitude_to_db(np.abs(s[:, o:o + l])), cmap='coolwarm', vmin=-50, vmax=25)
plt.colorbar()

# bottom: predicted spectrogram, same frame window
plt.subplot(2, 1, 2)
lrd.specshow(lr.amplitude_to_db(np.abs(s_pred[:, o:o + l])), cmap='coolwarm', vmin=-50, vmax=25)
plt.colorbar()