In [1]:
%%capture

import json
import speechbrain as sb
import os, sys
from speechbrain.utils.data_utils import get_all_files
import torch
from speechbrain.dataio.dataio import read_audio
import random
import torchaudio

import json
from tqdm import tqdm
import torchaudio
from joblib import Parallel, delayed

def process_file(path):
    """Build one manifest entry for a ``*_segment.wav`` audio file.

    Parameters
    ----------
    path : str
        Path to the audio file. Both ``/`` and ``\\`` separators are accepted.

    Returns
    -------
    tuple[str, dict]
        ``(utt_id, entry)``. ``utt_id`` is the file name without its last
        ``_``-separated token (e.g. the trailing ``segment.wav`` is stripped).
        ``entry`` holds ``wav_path`` (with forward slashes), ``num_speakers``
        (the 4th ``_``-separated token of the file name, kept as a string) and
        ``length`` (duration in seconds).

    Raises
    ------
    ValueError
        If the file name has fewer than four ``_``-separated tokens, so no
        speaker count can be read from it.
    """
    wav_path = path.replace("\\", "/")
    # Last path component, whichever separator style the path used.
    file_name = wav_path.rsplit("/", 1)[-1]
    parts = file_name.split("_")
    if len(parts) < 4:
        raise ValueError(
            f"Cannot read the speaker count from {file_name!r}: expected at "
            "least four '_'-separated fields"
        )
    utt_id = "_".join(parts[:-1])
    # NOTE(review): assumes the speaker count is always the 4th field of the
    # file name; confirm against the dataset naming scheme.
    num_speakers = parts[3]

    info = torchaudio.info(path)
    # Use the file's own sample rate rather than assuming 16 kHz, so the
    # duration in seconds is correct for any input rate.
    length = info.num_frames / info.sample_rate

    return utt_id, {
        "wav_path": wav_path,
        "num_speakers": num_speakers,
        "length": length,
    }

def load_json(json_paths, save_file="train", output_dir="../data", n_jobs=-1):
    """Probe audio files in parallel and write a JSON manifest keyed by utterance id.

    Parameters
    ----------
    json_paths : iterable of str
        Audio file paths to include. Despite the name, these are WAV paths,
        not JSON files.
    save_file : str
        Manifest prefix; the output file is ``<output_dir>/<save_file>_data.json``.
    output_dir : str
        Directory for the manifest. It is created if missing. Defaults to
        ``"../data"``.
    n_jobs : int
        Forwarded to ``joblib.Parallel``; ``-1`` uses all available CPUs.

    Returns
    -------
    str
        Path of the written manifest.
    """
    import warnings

    results = Parallel(n_jobs=n_jobs, verbose=10)(
        delayed(process_file)(path) for path in json_paths
    )

    data = {}
    for utt_id, entry in results:
        if utt_id in data:
            # Two files map to the same id. The later one still wins, but the
            # overwrite is reported instead of losing an entry silently.
            warnings.warn(
                f"Duplicate id {utt_id!r}: {data[utt_id]['wav_path']} "
                f"is replaced by {entry['wav_path']}"
            )
        data[utt_id] = entry

    os.makedirs(output_dir, exist_ok=True)
    manifest_path = os.path.join(output_dir, f"{save_file}_data.json")
    with open(manifest_path, "w") as json_file:
        json.dump(data, json_file, indent=4)
    return manifest_path


# Gather every file under the dev split whose path contains "_segment.wav".
# The next cell splits this list per speaker-count class to build the test sets.
# The other splits are prepared the same way (kept disabled here):
#   train_files = get_all_files("../data/train", match_and=['_segment.wav'])
#   valid_files = get_all_files("../data/eval", match_and=['_segment.wav'])
#   load_json(train_files, save_file="train")
segment_suffix = "_segment.wav"
test_files = get_all_files("../data/dev", match_and=[segment_suffix])


