In [None]:
import os, sys, math, random, glob, shutil, time, functools, itertools
from pathlib import Path 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers 
import tensorflow.signal as tfs 

from scipy.io import wavfile 
from scipy.signal import resample_poly
from IPython.display import Audio, display

## Environment Constants

In [None]:
# --- Reproducibility: seed every RNG the notebook draws from ---
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# --- Audio / STFT configuration ---
SR = 16000                        # sample rate in Hz
SEGMENT_SEC = 2.0                 # length of one training clip in seconds
SEGMENT = int(SEGMENT_SEC * SR)   # clip length in samples (32000 with the defaults)
N_FFT = 1024                      # FFT size -> N_FFT // 2 + 1 = 513 frequency bins
HOP = 256                         # samples between consecutive STFT frames
WIN_LENGTH = 1024                 # analysis window length in samples
N_MELS = 128                      # number of mel bands
PAD_MODE = 'REFLECT'              # padding-mode name (not used by the cells shown here)

# --- Training schedule ---
BATCH_SIZE = 8
EPOCHS = 15
STEPS_PER_EPOCH = 600
VAL_STEPS = 80
LEARNING_RATE = 3e-4
WARMUP_STEPS = 500
EMA_DECAY = 0.999

# --- Output locations (absolute Kaggle paths; adjust when running elsewhere) ---
CHECKPOINT_DIR = '/kaggle/working/denoiser_ckpt'
EXPORT_DIR = '/kaggle/working/denoiser_export'
EXPORT_DIR = '/kaggle/working/denoiser_export'

In [None]:
# Create the checkpoint and export folders (including parents); no-op if they exist.
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
Path(EXPORT_DIR).mkdir(parents=True, exist_ok=True)

## Audio I/O Utilities

