# Speech Enhancement

## Comparing the three most effective speech enhancement models (对比三个效果最明显的语音增强模型)

In [None]:
# Time a section of code: put the code under test where the placeholder is.
# time.perf_counter() is a high-resolution wall-clock timer, so the value
# reported below is elapsed time (seconds converted to milliseconds), not pure CPU time.
import time

time_start = time.perf_counter()
# function()   <- code under measurement goes here
time_end = time.perf_counter()
time_sum = time_end - time_start  # elapsed seconds
elapsed_ms = time_sum * 1000
print("cpu运行程序时间: ", elapsed_ms, " ms")

### Denoiser (Facebook Research, pretrained DNS64 model)

In [None]:
from IPython import display as disp
import torch
import torchaudio

from denoiser import pretrained
from denoiser.dsp import convert_audio

In [None]:
# Run the pretrained DNS64 Denoiser on one noisy utterance and play the noisy
# input next to the enhanced output.
NOISY_WAV_PATH = "/Users/yuexiajiao/Downloads/noisy_trainset_28spk_wav/p226_007.wav"
# Alternative test input:
# NOISY_WAV_PATH = "/Users/yuexiajiao/Downloads/gtcrn-main/stream/test_wavs/mix.wav"

model = pretrained.dns64()
model.eval()  # inference mode: disable any training-only layer behaviour

wav, sr = torchaudio.load(NOISY_WAV_PATH)
print(sr)  # native sample rate of the file, before conversion

# Convert to the model's sample rate and channel count (model.sample_rate, model.chin).
wav = convert_audio(wav, sr, model.sample_rate, model.chin)

with torch.no_grad():
    denoised = model(wav[None])[0]  # add a batch dimension, then strip it again

disp.display(disp.Audio(wav.detach().cpu().numpy(), rate=model.sample_rate))
disp.display(disp.Audio(denoised.detach().cpu().numpy(), rate=model.sample_rate))

### SEGAN+

In [None]:
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pre_model.segan.models import *
from pre_model.segan.datasets import *
import soundfile as sf
from scipy.io import wavfile
from torch.autograd import Variable
import numpy as np
import random
import librosa
import matplotlib
import timeit
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import json
import glob
import os

In [None]:
class ArgParser:
    """Expose the entries of a mapping (e.g. a JSON training config) as attributes.

    Example: ``ArgParser({'preemph': 0.95}).preemph == 0.95``.
    """

    def __init__(self, args):
        # Copy each key/value pair onto the instance as an attribute.
        for attr_name, attr_value in args.items():
            setattr(self, attr_name, attr_value)

In [None]:
def _list_test_wavs(opts):
    """Resolve ``opts.test_files`` into the inputs that ``main`` should clean.

    Returns the rows of the ``data`` dataset of the first HDF5 file when
    ``opts.h5`` is set, otherwise a list of wav file paths.
    """
    test_files = opts.test_files
    # argparse builds a list for nargs='+' only when the flag is given on the
    # command line; a plain-string value would otherwise be iterated per character.
    if isinstance(test_files, str):
        test_files = [test_files]
    if opts.h5:
        # h5py is not among the notebook's explicit imports; import it where needed.
        import h5py
        with h5py.File(test_files[0], 'r') as f:
            return f['data'][:]
    if len(test_files) == 1 and os.path.isdir(test_files[0]):
        # A single directory: clean every .wav inside it, in a stable (sorted) order.
        return sorted(glob.glob(os.path.join(test_files[0], '*.wav')))
    # Otherwise the entries are explicit wav file paths.
    return list(test_files)