In [8]:
# Split the dev segments by speaker count and write one manifest per class
# ("<save name>_data.json"). These names match the test_annotation_*_spk
# entries in the hyperparameter YAML.
# Matching is a plain substring test on the full path, e.g. "spk_2".
# NOTE(review): "spk_1" would also match "spk_10"; this is fine while the
# classes are 0-4.
SPK_SUBSET_NAMES = {
    0: "test_files_no_spk",
    1: "test_files_1_spk",
    2: "test_files_2_spk",
    3: "test_files_3_spk",
    4: "test_files_4_spk",
}

test_files_by_spk = {}
for n_spk, save_name in SPK_SUBSET_NAMES.items():
    subset = [file for file in test_files if f"spk_{n_spk}" in file]
    if not subset:
        print(f"warning: no dev files match 'spk_{n_spk}'; writing an empty manifest")
    test_files_by_spk[n_spk] = subset
    load_json(subset, save_file=save_name)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 20 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:    0.0s finished
[Parallel(n_jobs=-1)]: Batch computation too fast (0.09888553619384766s.) Setting batch_size=2.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  22 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done  33 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Batch computation too fast (0.05100083351135254s.) Setting batch_size=4.
[Parallel(n_jobs=-1)]: Done  52 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Done  78 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Done 108 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Batch computation too fast (0.05635428428649902s.) Setting batch_size=8.
[Parallel(n_jobs=-1)]: Done 156 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 224 tasks      | elapsed:    0.2s
[Paral

In [12]:
%%file hparams_xvector_augmentation_results_per_class.yaml

# Seed needs to be set at top of yaml, before objects with parameters are made,
# so that randomly initialised objects defined further down (e.g. the model
# weights) are reproducible.
seed: 1986
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]

# Checkpoints and the training log go under a per-seed results folder.
output_folder: !ref ../results/XVector/Augmented/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Paths to the JSON data manifests. They must exist before training: the
# training script does not run any data preparation.
# NOTE(review): the notebook's preparation cells only write the
# test_files_*_spk manifests; train/valid manifests are created elsewhere.
data_folder: ../data
train_annotation: !ref <data_folder>/train_data.json
valid_annotation: !ref <data_folder>/valid_data.json
# One test manifest per speaker-count class (0 to 4 speakers).
test_annotation_0_spk: !ref <data_folder>/test_files_no_spk_data.json
test_annotation_1_spk: !ref <data_folder>/test_files_1_spk_data.json
test_annotation_2_spk: !ref <data_folder>/test_files_2_spk_data.json
test_annotation_3_spk: !ref <data_folder>/test_files_3_spk_data.json
test_annotation_4_spk: !ref <data_folder>/test_files_4_spk_data.json


# Disabled noise/RIR download configuration, kept for reference. The CSV
# files below are expected to exist already.
# NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
# rirs_noises_root: !ref <data_folder>/RIRS_NOISES
# data_folder_noise:
#   - !ref <rirs_noises_root>/simulated_rirs/
#   - !ref <rirs_noises_root>/real_rirs_isotropic_noises/
#
# data_folder_rir: !ref <rirs_noises_root>/pointsource_noises/

# CSV listing the noise files used by add_noise.
noise_annotation: !ref <data_folder>/noises.csv
# NOTE(review): rir_annotation is not referenced by any object in this file.
rir_annotation: !ref <data_folder>/simulated_rirs.csv

# The train logger writes training statistics to a file, as well as stdout.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Factory (!name:, not !new:) for the classification-error tracker. The training
# script calls it at the start of every validation/test stage to get a fresh tracker.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# ckpt_enable gates checkpoint saving at the end of each validation stage.
ckpt_enable: True
ckpt_interval_minutes: 15 # save checkpoint every N min

# Filterbank size; also the x-vector input dimension (in_channels).
n_mels: 40
# NOTE(review): deltas is not passed to compute_features (the argument is
# commented out there), so this value currently has no effect.
deltas: True

# Training Parameters
sample_rate: 16000 # used by the augmentations below
number_of_epochs: 50
batch_size: 64
# The learning rate decays linearly from lr_start to lr_final (see lr_annealing).
lr_start: 0.001
lr_final: 0.0001
weight_decay: 0.00002

