In [None]:
# Important to run on a restarted kernel
!pip install smart_open
!pip install torchvision
!pip install sox
!pip install librosa
!pip install torchaudio


In [None]:
import numpy as np
from numpy.random import default_rng

import matplotlib.pyplot as plt
from smart_open import open

import os
from pathlib import Path
import random
import librosa

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# Download data

Running the cell below should download the data to the correct directories used by the notebook.

In [None]:
# make data directory if it does not exist
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Archives holding the clean speech and the noise recordings.
clean_audio_url = "https://huggingface.co/datasets/aarohig2/speech-separation-librispeech-esc50/resolve/main/clean_audio.zip"
noise_audio_url = "https://huggingface.co/datasets/aarohig2/speech-separation-librispeech-esc50/resolve/main/noise.zip"

# download the data
# Each archive is fetched, unpacked into data/ and deleted, but only when its
# target directory is missing, so re-running the cell does not download again.
# NOTE(review): this guard checks data/clean_audio, while
# get_clean_and_noise_urls() reads from data/clean_speech/ - confirm which
# directory the archive actually unpacks to.
if not (data_dir / "clean_audio").exists():
    !wget -P data/ $clean_audio_url
    !unzip data/clean_audio.zip -d data
    !rm data/clean_audio.zip

if not (data_dir / "noise").exists():
    !wget -P data $noise_audio_url
    !unzip data/noise.zip -d data
    !rm data/noise.zip

# Data Augmentation

In [None]:
# Number of samples/data points per second.
sample_rate=8000

# the signal length of training sequences
# (40000 samples at 8000 samples per second is 5 seconds of audio per example)
training_signal_len = 40000

# Generator used to pick the voice files for each mixture. It is created
# without a seed, so the sequence of picks is not reproducible between runs.
rng = default_rng()

In [None]:
# TODO: change these to correct lengths
# Hard-coded counts of clean-voice and noise files. They are placeholders, not
# values measured from the data directories.
voice_dir_len = 3000
noise_dir_len = 2000

