In [1]:
from asteroid.data import LibriMix
from asteroid_filterbanks import make_enc_dec

import torch
from torch import nn
from sklearn.cluster import KMeans

from asteroid import torch_utils
from asteroid_filterbanks.transforms import mag, apply_mag_mask
from asteroid.dsp.vad import ebased_vad
from asteroid.masknn.recurrent import SingleRNN
from asteroid.utils.torch_utils import pad_x_to_y

from pytorch_metric_learning.losses import BaseMetricLossFunction

In [2]:
# Adapted from https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/cluster.py
class DeepClusteringLoss(BaseMetricLossFunction):
    """Deep clustering loss ||V V^T - Y Y^T||_F^2, computed per batch element.

    ``V`` holds the estimated time-frequency (TF) bin embeddings and ``Y`` the
    one-hot dominant-source labels. Both are weighted by an optional per-bin
    mask, and the cost is divided by the total mask weight.

    NOTE(review): ``compute_loss`` here has its own signature and is meant to be
    called directly. It presumably does not match the arguments that
    pytorch_metric_learning's ``forward`` passes to ``compute_loss``; confirm
    before calling the loss object as ``loss(...)``.
    """

    def compute_loss(self, embedding, tgt_index, binary_mask=None):
        """Compute the (mask-weighted) deep clustering loss.

        Args:
            embedding (torch.Tensor): estimated embeddings of shape
                (batch, bins * frames, embedding_dim). The bins are flattened in
                the same bins-major order as ``tgt_index``.
            tgt_index (torch.Tensor): integer dominant-source label of every TF
                bin, shape (batch, bins, frames).
            binary_mask (torch.Tensor, optional): per-bin weight (bool or float)
                holding ``batch * bins * frames`` elements, e.g. of shape
                (batch, bins, frames), (batch, bins * frames) or
                (batch, bins * frames, 1). Defaults to all ones (every bin counts).

        Returns:
            torch.Tensor: loss of shape (batch,), normalised by the total mask
            weight of each batch element.
        """
        batch, bins, frames = tgt_index.shape
        n_bins = bins * frames
        device = tgt_index.device

        # Width of the one-hot axis. Use the largest label rather than the
        # number of distinct labels. When some source is absent from the batch
        # (e.g. only label 1 occurs), ``unique()`` undercounts and ``scatter_``
        # would index out of range. Extra all-zero columns leave every
        # Frobenius norm below unchanged.
        spk_cnt = int(tgt_index.max()) + 1

        if binary_mask is None:
            binary_mask = torch.ones(batch, n_bins, 1, device=device)
        # Bool masks become float weights. Any layout with one value per TF bin
        # is flattened to (batch, bins * frames, 1).
        binary_mask = binary_mask.to(device).float().reshape(batch, n_bins, 1)

        # One-hot target embedding Y for each TF bin (scatter_ needs int64 indices).
        tgt_embedding = torch.zeros(batch, n_bins, spk_cnt, device=device)
        tgt_embedding.scatter_(2, tgt_index.reshape(batch, n_bins, 1).long(), 1)

        # Apply the per-bin weights to both V and Y.
        tgt_embedding = tgt_embedding * binary_mask
        embedding = embedding * binary_mask

        # Low-rank expansion of the objective:
        #   ||VV^T - YY^T||^2 = ||V^T V||^2 + ||Y^T Y||^2 - 2 ||V^T Y||^2
        # This avoids building (bins*frames) x (bins*frames) affinity matrices.
        est_proj = torch.einsum("ijk,ijl->ikl", embedding, embedding)
        true_proj = torch.einsum("ijk,ijl->ikl", tgt_embedding, tgt_embedding)
        true_est_proj = torch.einsum("ijk,ijl->ikl", embedding, tgt_embedding)

        cost = batch_matrix_norm(est_proj) + batch_matrix_norm(true_proj)
        cost = cost - 2 * batch_matrix_norm(true_est_proj)

        # Divide by the total mask weight of each batch element. The clamp keeps
        # an all-zero mask from producing 0/0 = NaN (the cost is 0 there).
        active_weight = torch.sum(binary_mask, dim=[1, 2]).clamp(min=1e-8)
        return cost / active_weight