# x-vector TDNN widths: tdnn_channels for the first four blocks and
# tdnn_channels_out for the last one.
tdnn_channels: 64
tdnn_channels_out: 128
n_classes: 5 # speaker-count classes (0, 1, 2, 3, 4 speakers)
emb_dim: 128 # x-vector embedding size, also the classifier input size

# num_workers is shared by the dataloaders and the add_noise augmenter.
num_workers: 0
# The training script adds shuffle: True to this dict before training.
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

##################################
####### Data Augmentation ########
##################################
# The noise download/preparation step below is disabled, so noise_annotation
# must already point to an existing CSV.

skip_prep: True # NOTE(review): not read by the training script
# prepare_noise_data: !name:speechbrain.augment.preparation.prepare_dataset_from_URL
#     URL: !ref <NOISE_DATASET_URL>
#     dest_folder: !ref <data_folder_noise>
#     ext: wav
#     csv_file: !ref <noise_annotation>


# Add noise to input signal
snr_low: 0  # Min SNR for noise augmentation
snr_high: 15  # Max SNR for noise augmentation

add_noise: !new:speechbrain.augment.time_domain.AddNoise
    csv_file: !ref <noise_annotation>
    snr_low: !ref <snr_low>
    snr_high: !ref <snr_high>
    noise_sample_rate: !ref <sample_rate>
    clean_sample_rate: !ref <sample_rate>
    num_workers: !ref <num_workers>

# Speed perturbation
speed_changes: [95, 100, 105]  # List of speed changes for time-stretching

speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: !ref <speed_changes>

# Frequency drop: randomly drops a number of frequency bands to zero.
# NOTE(review): per the SpeechBrain DropFreq docs, drop_freq_low/high bound the
# normalized frequency range from which bands are dropped; they are not
# probabilities. Confirm against the installed version.
drop_freq_low: 0  # Lower bound of the droppable frequency range
drop_freq_high: 1  # Upper bound of the droppable frequency range
drop_freq_count_low: 1  # Min number of frequency bands to drop
drop_freq_count_high: 3  # Max number of frequency bands to drop
drop_freq_width: 0.05  # Width of frequency bands to drop

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Time drop: randomly drops a number of temporal chunks.
drop_chunk_count_low: 1  # Min number of audio chunks to drop
drop_chunk_count_high: 5  # Max number of audio chunks to drop
drop_chunk_length_low: 1000  # Min length of audio chunks to drop (presumably samples)
drop_chunk_length_high: 2000  # Max length of audio chunks to drop (presumably samples)

drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: !ref <drop_chunk_length_low>
    drop_length_high: !ref <drop_chunk_length_high>
    drop_count_low: !ref <drop_chunk_count_low>
    drop_count_high: !ref <drop_chunk_count_high>

# Augmenter: Combines previously defined augmentations to perform data augmentation
# concat_original: True keeps the clean batch alongside the augmented copy, which
# is why the training script replicates labels and lengths
# (wav_augment.replicate_labels). With min = max = 4 augmentations out of four
# listed and augment_prob 1.0, presumably all four are applied sequentially
# (parallel_augment is False) on every training batch.
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    parallel_augment: False
    concat_original: True
    repeat_augment: 1
    shuffle_augmentations: False
    min_augmentations: 4
    max_augmentations: 4
    augment_prob: 1.0
    augmentations: [
        !ref <add_noise>,
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]

##################################
##################################
##################################

# Feature extraction
# Only n_mels is passed, so deltas stays at the Fbank default because the line
# below is commented out.
# NOTE(review): enabling deltas would presumably change the feature dimension,
# so embedding_model.in_channels would have to change with it.
compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: !ref <n_mels>
    # deltas: !ref <deltas>

# Per-sentence mean normalization of the input features (std_norm is False,
# so the variance is left unchanged).
mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

