In [None]:
import numpy as np
import matplotlib.pyplot as plt 
%matplotlib inline
import IPython.display as ipd

import soundfile as sf
import librosa
from data_utils import MudNoise
from model_stft import Mud
import json
import os
import torch
from torch import nn
from torch.autograd import Variable
import csv
from librosa import stft
plt.style.use('dark_background')

In [None]:
# Validation split of the mud-noise data, wrapped in the project's MudNoise
# dataset class (noise directory and task are forwarded as keyword arguments).
datadir = '/Data/DATASETS/WSJ/mud_noise/'
validation_data_path = '{}{}'.format(datadir, 'cv/debug_mud_v1.h5')

data_verbose = MudNoise(
    validation_data_path,
    noisedir='/Data/DATASETS/NoiseX/8k/',
    task='cv',
)

In [None]:
def load_mask_model(load, base_dir='./', device=None):
    """Rebuild a trained ``Mud`` model from an experiment folder and load its newest weights.

    Reads the constructor hyper-parameters from
    ``<base_dir>/exp/<load>/architecture.json``, wraps the model in
    ``nn.DataParallel``, loads the ``model_weight_<N>.pt`` file with the largest
    ``N`` from ``<base_dir>/exp/<load>/net/cv/`` and puts the model in eval mode.

    Args:
        load: experiment folder name under ``<base_dir>/exp``.
        base_dir: directory that contains the ``exp`` folder.
        device: device to place the model on. ``None`` (default) selects CUDA
            when available, otherwise CPU.

    Returns:
        (model, p): the DataParallel-wrapped model in eval mode and the
        architecture dict read from ``architecture.json``.

    Raises:
        FileNotFoundError: if ``net/cv`` holds no ``model_weight_<N>.pt`` file.
    """
    json_dir = os.path.join(base_dir, 'exp', load)
    with open(os.path.join(json_dir, 'architecture.json'), 'r') as fff:
        p = json.load(fff)
    # The JSON handle is closed here; nothing below needs it.
    load_path = os.path.join(json_dir, 'net', 'cv')

    model = Mud(n_fft=p['nfft'], kernel=(p['kernel1'], p['kernel2']), causal=p['causal'],
                layers=p['layers'], stacks=p['stacks'], verbose=False)
    # load_state_dict is called on the wrapper, so the checkpoint is presumably
    # saved from a DataParallel model ("module."-prefixed keys).
    model = nn.DataParallel(model)

    # Only accept files named exactly model_weight_<int>.pt; any other file in
    # the folder (logs, optimizer state, ...) would otherwise break int().
    prefix, suffix = 'model_weight_', '.pt'
    indices = []
    for fname in os.listdir(load_path):
        if fname.startswith(prefix) and fname.endswith(suffix):
            middle = fname[len(prefix):-len(suffix)]
            if middle.isdigit():
                indices.append(int(middle))
    if not indices:
        raise FileNotFoundError('no model_weight_<N>.pt checkpoint in {}'.format(load_path))
    mdl_idx = max(indices)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine too.
    state = torch.load(os.path.join(load_path, 'model_weight_{}.pt'.format(mdl_idx)),
                       map_location=device)
    model.load_state_dict(state)
    # nn.DataParallel.forward requires the wrapped module's parameters to live
    # on device_ids[0]; callers pass .cuda() inputs, so move the model as well.
    model = model.to(device)
    model.eval()
    return model, p

mdl, _ = load_mask_model('201903162304_baseline_NC')

In [None]:
# Sanity check on one validation example: run the mask model on the first four
# channels of the mixture and play input and output side by side.
mix, s1 = data_verbose[7]
# torch.autograd.Variable is a no-op wrapper since PyTorch 0.4; use the tensor directly.
mix = mix[:4].contiguous().cuda()
# Inference only: no_grad skips building the autograd graph and saves GPU memory.
with torch.no_grad():
    recon = mdl(mix.unsqueeze(0))

ipd.display(ipd.Audio(mix.cpu().numpy(), rate=8000))
ipd.display(ipd.Audio(recon.squeeze().cpu().numpy(), rate=8000))

In [None]:
# (all files)
all_files_names = []
all_traces = []
all_mics = []
all_nspk = {}
with open("/Data/Dropbox/cocoha_workshop_zh/cocoha_code/automated_exps/output_session2_15s.csv", "r") as f:
    reader = csv.reader(f, delimiter=",")
    
    for i, line in enumerate(reader):
        all_nspk[line[0]] = int(line[-1])
        all_traces.append(line[5:-1])
        all_mics.append(line[1:5])
        all_files_names.append(line[0])

print(len(all_files_names))