def batch_matrix_norm(matrix, norm_order=2):
    """Per-batch-element ``norm_order``-norm of ``matrix``, raised to ``norm_order``.

    Every dimension except the first (batch) one is reduced as a single
    flattened vector. With the default ``norm_order=2`` this is the squared
    Frobenius norm of each matrix in the batch.

    Args:
        matrix (torch.Tensor): tensor of shape (batch, ...), at least 2-D.
        norm_order (int or float): order of the vector norm.

    Returns:
        torch.Tensor: tensor of shape (batch,).
    """
    reduce_dims = tuple(range(1, matrix.ndim))
    # torch.norm is deprecated. linalg.vector_norm is its documented replacement
    # and, like torch.norm with a numeric p, treats all ``reduce_dims`` as one vector.
    return torch.linalg.vector_norm(matrix, ord=norm_order, dim=reduce_dims) ** norm_order

In [3]:
# Base deep clustering model, i.e. Chimera++ without its mask-inference head.
# Adapted from https://github.com/asteroid-team/asteroid/blob/master/egs/wsj0-mix/DeepClustering/model.py

def make_model(conf):
    """Build the deep clustering model from a configuration mapping.

    Args:
        conf: mapping with a "filterbank" entry (keyword arguments forwarded to
            ``make_enc_dec('stft', ...)``) and a "deepclustering" entry (keyword
            arguments forwarded to ``Embedding``).

    Returns:
        Model: the encoder / embedding network / decoder stack.
    """
    encoder, decoder = make_enc_dec('stft', **conf["filterbank"])
    # The magnitude of the encoder output has half as many channels as the
    # encoder's feature count; that is the embedding network's input size.
    n_freq_bins = encoder.n_feats_out // 2
    return Model(encoder, Embedding(n_freq_bins, **conf["deepclustering"]), decoder)