# Embedding network: the SpeechBrain x-vector TDNN. To try another architecture,
# replace this `!new` call. Its output size must remain emb_dim, because the
# classifier below expects that input size.
embedding_model: !new:speechbrain.lobes.models.Xvector.Xvector
    in_channels: !ref <n_mels>
    activation: !name:torch.nn.LeakyReLU
    tdnn_blocks: 5
    tdnn_channels:
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels_out>
    tdnn_kernel_sizes: [5, 3, 3, 1, 1]
    tdnn_dilations: [1, 2, 3, 1, 1]
    lin_neurons: !ref <emb_dim>

# Classification head with n_classes outputs. The training script feeds its
# output to nll_loss.
# NOTE(review): input_shape presumably means [batch, time, emb_dim]; confirm.
classifier: !new:speechbrain.lobes.models.Xvector.Classifier
    input_shape: [null, null, !ref <emb_dim>]
    activation: !name:torch.nn.LeakyReLU
    lin_blocks: 1
    lin_neurons: !ref <emb_dim>
    out_neurons: !ref <n_classes>

# The first object passed to the Brain class is this "Epoch Counter"
# which is saved by the Checkpointer so that training can be resumed
# if it gets interrupted at any point.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Objects in "modules" dict will have their parameters moved to the correct
# device, as well as having train()/eval() called on them by the Brain class.
modules:
    compute_features: !ref <compute_features>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
    mean_var_norm: !ref <mean_var_norm>

# This optimizer will be constructed by the Brain class after all parameters
# are moved to the correct device. Then it will be added to the checkpointer.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr_start>
    weight_decay: !ref <weight_decay>

# This function manages learning rate annealing over the epochs.
# We here use the simple lr annealing method that linearly decreases
# the lr from the initial value to the final one. The training script calls it
# with the epoch number at the end of each validation stage.
lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler
    initial_value: !ref <lr_start>
    final_value: !ref <lr_final>
    epoch_count: !ref <number_of_epochs>

# This object is used for saving the state of training both so that it
# can be resumed if it gets interrupted, and also so that the best checkpoint
# can be later loaded for evaluation or inference.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        embedding_model: !ref <embedding_model>
        classifier: !ref <classifier>
        normalizer: !ref <mean_var_norm>
        counter: !ref <epoch_counter>

Overwriting hparams_xvector_augmentation_results_per_class.yaml


In [13]:
%%file train_xvector_augmentation_results_per_each_class.py

import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils import hpopt as hp
import torchaudio


