# Speech Enhancement Data Preparation + Primary Model

## Necessary Libraries

In [1]:
import os
from six.moves import urllib
from tqdm import tqdm
import tarfile
import zipfile
import librosa
import torch
import soundfile as sf
import numpy as np
from scipy.fft import fft
import matplotlib.pyplot as plt
import IPython.display as ipd
import random

In [2]:
sampling_rate = 16000

In [3]:
def makedir_exist_ok(dirpath):  # pylint: disable=no-self-use
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)

In [4]:
def gen_bar_updater():  # pylint: disable=no-self-use
    pbar = tqdm(total = None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update

def download_and_extract_archive(url, download_root, extract_root=None,
                                   filename=None,
                                   md5=None, remove_finished=False):
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    archive = os.path.join(download_root, filename)

    if not os.path.isfile(archive):
        download_url(url, download_root, filename, md5)
    else:
        print(f'{filename} exists, no need to download again!')

    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
    

def download_url(url, root, filename=None, md5=None):
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath, reporthook = gen_bar_updater())
    except (urllib.error.URLError, IOError) as e:
        if url[:5] == 'https':
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                  ' Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url, fpath,
                                       reporthook = gen_bar_updater())
        else:
            raise e
            
def extract_archive(from_path,  # pylint: disable=no-self-use
                      to_path=None, remove_finished=False):
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith('.tar.gz'):
        with tarfile.open(from_path, 'r:gz', errorlevel = 1) as tar:
            for dir in os.listdir(path = to_path):
                if os.path.isdir(os.path.join(to_path, dir)):
                    rmdirfiles(os.path.join(to_path, dir))
            tar.extractall(path = to_path)

    elif from_path.endswith('.zip'):
        with zipfile.ZipFile(from_path, "r") as zip_ref:
            zip_ref.extractall(path = to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)

def rmdirfiles(directory):
    for (dirpath, dirnames, filenames) in os.walk(directory):
        for filename in filenames:
            fname = os.path.join(dirpath, filename)
            os.chmod(fname, 0o777)
            os.remove(fname)
            
def resample(folder, sr=sampling_rate, ext='.flac', max_prog=100):
    resampled = []
    file_cnt = 0
    num_files = file_count(folder, ext)
    for (dirpath, dirnames, filenames) in os.walk(folder):
        for filename in sorted(filenames):
            file_cnt += 1
            progress = (100 * file_cnt / num_files)
            if filename.endswith(ext):
                fname = os.path.join(dirpath, filename)
                data, samplerate = librosa.load(fname, sr = sr)
                print(f'\r{int(10 * progress) / 10}%: {fname}, {samplerate} {data.shape}',
                      end = " ")
                resampled.append(data)
                if progress >= max_prog:
                    break
        else:
            continue
        break
    print('\r')
    return resampled


def file_count(folder, ext='.flac'):
    file_cnt = 0
    for (dirpath, dirnames, filenames) in os.walk(folder):
        for filename in sorted(filenames):
            if filename.endswith(ext):
                file_cnt += 1
    return file_cnt

def snr_mixer(clean, noise, snr):
    # Normalizing to rms equal to 1
    rmsclean = np.mean(clean[:] ** 2) ** 0.5
    rmsnoise = np.mean(noise[:] ** 2) ** 0.5
    noisyspeech = []

    if rmsclean != 0 and rmsnoise != 0:
        scalarclean = 1 / rmsclean
        clean = clean * scalarclean
        scalarnoise = 1 / rmsnoise
        noise = noise * scalarnoise

        # Set the noise level for a given SNR
        cleanfactor = 10 ** (snr / 20)
        noisyspeech = cleanfactor * clean + noise
        noisyspeech = noisyspeech / (scalarnoise + cleanfactor * scalarclean)
        scaled_clean = cleanfactor * clean / (scalarnoise + cleanfactor * scalarclean)
        scaled_noise = noise / (scalarnoise + cleanfactor * scalarclean)

        valid = True

    else:
        valid = False
    return noisyspeech, valid, scaled_clean, scaled_noise

def quantize_audio(data, num_bits=8):
    """Quantize audio
    """

    step_size = 1.0 / 2 ** (num_bits)
    max_val = 2 ** (num_bits) - 1
    q_data = np.round(data / step_size)
    q_data = np.clip(q_data, 0, max_val)

    return np.uint8(q_data)

In [5]:
url_speech = "https://www.openslr.org/resources/12/dev-clean.tar.gz"
url_noise = "https://zenodo.org/record/1227121/files/OOFFICE_48k.zip?download=1"
fs = sampling_rate