class Embedding(nn.Module):
    """Recurrent embedding network of the deep clustering model.

    A bidirectional RNN reads the (optionally log-compressed) magnitude
    spectrogram frame by frame. A linear + tanh layer then projects every frame
    to ``channel_in * embedding_dim`` values, i.e. one ``embedding_dim`` vector
    per time-frequency bin. Each vector is finally L2-normalised.

    Args:
        channel_in (int): number of frequency bins of the input spectrogram.
        n_src (int): number of sources. Not used in ``forward``; kept so the
            clustering step knows how many clusters to form.
        rnn_type (str): RNN type forwarded to ``SingleRNN`` (e.g. ``'lstm'``).
        n_layers (int): number of stacked RNN layers.
        hidden_layer_size (int): hidden size of each RNN direction.
        dropout (float): dropout probability applied to the RNN output.
        embedding_dim (int): dimension of each TF-bin embedding.
        take_log (bool): if True, the input is compressed with
            ``log(input + epsilon)`` before the RNN.
        epsilon (float): small constant used in the log and in the
            normalisation, to avoid ``log(0)`` and division by zero.
    """

    def __init__(
        self,
        channel_in,
        n_src=2,
        rnn_type='lstm',
        n_layers=2,
        hidden_layer_size=600,
        dropout=0.3,
        embedding_dim=40,
        take_log=True,
        epsilon=1e-8
    ):
        super().__init__()
        self.channel_in = channel_in  # channel_in = number of frequency bins
        self.n_src = n_src
        self.take_log = take_log
        self.embedding_dim = embedding_dim
        self.lstm = SingleRNN(
            rnn_type,
            channel_in,
            hidden_layer_size,
            n_layers=n_layers,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        # Bidirectional RNN: the output feature size is twice the hidden size.
        lstm_output_dim = 2 * hidden_layer_size
        self.embedding_layer = nn.Linear(lstm_output_dim, channel_in * embedding_dim)
        self.embedding_activation = nn.Tanh()
        self.epsilon = epsilon

    def forward(self, input_data):
        """Compute unit-norm embeddings for every time-frequency bin.

        Args:
            input_data (torch.Tensor): magnitude spectrogram of shape
                (batch, freq, time), with ``freq == channel_in``.

        Returns:
            torch.Tensor: embeddings of shape (batch, freq * time, embedding_dim),
            flattened frequency-major (all frames of bin 0, then all frames of
            bin 1, ...). Each row has unit L2 norm (up to ``epsilon``).
        """
        batch_size, _, frames = input_data.shape
        # Log-compress only when requested; otherwise feed the magnitudes as-is.
        # ``x`` must be bound on both paths, or take_log=False would fail.
        x = torch.log(input_data + self.epsilon) if self.take_log else input_data

        # RNN over time: (batch, freq, time) -> (batch, time, freq) -> RNN output.
        lstm_output = self.dropout(self.lstm(x.permute(0, 2, 1)))

        # Project every frame: (batch, time, freq * embedding_dim).
        embedding_out = self.embedding_activation(self.embedding_layer(lstm_output))

        # (batch, time, freq, embedding_dim) -> (batch, freq, time, embedding_dim)
        embedding_out = embedding_out.view(batch_size, frames, -1, self.embedding_dim)
        embedding_out = embedding_out.transpose(1, 2)

        # Flatten the TF grid: (batch, freq * time, embedding_dim).
        embedding_out = embedding_out.reshape(batch_size, -1, self.embedding_dim)

        # Normalise so the embedding vector of each TF bin has unit norm.
        embedding_norm = torch.norm(embedding_out, p=2, dim=-1, keepdim=True)
        return embedding_out / (embedding_norm + self.epsilon)

class Model(nn.Module):
    """Deep clustering separator: STFT encoder, embedding network, STFT decoder.

    ``forward`` returns the TF-bin embeddings used for training. ``cluster``
    performs separation by k-means clustering those embeddings into binary masks.

    Args:
        encoder: analysis filterbank applied to the waveform.
        embedding: network mapping the magnitude spectrogram to TF-bin
            embeddings. ``cluster`` also reads its ``n_src`` attribute.
        decoder: synthesis filterbank turning masked TF representations back
            into waveforms.
    """

    def __init__(self, encoder, embedding, decoder):
        super().__init__()
        self.encoder = encoder
        self.embedding = embedding
        self.decoder = decoder

    def forward(self, x):
        """Return the TF-bin embeddings of the mixture ``x``.

        Args:
            x (torch.Tensor): waveform of shape (batch, time) or (batch, 1, time).
                A 2-D input gets a channel axis added.

        Returns:
            torch.Tensor: output of ``self.embedding`` on the magnitude of the
            encoded mixture. For ``Embedding`` these are unit-norm vectors of
            shape (batch, freq * frames, embedding_dim).
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        tf_representation = self.encoder(x)
        spectral_magnitude = mag(tf_representation)
        return self.embedding(spectral_magnitude)

    def cluster(self, x):
        """Separate ``x`` by k-means clustering of its TF-bin embeddings.

        Each utterance in the batch is clustered independently into
        ``self.embedding.n_src`` groups, using only the TF bins kept by the
        energy-based VAD. In source ``i``'s mask, a VAD-kept bin is 1 iff it was
        assigned to cluster ``i``. Bins rejected by the VAD are 1 in every
        source's mask.

        Args:
            x (torch.Tensor): mixture waveform of shape (batch, time) or
                (batch, 1, time).

        Returns:
            torch.Tensor: decoded sources (one per cluster along dim 1), padded
            to the length of ``x`` with ``pad_x_to_y``. Presumably of shape
            (batch, n_src, time); confirm against the decoder.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        tf_representation = self.encoder(x)
        spectral_magnitude = mag(tf_representation)
        # (batch, freq * frames, embedding_dim), flattened frequency-major.
        normalised_embedding = self.embedding(spectral_magnitude)
        n_src = self.embedding.n_src

        # Energy-based voice activity detection: boolean mask of the TF bins
        # worth clustering (the original note: ignore bins with energy < -40dB).
        retained_bins = ebased_vad(spectral_magnitude)

        batch_masks = []
        # Cluster each utterance on its own. Pooling bins from different
        # mixtures into a single k-means would mix their speakers.
        for utt_embedding, utt_bins in zip(normalised_embedding, retained_bins):
            kmeans = KMeans(n_clusters=n_src)
            # Flatten the (freq, frames) VAD mask in the same frequency-major
            # order as the embeddings.
            active_embedding = utt_embedding[utt_bins.reshape(-1)]
            clusters = kmeans.fit_predict(active_embedding.detach().cpu().numpy())
            clusters = torch.from_numpy(clusters).to(utt_bins.device)

            src_masks = []
            for i in range(n_src):
                # Bins rejected by the VAD stay 1 for every source.
                mask = ~utt_bins
                mask[utt_bins] = clusters == i
                src_masks.append(mask.float())
            batch_masks.append(torch.stack(src_masks, dim=0))

        # (batch, n_src, freq, frames)
        estimated_masks = torch.stack(batch_masks, dim=0)
        # Add a source axis so each utterance is masked only by its own masks.
        masked_representation = apply_mag_mask(tf_representation.unsqueeze(1), estimated_masks)
        # Pad masked audio to have same size as original
        return pad_x_to_y(self.decoder(masked_representation), x)

In [4]:
BATCH_SIZE = 16
# Train / validation DataLoaders over the (presumably MiniLibriMix) "mini" subset.
train_loader, val_loader = LibriMix.loaders_from_mini(batch_size=BATCH_SIZE)

Drop 0 utterances from 800 (shorter than 3 seconds)
Drop 0 utterances from 200 (shorter than 3 seconds)


In [5]:
# Take the first (mixture, sources) batch from the training loader.
train_iter = iter(train_loader)
mixture, sources = next(train_iter)

In [6]:
# STFT analysis / synthesis filterbank pair.
# NOTE(review): a 16-sample window (kernel_size) inside 512-point filters looks
# unusually short for speech separation; confirm these settings are intended.
enc, dec = make_enc_dec('stft', n_filters=512, kernel_size=16, stride=8)
# The magnitude of the encoder output has n_feats_out // 2 channels (frequency bins).
embedding = Embedding(channel_in=enc.n_feats_out // 2)
model = Model(encoder=enc, embedding=embedding, decoder=dec)

In [7]:
# Magnitude spectrogram of every reference source, with the source axis at dim 1
# (presumably (batch, n_src, freq, frames); TODO confirm against the loader).
src_mag_spec = mag(model.encoder(sources))

# Ideal ratio mask (IRM): each source's share of the total magnitude per TF bin.
# The 1e-8 keeps silent bins from dividing by zero.
real_mask = src_mag_spec / (src_mag_spec.sum(dim=1, keepdim=True) + 1e-8)
# Index of the dominant source in every TF bin. Despite its name, this is the
# integer label map (`tgt_index`) fed to the deep clustering loss, not a 0/1 mask.
binary_mask = torch.argmax(real_mask, dim=1)

In [8]:
LEARNING_RATE = 1e-4
# SGD with momentum over every registered parameter of the model
# (encoder, embedding network and decoder submodules).
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE, momentum=0.9)

In [None]:
# One SGD step on the deep clustering objective for the batch loaded above.
model.train()
# Clear gradients left by a previous run of this cell; otherwise they accumulate.
optimizer.zero_grad()
pred = model(mixture)  # TF-bin embeddings of the mixture
loss = DeepClusteringLoss()
# `binary_mask` holds the dominant-source index of every TF bin (the loss's `tgt_index`).
deep_clustering_loss = loss.compute_loss(pred, binary_mask)
deep_clustering_loss.mean().backward()
optimizer.step()