In [None]:
def norm_audio(x):
    """Peak-normalise a signal so its largest absolute value is ~1.

    Args:
        x: array-like signal of any shape.

    Returns:
        float32 ndarray with the same shape as ``x``. An all-zero signal stays
        all-zero (the 1e-9 term prevents division by zero), and an empty input
        is returned as an empty float32 array.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        # np.max raises on zero-size arrays, and there is nothing to scale anyway.
        return x
    peak = np.max(np.abs(x)) + 1e-9
    return x / peak

In [None]:
def read_wav_mono(path, target_sr=SR):
    """Load a WAV file as a peak-normalised mono float32 signal at ``target_sr``.

    Args:
        path: path of a WAV file readable by ``scipy.io.wavfile.read``.
        target_sr: desired sample rate in Hz (defaults to ``SR``).

    Returns:
        ``(y, target_sr)``, where ``y`` is a 1-D float32 array normalised by
        ``norm_audio``.
    """
    sr, y = wavfile.read(path)
    if y.dtype == np.uint8:
        # 8-bit PCM WAV is unsigned with silence at 128; centre it so the DC
        # offset is not carried through peak normalisation.
        y = y.astype(np.float32) - 128.0
    else:
        y = y.astype(np.float32)
    if y.ndim == 2:
        # (samples, channels) -> mono by averaging the channels
        y = y.mean(axis=1)
    if sr != target_sr:
        # TensorFlow core has no `tf.audio.resample`; resample with a polyphase
        # filter using the smallest integer up/down factors.
        g = math.gcd(int(sr), int(target_sr))
        y = resample_poly(y, int(target_sr) // g, int(sr) // g).astype(np.float32)
    return norm_audio(y), target_sr

In [None]:
def write_wav(path, y, sr=SR):
    """Save ``y`` as a 16-bit PCM WAV file, peak-normalised to 0.99 of full scale.

    Args:
        path: destination file path.
        y: array-like waveform.
        sr: sample rate written into the file header (defaults to ``SR``).
    """
    samples = np.asarray(y, dtype=np.float32)
    peak = np.max(np.abs(samples)) + 1e-9
    # 0.99 headroom keeps the int16 conversion safely inside [-32767, 32767]
    scaled = samples / peak * 0.99
    pcm = (scaled * 32767.0).astype(np.int16)
    wavfile.write(path, sr, pcm)

## Signal Transforms

In [None]:
# Short-time Fourier transform of a real-valued signal.
# Output: complex tensor of shape (..., frames, N_FFT // 2 + 1).
def stft(sig):
    """Compute the STFT of ``sig`` with the notebook's Hann-window settings.

    Args:
        sig: real waveform tensor with samples on the last axis,
            e.g. (batch, SEGMENT).

    Returns:
        Complex tensor of shape (..., frames, N_FFT // 2 + 1), where
        frames = 1 + (samples - WIN_LENGTH) // HOP (no end padding).
    """
    return tfs.stft(
        sig,
        frame_length=WIN_LENGTH,          # samples analysed per frame
        frame_step=HOP,                   # hop between frames; WIN_LENGTH - HOP samples overlap (768 with the defaults)
        fft_length=N_FFT,                 # FFT size -> N_FFT // 2 + 1 frequency bins per frame
        window_fn=tf.signal.hann_window,  # tapers frame edges to reduce spectral leakage
    )

In [None]:
# Inverse STFT: complex time-frequency representation -> time-domain waveform.
def istft(stft_c, length):
    """Invert ``stft`` and return exactly ``length`` samples on the last axis.

    Args:
        stft_c: complex STFT as produced by ``stft``
            (shape (..., frames, N_FFT // 2 + 1)).
        length: number of output samples wanted (e.g. the original signal length).

    Returns:
        Real tensor of shape (..., length). Samples the overlap-add does not
        cover are zero-filled; surplus samples are cropped.
    """
    # tf.signal.inverse_stft has no `output_length` argument, and passing the plain
    # forward Hann window does not undo the overlap-add weighting; the dual window
    # from inverse_stft_window_fn gives (near-)perfect reconstruction.
    y = tfs.inverse_stft(
        stft_c,
        frame_length=WIN_LENGTH,
        frame_step=HOP,
        fft_length=N_FFT,  # must match the forward transform
        window_fn=tfs.inverse_stft_window_fn(HOP, forward_window_fn=tf.signal.hann_window),
    )
    # Overlap-add yields (frames - 1) * HOP + WIN_LENGTH samples; fit that to `length`.
    y = y[..., :length]
    missing = tf.maximum(length - tf.shape(y)[-1], 0)
    # pad only the last axis: zeros for all leading axes, (0, missing) for the last
    lead = tf.zeros(tf.stack([tf.rank(y) - 1, 2]), dtype=tf.int32)
    tail = tf.reshape(tf.stack([tf.zeros_like(missing), missing]), [1, 2])
    return tf.pad(y, tf.concat([lead, tail], axis=0))

In [None]:
def complex_mag(stft_c):
    """Return the elementwise magnitude |z| of a complex (e.g. STFT) tensor."""
    magnitude = tf.math.abs(stft_c)
    return magnitude

In [None]:
def eps():
    """Return the small epsilon constant 1e-8."""
    value = 1.0e-8
    return value

In [None]:
# Mel filterbank: maps linear-frequency STFT bins to N_MELS mel bands.
# The mel scale spaces low-frequency bands more densely, because humans notice
# small pitch differences at low frequencies.
# Shape: (N_FFT // 2 + 1, N_MELS); right-multiply a (..., N_FFT // 2 + 1) magnitude by it.
MEL_FILTER = tfs.linear_to_mel_weight_matrix(
    num_mel_bins=N_MELS,
    num_spectrogram_bins=N_FFT // 2 + 1,  # integer count of linear bins produced by the STFT
    sample_rate=SR,
    lower_edge_hertz=0.0,
    upper_edge_hertz=SR / 2,  # Nyquist frequency (half the sample rate)
)

## Visualization functions

## Synthetic Data Generators (clean tones, noise, and noisy mixtures)

In [None]:
def gen_tone(duration, sr=SR):
    """Synthesize a smoothly faded tone used as a stand-in "clean" signal.

    A sine at a random base frequency f0 ~ U(100, 1000) Hz. With probability
    0.5 it is blended (0.6 / 0.4) with a linear chirp that sweeps from f0 to
    f1 ~ U(200, 2000) Hz over the clip. The result is multiplied by a
    raised-cosine envelope (zero at the start, peak at the middle) so it has
    no abrupt onset or offset.

    Args:
        duration: clip length in seconds.
        sr: sample rate in Hz (defaults to ``SR``).

    Returns:
        float32 array of int(sr * duration) samples, peak-normalised.
    """
    # int(sr * duration) sample times covering [0, duration)
    t = np.linspace(0, duration, int(sr * duration), endpoint=False)
    f0 = np.random.uniform(100, 1000)
    y = np.sin(2 * np.pi * f0 * t)
    if np.random.rand() < 0.5:
        f1 = np.random.uniform(200, 2000)
        # Linear chirp: phase 2*pi*(f0*t + (f1 - f0)*t^2 / (2*duration)) has
        # instantaneous frequency f0 + (f1 - f0)*t/duration, i.e. f0 -> f1.
        # (sin(2*pi*(f0 + (f1 - f0)*t/duration)*t) would end at 2*f1 - f0 instead.)
        phase = 2 * np.pi * (f0 * t + 0.5 * (f1 - f0) * t ** 2 / duration)
        chirp = np.sin(phase)
        y = 0.6 * y + 0.4 * chirp
    env = 0.5 * (1 - np.cos(2 * np.pi * np.minimum(1.0, t / duration)))
    return norm_audio(y * env)

In [None]:
def gen_noise(duration, sr=SR):
    """Synthesize a broadband noise clip: white + pink + tonal "babble".

    Each component is peak-normalised, then mixed with weights 0.5 (white),
    0.3 (pink) and 0.2 (babble); the mixture is peak-normalised again.

    Args:
        duration: clip length in seconds.
        sr: sample rate in Hz (defaults to ``SR``).

    Returns:
        float32 array of int(sr * duration) samples.
    """
    n = int(sr * duration)
    # white noise: flat power spectrum
    white = np.random.randn(n).astype(np.float32)

    # pink noise: power ~ 1/f, i.e. amplitude ~ 1/sqrt(f). A random complex
    # spectrum is shaped in the rfft domain; frequencies below 1 Hz are clamped so
    # they are not boosted without bound, and the DC bin is zeroed so the result
    # carries no constant offset.
    freqs = np.fft.rfftfreq(n, 1 / sr)
    spectrum = np.random.randn(len(freqs)) + 1j * np.random.randn(len(freqs))
    spectrum = spectrum / np.sqrt(np.maximum(freqs, 1.0))
    spectrum[0] = 0.0
    pink = np.fft.irfft(spectrum, n=n).astype(np.float32)

    # "babble": 3 to 6 overlapping random tones
    babble = np.zeros(n, dtype=np.float32)
    for _ in range(np.random.randint(3, 7)):
        babble += gen_tone(duration, sr)

    mix = 0.5 * norm_audio(white) + 0.3 * norm_audio(pink) + 0.2 * norm_audio(babble)
    return norm_audio(mix)
    

In [None]:
def random_segment(y, length):
    """Return a window of exactly ``length`` samples from the 1-D signal ``y``.

    - ``len(y) > length``: a uniformly random window; every start offset
      0 .. len(y) - length is possible.
    - ``len(y) == length``: ``y`` itself.
    - ``0 < len(y) < length``: ``y`` reflection-padded at the end.
    - empty ``y``: ``length`` zeros (reflection cannot extend an empty axis).
    """
    y = np.asarray(y)
    if len(y) == 0:
        return np.zeros(length, dtype=y.dtype)
    if len(y) < length:
        return np.pad(y, (0, length - len(y)), mode='reflect')
    # randint's upper bound is exclusive: +1 makes the last valid offset reachable
    # and avoids randint(0, 0) raising when len(y) == length.
    start = np.random.randint(0, len(y) - length + 1)
    return y[start:start + length]

In [None]:
def mix_clean_noise(clean, noise, snr_db=None):
    """Mix ``clean`` and ``noise`` at a target signal-to-noise ratio.

    Args:
        clean: clean signal array.
        noise: noise array of the same shape (the two are added elementwise).
        snr_db: target SNR in dB; drawn from U(-5, 15) when None.

    Returns:
        ``(noisy, clean, noise)`` as float32 arrays. All three are divided by the
        same factor, so the mixture's peak is ~1 and ``noisy == clean + noise``
        up to float rounding. The returned clean is therefore exactly the clean
        component of the returned noisy input, which is what a denoising target
        must be.
    """
    if snr_db is None:
        snr_db = np.random.uniform(-5, 15)
    # unit-variance versions of both signals
    c = clean / (np.std(clean) + 1e-9)
    n = noise / (np.std(noise) + 1e-9)
    rms_c = np.sqrt(np.mean(c ** 2) + 1e-9)
    rms_n = np.sqrt(np.mean(n ** 2) + 1e-9)
    # rescale the noise so that 20*log10(rms_c / rms_noise) == snr_db
    target_rms_n = rms_c / (10 ** (snr_db / 20.0))
    n = n * (target_rms_n / (rms_n + 1e-9))
    noisy = c + n
    # One shared gain for every output: normalising each signal by its own peak
    # would give the target a different scale than the clean part of the input.
    scale = np.max(np.abs(noisy)) + 1e-9
    return (
        (noisy / scale).astype(np.float32),
        (c / scale).astype(np.float32),
        (n / scale).astype(np.float32),
    )

## Pipeline

In [None]:
def wav_loader_factory(clean_paths, noise_paths):
    """Build a loader that returns one freshly mixed (noisy, clean) training pair.

    Clean audio comes from a random file in ``clean_paths``, or from
    ``gen_tone`` when that list is empty/None. Noise comes from a random file in
    ``noise_paths`` with probability 0.9 when that list is non-empty, otherwise
    from ``gen_noise``. Both are cut to ``SEGMENT`` samples and mixed at a
    random SNR by ``mix_clean_noise``.

    Returns:
        ``load_and_mix(_)``: ignores its single argument (``tf_dataset`` passes
        None) and returns ``(noisy, clean)`` float32 arrays of ``SEGMENT`` samples.
    """
    def load_and_mix(_):
        if clean_paths:
            c, _sr = read_wav_mono(random.choice(clean_paths), SR)
        else:
            c = gen_tone(SEGMENT_SEC)
        if noise_paths and np.random.rand() < 0.9:
            n, _sr = read_wav_mono(random.choice(noise_paths), SR)
        else:
            # presumably one extra second so random_segment has room to pick an offset
            n = gen_noise(SEGMENT_SEC + 1.0)
        c_seg = random_segment(c, SEGMENT)
        n_seg = random_segment(n, SEGMENT)
        # noisy is the model input, clean is the target; the noise part is unused here
        noisy, clean, _noise = mix_clean_noise(c_seg, n_seg)
        return noisy.astype(np.float32), clean.astype(np.float32)
    return load_and_mix
        

In [None]:
# Example workflow (STEPS_PER_EPOCH = 600, BATCH_SIZE = 8):
# Each call to the loader function returns one (noisy, clean) pair with shapes ((32000,), (32000,)).
# The generator inside tf_dataset calls the loader 600 * 8 * 2 = 9600 times per pass.
# Every time the dataset is iterated, a fresh pool of 9600 new samples is generated.
# The samples are shuffled through a shuffle buffer.
# They are then grouped into batches of 8: 9600 / 8 = 1200 batches per pass (twice the step count).
# train_ds is an iterable object.
# Calling next() on an iterator over it yields one batch -> ((8, 32000), (8, 32000)).
# next() can be called 1200 times before one pass is exhausted.

def tf_dataset(clean_paths, noise_paths, batch_size, steps, shuffle_buffer=8192):
    """Build a re-iterable, shuffled, batched ``tf.data.Dataset`` of (noisy, clean) pairs.

    Every iteration over the returned dataset restarts ``gen``. That creates a
    new loader and yields ``steps * batch_size * 2`` freshly mixed examples.
    With ``drop_remainder=True`` this gives ``2 * steps`` batches of shape
    ``((batch_size, SEGMENT), (batch_size, SEGMENT))`` per pass.

    Args:
        clean_paths: list of clean WAV paths (empty/None -> synthetic tones).
        noise_paths: list of noise WAV paths (empty/None -> synthetic noise).
        batch_size: examples per batch.
        steps: batches the caller needs per pass; the generator produces twice
            as many examples to feed the shuffle buffer.
        shuffle_buffer: number of examples held in the shuffle buffer
            (default 8192). Each example is two float32 vectors of SEGMENT
            samples (~256 KB for 2 s at 16 kHz). A full buffer of 8192 is
            therefore roughly 2 GB, and tf.data fills the buffer (up to the
            number of examples available) before emitting the first batch.
            Lower it if memory or start-up latency matters.
    """
    def gen():
        loader = wav_loader_factory(clean_paths, noise_paths)
        for _ in range(steps * batch_size * 2):
            yield loader(None)

    output_sig = (tf.TensorSpec(shape=(SEGMENT,), dtype=tf.float32),
                  tf.TensorSpec(shape=(SEGMENT,), dtype=tf.float32))
    # from_generator calls gen() again on every iteration -> fresh samples each pass
    ds = tf.data.Dataset.from_generator(gen, output_signature=output_sig)
    ds = ds.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    ds = ds.batch(batch_size, drop_remainder=True)
    # prepare the next batch asynchronously while the current one is consumed
    return ds.prefetch(tf.data.AUTOTUNE)

# No earlier cell in this notebook defines CLEAN_WAVS / NOISE_WAVS, so a fresh
# "Restart & Run All" would fail here with a NameError. Keep lists defined
# elsewhere; otherwise fall back to [] so the loader uses the synthetic generators.
CLEAN_WAVS = globals().get('CLEAN_WAVS', [])
NOISE_WAVS = globals().get('NOISE_WAVS', [])

train_ds = tf_dataset(CLEAN_WAVS, NOISE_WAVS, BATCH_SIZE, STEPS_PER_EPOCH)
val_ds = tf_dataset(CLEAN_WAVS, NOISE_WAVS, BATCH_SIZE, VAL_STEPS)

# Sanity check: one batch of (noisy, clean), each of shape (BATCH_SIZE, SEGMENT).
noisy_b, clean_b = next(iter(train_ds))
assert noisy_b.shape == (BATCH_SIZE, SEGMENT) and clean_b.shape == (BATCH_SIZE, SEGMENT)
print("Batch shapes:", noisy_b.shape, clean_b.shape)

## Model

In [None]:
# Layer-friendly STFT.
# Input: (batch, samples), e.g. (8, 32000) for a 2 s clip at 16 kHz.
# Output: complex (batch, frames, N_FFT // 2 + 1).
def stft_layer(x):
    """STFT of a batch of waveforms that stays inside the TensorFlow graph.

    ``tf.signal.stft`` is called directly rather than through ``tf.numpy_function``.
    A numpy_function body has no gradient and is not serialized into a saved
    graph, which would block training through this op and exporting the model.

    Args:
        x: waveform tensor of shape (batch, samples).

    Returns:
        Complex tensor (complex64 for float32 input) of shape
        (batch, frames, N_FFT // 2 + 1).
    """
    X = tfs.stft(x, frame_length=WIN_LENGTH, frame_step=HOP, fft_length=N_FFT,
                 window_fn=tf.signal.hann_window)
    # pin the rank-3 contract and the static frequency dimension
    X.set_shape([None, None, N_FFT // 2 + 1])
    return X

In [None]:
class STFTMagLayer(layers.Layer):
    """Keras layer returning the magnitude, phase and complex STFT of waveforms."""

    def call(self, x):
        """Return ``(magnitude, phase, stft)`` for the waveforms ``x``.

        With the default settings a (8, 32000) input gives three outputs of
        shape (8, 122, 513). The phase is returned so the waveform can later be
        rebuilt with an inverse STFT.
        """
        spec = tfs.stft(x, frame_length=WIN_LENGTH, frame_step=HOP, fft_length=N_FFT,
                        window_fn=tf.signal.hann_window)
        # No transpose needed: the STFT is already laid out as (batch, frames, bins).
        return tf.abs(spec), tf.math.angle(spec), spec

In [None]:
def db_log(x):
    """Log-compress ``x`` as ``log(x + 1e-6)``.

    Despite the name this is the natural logarithm, not decibels. The 1e-6
    offset keeps the result finite when ``x`` is zero.
    """
    shifted = x + 1e-6
    return tf.math.log(shifted)

In [None]:
def inv_db_log(x):
    """Undo ``db_log``: map ``log(m + 1e-6)`` back to ``m``.

    ``db_log`` computes ``log(m + 1e-6)``, so its inverse is ``exp(x) - 1e-6``.
    ``expm1`` (``exp(x) - 1``) inverts ``log1p`` instead, and would have shifted
    every reconstructed magnitude by about -1.
    """
    # Clamp at 0: inputs below log(1e-6) (e.g. model predictions) would otherwise
    # give slightly negative magnitudes.
    return tf.maximum(tf.exp(x) - 1e-6, 0.0)

In [None]:
def unet_block(x, filters, name, down=True):
    """One U-Net stage: stride-2 3x3 (transposed) convolution -> BatchNorm -> ReLU.

    Args:
        x: 4-D feature map (presumably (batch, H, W, C), the Keras channels_last
            default - confirm against the model input).
        filters: number of output channels.
        name: prefix for the created layer names.
        down: True -> ``Conv2D`` with stride 2 (spatial size ceil(in / 2) with
            'same' padding); False -> ``Conv2DTranspose`` with stride 2
            (spatial size doubled).

    Returns:
        The activated output tensor.
    """
    if down:
        conv = layers.Conv2D(filters, 3, strides=2, padding='same', name=name + '_conv')
    else:
        conv = layers.Conv2DTranspose(filters, 3, strides=2, padding='same', name=name + '_deconv')
    x = conv(x)
    x = layers.BatchNormalization(name=name + '_bn')(x)
    return layers.Activation('relu', name=name + '_relu')(x)