def main(opts):
    """Enhance wav files with a pretrained SEGAN+ (or WSEGAN) generator.

    Args:
        opts: namespace providing ``cfg_file`` (JSON training config),
            ``test_files`` (a directory, a list of wav paths, or an HDF5 file
            when ``h5`` is set), ``g_pretrained_ckpt``, ``synthesis_path``,
            ``cuda`` and ``soundfile``.

    Each enhanced signal is written to ``synthesis_path`` under the input's
    basename (``tfile_<i>.wav`` for HDF5 rows) at 16 kHz.
    """
    assert opts.cfg_file is not None
    assert opts.test_files is not None
    assert opts.g_pretrained_ckpt is not None
    out_rate = 16000  # sample rate used for every written file

    with open(opts.cfg_file, 'r') as cfg_f:
        args = ArgParser(json.load(cfg_f))
        print('Loaded train config: ')
    args.cuda = opts.cuda
    # The training config decides which generator architecture to rebuild.
    if getattr(args, 'wsegan', False):
        segan = WSEGAN(args)
    else:
        segan = SEGAN(args)
    segan.G.load_pretrained(opts.g_pretrained_ckpt, True)
    if opts.cuda:
        segan.cuda()
    segan.G.eval()

    twavs = _list_test_wavs(opts)
    print('Cleaning {} wavs'.format(len(twavs)))
    beg_t = timeit.default_timer()
    for t_i, twav in enumerate(twavs, start=1):
        if not opts.h5:
            tbname = os.path.basename(twav)
            rate, wav = wavfile.read(twav)
            if rate != out_rate:
                # The output is always written at 16 kHz, so another input rate
                # would silently change the playback speed of the result.
                print('WARNING: {} is sampled at {} Hz, expected {} Hz'.format(
                    twav, rate, out_rate))
            wav = normalize_wave_minmax(wav)
        else:
            tbname = 'tfile_{}.wav'.format(t_i)
            wav = twav
            twav = tbname
        wav = pre_emphasize(wav, args.preemph)
        pwav = torch.FloatTensor(wav).view(1, 1, -1)  # (batch=1, channels=1, samples)
        if opts.cuda:
            pwav = pwav.cuda()
        g_wav, _ = segan.generate(pwav)
        out_path = os.path.join(opts.synthesis_path, tbname)
        if opts.soundfile:
            sf.write(out_path, g_wav, out_rate)
        else:
            wavfile.write(out_path, out_rate, g_wav)
        end_t = timeit.default_timer()
        print('Cleaned {}/{}: {} in {} s'.format(t_i, len(twavs), twav,
                                                 end_t - beg_t))
        beg_t = timeit.default_timer()

In [None]:
if __name__ == '__main__':
    # Build the same options the SEGAN+ clean script takes on the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--g_pretrained_ckpt', type=str, default="/Users/yuexiajiao/Desktop/DailyReport/pre_model/segan/ckpt_segan+/segan+_generator.ckpt")
    # nargs='+' yields a list when the flag is given, so the default is a list too:
    # main() indexes and measures test_files as a list of paths.
    parser.add_argument('--test_files', type=str, nargs='+', default=["/Users/yuexiajiao/Desktop/DailyReport/test_waves/dpcrn-compress"])
    parser.add_argument('--h5', action='store_true', default=False)
    parser.add_argument('--seed', type=int, default=111,
                        help="Random seed (Def: 111).")
    parser.add_argument('--synthesis_path', type=str, default="/Users/yuexiajiao/Desktop/DailyReport/test_waves/segan+",
                        help='Path to save output samples (Def: ' \
                             'segan_samples).')
    parser.add_argument('--cuda', action='store_true', default=False)
    parser.add_argument('--soundfile', action='store_true', default=False)
    parser.add_argument('--cfg_file', type=str, default="/Users/yuexiajiao/Desktop/DailyReport/pre_model/segan/ckpt_segan+/train.opts")

    # Parse an empty argv: inside a Jupyter kernel sys.argv typically holds the
    # kernel's own launch arguments, which this parser would reject.
    opts = parser.parse_args([])

    # exist_ok makes the cell idempotent on re-runs and avoids a check-then-create race.
    os.makedirs(opts.synthesis_path, exist_ok=True)

    # Seed the Python, NumPy and PyTorch RNGs.
    random.seed(opts.seed)
    np.random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if opts.cuda:
        torch.cuda.manual_seed_all(opts.seed)

    main(opts)

In [None]:
%%bash
# Command-line equivalent of the SEGAN+ cell above, via a clean.py script
# (not part of this notebook). Each input path can be overridden from the
# environment; the defaults are assumptions relative to the notebook directory,
# derived from the Python defaults above - TODO confirm they match your checkout.
CKPT_PATH="${CKPT_PATH:-pre_model/segan/ckpt_segan+}"
G_PRETRAINED_CKPT="${G_PRETRAINED_CKPT:-segan+_generator.ckpt}"
TEST_FILES_PATH="${TEST_FILES_PATH:-test_waves/dpcrn-compress}"
SAVE_PATH="synth_segan+"