# NOTE: for quick testing to only download small portion of repo, make 100 to resample all repo (VERY LONG)
speech_repo_resample_percent = 2  # 1-100%
noise_repo_resample_percent = 2
raw_folder_speech = os.path.join('SE', 'raw_speech')
raw_folder_noise = os.path.join('SE', 'raw_noise')
processed_folder = os.path.join('SE', 'processed')

resampled_speech_file = "speech_" + str(fs / 1000) + "KHz.pt"
resampled_noise_file = "noise_" + str(fs / 1000) + "KHz.pt"
data_file = 'dataset-speech.pt'
print(f'\rWarning: Resample {speech_repo_resample_percent}% of data repo and {noise_repo_resample_percent}% of noise repo')

makedir_exist_ok(raw_folder_speech)
makedir_exist_ok(raw_folder_noise)
makedir_exist_ok(processed_folder)

filename_speech = url_speech.rpartition('/')[2]
filename_noise = url_noise.rpartition('/')[2].rsplit('?')[0]
download_and_extract_archive(url_speech, download_root = raw_folder_speech, filename = filename_speech)
download_and_extract_archive(url_noise, download_root = raw_folder_noise, filename = filename_noise)

if not os.path.exists(os.path.join(processed_folder, resampled_speech_file)):
    print(f'\rResampling data @ {fs}Hz')
    data = resample(raw_folder_speech, sr = fs, ext = '.flac',
                           max_prog = speech_repo_resample_percent)

    # save resampled speech
    print(f'\rSaving resampled data: {os.path.join(processed_folder, resampled_speech_file)}')
    torch.save(data, os.path.join(processed_folder, resampled_speech_file))
else:
    print(f'\rWarning: Resampled data file exists (remove and run again to regenerate), start loading...')
    data = torch.load(os.path.join(processed_folder, resampled_speech_file))
    
if not os.path.exists(os.path.join(processed_folder, resampled_noise_file)):
    print(f'\rResampling noise @ {fs}Hz\r')
    noise = resample(raw_folder_noise, sr = fs, ext = '.wav',
                            max_prog = noise_repo_resample_percent)

    # save resampled noise
    print(
        f'\rSaving resampled noise: {os.path.join(processed_folder, resampled_noise_file)}')
    torch.save(noise, os.path.join(processed_folder, resampled_noise_file))
else:
    print(
        f'\rWarning: Resampled noise file exists (remove and run again to regenerate), start loading...')
    noise = torch.load(
        os.path.join(processed_folder, resampled_noise_file))
    



dev-clean.tar.gz exists, no need to download again!
Extracting SE\raw_speech\dev-clean.tar.gz to SE\raw_speech
OOFFICE_48k.zip exists, no need to download again!
Extracting SE\raw_noise\OOFFICE_48k.zip to SE\raw_noise


In [6]:
noise_in = []

for noise_frame in noise:
    noise_in = np.append(noise_in, noise_frame, axis = 0)  # make an nx1 noise array


In [7]:
sampling_rate = 16000
frame_length = int((20*sampling_rate)/1000) # 20 msec
chunk_size = 128 # temporal context 128*10msec ~ 1sec
fft_num = 320
win_length = int((20*sampling_rate)/1000) # 20 msec
win_hop = int((10*sampling_rate)/1000) # 10 msec
snr = 6
flag = True
stft_len = fft_num//2 + 1

In [8]:
for count, data_frame in tqdm(enumerate(data)):
    print(f'\rProcessing {count + 1} of {len(data)} data', end = "")
    noise_start = random.randint(0, noise_in.shape[0]- data_frame.shape[0])
    noise = noise_in[noise_start: noise_start + data_frame.shape[0]]
    noisy_frame, valid, scaled_clean, scaled_noise = snr_mixer(data_frame, noise, snr)
    noisy_feat = np.abs(librosa.stft(noisy_frame, n_fft=fft_num, win_length=win_length, hop_length=win_hop, window='hanning').T)
    noisy_angle = np.angle(librosa.stft(noisy_frame, n_fft=fft_num, win_length=win_length, hop_length=win_hop, window='hanning').T)
    clean_feat = np.abs(librosa.stft(scaled_clean, n_fft=fft_num, win_length=win_length, hop_length=win_hop, window='hanning').T)
    # normalize to 0-1.0
    noisy_feat = noisy_feat / np.amax(noisy_feat)
    clean_feat = clean_feat / np.amax(clean_feat)
    
    chunk_start = 0
    clean_start = chunk_size - 1
    while True:
        if chunk_start + chunk_size > noisy_feat.shape[0]:
            break

        feat_in = noisy_feat[chunk_start: chunk_start + chunk_size, :]
        angles = noisy_angle[chunk_start: chunk_start + chunk_size, :]
        feat_out = clean_feat[clean_start, :]
        chunk_start += 1
        clean_start += 1
        # only first column?
        feat_in = np.expand_dims(feat_in, axis=0)
        feat_out = np.expand_dims(feat_out, axis=0)
        angles = np.expand_dims(angles, axis=0)

        
        if flag:
            stft_in = feat_in.copy()
            label = feat_out.copy()
            angle = angles.copy()
            flag = False
        else:
            stft_in = np.concatenate((stft_in, feat_in), axis=0)
            label = np.concatenate((label, feat_out), axis = 0)
            angle = np.concatenate((angle, angles), axis = 0)
        #print(stft_in.shape)
        #print(label.shape)
    if count == 1:
        break

    