# Brain subclass for training the speaker-count (x-vector) classifier
class SpkIdBrain(sb.Brain):
    """Speaker-count classifier built on ``speechbrain.core.Brain``.

    Waveforms are augmented (training only, when ``wav_augment`` is configured),
    converted to normalized filterbank features, embedded by the x-vector
    network and classified. Loss and classification error are tracked per
    stage, and checkpoints are ranked by validation error.
    """

    def compute_forward(self, batch, stage):
        """Produce class scores for one batch.

        Arguments
        ---------
        batch : PaddedBatch
            Batch whose ``sig`` field holds ``(waveforms, relative_lengths)``.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.

        Returns
        -------
        predictions : Tensor
            Classifier output. ``compute_objectives`` feeds it to ``nll_loss``,
            so it is expected to hold log-probabilities over the classes.
        """
        on_device = batch.to(self.device)
        feats, feat_lens = self.prepare_features(on_device.sig, stage)
        return self.modules.classifier(
            self.modules.embedding_model(feats, feat_lens)
        )

    def prepare_features(self, wavs, stage):
        """Turn raw waveforms into normalized features, augmenting when training.

        Arguments
        ---------
        wavs : tuple
            ``(signals, relative_lengths)`` tensors.
        stage : sb.Stage
            The current stage; augmentation only happens for sb.Stage.TRAIN.

        Returns
        -------
        tuple
            ``(features, relative_lengths)``. The lengths are the ones returned
            by the augmenter when augmentation ran.
        """
        signal, rel_lens = wavs

        train_time_augment = (
            stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment")
        )
        if train_time_augment:
            # The augmenter may enlarge the batch, so its lengths replace ours.
            signal, rel_lens = self.hparams.wav_augment(signal, rel_lens)

        fbank = self.modules.compute_features(signal)
        return self.modules.mean_var_norm(fbank, rel_lens), rel_lens

    def compute_objectives(self, predictions, batch, stage):
        """Compute the NLL loss and update the stage metrics.

        Arguments
        ---------
        predictions : tensor
            The output tensor from `compute_forward`.
        batch : PaddedBatch
            Batch providing ``sig`` and ``num_speakers_encoded``.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.

        Returns
        -------
        loss : torch.Tensor
            A one-element tensor used for backpropagating the gradient.
        """
        _, rel_lens = batch.sig
        targets, _ = batch.num_speakers_encoded

        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
            # Replicate labels/lengths to match the augmented batch size.
            replicate = self.hparams.wav_augment.replicate_labels
            targets, rel_lens = replicate(targets), replicate(rel_lens)

        loss = sb.nnet.losses.nll_loss(predictions, targets, rel_lens)

        # Per-utterance losses feed the stage-level loss summary.
        self.loss_metric.append(
            batch.id, predictions, targets, rel_lens, reduction="batch"
        )

        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, targets, rel_lens)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers for the stage that is starting.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        epoch : int
            The starting epoch, or `None` during the test stage.
        """
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # Classification error is only tracked outside training.
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Log statistics, anneal the LR and checkpoint at the end of a stage.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            Average loss over the data processed in this stage.
        epoch : int
            The current epoch, or `None` during the test stage.
        """
        if stage == sb.Stage.TRAIN:
            # Kept until validation so both losses are logged together.
            self.train_loss = stage_loss
            return

        stats = {
            "loss": stage_loss,
            "error": self.error_metrics.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            lr_used, lr_next = self.hparams.lr_annealing(epoch)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, lr_next)

            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": lr_used},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            if self.hparams.ckpt_enable:
                # Save and prune checkpoints, ranking them by lowest "error".
                self.checkpointer.save_and_keep_only(
                    meta=stats, min_keys=["error"]
                )
            hp.report_result(stats)
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )


def dataio_prep(hparams):
    """Build the datasets used for training, validation and per-class testing.

    Reads the JSON manifests referenced by ``hparams`` and attaches two dynamic
    pipelines: audio loading (``sig``) and label encoding (``num_speakers`` and
    ``num_speakers_encoded``). The label encoder is fitted on the training set
    only and cached in ``<save_folder>/label_encoder.txt``.

    Side effect: sets ``hparams["dataloader_options"]["shuffle"] = True``. The
    training loop relies on this to shuffle the training batches.

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters. Must contain ``train_annotation``,
        ``valid_annotation``, ``test_annotation_{0..4}_spk``, ``data_folder``,
        ``save_folder`` and ``dataloader_options``. ``sample_rate`` is optional
        and defaults to 16000.

    Returns
    -------
    datasets : dict
        ``DynamicItemDataset`` objects keyed by ``"train"``, ``"valid"``,
        ``"test_no_spk"`` and ``"test_{1..4}_spk"``.
    """
    # Target rate for all loaded audio; taken from the config instead of
    # being hard-coded, with the previous value as the fallback.
    target_sr = hparams.get("sample_rate", 16000)

    # Assigns each observed label a unique index (e.g. '2': 0, '0': 1, ...).
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    @sb.utils.data_pipeline.takes("wav_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav_path):
        """Load a file as a 1-D mono signal resampled to ``target_sr``."""
        sig, fs = torchaudio.load(wav_path)
        # Average the channels so multi-channel files also yield a 1-D signal.
        # For mono input this equals the previous squeeze(0).
        sig = sig.mean(dim=0)
        return torchaudio.functional.resample(sig, fs, target_sr)

    @sb.utils.data_pipeline.takes("num_speakers")
    @sb.utils.data_pipeline.provides("num_speakers", "num_speakers_encoded")
    def label_pipeline(num_speakers):
        """Yield the raw speaker-count label, then its encoded tensor."""
        yield num_speakers
        num_speakers_encoded = label_encoder.encode_label_torch(num_speakers)
        yield num_speakers_encoded

    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test_no_spk": hparams["test_annotation_0_spk"],
        "test_1_spk": hparams["test_annotation_1_spk"],
        "test_2_spk": hparams["test_annotation_2_spk"],
        "test_3_spk": hparams["test_annotation_3_spk"],
        "test_4_spk": hparams["test_annotation_4_spk"],
    }

    # NOTE: this mutates the shared dataloader_options dict; the training
    # script relies on it to shuffle the training loader.
    hparams["dataloader_options"]["shuffle"] = True

    datasets = {
        name: sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=json_path,
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "num_speakers_encoded"],
        )
        for name, json_path in data_info.items()
    }

    # Load or compute the label encoder (with multi-GPU DDP support).
    # See lab_enc_file for the label-to-index mapping.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="num_speakers",
    )

    return datasets