In [None]:
def standardize_length(sample, training_signal_len):
    '''
    Force `sample` to contain exactly `training_signal_len` entries.

    A sample that is already the right length is returned unchanged, a longer
    one is cut down to its first `training_signal_len` entries, and a shorter
    one is extended by looping back over its own beginning until it is full.
    '''
    current_len = len(sample)

    if current_len == training_signal_len:
        return sample

    if current_len > training_signal_len:
        return sample[:training_signal_len]

    # Too short: tile the sample often enough to cover the gap, then keep only
    # as many of the repeated entries as are still missing.
    missing = training_signal_len - current_len
    repeats = (missing // current_len) + 1
    filler = np.tile(sample, repeats)[:missing]
    return np.concatenate([sample, filler])

def dynamic_mix_data_prep(voice_dir, noise_dir, training_signal_len, sample_rate):
    """
    Build one random training example by mixing two voices and one noise clip.

    Two distinct voice files and one noise file are picked at random. Each is
    loaded at `sample_rate`, forced to `training_signal_len` samples,
    peak-normalized and attenuated by a random gain (in dB), then the three are
    summed and the sum is rescaled to a peak of at most 0.9.

    Args:
        voice_dir: indexable collection of clean-speech file paths (at least 2).
        noise_dir: indexable collection of noise file paths (at least 1).
        training_signal_len: number of samples in every returned signal.
        sample_rate: sampling rate the files are loaded at.

    Returns:
        (mixture, clean, sample_rate): the mixed signal, the target voice, and
        the sampling rate that was passed in.

    Raises:
        ValueError: if `voice_dir` has fewer than two entries or `noise_dir`
            is empty (raised by the random index draws).
    """

    # Draw indices from the collections that were actually passed in, so the
    # picks are always valid and every available file can be selected.
    target_idx, interferer_idx = rng.choice(len(voice_dir), size=2, replace=False)
    clean1_file, clean2_file = voice_dir[target_idx], voice_dir[interferer_idx]
    noise_file = noise_dir[np.random.randint(len(noise_dir))]

    sources = []
    first_lvl = None
    target_raw = None
    # The target voice goes last so that `gain` still holds its level after the loop.
    spk_files = [clean2_file, noise_file, clean1_file]
    target_pos = len(spk_files) - 1

    for i, spk_file in enumerate(spk_files):

        # Load audio and ensure it's of minimum length
        with open(spk_file, 'rb') as file:
            y, _ = librosa.load(file, sr=sample_rate, duration=training_signal_len/sample_rate)
        y = standardize_length(y, training_signal_len)

        if i == target_pos:
            # Keep the un-normalized target so it does not have to be read and
            # decoded a second time when the clean reference is built.
            target_raw = y

        # Normalize audio
        tmp = librosa.util.normalize(y)

        # Layer on a stack
        if i == 0 or i == 1:
            gain = np.clip(random.normalvariate(-27.43, 2.57), -45, 0)
            tmp *= 10 ** (gain / 20)  # Convert dB to amplitude
            # Overwritten on the second pass, so the target's level ends up
            # relative to the noise level.
            first_lvl = gain
        else:
            gain = np.clip(first_lvl + random.normalvariate(-2.51, 2.66), -45, 0)
            tmp *= 10 ** (gain / 20)  # Convert dB to amplitude

        sources.append(tmp)

    # Mix the sources together
    mixture = sum(sources)

    # Largest absolute amplitude over the mixture and each individual source
    max_amp = max(np.abs(mixture).max(), *[np.abs(s).max() for s in sources])

    # Rescale to a peak of 0.9. All-silent inputs give max_amp == 0; skip the
    # scaling then instead of dividing by zero and filling the mixture with NaN.
    if max_amp > 0:
        mix_scaling = 1 / max_amp * 0.9
        mixture *= mix_scaling

    # Clean reference: the target voice with the gain it received in the mix.
    # NOTE(review): unlike the copy inside the mixture, this reference is neither
    # peak-normalized nor multiplied by the mix scaling, so it differs from the
    # mixed-in voice by a constant factor. That does not matter to a
    # scale-invariant loss, but confirm it is intended.
    clean = target_raw.copy()
    clean *= 10 ** (gain / 20)

    return mixture, clean, sample_rate

# Dataset

In [None]:
def log_memory_usage():
    """
    Print memory statistics of CUDA device 0, in megabytes.

    Reports the memory currently allocated to tensors, the memory reserved by
    the caching allocator, the device memory not reserved, and the peak
    allocation. When CUDA is not available there is nothing to query, so a
    short notice is printed instead of raising.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available; no GPU memory statistics to report.")
        return

    mb = 1024 ** 2  # bytes per megabyte
    t = torch.cuda.get_device_properties(0).total_memory / mb
    r = torch.cuda.memory_reserved(0) / mb
    a = torch.cuda.memory_allocated(0) / mb
    max_allocated = torch.cuda.max_memory_allocated(0) / mb
    print(f"Memory allocated: {a:.2f} MB")
    print(f"Memory reserved: {r: .2f} MB")
    print(f"Memory remaining: {t - r: .2f}")
    print(f"Peak memory allocated: {max_allocated:.2f} MB")

In [None]:
def get_clean_and_noise_urls(clean_speech_path="data/clean_speech/", noise_path="data/noise/"):
    '''
    List the audio files that make up the dataset.

    Args:
        clean_speech_path: directory scanned for clean speech (`.flac` files).
        noise_path: directory scanned for noise recordings (`.wav` files).

    Returns (Clean URLs, Noise URLs): two numpy arrays of file paths, each
    sorted by file name so the order does not depend on the file system.

    Raises:
        FileNotFoundError: if either directory does not exist.
    '''

    def _list_audio(directory, suffix):
        # One pass over the directory; sorting makes index-based sampling
        # reproducible, because os.listdir returns entries in arbitrary order.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith(suffix)
        ]

    # Build each list in one go rather than growing an array with np.append,
    # which copies the whole array on every call.
    clean = np.array(_list_audio(clean_speech_path, ".flac"))
    print(f'Number of clean audio files are {len(clean)}')

    noise = np.array(_list_audio(noise_path, ".wav"))
    print(f'Number of Noisy audio files are {len(noise)}')

    return clean, noise

class CustomDataset(Dataset):
    """
    Dataset of on-the-fly (mixture, clean voice) pairs.

    No example is stored: every `__getitem__` call mixes a new random example,
    so the index only bounds the epoch length and the same index yields
    different data on each access.

    Args:
        transform: optional callable applied to the mixture tensor.
        target_transform: optional callable applied to the clean-voice tensor.
        num_samples: number of examples reported by `__len__`.
    """

    def __init__(self, transform=None, target_transform=None, num_samples=100):
        self.transform = transform
        self.target_transform = target_transform
        self.num_samples = num_samples

        clean_dir, noise_dir = get_clean_and_noise_urls()
        self.voice_dir = clean_dir
        self.noise_dir = noise_dir

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Without this bound, plain iteration (`for x in dataset`) never stops,
        # because Python keeps indexing until IndexError is raised.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {self.num_samples} samples")

        mixture, clean, _ = dynamic_mix_data_prep(self.voice_dir, self.noise_dir, training_signal_len, sample_rate)

        mixture = torch.tensor(standardize_length(mixture, training_signal_len))
        clean = torch.tensor(standardize_length(clean, training_signal_len))

        # Apply the transforms passed to the constructor.
        if self.transform is not None:
            mixture = self.transform(mixture)
        if self.target_transform is not None:
            clean = self.target_transform(clean)

        return mixture, clean

# Conv-TasNet Architecture

In [None]:
class GlobalLayerNorm(nn.Module):
    '''
       Global layer normalization (gLN) for N x C x L tensors.

       Each item of the batch is normalized with one mean and one variance
       taken jointly over its channel and time axes.

       dim: (int) number of channels C of the expected input
       eps: a value added to the denominator for numerical stability.
       elementwise_affine: a boolean value that when set to True,
           this module has learnable per-channel affine parameters
           initialized to ones (for weights) and zeros (for biases).
    '''

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            # Register the names anyway so the attributes exist (as None).
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            # C x 1, broadcast over the batch and time axes
            self.weight = nn.Parameter(torch.ones(self.dim, 1))
            self.bias = nn.Parameter(torch.zeros(self.dim, 1))

    def forward(self, x):
        # x: N x C x L
        if x.dim() != 3:
            raise RuntimeError("{} accept 3D tensor as input".format(
                self.__class__.__name__))

        # statistics per batch item: N x 1 x 1
        reduce_axes = (1, 2)
        mean = torch.mean(x, reduce_axes, keepdim=True)
        centered = x - mean
        var = torch.mean(centered ** 2, reduce_axes, keepdim=True)
        denom = torch.sqrt(var + self.eps)

        # N x C x L
        if not self.elementwise_affine:
            return centered / denom
        return self.weight * centered / denom + self.bias


class CumulativeLayerNorm(nn.LayerNorm):
    '''
       Channel-wise layer normalization for N x C x L tensors.

       Every time step is normalized on its own across the channel axis, using
       nn.LayerNorm on a transposed view. No statistics are accumulated over
       time, despite the class name.

       dim: number of channels C (the size of the normalized axis)
       elementwise_affine: learnable per-element affine parameters
    '''

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine)

    def forward(self, x):
        # nn.LayerNorm normalizes the last axis, so move channels there:
        # N x C x L -> N x L x C
        channels_last = x.transpose(1, 2)
        normed = super().forward(channels_last)
        # back to N x C x L
        return normed.transpose(1, 2)


def select_norm(norm, dim):
    """
    Build the normalization layer named by `norm` for `dim` channels.

    'gln' gives GlobalLayerNorm, 'cln' gives CumulativeLayerNorm (both with
    learnable affine parameters) and 'bn' gives nn.BatchNorm1d.

    Raises:
        ValueError: for any other name.
    """
    if norm == 'gln':
        return GlobalLayerNorm(dim, elementwise_affine=True)
    if norm == 'cln':
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == 'bn':
        return nn.BatchNorm1d(dim)
    raise ValueError("Unsupported norm type")


class Conv1D(nn.Conv1d):
    '''
       nn.Conv1d that also accepts input without a channel axis.

       A 2D input (N x L) is treated as single-channel and expanded to
       N x 1 x L before the convolution; a 3D input is used as is.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        """
        x: N x L or N x C x L
        squeeze: when True, every size-1 axis of the result is removed
                 (this includes the batch axis when N == 1)
        """
        rank = x.dim()
        if rank not in [2, 3]:
            raise RuntimeError("{} accept 2/3D tensor as input".format(
                self.__class__.__name__))
        if rank == 2:
            # add the channel axis: N x L -> N x 1 x L
            x = torch.unsqueeze(x, 1)
        out = super().forward(x)
        return torch.squeeze(out) if squeeze else out


class ConvTrans1D(nn.ConvTranspose1d):
    '''
       nn.ConvTranspose1d that also accepts input without a channel axis.

       A 2D input (N x L) is treated as single-channel and expanded to
       N x 1 x L before the transposed convolution; a 3D input is used as is.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        """
        x: N x L or N x C x L
        squeeze: when True, every size-1 axis of the result is removed
                 (this includes the batch axis when N == 1)
        """
        rank = x.dim()
        if rank not in [2, 3]:
            raise RuntimeError("{} accept 2/3D tensor as input".format(
                self.__class__.__name__))
        if rank == 2:
            # add the channel axis: N x L -> N x 1 x L
            x = torch.unsqueeze(x, 1)
        out = super().forward(x)
        return torch.squeeze(out) if squeeze else out

class Conv1D_Block(nn.Module):
    '''
       One residual block of the separation network:
       1x1 conv -> PReLU -> norm -> depthwise dilated conv -> PReLU -> norm
       -> 1x1 conv, added back onto the block input (residual link only).

       in_channels: channels of the block input and output
       out_channels: channels used inside the block
       kernel_size: kernel size of the depthwise convolution
       dilation: dilation of the depthwise convolution
       norm: normalization type passed to select_norm ('gln', 'cln' or 'bn')
       causal: when True, padding is arranged so that an output step depends
               only on current and earlier input steps
    '''

    def __init__(self, in_channels=256, out_channels=512,
                 kernel_size=3, dilation=1, norm='gln', causal=False):
        super(Conv1D_Block, self).__init__()
        # conv 1 x 1
        self.conv1x1 = Conv1D(in_channels, out_channels, 1)
        self.PReLU_1 = nn.PReLU()
        self.norm_1 = select_norm(norm, out_channels)
        # Non-causal: pad half the receptive field on each side, which keeps the length.
        # Causal: pad the full receptive field; the surplus tail is cut off in forward().
        self.pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        # depthwise convolution
        self.dwconv = Conv1D(out_channels, out_channels, kernel_size,
                             groups=out_channels, padding=self.pad, dilation=dilation)
        self.PReLU_2 = nn.PReLU()
        self.norm_2 = select_norm(norm, out_channels)
        self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True)
        self.causal = causal

    def forward(self, x):
        # x: N x C x L
        # N x O_C x L
        c = self.conv1x1(x)
        # N x O_C x L
        c = self.PReLU_1(c)
        c = self.norm_1(c)
        # causal: N x O_C x (L+pad)
        # noncausal: N x O_C x L
        c = self.dwconv(c)
        # N x O_C x L
        # Drop the look-ahead tail. Skipped when pad == 0, because slicing
        # with [:-0] would return an empty tensor.
        if self.causal and self.pad > 0:
            c = c[:, :, :-self.pad]
        # Second activation + normalization, after the depthwise convolution.
        # These layers were constructed but never applied before.
        c = self.PReLU_2(c)
        c = self.norm_2(c)
        # back to the input width: N x C x L
        c = self.Sc_conv(c)
        return x+c


class ConvTasNet(nn.Module):
    '''
       ConvTasNet module
       N	Number of filters in autoencoder
       L	Length of the filters (in samples)
       B	Number of channels in bottleneck and the residual paths' 1 x 1-conv blocks
       H	Number of channels in convolutional blocks
       P	Kernel size in convolutional blocks
       X	Number of convolutional blocks in each repeat
       R	Number of repeats
       norm	Normalization used inside the convolutional blocks ('gln', 'cln' or 'bn')
       num_spks	Number of sources to estimate
       activate	Mask activation: 'relu', 'sigmoid' or 'softmax'
       causal	Whether the convolutional blocks are causal
    '''

    def __init__(self,
                 N=512,
                 L=16,
                 B=128,
                 H=512,
                 P=3,
                 X=8,
                 R=3,
                 norm="gln",
                 num_spks=2,
                 activate="relu",
                 causal=False):
        super(ConvTasNet, self).__init__()
        # n x 1 x T => n x N x T'
        self.encoder = Conv1D(1, N, L, stride=L // 2, padding=0)
        # n x N x T'  Layer Normalization of Separation
        self.LayerN_S = select_norm('cln', N)
        # n x B x T'  Conv 1 x 1 of  Separation
        self.BottleN_S = Conv1D(N, B, 1)
        # Separation block
        # n x B x T' => n x B x T'
        self.separation = self._Sequential_repeat(
            R, X, in_channels=B, out_channels=H, kernel_size=P, norm=norm, causal=causal)
        # n x B x T' => n x num_spks*N x T'
        self.gen_masks = Conv1D(B, num_spks*N, 1)
        # n x N x T' => n x 1 x T
        self.decoder = ConvTrans1D(N, 1, L, stride=L//2)
        # activation function; softmax runs over axis 0, which is the speaker
        # axis of the stacked masks in forward()
        active_f = {
            'relu': nn.ReLU(),
            'sigmoid': nn.Sigmoid(),
            'softmax': nn.Softmax(dim=0)
        }
        self.activation_type = activate
        self.activation = active_f[activate]
        self.num_spks = num_spks

    def _Sequential_block(self, num_blocks, **block_kwargs):
        '''
           Sequential 1-D Conv Block
           input:
                 num_block: how many blocks in every repeats
                 **block_kwargs: parameters of Conv1D_Block
           The i-th block uses dilation 2**i.
        '''
        Conv1D_Block_lists = [Conv1D_Block(
            **block_kwargs, dilation=(2**i)) for i in range(num_blocks)]

        return nn.Sequential(*Conv1D_Block_lists)

    def _Sequential_repeat(self, num_repeats, num_blocks, **block_kwargs):
        '''
           Sequential repeats
           input:
                 num_repeats: Number of repeats
                 num_blocks: Number of block in every repeats
                 **block_kwargs: parameters of Conv1D_Block
        '''
        repeats_lists = [self._Sequential_block(
            num_blocks, **block_kwargs) for i in range(num_repeats)]
        return nn.Sequential(*repeats_lists)

    def forward(self, x):
        '''
           x: waveform(s), shape T or n x T
           returns: list of num_spks tensors, each of shape T for a 1D input
                    and n x T for a 2D input
        '''
        if x.dim() >= 3:
            raise RuntimeError(
                "{} accept 1/2D tensor as input, but got {:d}".format(
                    self.__class__.__name__, x.dim()))
        unbatched = x.dim() == 1
        if unbatched:
            x = torch.unsqueeze(x, 0)
        # x: n x T => n x N x T'
        w = self.encoder(x)
        # n x N x T' => n x B x T'
        e = self.LayerN_S(w)
        e = self.BottleN_S(e)
        # n x B x T' => n x B x T'
        e = self.separation(e)
        # n x B x T' => n x num_spks*N x T'
        m = self.gen_masks(e)
        # num_spks chunks of n x N x T'
        m = torch.chunk(m, chunks=self.num_spks, dim=1)
        # num_spks x n x N x T'
        m = self.activation(torch.stack(m, dim=0))
        d = [w*m[i] for i in range(self.num_spks)]
        # Decode each masked representation: n x 1 x T => n x T.
        # Only the channel axis is removed. Squeezing every size-1 axis would
        # also drop the batch axis for a batch of one and break shape checks
        # against n x T targets.
        s = [self.decoder(d[i]).squeeze(1) for i in range(self.num_spks)]
        if unbatched:
            # give a 1D input back as 1D output
            s = [est.squeeze(0) for est in s]
        return s


def check_parameters(net):
    '''
        Count the parameters of `net`, in millions.

        The result is a number of parameters divided by 10**6, not a memory
        size in megabytes.
    '''
    total = 0
    for param in net.parameters():
        total += param.numel()
    return total / 10**6


def test_convtasnet():
    """
    Smoke test: push a random 320-sample waveform through a default ConvTasNet
    and print the parameter count (in millions, although the printed suffix
    reads "Mb") and the shape of the second estimated source.
    """
    waveform = torch.randn(320)
    net = ConvTasNet()
    estimates = net(waveform)
    print(f"{check_parameters(net)} Mb")
    print(estimates[1].shape)


# Run the smoke test when this code is executed as the main module.
# NOTE(review): a notebook kernel presumably runs cells with __name__ set to
# "__main__", so this would build a full default ConvTasNet on every execution
# of the cell - confirm that is wanted.
if __name__ == "__main__":
    test_convtasnet()


# Training

In [None]:
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so an error is
# reported at the call that caused it. It is a debugging setting and slows
# training; remove it for real runs.
# NOTE(review): both variables are set after torch has been imported;
# presumably they only take effect if CUDA has not been initialized yet - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Caching-allocator option intended to reduce memory fragmentation.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"

In [None]:
from itertools import permutations

def sisnr(x, s, eps=1e-8):
    """
    Scale-invariant signal-to-noise ratio in dB (higher is better).

    Both signals are made zero-mean along the last axis, the estimate is
    projected onto the reference, and the ratio between the norm of that
    projection and the norm of the remainder is expressed in decibels.

    input:
          x: separated signal, N x S tensor
          s: reference signal, N x S tensor
          eps: small constant guarding the divisions and the logarithm
    Return:
          sisnr: N tensor
    Raises:
          RuntimeError: if `x` and `s` differ in shape
    """
    if x.shape != s.shape:
        raise RuntimeError(
            "Dimention mismatch when calculate si-snr, {} vs {}".format(
                x.shape, s.shape))

    # remove the DC offset of every signal
    est = x - torch.mean(x, dim=-1, keepdim=True)
    ref = s - torch.mean(s, dim=-1, keepdim=True)

    # component of the estimate that lies along the reference
    ref_energy = torch.norm(ref, dim=-1, keepdim=True) ** 2
    inner = torch.sum(est * ref, dim=-1, keepdim=True)
    target = inner * ref / (ref_energy + eps)

    # everything else counts as noise
    target_norm = torch.norm(target, dim=-1, keepdim=False)
    residual_norm = torch.norm(est - target, dim=-1, keepdim=False)
    return 20 * torch.log10(eps + target_norm / (residual_norm + eps))


def si_snr_loss(ests, egs):
    """
    Permutation-invariant SI-SNR loss.

    Every assignment of estimates to references is scored with the mean
    SI-SNR over speakers. For each utterance the best assignment is kept, and
    the result is the negated mean of those best scores, so lower is better.

    input:
          ests: estimated sources, spks x n x S (a tensor or a sequence of
                n x S tensors)
          egs:  reference sources, spks x n x S (a tensor or a sequence of
                n x S tensors)
    Return:
          scalar tensor
    """
    refs = egs
    # len() and indexing work for a stacked tensor as well as for a list
    num_spks = len(refs)

    def sisnr_loss(permute):
        # mean SI-SNR over speakers for one assignment: n tensor
        return sum(
            [sisnr(ests[s], refs[t])
             for s, t in enumerate(permute)]) / len(permute)

    # number of utterances in the batch
    N = refs[0].size(0)
    # P x n: one row per permutation
    sisnr_mat = torch.stack(
        [sisnr_loss(p) for p in permutations(range(num_spks))])
    # best assignment per utterance
    max_perutt, _ = torch.max(sisnr_mat, dim=0)
    # negate: a higher SI-SNR must give a lower loss
    return -torch.sum(max_perutt) / N

In [None]:
# Create a directory to save periodic checkpoints.
# Path.mkdir(exist_ok=True) succeeds quietly when the directory already exists,
# so the cell can be re-run, and it does not depend on a shell being available.
Path("checkpoints").mkdir(exist_ok=True)

def calculate_snr(signal, noise):
    """
    Linear (not dB) signal-to-noise ratio: the mean power of `signal` divided
    by the mean power of `noise`. Each mean is taken over all elements of the
    respective tensor.
    """
    def mean_power(wave):
        return torch.sum(wave ** 2) / wave.numel()

    return mean_power(signal) / mean_power(noise)

# TODO:
# 1. Do we need mixed precision / torch.autocast?
# Speechbrain uses it on CPU fp16, or CUDA fp16/bf16:
# https://github.com/speechbrain/speechbrain/blob/5a8535c7bb202ecc852223f4a200b1f1bbc248fb/speechbrain/core.py#L790C1-L794C32
# 2. (Done? Unclear why the threshold is -30) Consider loss thresholding:
# https://github.com/speechbrain/speechbrain/blob/1350e9b3cebae9f78e57e97d82a2d89ba3fc2ae1/recipes/WSJ0Mix/separation/train.py#L156C1-L160C43
# 3. (Done?) Consider gradient clipping, skipping batches with bad losses:
# https://github.com/speechbrain/speechbrain/blob/1350e9b3cebae9f78e57e97d82a2d89ba3fc2ae1/recipes/WSJ0Mix/separation/train.py#L164C1-L181C64

# Reference hyperparameters:
# https://github.com/speechbrain/speechbrain/blob/1350e9b3cebae9f78e57e97d82a2d89ba3fc2ae1/recipes/WSJ0Mix/separation/hparams/sepformer.yaml


def train_model(model, criterion, optimizer, train_loader, val_loader,
                num_epochs=50, threshold=-30, loss_upper_lim=999999, clip_grad_norm=5):
    """
    Train `model`, validate after every epoch, save checkpoints and plot the losses.

    Args:
        model: network whose output is indexable; only the first estimate
            (`outputs[0]`) is compared with the target.
        criterion: callable(estimate, target) returning a tensor of losses to
            MINIMIZE (lower is better). A quality score such as SI-SNR has to
            be negated by the caller before it is passed in.
        optimizer: optimizer over the parameters of `model`.
        train_loader: iterable of (inputs, targets) batches used for training.
        val_loader: iterable of (inputs, targets) batches used for validation.
        num_epochs: number of passes over `train_loader`.
        threshold: loss entries that are not above this value are dropped
            before averaging.
        loss_upper_lim: a batch whose mean loss is not below this value is skipped.
        clip_grad_norm: maximum gradient norm; a negative value disables clipping.

    Side effects:
        Writes `checkpoints/detector_epoch_<k>.pth` after every epoch and
        `final_model.pth` at the end, prints progress, and shows a plot of the
        training and validation losses.
    """

    # Move model to GPU if available
    device = torch.device(
        "cuda" if torch.cuda.is_available() else
        "mps" if torch.backends.mps.is_available() else
        "cpu"
    )
    model.to(device)

    # Saving a checkpoint must not fail because the directory is missing.
    os.makedirs('checkpoints', exist_ok=True)

    train_losses = []
    val_losses = []
    snrs = []

    # Training Loop
    for epoch in range(num_epochs):
        print("=" * 50)
        model.train()  # Set model to training mode
        train_loss = 0.0
        step = 0  # batches that produced an optimizer step
        skipped_count = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Start every batch from clean gradients.
            optimizer.zero_grad()

            # The model performs the masking itself and returns one estimate
            # per source; the first one is scored against the target.
            outputs = model(inputs)
            loss = criterion(outputs[0], targets)

            # Loss thresholding: drop entries that are not above the threshold.
            loss = loss[loss > threshold]
            batch_loss = loss.mean() if loss.nelement() > 0 else None

            usable = (
                batch_loss is not None
                and bool(torch.isfinite(batch_loss))
                and bool(batch_loss < loss_upper_lim)
            )

            # Backward pass, gradient clipping, optimization
            if usable:
                batch_loss.backward()
                if clip_grad_norm >= 0:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        clip_grad_norm,
                    )

                optimizer.step()
                step += 1
                train_loss += batch_loss.item()

                # Report the running mean, and only after a real step, so a
                # skipped batch does not repeat the same progress line.
                if step % 50 == 0:
                    print(f'Epoch [{epoch+1}/{num_epochs}], Iter [{step}/{len(train_loader)}], Loss: {train_loss / step:.4f}')
            else:
                skipped_count += 1
                print("infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                            skipped_count
                     )
                )

        # Validation
        model.eval()  # Set model to evaluation mode
        val_loss = 0.0
        snr_sum = 0.0
        val_batches = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                estimate = model(inputs)[0]

                loss = criterion(estimate, targets)
                if loss.nelement() > 0:
                    loss = loss.mean()
                val_loss += loss.item()

                # SNR of the estimate: the target is the signal and the
                # estimation error is the noise. This value depends on the
                # scale of the estimate, unlike a scale-invariant criterion.
                snr_ratio = calculate_snr(targets, estimate - targets)
                snr_sum += snr_ratio.item()
                val_batches += 1

        # Average over the batches that actually contributed; max(..., 1)
        # avoids dividing by zero for an empty loader or a fully skipped epoch.
        train_loss /= max(step, 1)
        val_loss /= max(val_batches, 1)
        snr_sum /= max(val_batches, 1)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        snrs.append(snr_sum)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val SNR: {10 * np.log10(snr_sum):.4f}, LR: {optimizer.param_groups[-1]['lr']}")
        torch.save({'model_state_dict': model.state_dict(), 'training_loss': train_loss, 'val_loss': val_loss},f'checkpoints/detector_epoch_{epoch + 1}.pth')


    # Save the model after training
    torch.save(model.state_dict(), 'final_model.pth')

    # Plotting
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
    plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

In [None]:
# Create the datasets for the training and validation sets
# (the DataLoaders that wrap them are built in a later cell).
# The sample counts only set how many random mixtures make up one epoch;
# every item is mixed on the fly when it is requested.
train_samples = 10_000
val_samples = 2_000

# NOTE(review): both datasets are built by the same constructor with no
# split argument, so they read the same file lists - validation mixtures can
# reuse recordings seen in training. Confirm whether a held-out split is wanted.
print("Initializing Train Dataset:")
train_dataset = CustomDataset(num_samples=train_samples)
print("Initializing Validation Dataset:")
val_dataset = CustomDataset(num_samples=val_samples)

In [None]:
# Initialize the Conv-Tas-Net model
model = ConvTasNet()


train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


def negative_sisnr(estimate, reference):
    """Per-utterance loss: the negated SI-SNR, so that minimizing it raises the SI-SNR."""
    return -sisnr(estimate, reference)


# Define a loss function and optimizer
# sisnr() is a quality score in dB (higher is better), while train_model()
# minimizes whatever the criterion returns. Passing sisnr directly would train
# the network to make the separation worse, so the score is negated here.
criterion = negative_sisnr
# criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_model(model, criterion, optimizer, train_loader, val_loader)