print('\rQuantizing stft output and labels...\r')
stft_in_q = quantize_audio(stft_in)
label_q = quantize_audio(label)
# 0 (90%): train and test   1 (10%): validate
data_type = (np.random.rand(label_q.shape[0], 1) + 0.1).astype(int)
print(f'\rtrain data: {stft_in_q.shape}  labels: {label_q.shape}  data type:{data_type.shape}')
speech_dataset = (stft_in_q, label_q, angle, data_type)
torch.save(speech_dataset, os.path.join(processed_folder, data_file))
print(f'\rDataset created: {os.path.join(processed_folder, data_file)}')

0it [00:00, ?it/s]

Processing 1 of 47 data

1it [00:11, 11.61s/it]

Processing 2 of 47 data

1it [00:36, 36.28s/it]


Quantizing stft output and labels...
train data: (814, 128, 161)  labels: (814, 161)  data type:(814, 1)
Dataset created: SE\processed\dataset-speech.pt


# 上面好像只process了一个data

In [9]:
dataset = torch.load('SE/processed/dataset-speech.pt')

In [10]:
dataset[1].shape

(814, 161)

In [11]:
clean_snippet = dataset[1][:1280]

In [12]:
import librosa
import numpy as np
noisy_snippet = dataset[0][:, 127, :]
print(noisy_snippet.shape)
noisy_angle = dataset[2][:, 127, :]
print(noisy_angle.shape)
clean_snippet = dataset[1]
print(clean_snippet.shape)

(814, 161)
(814, 161)
(814, 161)


In [13]:
noisy_recon = clean_snippet*(np.cos(noisy_angle)+1j*np.sin(noisy_angle))
x = librosa.istft(noisy_recon.T, hop_length=win_hop, win_length=win_length, window='hanning', center=True, dtype=None, length=None)

In [14]:
ipd.Audio(x,rate=16000)

In [15]:
import torch.nn as nn



class SENetv0(nn.Module):

    # output = 161 STFT feats
    def __init__(
            self,
            num_channels=161,
            dimensions=(161, 1),  # pylint: disable=unused-argument
            bias=False,
            **kwargs

    ):
        super().__init__()

        self.project = nn.Sequential(
           
            nn.Conv1d(161, 256, 9), # in : 161 x 128; out: 256 x 120
            nn.ReLU(inplace=True),   
            
            nn.MaxPool1d(2),# in : 256 x 120; out: 256 x 60           
            nn.Conv1d(256, 128, 9),# in : 256 x 60; out: 128 x 52
            nn.ReLU(),
            nn.BatchNorm1d(128),
            
            #########################################################################
            nn.Conv1d(128, 128, 9, padding=4),# in : 128 x 52; out: 128 x 52
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, 9, padding=4),# in : 128 x 52; out: 128 x 52
            nn.ReLU(),
            nn.BatchNorm1d(128),
            #########################################################################
            nn.MaxPool1d(2),# in : 128 x 52; out: 128 x 26
            
            nn.Conv1d(128, 96, 9),# in : 128 x 26, out: 96 x 18
            nn.ReLU(),
            nn.BatchNorm1d(96),
            nn.MaxPool1d(2), # in : 96 x 18, out: 96 x 9            
            nn.Conv1d(96, 161, 9), # in : 96 x 9, out: 161 x 1
        )



    def forward(self, x):
        x = self.project(x)
        return x

In [16]:
model = SENetv0().cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
optimizer.zero_grad()
loss_fn = torch.nn.MSELoss()
_input, _output, angles, _lab = dataset
from tqdm import tqdm
for au in tqdm(range(_input.shape[0])):
    au_in = torch.from_numpy(_input[0].T).cuda().unsqueeze(0).float()
    au_out = torch.from_numpy(_output[0]).cuda().float()
    pred_out = model(au_in)
    loss = loss_fn(pred_out, au_out)
    loss.backward()
    optimizer.step()

  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 814/814 [00:29<00:00, 27.24it/s]