# Recipe begins!
if __name__ == "__main__":

    with hp.hyperparameter_optimization(objective_key="error") as hp_ctx:

        # Reading command line arguments
        hparams_file, run_opts, overrides = hp_ctx.parse_arguments(
            sys.argv[1:], pass_trial_id=False
        )

        # NOTE: ddp_init_group is not called, so multi-GPU DDP is not set up.
        # sb.utils.distributed.ddp_init_group(run_opts)

        # Load hyperparameters file with command-line overrides.
        with open(hparams_file) as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Create experiment directory
        sb.create_experiment_directory(
            experiment_directory=hparams["output_folder"],
            hyperparams_to_save=hparams_file,
            overrides=overrides,
        )

        # No data preparation runs here. The JSON manifests referenced by the
        # YAML (train/valid/test_annotation_*) must already exist.

        # Create the "train", "valid" and per-class "test_*" datasets.
        datasets = dataio_prep(hparams)

        # Initialize the Brain object.
        spk_id_brain = SpkIdBrain(
            modules=hparams["modules"],
            opt_class=hparams["opt_class"],
            hparams=hparams,
            run_opts=run_opts,
            checkpointer=hparams["checkpointer"],
        )

        # dataio_prep turns on shuffling in dataloader_options for training.
        # Validation/test metrics are averages, so their order does not matter;
        # use a non-shuffled copy there to keep evaluation deterministic.
        train_loader_kwargs = hparams["dataloader_options"]
        eval_loader_kwargs = {**hparams["dataloader_options"], "shuffle": False}

        # The `fit()` method iterates the training loop, calling the methods
        # necessary to update the parameters of the model. Since all objects
        # with changing state are managed by the Checkpointer, training can be
        # stopped at any point, and will be resumed on next call.
        spk_id_brain.fit(
            epoch_counter=spk_id_brain.hparams.epoch_counter,
            train_set=datasets["train"],
            valid_set=datasets["valid"],
            train_loader_kwargs=train_loader_kwargs,
            valid_loader_kwargs=eval_loader_kwargs,
        )

        if not hp_ctx.enabled:
            # Evaluate the checkpoint with the lowest validation error
            # separately on each speaker-count test set.
            test_subsets = (
                ("test_no_spk", "Testing on no speakers class"),
                ("test_1_spk", "Testing on 1 speaker class"),
                ("test_2_spk", "Testing on 2 speakers class"),
                ("test_3_spk", "Testing on 3 speakers class"),
                ("test_4_spk", "Testing on 4 speakers class"),
            )
            for dataset_key, banner in test_subsets:
                print(banner)
                spk_id_brain.evaluate(
                    test_set=datasets[dataset_key],
                    min_key="error",
                    test_loader_kwargs=eval_loader_kwargs,
                )

Overwriting train_xvector_augmentation_results_per_each_class.py


In [14]:
# Launch training and per-class evaluation in a subprocess, using the YAML file written above.
# Per the script's own comments, an interrupted run resumes from its latest checkpoint when relaunched.
!python train_xvector_augmentation_results_per_each_class.py hparams_xvector_augmentation_results_per_class.yaml

^C