In [None]:
def find_trigger_v2(x, percentile=65, n_fft=2048, hop_length=512):
    """Return the sample index of the first STFT frame whose energy reaches a percentile.

    The log-power spectrogram of ``x`` is summed over frequency to get one
    energy value per frame. The first frame at or above the ``percentile``-th
    percentile of those values is the onset, converted back to samples.

    Args:
        x: signal to analyse; callers here pass a single 1-D channel
            (multi-channel input is not handled - TODO confirm if needed).
        percentile: percentile of per-frame log energy used as the threshold.
        n_fft: STFT window size forwarded to ``librosa.stft``.
        hop_length: STFT hop forwarded to ``librosa.stft``, also used to turn
            the frame index into a sample index. The defaults reproduce the
            previous behaviour: librosa's default n_fft=2048 and hop n_fft // 4 = 512.

    Returns:
        int: sample index of the detected onset (a multiple of ``hop_length``).

    Raises:
        ValueError: if no frame reaches the threshold (e.g. NaNs in ``x``).
    """
    spec = stft(x, n_fft=n_fft, hop_length=hop_length)
    # 1e-8 floor keeps log10 finite on silent bins.
    log_power = np.log10(np.abs(spec) ** 2 + 1e-8)
    frame_energy = np.sum(log_power, axis=0)
    threshold = np.percentile(frame_energy, percentile)
    above = np.flatnonzero(frame_energy >= threshold)
    if above.size == 0:
        raise ValueError('no STFT frame reaches the energy threshold (non-finite input?)')
    return int(above[0]) * hop_length

# Load the four whisper recordings of one session and stack them column-wise
# into a single (samples, channels) array trimmed to the shortest recording.
name = '20190222114858'
fs = 8000
WAV_TEMPLATE = '/Data/Dropbox/cocoha_workshop_zh/cocoha_code/data_whisper/session2/test_%s_wh%d.wav'
all_data = []
t_in_secs = []
for i in [5, 6, 7, 8]:
    # sr=fs makes librosa resample on load, so the returned rate equals fs.
    _y, _fs = librosa.load(WAV_TEMPLATE % (name, i), sr=fs, mono=False)
    # librosa returns (channels, samples) for multichannel files and (samples,)
    # for mono; atleast_2d + .T yields (samples, channels) in both cases, so
    # np.hstack below stacks channels instead of concatenating along time.
    _y = np.atleast_2d(_y).T.astype('float32')
    all_data.append(_y)
    assert _fs == fs
    t_in_secs.append(_y.shape[0] // fs)
    print("%d: %d seconds" % (i, _y.shape[0] // fs), end='\t\t')
print()  # finish the tab-separated duration line

n_spk = all_nspk[name]
print("N SPK: {}".format(n_spk))
m_len = np.min([a.shape[0] for a in all_data])
all_data_cut = [a[:m_len] for a in all_data]
all_data_array = np.hstack(all_data_cut)


# Time axis in seconds for every row of the stacked recording.
t = np.arange(len(all_data_array[:, 0])) / fs

TRIGGER_CHANNEL = 6  # column of all_data_array searched for the start trigger
for_trigger = all_data_array[:, TRIGGER_CHANNEL]
trigger = find_trigger_v2(for_trigger)
# Everything before the trigger is kept as noise_calib.
noise_calib = all_data_array[:trigger]
print("Noise {}    [MAX = 22s]".format(len(noise_calib) // fs))

# Four 10 s calibration segments starting 0, 15, 30 and 45 s after the trigger.
calib1, calib2, calib3, calib4 = [
    all_data_array[start * fs + trigger:(start + 10) * fs + trigger]
    for start in (0, 15, 30, 45)
]

# rec starts max(15 * n_spk, 30) seconds after the trigger.
rec_start_s = max(15 * n_spk, 30)
rec = all_data_array[rec_start_s * fs + trigger:]

print("Removing trigger => [{} ~= {}]".format(t_in_secs[0] - trigger // fs, rec_start_s + 15))

# Visual check: trigger channel up to 5 s past the detected trigger (green dashed line).
window = for_trigger[:trigger + 5 * fs]
fig, ax = plt.subplots()
ax.plot(window)
ax.plot([trigger, trigger], [np.min(window), np.max(window)], 'g--')
ax.set(title='Trigger detection (channel {})'.format(TRIGGER_CHANNEL),
       xlabel='sample', ylabel='amplitude');

In [None]:
# Play channel 0 of the recording segment, then of the first calibration segment.
for clip in (rec, calib1):
    ipd.display(ipd.Audio(clip[:, 0], rate=8000))

In [None]:
# Run the mask model on the real recording. rec is (samples, channels), so
# rec.T[:8] selects up to 8 channels as (channels, samples); contiguous()
# packs the transposed view before moving it to the GPU.
# torch.autograd.Variable is a no-op since PyTorch 0.4, so it is dropped.
mix = torch.from_numpy(rec.T[:8]).contiguous().cuda()
# Inference only: no_grad avoids building an autograd graph over a long recording.
with torch.no_grad():
    recon = mdl(mix.unsqueeze(0))

ipd.display(ipd.Audio(calib1[:, 0], rate=8000))
ipd.display(ipd.Audio(rec[:, 0], rate=8000))
ipd.display(ipd.Audio(recon.squeeze().cpu().numpy(), rate=8000))