python -u clean.py --g_pretrained_ckpt "$CKPT_PATH/$G_PRETRAINED_CKPT" \
    --test_files "$TEST_FILES_PATH" --cfg_file "$CKPT_PATH/train.opts" \
    --synthesis_path "$SAVE_PATH" --soundfile

### DRCRN

**TODO:** no experiment has been added for this model yet.

### DPCRN Compressed (run with the GTCRN model below)

In [None]:
import os
import torch
import soundfile as sf
from pre_model.gtcrn import GTCRN

In [None]:
## Load the pretrained GTCRN checkpoint (CPU inference).
device = torch.device("cpu")
ckpt = torch.load(os.path.join('pre_model/gtcrn', 'model_trained_on_dns3.tar'), map_location=device)
model = GTCRN().eval()
# Without loading the checkpoint the model would run with randomly initialised weights.
model.load_state_dict(ckpt['model'])

## Load the noisy input.
# The model's default input rate is 16 kHz, so this reads the 16 kHz copy that the
# resampling cell below writes from the 48 kHz original - run that cell first on a
# fresh checkout, otherwise this file may be missing or stale.
NOISY_16K_PATH = '/Users/yuexiajiao/Desktop/DailyReport/test_waves/noisy_test/p226_002_16k.wav'
mix, fs = sf.read(NOISY_16K_PATH, dtype='float32')
if fs != 16000:
    raise ValueError('GTCRN expects 16 kHz input, got {} Hz from {}'.format(fs, NOISY_16K_PATH))
print(fs)

In [None]:
## Downsample the 48 kHz noisy recording to the 16 kHz rate GTCRN expects.
# Requires librosa:  conda install -c conda-forge librosa
import librosa
import soundfile as sf

SRC_48K_PATH = '/Users/yuexiajiao/Desktop/DailyReport/test_waves/noisy_test/p226_002.wav'
newFilename = '/Users/yuexiajiao/Desktop/DailyReport/test_waves/noisy_test/p226_002_16k.wav'
TARGET_SR = 16000

# sr=48000 makes librosa return the signal at 48 kHz (resampling if the file differs).
y, sr = librosa.load(SRC_48K_PATH, sr=48000)
y_16k = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SR)
sf.write(newFilename, y_16k, TARGET_SR)

In [None]:
## Inference: STFT -> GTCRN -> inverse STFT.
N_FFT = 512
HOP_LENGTH = 256
# Square-root Hann window, used for both analysis and synthesis.
window = torch.hann_window(N_FFT).pow(0.5)

# The model takes the spectrum as real/imag pairs in a trailing dimension of size 2.
# view_as_real of the complex STFT gives that layout without the deprecated
# return_complex=False. Named mix_spec so the builtin input() is not shadowed.
mix_spec = torch.view_as_real(
    torch.stft(torch.from_numpy(mix), N_FFT, HOP_LENGTH, N_FFT, window, return_complex=True)
)

with torch.no_grad():
    output = model(mix_spec[None])[0]  # add a batch dimension, then strip it again

# Recent PyTorch releases require a complex spectrogram for istft.
if not output.is_complex():
    output = torch.view_as_complex(output.contiguous())
enh = torch.istft(output, N_FFT, HOP_LENGTH, N_FFT, window)

## Save the enhanced waveform at the input's sample rate.
ENH_WAV_PATH = '/Users/yuexiajiao/Desktop/DailyReport/test_waves/dpcrn-compress/enh.wav'
os.makedirs(os.path.dirname(ENH_WAV_PATH), exist_ok=True)
sf.write(ENH_WAV_PATH, enh.detach().cpu().numpy(), fs)

In [None]:
# Play the noisy input and the GTCRN-enhanced output side by side, at the
# input's sample rate (the same rate the save cell writes with).
from IPython import display as disp

disp.display(disp.Audio(mix, rate=fs))
# enh is a torch tensor; give IPython a plain NumPy array, as the save cell does.
disp.display(disp.Audio(enh.detach().cpu().numpy(), rate=fs))