# Rosie's Hearing Aid

# To-Do
* Debug decoder block to check if output dim follows FxCxNc
* Make decoder be able to seperate N sources 


In [1]:
import pandas as pd 
import numpy as np
import math
from itertools import permutations
import evaluation

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from tensorflow_addons.optimizers import AdamW

from einops import rearrange

import IPython as ip
import pywt
import librosa
import librosa.display
import sounddevice as sd
import soundfile as sf

import matplotlib.pyplot as plt 
import seaborn as sns 

## Encoder Block
![title](Diagrams/encoder.jpg)

### Chunk Operator 
Constants: $$ \lambda \sim \text{sample rate} \newline T \sim \text{max duration of inputs in seconds} \newline C_L \sim \text{chunk length}
$$
Dependants: 
$$ \text{N-Chunks} = \text{floor}({\frac{\lambda T}{C_L}}) $$

In [2]:
class ChunkOperator(layers.Layer):
    """ Performs a chunk operation on a 2D (feature,time) tensor 
    and outputs a 3D (feature,short_time,long_time) tensor. 

    Parameters
    ----------
    samplerate_hz : int
        The samlerate of the audio clip 
    max_input_length_in_seconds : int
        The max length of input audio in seconds, if clip is shorter than 
        max length then zero padding is applied     
    chunk_length : int
        States length of sliding window in chunk operation over time axis 
    num_filters_in_encoder: int 
        Number of filters in Conv1d operation used before chunk operation, this 
        data is needed for dim/padding problems 
    batch_size : int 
        The batch size is needed for reshaping during the chunk operation

    Returns
    -------
    tf.constant 
        3D (feature,short_time,long_time) tensor
    """
    
    def __init__(self,samplerate_hz,
                 max_input_length_in_seconds,
                 chunk_length,
                 num_filters_in_encoder,
                 batch_size):
        super(ChunkOperator,self).__init__()
        
        # Constants
        self.num_filters_in_encoder = num_filters_in_encoder
        self.batch_size = batch_size
        self.samplerate_hz = samplerate_hz
        self.max_input_length_in_seconds = max_input_length_in_seconds
        self.chunk_length = chunk_length 
        
        # Dependants
        self.num_full_chunks = self.samplerate_hz*self.max_input_length_in_seconds//self.chunk_length
        self.signal_length_samples = self.chunk_length*self.num_full_chunks
        self.chunk_advance = self.chunk_length // 2
        self.num_overlapping_chunks = self.num_full_chunks*2-1
        
        # The layer itself
        self.chunk_operator = tf.keras.layers.Lambda(self.segment_encoded_signal)
        
    @tf.function
    def segment_encoded_signal(self, x):
        x1 = tf.reshape(x, (self.batch_size, self.signal_length_samples//self.chunk_length , self.chunk_length , self.num_filters_in_encoder))
        x2 = tf.roll(x, shift=-self.chunk_advance, axis=1)
        x2 = tf.reshape(x2, (self.batch_size, self.signal_length_samples//self.chunk_length , self.chunk_length , self.num_filters_in_encoder))
        x2 = x2[:, :-1, :, :] # Discard last segment with invalid data

        x_concat = tf.concat([x1, x2], axis=1)
        x = x_concat[:, ::self.num_full_chunks, :, :]
        for i in range(1, self.num_full_chunks):
            x = tf.concat([x, x_concat[:, i::self.num_full_chunks, :, :]], axis=1)
        return x

    @tf.function
    def call(self, x):
        x = self.chunk_operator(x)
        return x

In [3]:
# Chunk operator parameters
BATCH_SIZE = 1
NETWORK_NUM_FILTERS_IN_ENCODER = 64
MAX_INPUT_LENGTH_IN_SECONDS = 10
SAMPLERATE_HZ = 8000
NETWORK_CHUNK_SIZE = 256
MAX_INPUT_LENGTH_IN_SECONDS = 10
SAMPLERATE_HZ = 8000

In [4]:
chunk_layer = ChunkOperator(samplerate_hz=SAMPLERATE_HZ,
    max_input_length_in_seconds=MAX_INPUT_LENGTH_IN_SECONDS,
    chunk_length=NETWORK_CHUNK_SIZE,
    num_filters_in_encoder=NETWORK_NUM_FILTERS_IN_ENCODER,
    batch_size=BATCH_SIZE)

x = np.random.uniform(0,1,(1,79872, 64))
x_chunk = chunk_layer(x)
if x_chunk.shape != (1, 623, 256, 64):
    raise ValueError('Chunk operation output is the wrong size!!!')
else:
    print("Chunk operation test passed!!!")

Chunk operation test passed!!!


### Encoder

In [5]:
class Encoder(layers.Layer):
    """ Takes original audio input x -> ReLU(Conv1D(x)) -> h0
         -> LayerNorm(Dense(h0)) -> h1 -> Chunk(h1) -> hc

    Parameters
    ----------
    num_filters_in_encoder: int 
        Number of filters in Conv1d operation used before chunk operation, this 
        data is needed for dim/padding problems 
    encoder_filter_length: int 
        Length of filter in Conv1D
    samplerate_hz : int
        The samlerate of the audio clip 
    max_input_length_in_seconds : int
        The max length of input audio in seconds, if clip is shorter than 
        max length then zero padding is applied     
    chunk_length : int
        States length of sliding window in chunk operation over time axis 
    batch_size : int 
        The batch size is needed for reshaping during the chunk operation

    Returns
    -------
    tf.constant 
        3D (feature,short_time,long_time) tensor
    """
    
    def __init__(self,num_filters_in_encoder,
                 encoder_filter_length,
                 samplerate_hz,
                 max_input_length_in_seconds,
                 chunk_length,
                 batch_size):
        super(Encoder,self).__init__()
        
        # Constants
        self.batch_size = batch_size
        self.num_filters_in_encoder = num_filters_in_encoder
        self.encoder_filter_length = encoder_filter_length
        self.batch_size = batch_size
        self.samplerate_hz = samplerate_hz
        self.max_input_length_in_seconds = max_input_length_in_seconds
        self.chunk_length = chunk_length 
        
        # Dependants
        self.encoder_hop_size = self.encoder_filter_length // 2
        
    def build(self,input_shape):
        self.conv1d = tf.keras.layers.Conv1D(filters=self.num_filters_in_encoder, \
                                      kernel_size=self.encoder_filter_length,  \
                                      strides=self.encoder_hop_size, use_bias=False, \
                                      padding="same")
        self.layer_norm = tf.keras.layers.LayerNormalization()
        self.chunk_operator = ChunkOperator(samplerate_hz=self.samplerate_hz,
                    max_input_length_in_seconds=self.max_input_length_in_seconds,
                    chunk_length=self.chunk_length,
                    num_filters_in_encoder=self.num_filters_in_encoder,
                    batch_size=self.batch_size)
        
    @tf.function
    def call(self, inputs):
        encoder_out = self.conv1d(tf.expand_dims(inputs, axis=2))
        x = self.layer_norm(encoder_out)
        x = self.chunk_operator(x)
        return x, encoder_out

In [6]:
# Encoder block parameters
BATCH_SIZE = 1
NETWORK_NUM_FILTERS_IN_ENCODER = 64
NETWORK_ENCODER_FILTER_LENGTH = 2
MAX_INPUT_LENGTH_IN_SECONDS = 10
SAMPLERATE_HZ = 8000
NETWORK_CHUNK_SIZE = 256
MAX_INPUT_LENGTH_IN_SECONDS = 10
SAMPLERATE_HZ = 8000

In [7]:
encoder_block = Encoder(num_filters_in_encoder=NETWORK_NUM_FILTERS_IN_ENCODER,
    encoder_filter_length=NETWORK_ENCODER_FILTER_LENGTH,
    samplerate_hz=SAMPLERATE_HZ,
    max_input_length_in_seconds=MAX_INPUT_LENGTH_IN_SECONDS,
    chunk_length=NETWORK_CHUNK_SIZE,
    batch_size=BATCH_SIZE)

x = np.random.uniform(0,1,(1,79872))
x_chunk, _ = encoder_block(x)
if x_chunk.shape != (1, 623, 256, 64):
    raise ValueError('Encoder block output is the wrong size!!!')
else:
    print("Encoder block test passed!!!")

Encoder block test passed!!!


## Linear Transformer with SineSPE

In [8]:
class SineSPE(layers.Layer):
    def __init__(self, 
                 num_heads: int = 8,
                 in_features: int = 64,
                 num_realizations: int = 256,
                 num_sines: int = 1):
        super(SineSPE, self).__init__()
        
        self.num_heads = num_heads
        self.in_features = in_features 
        self.num_sines = num_sines 
        self.num_realizations = num_realizations
        
        freqs_init = tf.random_normal_initializer()
        self.freqs = tf.Variable(
            initial_value=freqs_init(shape=(num_heads, in_features, num_sines), dtype="float32"),
            trainable=True,
        )
        
        offsets_init = tf.random_normal_initializer()
        self.offsets = tf.Variable(
            initial_value=offsets_init(shape=(num_heads, in_features, num_sines), dtype="float32"),
            trainable=True,
        )
        
        gains_init = tf.random_normal_initializer()
        self.gains = tf.Variable(
            initial_value=gains_init(shape=(num_heads, in_features, num_sines), dtype="float32"),
            trainable=True,
        )
        
        # Normalize gains 
        self.gains = self.gains/(tf.math.sqrt(tf.norm(self.gains,axis=-1,keepdims=True))/2)
        
        # Bias intial freqs
        self.freqs = self.freqs-4
        
        self.code_shape = (num_heads,in_features)

    def call(self, shape):
        """
        Generate the code, composed of a random QBar and Kbar,
        depending on the parameters, and return them for use with a
        SPE module to actually encode queries and keys.
        Args:
            shape: The outer shape of the inputs: (batchsize, *size)
            num_realizations: if provided, overrides self.num_realizations
        """
        
        if len(shape) != 2:
            raise ValueError('Only 1D inputs are supported by SineSPE')
        
        max_len = shape[1]
        
        # build omega_q and omega_k
        # with shape (num_heads,keys_dim,length,2*num_sines)
        indices = tf.linspace(0,max_len-1,max_len)
        indices = tf.cast(indices, dtype=tf.float32)

        # make sure freqs are in [0,.5]
        freqs = tf.nn.sigmoid(self.freqs[:,:,None,:])/2
        
        phases_q = 2*math.pi*freqs*indices[None,None,:,None]*self.offsets[:,:,None,:]
        omega_q = tf.stack([tf.math.cos(phases_q),tf.math.sin(phases_q)],axis=-1)
        omega_q = tf.reshape(omega_q,[1,self.num_heads,self.in_features,max_len,2*self.num_sines] )
        
        phases_k = 2*math.pi*freqs*indices[None,None,:,None]
        omega_k = tf.stack([tf.math.cos(phases_k),tf.math.sin(phases_k)],axis=-1)
        omega_k = tf.reshape(omega_k,[1,self.num_heads,self.in_features,max_len,2*self.num_sines] )
        
        # Gains is (num_heads,keys_dim,num_sines), make nonnegative with softplut
        gains = tf.math.softplus(self.gains)
        
        # Upsample
        gains = tf.stack([gains,gains],axis=-1)
        gains = tf.reshape(gains, [self.num_heads,self.in_features,2*self.num_sines])
        
        # Draw noise
        z = tf.random.normal((1,self.num_heads,self.in_features,2*self.num_sines,self.num_realizations))
        z = z/tf.math.sqrt(tf.cast(self.num_sines*2, dtype=tf.float32))
        
        # Scale each of the 2*num_sines by the appropriate gain
        z = z*gains[None, ..., None]
    
        # Compute sums over sines
        qbar = tf.linalg.matmul(omega_q,z)
        kbar = tf.linalg.matmul(omega_k,z)
        
        # Pemute to (1,length,num_heads,key_dim,num_realization)
        qbar = tf.transpose(qbar, perm=[0,3,1,2,4])
        kbar = tf.transpose(kbar, perm=[0,3,1,2,4])

        # scale
        scale = (self.num_realizations*self.in_features)**.25
        return (qbar/scale,kbar/scale)

In [9]:
class SPEFilter(layers.Layer):
    """Stochastic positional encoding filter
    Applies a positional code provided by a SPE module on actual queries and keys.
    Implements gating, i.e. some "dry" parameter, that lets original queries and keys through if activated.
    Args:
    gated: whether to use the gated version, which learns to balance
        positional and positionless features.
    code_shape: the inner shape of the codes, i.e. (num_heads, key_dim),
        as given by `spe.code_shape`
    """
    def __init__(self,gated,code_shape):
        super(SPEFilter, self).__init__()

        self.gated = gated
        self.code_shape = code_shape

        # create the gating parameters if required
        if gated:
            if code_shape is None:
                raise RuntimeError('code_shape has to be provided if gated is True.')

            gate_init = tf.random_normal_initializer()
            self.gate = tf.Variable(
                initial_value=gate_init(shape=(code_shape), dtype="float32"),
                trainable=True,
            )  

    def call(self,queries,keys,code):
        """
        Apply SPE on keys with a given code.
        Expects keys and queries of shape `(batch_size, ..., num_heads,
        key_dim)` and outputs keys and queries of shape `(batch_size,
        ..., num_heads, num_realizations)`. code is the tuple
        of the 2 tensors provided by the code instance, each one of
        shape (1, ..., num_heads, key_dim, num_realizations)
        """
        assert (queries.shape == keys.shape), \
            "As of current implementation, queries and keys must have the same shape. "\
            "got queries: {} and keys: {}".format(queries.shape, keys.shape)

        # qbar and kbar are (1, *shape, num_heads, keys_dim, num_realizations)
        (qbar, kbar) = code

        # check that codes have the shape we are expecting
        if self.code_shape is not None and qbar.shape[-3:-1] != self.code_shape:
            raise ValueError(
                f'The inner shape of codes is {qbar.shape[-3:-1]}, '
                f'but expected {self.code_shape}')

        # check shapes: size of codes should be bigger than queries, keys
        code_size = qbar.shape[1:-3]
        query_size = queries.shape[1:-2]
        

        #if (len(code_size) != len(query_size)
        #    or tf.reduce_any(
        #        tf.Variable(code_size) < tf.Variable(query_size)
        #    )):
        #        raise ValueError(f'Keys/queries have length {query_size}, '
        #                         f'but expected at most {code_size}')
                
        if (len(code_size) != len(query_size)):
                raise ValueError(f'Keys/queries have length {query_size}, '
                                 f'but expected at most {code_size}')
                
        if qbar.shape[-3:-1] != queries.shape[-2:]:
            raise ValueError(f'shape mismatch. codes have shape {qbar.shape}, '
                             f'but queries are {queries.shape}')

        # truncate qbar and kbar for matching current queries and keys,
        # but only if we need to
        for dim in range(len(query_size)):
            if code_size[dim] > query_size[dim]:
                indices = [slice(1), *[slice(qbar.shape[1+k]) for k in range(dim)],
                           slice(query_size[dim])]
                qbar = qbar[indices]
                kbar = kbar[indices]

        # apply gate if required
        if self.gated:
            # incorporate the constant bias for Pd if required. First draw noise
            # such that noise noise^T = 1, for each head, feature, realization.
            # qbar is : (1, *shape, num_heads, keys_dim, num_realizations)
            in_features = qbar.shape[-2]
            num_realizations = qbar.shape[-1]
            gating_noise = tf.random.normal(self.code_shape+\
                            (num_realizations,))/(in_features*num_realizations)**.25
            
            
            # normalize it so that it's an additive 1 to Pd
            #gating_noise = gating_noise / gating_noise.norm(dim=2, keepdim=True)

            # constrain the gate parameter to be in [0 1]
            gate = tf.math.sigmoid(self.gate[..., None])

            # qbar is (1, *shape, num_heads, keys_dim, num_realizations)
            # gating noise is (num_heads, keys_dim, num_realizations)
            # gate is (num_heads, keys_dim, 1)
            #import ipdb; ipdb.set_trace()
            qbar = tf.math.sqrt(1.-gate) * qbar  + tf.math.sqrt(gate) * gating_noise
            kbar = tf.math.sqrt(1.-gate) * kbar  + tf.math.sqrt(gate) * gating_noise

        # sum over d after multiplying by queries and keys
        # qbar/kbar are (1, *shape, num_heads, keys_dim, num_realizations)
        # queries/keys  (batchsize, *shape, num_heads, keys_dim)
        qhat = tf.math.reduce_sum(qbar * queries[..., None],axis=-2)
        khat = tf.math.reduce_sum(kbar * keys[..., None],axis=-2)

        # result is (batchsize, ..., num_heads, num_realizations)
        return (qhat, khat)

In [10]:
@tf.function
def compute_linear_mhsa(q, k, v):
    q = tf.nn.gelu(q)+1 # Needed for kernel assumption 
    k = tf.nn.gelu(k)+1
    kv = tf.einsum('... h s d, ...  h s m  -> ... h m d',k,v)
    k_sum = tf.math.reduce_sum(k,axis=2)
    z = 1/ (tf.einsum('... h l d, ... h d -> ... h l',q ,k_sum)+1e-4)
    Vhat = tf.einsum('... h l d, ... h m d, ... h l -> ... h l m',q,kv,z)
    return Vhat

class LinearAttentionSineSPE(tf.keras.layers.Layer):
    def __init__(self, d_model, heads=8, num_sines=5):
        super(LinearAttentionSineSPE, self).__init__()
        self.num_heads = heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)
        
        self.spe_encoder = SineSPE(num_heads=heads,          # Number of attention heads
                          in_features=self.depth,       # Dimension of keys and queries
                          num_realizations=self.depth,  # New dimension of keys and queries
                          num_sines=num_sines)          # Number of sinusoidal components
        self.spe_filter = SPEFilter(gated=True, code_shape=self.spe_encoder.code_shape)

    @tf.function
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)   

        q = tf.transpose(q,perm=[0,2,1,3])
        k = tf.transpose(k,perm=[0,2,1,3])
        
        pos_codes = self.spe_encoder(q.shape[:2])  # pos_codes is a tuple (qbar, kbar)
        q, k = self.spe_filter(q, k, pos_codes)
        q = tf.transpose(q,perm=[0,2,1,3])
        k = tf.transpose(k,perm=[0,2,1,3])
        
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention = compute_linear_mhsa(q, k, v)

        scaled_attention = tf.transpose(
            scaled_attention, perm=[0, 2, 1, 3]
        )  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(
            scaled_attention, (batch_size, -1, self.d_model)
        )  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output

In [11]:
class LinearSineSPETransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        
        super(LinearSineSPETransformerBlock, self).__init__()

        self.lha = LinearAttentionSineSPE(embed_dim,num_heads)
        
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="gelu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        attn_output = self.lha(inputs,inputs,inputs)
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

In [12]:
BATCH_SIZE = 32
seq_len = 100
n_heads = 8
embedding_dim = 64 
ff_dim = 40

x = np.random.uniform(0,1,(BATCH_SIZE,seq_len,embedding_dim))
ltb = LinearSineSPETransformerBlock(embedding_dim, n_heads ,ff_dim)

y = ltb(x)

print(y.shape)
if y.shape != (32, 100, 64):
    raise ValueError('Linear SineSPE Transformer block output is the wrong size!!!')
else:
    print("Linear SineSPE Transformer block test passed!!!")

(32, 100, 64)
Linear SineSPE Transformer block test passed!!!


## Decoder Block 

In [146]:
class Decoder(layers.Layer):
    """ Takes Chop Shop output and combines with decoder output 
        to form sepeated sources.
        
    Parameters
    ----------
    n_sources: int 
        Number of sources to sepeate audio into
    num_filters_in_encoder: int 
        Number of filters in Conv1d operation used before chunk operation, this 
        data is needed for dim/padding problems 
    encoder_filter_length: int 
        Length of filter in Conv1D
    samplerate_hz : int
        The samlerate of the audio clip 
    max_input_length_in_seconds : int
        The max length of input audio in seconds, if clip is shorter than 
        max length then zero padding is applied     
    chunk_length : int
        States length of sliding window in chunk operation over time axis 
    batch_size : int 
        The batch size is needed for reshaping during the chunk operation

    Returns
    -------
    tf.constant 
        List of seperated signals 
    """
    
    def __init__(self,signal_length_samples,
                 n_sources,
                 num_filters_in_encoder,
                 encoder_filter_length,
                 samplerate_hz,
                 max_input_length_in_seconds,
                 chunk_length,
                 batch_size):
        super(Decoder,self).__init__()
        
        # Constants
        self.n_sources = n_sources
        self.batch_size = batch_size
        self.num_filters_in_encoder = num_filters_in_encoder
        self.encoder_filter_length = encoder_filter_length
        self.batch_size = batch_size
        self.samplerate_hz = samplerate_hz
        self.signal_length_samples = signal_length_samples
        self.max_input_length_in_seconds = max_input_length_in_seconds
        self.chunk_length = chunk_length 
        self.chunk_advance =  chunk_length  // 2
        
        # Dependants
        self.encoder_hop_size = self.encoder_filter_length // 2
        
        self.cut_len = 0
        if self.num_filters_in_encoder%self.n_sources != 0:
            self.cut_len = self.num_filters_in_encoder%self.n_sources
            
        
    def build(self,input_shape):
        self.overlap_and_add_mask_segments_layer = keras.layers.Lambda(self.overlap_and_add_mask_segments)
    
        self.DenseLayers = []
        for i in range(self.n_sources):
            dense_layer = keras.layers.Dense(units=self.encoder_filter_length, use_bias=False,name='OverLapAddDense_'+str(i))
            self.DenseLayers.append(dense_layer)
            
        self.OverLapAndAddDecoderLayers = []
        for i in range(self.n_sources):
            overlap_and_add_in_decoder_layer = keras.layers.Lambda(self.overlap_and_add_in_decoder,name='OverLapAdd_'+str(i))
            self.OverLapAndAddDecoderLayers.append(overlap_and_add_in_decoder_layer)
        
    @tf.function
    def overlap_and_add_mask_segments(self, x):
        x = tf.transpose(x, [0, 3, 1, 2])
        x = tf.signal.overlap_and_add(x, self.chunk_advance)
        return tf.transpose(x, [0, 2, 1])

    @tf.function
    def overlap_and_add_in_decoder(self, x):
        return tf.signal.overlap_and_add(x, self.encoder_hop_size)

    @tf.function
    def call(self, x, encoder_out):
        masks = self.overlap_and_add_mask_segments_layer(x)
        x = masks*encoder_out

        if self.cut_len != 0:
            x = tf.split(x,num_or_size_splits=[self.num_filters_in_encoder-self.cut_len,self.cut_len], axis=-1, num=None, name='split_cut')[0]
        d_sources = tf.split(x, num_or_size_splits=self.n_sources, axis=-1, num=None, name='split')
        
        decoded_sources = []
        for i in range(self.n_sources):
            decoded_spk = self.DenseLayers[i](d_sources[i])
            decoded_spk = self.OverLapAndAddDecoderLayers[i](decoded_spk)[:, :self.signal_length_samples]
            decoded_sources.append(decoded_spk)
            
        decoded = tf.stack(decoded_sources, axis=1)

        return decoded

## Loss Function: Si-SNR

Si-SNR $= 10 \log_{10} \frac{||x_{target}||^2}{||e_{noise}||^2} $, where 
$x_{target}=\frac{< \hat{x},x > x }{||x||^2}$

In [147]:
"""
@tf.function
def log10(x):
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator

@tf.function
def sisnr(s, s_hat, do_expand=False, eps=1e-8):
    if do_expand:
        s = np.expand_dims(s, axis=0)
        s_hat = np.expand_dims(s_hat, axis=0)
    dot_product = tf.math.reduce_sum(s*s_hat, axis=1, keepdims=True)
    squares = tf.math.reduce_sum(s*s, axis=1, keepdims=True)
    s_target = s * dot_product / squares
    e_noise = s_hat - s_target
    s_target_squared = tf.math.reduce_sum(s_target*s_target, axis=1)
    e_noise_squared = tf.math.reduce_sum(e_noise*e_noise, axis=1)
    return 10*log10(s_target_squared / (e_noise_squared + eps))

@tf.function
def permutation_invariant_loss(y_true, y_pred,n_sources=2):
    sn, sn_hat = [], []
    for i in range(n_sources):
        sn.append(y_true[:,i,:])
        sn_hat.append(y_pred[:,i,:])
    
    if n_sources == 2:
        sisnr_perm0_spk0 = sisnr(sn[0], sn_hat[0])
        sisnr_perm0_spk1 = sisnr(sn[1], sn_hat[1])
        sisnr_perm0 = (sisnr_perm0_spk0 + sisnr_perm0_spk1) / 2

        sisnr_perm1_spk0 = sisnr(sn[0], sn_hat[1])
        sisnr_perm1_spk1 = sisnr(sn[1], sn_hat[0])
        sisnr_perm1 = (sisnr_perm1_spk0 + sisnr_perm1_spk1) / 2

        sisnr_perm_invariant = tf.math.maximum(sisnr_perm0, sisnr_perm1)
        return -sisnr_perm_invariant
    
    elif n_sources == 3:
        sisnr_perm0_spk0 = sisnr(sn[0], sn_hat[0])
        sisnr_perm0_spk1 = sisnr(sn[1], sn_hat[1])
        sisnr_perm0_spk2 = sisnr(sn[2], sn_hat[2])
        sisnr_perm0 = (sisnr_perm0_spk0 + sisnr_perm0_spk1 + sisnr_perm0_spk2) / 3

        sisnr_perm1_spk0 = sisnr(sn[0], sn_hat[0])
        sisnr_perm1_spk1 = sisnr(sn[1], sn_hat[2])
        sisnr_perm1_spk2 = sisnr(sn[2], sn_hat[1])
        sisnr_perm1 = (sisnr_perm1_spk0 + sisnr_perm1_spk1 + sisnr_perm1_spk2) / 3
        
        sisnr_perm2_spk0 = sisnr(sn[0], sn_hat[1])
        sisnr_perm2_spk1 = sisnr(sn[1], sn_hat[0])
        sisnr_perm2_spk2 = sisnr(sn[2], sn_hat[2])
        sisnr_perm2 = (sisnr_perm2_spk0 + sisnr_perm2_spk1 + sisnr_perm2_spk2) / 3
        
        sisnr_perm3_spk0 = sisnr(sn[0], sn_hat[1])
        sisnr_perm3_spk1 = sisnr(sn[1], sn_hat[2])
        sisnr_perm3_spk2 = sisnr(sn[2], sn_hat[0])
        sisnr_perm3 = (sisnr_perm3_spk0 + sisnr_perm3_spk1 + sisnr_perm3_spk2) / 3
        
        sisnr_perm4_spk0 = sisnr(sn[0], sn_hat[2])
        sisnr_perm4_spk1 = sisnr(sn[1], sn_hat[0])
        sisnr_perm4_spk2 = sisnr(sn[2], sn_hat[1])
        sisnr_perm4 = (sisnr_perm4_spk0 + sisnr_perm4_spk1 + sisnr_perm4_spk2) / 3
        
        sisnr_perm5_spk0 = sisnr(sn[0], sn_hat[2])
        sisnr_perm5_spk1 = sisnr(sn[1], sn_hat[1])
        sisnr_perm5_spk2 = sisnr(sn[2], sn_hat[0])
        sisnr_perm5 = (sisnr_perm5_spk0 + sisnr_perm5_spk1 + sisnr_perm5_spk2) / 3
        
        sisnr_perm_invariant = tf.stack([sisnr_perm1,sisnr_perm2,sisnr_perm3,sisnr_perm4,sisnr_perm5])
        sisnr_perm_invariant = tf.math.reduce_max(sisnr_perm_invariant)
        return -sisnr_perm_invariant
"""

"""
@tf.function
def get_permutation_invariant_sisnr(spk0_estimate, spk1_estimate, spk0_groundtruth, spk1_groundtruth):
    perm0_spk0 = sisnr(spk0_groundtruth, spk0_estimate, do_expand=True)
    perm0_spk1 = sisnr(spk1_groundtruth, spk1_estimate, do_expand=True)
    perm1_spk0 = sisnr(spk0_groundtruth, spk1_estimate, do_expand=True)
    perm1_spk1 = sisnr(spk1_groundtruth, spk0_estimate, do_expand=True)

    # Get best permutation
    if perm0_spk0 + perm0_spk1 > perm1_spk0 + perm1_spk1:
        return perm0_spk0, perm0_spk1

    return perm1_spk0, perm1_spk1 


@tf.function
def permutation_invariant_loss(y_true, y_pred,n_sources=2):
    # PIT for n-sources, work but very slow, needs to use tf.while_loop
    # yet, implementing tf.while_loop is a nightmare
    sn, sn_hat = [], []
    for i in range(n_sources):
        sn.append(y_true[:,i,:])
        sn_hat.append(y_pred[:,i,:])
        
    perm = list(permutations(range(n_sources)))
    sisnr_perm = []
        
    for i in range(len(perm)):
        sisnr_perm_spk = [sisnr(sn[p], sn_hat[j]) for p,j in enumerate(perm[i])]
        sisnr_perm.append(tf.math.reduce_sum(sisnr_perm_spk)/n_sources)

    sisnr_perm = tf.stack(sisnr_perm) 
    return -tf.math.reduce_max(sisnr_perm)
"""


'\n@tf.function\ndef get_permutation_invariant_sisnr(spk0_estimate, spk1_estimate, spk0_groundtruth, spk1_groundtruth):\n    perm0_spk0 = sisnr(spk0_groundtruth, spk0_estimate, do_expand=True)\n    perm0_spk1 = sisnr(spk1_groundtruth, spk1_estimate, do_expand=True)\n    perm1_spk0 = sisnr(spk0_groundtruth, spk1_estimate, do_expand=True)\n    perm1_spk1 = sisnr(spk1_groundtruth, spk0_estimate, do_expand=True)\n\n    # Get best permutation\n    if perm0_spk0 + perm0_spk1 > perm1_spk0 + perm1_spk1:\n        return perm0_spk0, perm0_spk1\n\n    return perm1_spk0, perm1_spk1 \n\n\n@tf.function\ndef permutation_invariant_loss(y_true, y_pred,n_sources=2):\n    # PIT for n-sources, work but very slow, needs to use tf.while_loop\n    # yet, implementing tf.while_loop is a nightmare\n    sn, sn_hat = [], []\n    for i in range(n_sources):\n        sn.append(y_true[:,i,:])\n        sn_hat.append(y_pred[:,i,:])\n        \n    perm = list(permutations(range(n_sources)))\n    sisnr_perm = []\n      

In [148]:
def pit_loss(y_true, y_pred, loss_type, batch_size, n_speaker, n_output, pit_axis=1):
    # [batch, spk #, length]
    real_spk_num = n_speaker

    # TODO 1: # output channel != # speaker
    v_perms = tf.constant(list(permutations(range(n_output), n_speaker)))
    v_perms_onehot = tf.one_hot(v_perms, n_output)

    y_true_exp = tf.expand_dims(y_true, pit_axis+1) # [batch, n_speaker, 1,        len]
    y_pred_exp = tf.expand_dims(y_pred, pit_axis)   # [batch, 1,         n_output, len]

    cross_total_loss = get_loss(loss_type, y_true_exp, y_pred_exp)

    loss_sets = tf.einsum('bij,pij->bp', cross_total_loss, v_perms_onehot) 
    loss = tf.reduce_min(loss_sets, axis=1)
    loss = tf.reduce_mean(loss)
        
    # find permutation sets for y pred
    s_perm_sets = tf.argmin(loss_sets, 1)
    s_perm_choose = tf.gather(v_perms, s_perm_sets)
    s_perm_idxs = tf.stack([
        tf.tile(
            tf.expand_dims(tf.range(batch_size), 1),
            [1, n_speaker]),
        s_perm_choose], axis=2)

    s_perm_idxs = tf.reshape(s_perm_idxs, [batch_size*n_speaker, 2])
    y_pred = tf.gather_nd(y_pred, s_perm_idxs)
    y_pred = tf.reshape(y_pred, [batch_size, n_speaker, -1])

    if loss_type != 'sdr':
        sdr = evaluation.sdr(y_true[:,:real_spk_num,:], y_pred[:,:real_spk_num,:])
        sdr = tf.reduce_mean(sdr)
    else:
        sdr = -loss/n_speaker

    return loss, y_pred, sdr, s_perm_choose

def get_loss(loss_type, t_true_exp, t_pred_exp, axis=-1):
    if loss_type == 'l1':
        y_cross_loss = t_true_exp - t_pred_exp
        cross_total_loss = tf.reduce_sum(tf.abs(y_cross_loss), axis=axis)

    elif loss_type == 'l2':
        y_cross_loss = t_true_exp - t_pred_exp
        cross_total_loss = tf.reduce_sum(tf.square(y_cross_loss), axis=axis)

    elif loss_type == 'snr':
        cross_total_loss = -evaluation.snr(t_true_exp, t_pred_exp)

    elif loss_type == 'sdr':
        cross_total_loss = -evaluation.sdr(t_true_exp, t_pred_exp)

    elif loss_type == 'sisnr':
        cross_total_loss = -evaluation.sisnr(t_true_exp, t_pred_exp)

    elif loss_type == 'sdr_modify':
        cross_total_loss = -evaluation.sdr_modify(t_true_exp, t_pred_exp)

    elif loss_type == 'sisdr':
        cross_total_loss = -evaluation.sisdr(t_true_exp, t_pred_exp)

    elif loss_type == 'sym_sisdr':
        cross_total_loss = -evaluation.sym_sisdr(t_true_exp, t_pred_exp)

    return cross_total_loss

@tf.function
def sisnri_sdri(s, s_est, mix_s, batch_size, n_speaker, n_output, pit_axis=1):
    mix_s = tf.repeat(mix_s,n_speaker,axis=1) 
    mix_s = tf.reshape(mix_s,(batch_size,n_speaker,-1))
    
    loss, _, sdr, _ = pit_loss(s, s_est,  'sisnr', batch_size, n_speaker, n_output, pit_axis=1)
    loss_b, _, sdr_b, _ = pit_loss(s, mix_s, 'sisnr', batch_size, n_speaker, n_output, pit_axis=1)
    
    loss *= -1
    loss_b *= -1
    
    if tf.math.is_nan(loss_b) and tf.math.is_nan(sdr_b) == False:
        return loss*-1, sdr-sdr_b
    elif tf.math.is_nan(loss_b) == False and tf.math.is_nan(sdr_b):
        return loss-loss_b, sdr_b
    elif tf.math.is_nan(loss_b) and tf.math.is_nan(sdr_b):
        return loss*-1, sdr
    else:
        return loss-loss_b, sdr-sdr_b
        


## RHA

In [149]:
class RHA(keras.Model):
    def __init__(self,batch_size,
                model_weights_file,
                num_filters_in_encoder,
                encoder_filter_length,
                chunk_size,
                num_full_chunks,
                num_chop_blocks,
                num_tran_blocks,
                num_head_per_att,
                dim_key_att,
                max_input_length_in_seconds,
                samplerate_hz,
                num_speakers):
        super(RHA, self).__init__()
        
        if num_speakers <= 1 or isinstance(num_speakers , int) != True:
            raise AssertionError('Passed value for num_speakrs is invalid, must be int greater than one.')
        self.num_speakers = num_speakers
        
        self.num_chop_blocks = num_chop_blocks 
        self.batch_size = batch_size
        self.model_weights_file = model_weights_file
        self.max_input_length_in_seconds = max_input_length_in_seconds
        self.num_tran_blocks = num_tran_blocks
        self.num_head_per_att = num_head_per_att
        self.dim_key_att = dim_key_att
        self.encoder_filter_length = encoder_filter_length
        self.num_filters_in_encoder = num_filters_in_encoder
        self.encoder_hop_size = encoder_filter_length // 2
        self.num_full_chunks = num_full_chunks
        self.signal_length_samples = chunk_size*num_full_chunks
        self.chunk_size = chunk_size
        self.chunk_advance = chunk_size // 2
        self.num_overlapping_chunks = num_full_chunks*2-1
        self.samplerate_hz = samplerate_hz
        
        # Build model
        self.model = self.getModel()
        self.loss_tracker = keras.metrics.Mean(name="loss_sisnr")
        self.sdr_tracker = keras.metrics.Mean(name="sdr")
        self.sdri_tracker = keras.metrics.Mean(name="sdri")
        self.sisnri_tracker = keras.metrics.Mean(name="sisnri")
        
    def getModel(self):
        # Model input 
        inputs = tf.keras.Input(self.signal_length_samples)
        
        # Encoder Block 
        z, encoder_out = Encoder(num_filters_in_encoder=self.num_filters_in_encoder,
                    encoder_filter_length=self.encoder_filter_length,
                    samplerate_hz=self.samplerate_hz,
                    max_input_length_in_seconds=self.max_input_length_in_seconds,
                    chunk_length=self.chunk_size,
                    batch_size=self.batch_size)(inputs)
        
        # Chop Shop 
        for i in range(self.num_chop_blocks):  
            # Intra ~ Short
            z = tf.reshape(z,(self.batch_size,-1,self.chunk_size))
            _ , _ , dim_check = z.get_shape()
            if dim_check % self.num_head_per_att != 0:
                for i in range(self.num_tran_blocks):
                    if i > 0: 
                        z = tf.reshape(z,(self.batch_size,-1,self.chunk_size))
                    pad_len = self.num_head_per_att-(dim_check % self.num_head_per_att)
                    zero_pad = tf.zeros((self.batch_size,self.num_overlapping_chunks*self.num_filters_in_encoder,pad_len))
                    z_pad = tf.concat([z,zero_pad], axis=-1)
                    x = LinearSineSPETransformerBlock(self.chunk_size+pad_len,self.num_head_per_att,self.dim_key_att)(z_pad)
                    x = x[:,:,:-pad_len]
                    z = x+z
                    z = tf.reshape(z,(self.batch_size,self.num_overlapping_chunks,self.chunk_size,self.num_filters_in_encoder))
            else:      
                for i in range(self.num_tran_blocks):
                    if i > 0: 
                        z = tf.reshape(z,(self.batch_size,-1,self.chunk_size))
                    x = LinearSineSPETransformerBlock(self.chunk_size,self.num_head_per_att,self.dim_key_att)(z)
                    z = x+z
                    z = tf.reshape(z,(self.batch_size,self.num_overlapping_chunks,self.chunk_size,self.num_filters_in_encoder))

            # Inter ~ Long 
            z = tf.reshape(z,(self.batch_size,-1,self.num_overlapping_chunks))
            _ , _ , dim_check = z.get_shape()
            if dim_check % self.num_head_per_att != 0:
                # Case 1: padding needed
                for i in range(self.num_tran_blocks):
                    if i > 0:  
                        z = tf.reshape(z,(self.batch_size,-1,self.num_overlapping_chunks))
                    pad_len = self.num_head_per_att-(dim_check % self.num_head_per_att)
                    zero_pad = tf.zeros((self.batch_size,self.chunk_size*self.num_filters_in_encoder,pad_len))
                    z_pad = tf.concat([z,zero_pad], axis=-1)
                    x = LinearSineSPETransformerBlock(self.num_overlapping_chunks+pad_len,self.num_head_per_att,self.dim_key_att)(z_pad)
                    x = x[:,: ,:-pad_len]
                    z = x+z
                z = tf.reshape(z,(self.batch_size,self.num_overlapping_chunks,self.chunk_size,self.num_filters_in_encoder))
            else:    
                # Case 2: padding not needed
                for i in range(self.num_tran_blocks):
                    if i > 0: 
                        z = tf.reshape(z,(self.batch_size,-1,self.chunk_size))
                    x = LinearSineSPETransformerBlock(self.num_overlapping_chunks,self.num_head_per_att,self.dim_key_att)(z)
                    z = x+z
                    z = tf.reshape(z,(self.batch_size,self.num_overlapping_chunks,self.chunk_size,self.num_filters_in_encoder))
        x = z
        
        # Decoder Block 
        decoded = Decoder(signal_length_samples=self.signal_length_samples,
                    n_sources=self.num_speakers,
                    num_filters_in_encoder=self.num_filters_in_encoder,
                    encoder_filter_length=self.encoder_filter_length,
                    samplerate_hz=self.samplerate_hz,
                    max_input_length_in_seconds=self.max_input_length_in_seconds,
                    chunk_length=self.chunk_size,
                    batch_size=self.batch_size)(x,encoder_out)
        
        # Final model 
        model = tf.keras.Model(inputs, decoded)
        return model 
        
    def call(self,inputs):
        yh = self.model(inputs)
        return yh
        
    def train_step(self, inputs):
        X, y = inputs
        with tf.GradientTape() as tape:
            yh = self.model(X)
            #loss = permutation_invariant_loss(y,yh,n_sources=self.num_speakers)
            loss, y_pred, sdr, s_perm_choose = pit_loss(y, yh, 'sisnr', self.batch_size, self.num_speakers, self.num_speakers, pit_axis=1)
            
        y_mix = tf.math.reduce_sum(y,axis=1,keepdims=True)
        sisnri_val, sdri_val = sisnri_sdri(y, yh, y_mix,  self.batch_size, self.num_speakers, self.num_speakers, pit_axis=1)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.loss_tracker.update_state(loss)
        self.sdr_tracker.update_state(sdr)
        self.sdri_tracker.update_state(sdri_val)
        self.sisnri_tracker.update_state(sisnri_val)
        return {
            "si-snr": self.loss_tracker.result()*-1,
            "sdr": self.sdr_tracker.result(),
            "si-snri": self.sisnri_tracker.result(),
            "sdri": self.sdri_tracker.result(),
        }
    
    @property
    def metrics(self):
        """List of the model's metrics.
        We make sure the loss tracker is listed as part of `model.metrics`
        so that `fit()` and `evaluate()` are able to `reset()` the loss tracker
        at the start of each epoch and at the start of an `evaluate()` call.
        """
        return [self.loss_tracker,self.sdr_tracker,self.sisnri_tracker,self.sdri_tracker]

    def test_step(self, inputs):
        X,y = inputs
        yh = self.model(X)
        #loss = permutation_invariant_loss(y,yh,n_sources=self.num_speakers)
        loss, y_pred, sdr, s_perm_choose = pit_loss(y, yh, 'sisnr', self.batch_size, self.num_speakers, self.num_speakers, pit_axis=1)
        y_mix = tf.math.reduce_sum(y,axis=1,keepdims=True)
        sisnri_val, sdri_val = sisnri_sdri(y, yh, y_mix,  self.batch_size, self.num_speakers, self.num_speakers, pit_axis=1)

        self.loss_tracker.update_state(loss)
        self.sdr_tracker.update_state(sdr)
        self.sdri_tracker.update_state(sdri_val)
        self.sisnri_tracker.update_state(sisnri_val)
        return {
            "si-snr": self.loss_tracker.result()*-1,
            "sdr": self.sdr_tracker.result(),
            "si-snri": self.sisnri_tracker.result(),
            "sdri": self.sdri_tracker.result(),
        }

In [154]:
BATCH_SIZE = 1
OPTIMIZER_CLIP_L2_NORM_VALUE = 5

NETWORK_NUM_SPEAKERS = 4
MAX_INPUT_LENGTH_IN_SECONDS = 6
SAMPLERATE_HZ = 8000

NETWORK_NUM_FILTERS_IN_ENCODER = 64
NETWORK_ENCODER_FILTER_LENGTH = 2
NETWORK_NUM_HEAD_PER_ATT = 8
NEWORK_DIM_KEY_ATT = 1024
NETWORK_NUM_TRAN_BLOCKS = 1
NETWORK_NUM_CHOP_BLOCKS = 1
NETWORK_CHUNK_SIZE = 256

NUM_CHUNKS = SAMPLERATE_HZ*MAX_INPUT_LENGTH_IN_SECONDS//NETWORK_CHUNK_SIZE
num_overlapping_chunks = NUM_CHUNKS*2-1

print(num_overlapping_chunks)
print((num_overlapping_chunks*NETWORK_CHUNK_SIZE)%NETWORK_NUM_SPEAKERS)

# (1, 373, 256, 64) = (batch_size, num_overlapping_chunks, \
#            NETWORK_CHUNK_SIZE, NETWORK_NUM_FILTERS_IN_ENCODER)
# 47872

373
0


In [155]:
model =  RHA(batch_size=BATCH_SIZE,
                model_weights_file=None,
                num_filters_in_encoder=NETWORK_NUM_FILTERS_IN_ENCODER,
                encoder_filter_length=NETWORK_ENCODER_FILTER_LENGTH,
                chunk_size=NETWORK_CHUNK_SIZE,
                num_full_chunks=NUM_CHUNKS,
                num_chop_blocks=NETWORK_NUM_CHOP_BLOCKS,
                num_tran_blocks=NETWORK_NUM_TRAN_BLOCKS,
                num_head_per_att=NETWORK_NUM_HEAD_PER_ATT,
                dim_key_att=NEWORK_DIM_KEY_ATT,
                max_input_length_in_seconds=MAX_INPUT_LENGTH_IN_SECONDS,
                samplerate_hz=SAMPLERATE_HZ,
                num_speakers=NETWORK_NUM_SPEAKERS)

opt = AdamW(1e-4,clipnorm=OPTIMIZER_CLIP_L2_NORM_VALUE)
model.compile(optimizer=opt,metrics=["mse"],)

print(model.model.summary()) 

Model: "model_19"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_33 (InputLayer)           [(None, 47872)]      0                                            
__________________________________________________________________________________________________
encoder_33 (Encoder)            ((1, 373, 256, 64),  256         input_33[0][0]                   
__________________________________________________________________________________________________
tf.reshape_128 (TFOpLambda)     (1, 23872, 256)      0           encoder_33[0][0]                 
__________________________________________________________________________________________________
linear_sine_spe_transformer_blo (1, 23872, 256)      791296      tf.reshape_128[0][0]             
___________________________________________________________________________________________

## Load data 

In [None]:
def pad_along_axis(array: np.ndarray, target_length: int, axis: int = 0):

    pad_size = target_length - array.shape[axis]

    if pad_size <= 0:
        return array

    npad = [(0, 0)] * array.ndim
    npad[axis] = (0, pad_size)

    return np.pad(array, pad_width=npad, mode='constant', constant_values=0)

In [None]:
music_data, music_fs = librosa.load('street_music_sample.wav', sr=SAMPLERATE_HZ)
dog_data, dog_fs = librosa.load('dog_sample.wav', sr=SAMPLERATE_HZ)
print(len(music_data),music_fs,librosa.get_duration(y=music_data, sr=music_fs))
print(len(dog_data),dog_fs,librosa.get_duration(y=dog_data, sr=dog_fs))

In [None]:
# Increase dog bark
dog_data = dog_data*2

# Pad audio
music_data = pad_along_axis(music_data,47872)
dog_data = pad_along_axis(dog_data,47872)
print(len(music_data),music_fs,librosa.get_duration(y=music_data, sr=music_fs))
print(len(dog_data),dog_fs,librosa.get_duration(y=dog_data, sr=dog_fs))

# Mix audio 
mix_data = music_data+dog_data
mix_data = mix_data[np.newaxis,...]
print(mix_data.shape)

# Stack audio for target 
target = np.stack([dog_data,music_data])
target = target[np.newaxis,...]
print(target.shape)

In [None]:
sd.play(music_data, music_fs)
status = sd.wait() 

In [None]:
plt.plot(range(music_data.shape[0]),music_data)
plt.title('Music Data')

In [None]:
sd.play(dog_data, dog_fs)
status = sd.wait() 

In [None]:
plt.plot(range(dog_data.shape[0]),dog_data)
plt.title('Dog Data')

In [None]:
sd.play(mix_data[0,:], dog_fs)
status = sd.wait() 

In [None]:
plt.plot(range(mix_data[0,:].shape[0]),mix_data[0,:])
plt.title('Mix Data')

## Train model

In [None]:
model.model.load_weights('test_weights.tf')

In [None]:
history = model.fit(mix_data, target, epochs=1, batch_size=1)
#model.model.save_weights('test_weights.tf')

## Test model 

In [None]:
yh = model(mix_data).numpy()
y1, y2 = yh[0,0,:], yh[0,1,:]

In [None]:
sd.play(y1, SAMPLERATE_HZ)
status = sd.wait() 
sf.write('seperated_outputA.wav', y1, 8000)

In [None]:
sd.play(y2, SAMPLERATE_HZ)
status = sd.wait() 
sf.write('seperated_outputB.wav', y2, 8000)

In [None]:
sd.play(mix_data[0,:], dog_fs)
status = sd.wait() 
sf.write('original_mix_input.wav', mix_data[0,:], 8000)