In [None]:
%%capture

import json
import speechbrain as sb
import os, sys
from speechbrain.utils.data_utils import get_all_files
import torch
from speechbrain.dataio.dataio import read_audio
import random
import torchaudio


In [None]:
import speechbrain as sb
import os, sys
from speechbrain.utils.data_utils import get_all_files
from speechbrain.dataio.dataio import read_audio
import json
import torchaudio
from joblib import Parallel, delayed

def process_file(path):
    """
    Extract the segment id, normalized path, speaker count and duration of one audio file.

    The file name is split on underscores: the last field (``segment.wav``) is
    dropped to form the id, and the fourth field (index 3) is the number of
    speakers, i.e. ``<f0>_<f1>_<f2>_<num_speakers>[_...]_segment.wav``.

    Parameters:
    - path (str): The file path of the audio file. Both "/" and "\\" are
      accepted as separators so Windows-style paths work on any OS.

    Returns:
    - tuple: ``(id, metadata)`` where ``metadata`` holds ``wav_path`` (the path
      with "/" separators), ``num_speakers`` (the raw string field from the
      file name) and ``length`` (duration in seconds at the file's own rate).

    Raises:
    - ValueError: if the file name has fewer than five underscore-separated
      fields, in which case index 3 would not be a speaker-count field.
    """
    # Normalize Windows separators first so basename extraction works everywhere.
    wav_path = path.replace("\\", "/")
    file_name = wav_path.rsplit("/", 1)[-1]
    parts = file_name.split("_")
    if len(parts) < 5:
        raise ValueError(
            f"Cannot parse speaker count from file name {file_name!r}: "
            "expected at least 5 underscore-separated fields"
        )
    seg_id = "_".join(parts[:-1])
    num_speakers = parts[3]

    info = torchaudio.info(path)
    # Use the file's own sample rate: the training pipeline resamples on load,
    # so source files are not guaranteed to be 16 kHz.
    length = info.num_frames / info.sample_rate

    return seg_id, {
        "wav_path": wav_path,
        "num_speakers": num_speakers,
        "length": length,
    }

def load_json(json_paths, save_file="train", output_dir="../data"):
    """
    Build a metadata manifest for a list of audio files and write it as JSON.

    Despite its name, this function does not read JSON. It runs
    ``process_file`` on every path (in parallel via joblib) and writes the
    collected metadata to ``<output_dir>/<save_file>_data.json``.

    Parameters:
    - json_paths (list): Paths of the audio files to process.
    - save_file (str): Split name used as the output file prefix (e.g. "train").
    - output_dir (str): Directory the manifest is written to; created if it
      does not exist. Defaults to "../data", where the hparams files read from.

    Returns:
    None. Writes one JSON file mapping segment id -> metadata dict.
    """
    import warnings

    # n_jobs=-1: let joblib use all available CPUs.
    results = Parallel(n_jobs=-1, verbose=10)(
        delayed(process_file)(path) for path in json_paths
    )

    data = {}
    for seg_id, path_data in results:
        # Two files that map to the same id would otherwise silently overwrite
        # each other and drop a training example.
        if seg_id in data:
            warnings.warn(
                f"Duplicate segment id {seg_id!r}: {data[seg_id]['wav_path']} "
                f"is replaced by {path_data['wav_path']}"
            )
        data[seg_id] = path_data

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"{save_file}_data.json")
    with open(out_path, "w") as json_file:
        json.dump(data, json_file, indent=4)

# Output manifest name -> directory holding that split's segmented audio.
# NOTE(review): the "test" manifest is built from ../data/dev and the "valid"
# manifest from ../data/eval; conventionally dev is used for validation and
# eval for testing. Confirm this assignment is intentional, since the
# training script selects the best checkpoint on "valid".
split_dirs = {
    "train": "../data/train",
    "test": "../data/dev",
    "valid": "../data/eval",
}
segment_files = {
    split: get_all_files(directory, match_and=['_segment.wav'])
    for split, directory in split_dirs.items()
}
train_files = segment_files["train"]
test_files = segment_files["test"]
valid_files = segment_files["valid"]

# Writes ../data/train_data.json, ../data/test_data.json, ../data/valid_data.json.
for split, files in segment_files.items():
    load_json(files, save_file=split)


## XVector Augmented

In [None]:
%%file hparams_xvector_augmentation.yaml

# Seed for reproducibility of results. Must be set before initializing
# model components that depend on randomness.
seed: 1986
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]

# Output locations: experiment folder, checkpoints and training log.
output_folder: !ref ../results/XVector/Augmented/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Input data: data folder and the JSON manifest for each split.
data_folder: ../data
train_annotation: !ref <data_folder>/train_data.json
valid_annotation: !ref <data_folder>/valid_data.json
test_annotation: !ref <data_folder>/test_data.json

# Annotations for additional noise and room impulse responses (RIRs).
# noise_annotation feeds `add_noise`; rir_annotation is not referenced by any
# object in this file.
noise_annotation: !ref <data_folder>/noises.csv
rir_annotation: !ref <data_folder>/simulated_rirs.csv

# Logger configuration for recording training progress and statistics.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Factory for the classification-error tracker; the training script calls it
# at the start of every VALID/TEST stage.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Checkpoint configuration. ckpt_enable is checked by the training script
# before saving a checkpoint at the end of each validation stage.
ckpt_enable: True
ckpt_interval_minutes: 15 # Interval in minutes for saving checkpoints.

####################### Training Parameters ####################################
# Model and training hyperparameters.
n_mels: 40  # Number of Mel bands in the Fbank features (also x-vector input size)
sample_rate: 16000
number_of_epochs: 50
batch_size: 64
lr_start: 0.001  # Initial learning rate (linearly annealed to lr_final)
lr_final: 0.0001
weight_decay: 0.00002
tdnn_channels: 64  # Channels of the first four TDNN blocks
tdnn_channels_out: 128  # Channels of the last TDNN block
n_classes: 5  # Classifier outputs; must cover every num_speakers label in the data
emb_dim: 128  # X-vector embedding size

# DataLoader configuration to specify how training data is batched and handled during training.
num_workers: 0 # Number of workers for data loading. Use 2 for Linux, 0 for Windows compatibility.
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

####################### Data Augmentation ####################################
# Data augmentation settings to enhance model robustness.
skip_prep: True
snr_low: 0  # Min SNR (signal-to-noise ratio) for noise augmentation
snr_high: 15  # Max SNR for noise augmentation

# Additive noise drawn from the files listed in noise_annotation.
add_noise: !new:speechbrain.augment.time_domain.AddNoise
    csv_file: !ref <noise_annotation>
    snr_low: !ref <snr_low>
    snr_high: !ref <snr_high>
    noise_sample_rate: !ref <sample_rate>
    clean_sample_rate: !ref <sample_rate>
    num_workers: !ref <num_workers>

# Speed perturbation for time-stretching audio samples.
speed_changes: [95, 100, 105]  # Speeds in percent of the original (100 = unchanged)

speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: !ref <speed_changes>

# Frequency band dropout to simulate frequency-specific signal loss.
# drop_freq_low/high bound the normalized frequency range the dropped bands
# are drawn from; they are not probabilities (see SpeechBrain DropFreq docs).
drop_freq_low: 0  # Lowest normalized frequency that may be dropped
drop_freq_high: 1  # Highest normalized frequency that may be dropped
drop_freq_count_low: 1  # Min number of frequency bands to drop
drop_freq_count_high: 3  # Max number of frequency bands to drop
drop_freq_width: 0.05  # Width of each dropped band (normalized frequency)

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Chains all defined augmentations. concat_original keeps the clean batch next
# to the augmented one, which is why the training script replicates labels.
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    parallel_augment: False
    concat_original: True
    repeat_augment: 1
    shuffle_augmentations: False
    # Exactly three augmentations are listed below, so at most three can be
    # applied; 3 states the effective count explicitly.
    min_augmentations: 3
    max_augmentations: 3
    augment_prob: 1.0
    augmentations: [
        !ref <add_noise>,
        !ref <speed_perturb>,
        !ref <drop_freq>]
        
###################### Feature and Model Configuration ##################################

# Feature extraction: Mel filterbank (Fbank) features with n_mels bands (not MFCCs).
compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: !ref <n_mels>

# Per-sentence feature normalization. std_norm: False disables variance
# scaling, so only the mean is normalized.
mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

# X-vector model configuration for embedding extraction.
embedding_model: !new:speechbrain.lobes.models.Xvector.Xvector
    in_channels: !ref <n_mels>
    activation: !name:torch.nn.LeakyReLU
    tdnn_blocks: 5
    tdnn_channels:
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels_out>
    tdnn_kernel_sizes: [5, 3, 3, 1, 1]
    tdnn_dilations: [1, 2, 3, 1, 1]
    lin_neurons: !ref <emb_dim>

# Classifier mapping an x-vector embedding to n_classes speaker-count classes.
classifier: !new:speechbrain.lobes.models.Xvector.Classifier
    input_shape: [null, null, !ref <emb_dim>]
    activation: !name:torch.nn.LeakyReLU
    lin_blocks: 1
    lin_neurons: !ref <emb_dim>
    out_neurons: !ref <n_classes>

# Epoch counter; training stops after number_of_epochs epochs.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Modules handed to the Brain class (passed as `modules` by the training script).
modules:
    compute_features: !ref <compute_features>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
    mean_var_norm: !ref <mean_var_norm>

# Optimizer configuration for model training.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr_start>
    weight_decay: !ref <weight_decay>

# Linear learning-rate decay from lr_start to lr_final over number_of_epochs.
lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler
    initial_value: !ref <lr_start>
    final_value: !ref <lr_final>
    epoch_count: !ref <number_of_epochs>

# Checkpointer for managing the saving and loading of model states.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        embedding_model: !ref <embedding_model>
        classifier: !ref <classifier>
        normalizer: !ref <mean_var_norm>
        counter: !ref <epoch_counter>

In [None]:
%%file train_xvector_augmentation.py

import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils import hpopt as hp
import torchaudio


class XVectorSpkCounter(sb.Brain):
    """
    Speaker-counting recipe built on SpeechBrain's ``Brain`` training loop.

    Per batch: optional waveform augmentation (TRAIN stage only, and only when
    the hparams define ``wav_augment``) -> feature extraction -> feature
    normalization -> x-vector embedding -> classifier over speaker-count
    classes. Logging, LR annealing and checkpointing live in the
    ``on_stage_*`` hooks.
    """

    def compute_forward(self, batch, stage):
        """
        Run the full model on one batch.

        Parameters:
        - batch (PaddedBatch): Batch carrying the padded ``sig`` field.
        - stage (sb.Stage): TRAIN, VALID or TEST.

        Returns:
        - Tensor: Classifier output. It is scored with ``nll_loss``, so it is
          presumably log-probabilities over the speaker-count classes.
        """
        batch = batch.to(self.device)
        features, feature_lens = self.prepare_features(batch.sig, stage)
        embeddings = self.modules.embedding_model(features, feature_lens)
        return self.modules.classifier(embeddings)

    def prepare_features(self, wavs, stage):
        """
        Turn padded waveforms into normalized features.

        Parameters:
        - wavs (tuple): ``(signals, relative_lengths)`` as stored in ``batch.sig``.
        - stage (sb.Stage): Current stage; augmentation runs only in TRAIN.

        Returns:
        - Tuple[Tensor, Tensor]: Normalized features and their relative lengths
          (the augmented lengths when augmentation ran).
        """
        signals, lengths = wavs

        should_augment = stage == sb.Stage.TRAIN and hasattr(
            self.hparams, "wav_augment"
        )
        if should_augment:
            signals, lengths = self.hparams.wav_augment(signals, lengths)

        raw_features = self.modules.compute_features(signals)
        normalized = self.modules.mean_var_norm(raw_features, lengths)
        return normalized, lengths

    def compute_objectives(self, predictions, batch, stage):
        """
        Compute the NLL loss and update the stage's metric trackers.

        Parameters:
        - predictions (Tensor): Output of ``compute_forward``.
        - batch (PaddedBatch): Batch providing ``sig`` lengths and the
          ``num_speakers_encoded`` targets.
        - stage (sb.Stage): Current stage.

        Returns:
        - Tensor: The loss tensor.
        """
        _, wav_lens = batch.sig
        targets, _ = batch.num_speakers_encoded

        augmented = stage == sb.Stage.TRAIN and hasattr(
            self.hparams, "wav_augment"
        )
        if augmented:
            # The augmenter can return more examples than it was given, so the
            # targets and lengths are replicated to line up with the predictions.
            augmenter = self.hparams.wav_augment
            targets = augmenter.replicate_labels(targets)
            wav_lens = augmenter.replicate_labels(wav_lens)

        loss = sb.nnet.losses.nll_loss(predictions, targets, wav_lens)

        self.loss_metric.append(
            batch.id, predictions, targets, wav_lens, reduction="batch"
        )
        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, targets, wav_lens)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """
        Create fresh metric trackers for the stage that is starting.

        Parameters:
        - stage (sb.Stage): Stage about to run.
        - epoch (int, optional): Current epoch number (unused).
        """
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )
        if stage == sb.Stage.TRAIN:
            return
        # Classification error is only tracked on VALID / TEST.
        self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """
        Log stage results, anneal the LR and checkpoint after validation.

        Parameters:
        - stage (sb.Stage): Stage that just finished.
        - stage_loss (float): Average loss over the stage.
        - epoch (int, optional): Current epoch number.
        """
        if stage == sb.Stage.TRAIN:
            # Held until VALID so both losses are logged on the same line.
            self.train_loss = stage_loss
            return

        stats = {
            "loss": stage_loss,
            "error": self.error_metrics.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            old_lr, new_lr = self.hparams.lr_annealing(epoch)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Keep only the checkpoint with the lowest validation error.
            if self.hparams.ckpt_enable:
                self.checkpointer.save_and_keep_only(
                    meta=stats, min_keys=["error"]
                )
            hp.report_result(stats)
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )


def dataio_prep(hparams):
    """
    Prepares and returns datasets for training, validation, and testing.

    Parameters:
    - hparams (dict): Loaded hyperparameters. Uses ``train_annotation``,
      ``valid_annotation``, ``test_annotation``, ``data_folder``,
      ``save_folder`` and ``dataloader_options``. ``sample_rate``, if present,
      sets the rate every signal is resampled to (default 16000).

    Returns:
    - datasets (dict): ``DynamicItemDataset`` objects under 'train', 'valid'
      and 'test', each exposing ``id``, ``sig`` and ``num_speakers_encoded``.

    Side effects:
    - Sets ``hparams["dataloader_options"]["shuffle"] = True``.
    - Loads the label encoder from ``<save_folder>/label_encoder.txt``, or
      creates it from the training set's ``num_speakers`` values.
    """

    # Rate every signal is brought to before feature extraction.
    target_sample_rate = hparams.get("sample_rate", 16000)

    # Maps num_speakers values to contiguous class indices.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    @sb.utils.data_pipeline.takes("wav_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav_path):
        """
        Load an audio file as a mono 1-D signal at ``target_sample_rate``.

        Parameters:
            - wav_path (str): Path to the audio file.
        Returns:
            - sig (Tensor): 1-D audio signal tensor.
        """
        sig, fs = torchaudio.load(wav_path)  # (channels, samples)
        # Downmix to mono. For single-channel audio this equals squeeze(0);
        # for multichannel audio squeeze(0) would leave a 2-D tensor that
        # cannot be batched with mono signals.
        sig = sig.mean(dim=0)
        return torchaudio.functional.resample(sig, fs, target_sample_rate)

    @sb.utils.data_pipeline.takes("num_speakers")
    @sb.utils.data_pipeline.provides("num_speakers", "num_speakers_encoded")
    def label_pipeline(num_speakers):
        """
        Pass through and encode the number of speakers.

        Parameters:
        - num_speakers (str): Speaker count as stored in the JSON manifest
          (a string field parsed from the file name, e.g. "3").

        Yields:
        - num_speakers (str): The original value.
        - num_speakers_encoded (Tensor): Class index from the label encoder.
        """
        yield num_speakers
        num_speakers_encoded = label_encoder.encode_label_torch(num_speakers)
        yield num_speakers_encoded

    # Define datasets and connect them to the processing pipelines above.
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }
    # NOTE: this dict is shared; callers that reuse it for valid/test loaders
    # should override "shuffle" there.
    hparams["dataloader_options"]["shuffle"] = True
    for dataset in data_info:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "num_speakers_encoded"],
        )

    # The encoder is fitted on train only, so every valid/test label must also
    # occur in train. `from_didatasets` is SpeechBrain's own spelling.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="num_speakers",
    )

    return datasets


# RECIPE BEGINS!
if __name__ == "__main__":

    with hp.hyperparameter_optimization(objective_key="error") as hp_ctx:

        # Parse the hparams file, run options (e.g. --device) and overrides.
        hparams_file, run_opts, overrides = hp_ctx.parse_arguments(
            sys.argv[1:], pass_trial_id=False
        )

        # Load hyperparameters file with command-line overrides.
        with open(hparams_file) as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Create experiment directory
        sb.create_experiment_directory(
            experiment_directory=hparams["output_folder"],
            hyperparams_to_save=hparams_file,
            overrides=overrides,
        )

        # Create dataset objects "train", "valid", and "test".
        datasets = dataio_prep(hparams)

        # Shuffling only matters for training. Evaluate valid/test in a fixed
        # order, using a copy so the shared dataloader_options stay untouched.
        train_loader_opts = hparams["dataloader_options"]
        eval_loader_opts = {**train_loader_opts, "shuffle": False}

        # Initialize the Brain object for speaker-count training.
        spk_counter = XVectorSpkCounter(
            modules=hparams["modules"],
            opt_class=hparams["opt_class"],
            hparams=hparams,
            run_opts=run_opts,
            checkpointer=hparams["checkpointer"],
        )

        spk_counter.fit(
            epoch_counter=spk_counter.hparams.epoch_counter,
            train_set=datasets["train"],
            valid_set=datasets["valid"],
            train_loader_kwargs=train_loader_opts,
            valid_loader_kwargs=eval_loader_opts,
        )

        # During a hyperparameter search only the validation objective is
        # reported; the test set is evaluated on regular runs.
        if not hp_ctx.enabled:
            # min_key="error": load the checkpoint with the lowest valid error.
            spk_counter.evaluate(
                test_set=datasets["test"],
                min_key="error",
                test_loader_kwargs=eval_loader_opts,
            )


In [None]:
import torch
torch.cuda.set_device("cuda:0")
!python train_xvector_augmentation.py hparams_xvector_augmentation.yaml

## XVector Unaugmented (baseline without waveform augmentation)

In [None]:
%%file hparams_xvector_fbanks.yaml

# Seed for reproducibility of results. Must be set before initializing
# model components that depend on randomness.
seed: 1986
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]

# Output locations: experiment folder, checkpoints and training log.
# Must differ from the augmented run's folder; otherwise this run would resume
# that run's checkpoints and label encoder instead of training from scratch.
output_folder: !ref ../results/XVector/Unaugmented/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Input data: data folder and the JSON manifest for each split.
data_folder: ../data
train_annotation: !ref <data_folder>/train_data.json
valid_annotation: !ref <data_folder>/valid_data.json
test_annotation: !ref <data_folder>/test_data.json

# Annotations for additional noise and room impulse responses (RIRs).
# Not referenced by any object in this file (no augmentation is configured).
noise_annotation: !ref <data_folder>/noises.csv
rir_annotation: !ref <data_folder>/simulated_rirs.csv

# Logger configuration for recording training progress and statistics.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Factory for the classification-error tracker; the training script calls it
# at the start of every VALID/TEST stage.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Checkpoint configuration. ckpt_enable is checked by the training script
# before saving a checkpoint at the end of each validation stage.
ckpt_enable: True
ckpt_interval_minutes: 15 # Interval in minutes for saving checkpoints.

####################### Training Parameters ####################################
# Model and training hyperparameters.
n_mels: 40  # Number of Mel bands in the Fbank features (also x-vector input size)
sample_rate: 16000
number_of_epochs: 50
batch_size: 64
lr_start: 0.001  # Initial learning rate (linearly annealed to lr_final)
lr_final: 0.0001
weight_decay: 0.00002
tdnn_channels: 64  # Channels of the first four TDNN blocks
tdnn_channels_out: 128  # Channels of the last TDNN block
n_classes: 5  # Classifier outputs; must cover every num_speakers label in the data
emb_dim: 128  # X-vector embedding size

# DataLoader configuration to specify how training data is batched and handled during training.
num_workers: 0 # Number of workers for data loading. Use 2 for Linux, 0 for Windows compatibility.
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
        
###################### Feature and Model Configuration ##################################

# Feature extraction: Mel filterbank (Fbank) features with n_mels bands (not MFCCs).
compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: !ref <n_mels>

# Per-sentence feature normalization. std_norm: False disables variance
# scaling, so only the mean is normalized.
mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

# X-vector model configuration for embedding extraction.
embedding_model: !new:speechbrain.lobes.models.Xvector.Xvector
    in_channels: !ref <n_mels>
    activation: !name:torch.nn.LeakyReLU
    tdnn_blocks: 5
    tdnn_channels:
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels>
        - !ref <tdnn_channels_out>
    tdnn_kernel_sizes: [5, 3, 3, 1, 1]
    tdnn_dilations: [1, 2, 3, 1, 1]
    lin_neurons: !ref <emb_dim>

# Classifier mapping an x-vector embedding to n_classes speaker-count classes.
classifier: !new:speechbrain.lobes.models.Xvector.Classifier
    input_shape: [null, null, !ref <emb_dim>]
    activation: !name:torch.nn.LeakyReLU
    lin_blocks: 1
    lin_neurons: !ref <emb_dim>
    out_neurons: !ref <n_classes>

# Epoch counter; training stops after number_of_epochs epochs.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Modules handed to the Brain class (passed as `modules` by the training script).
modules:
    compute_features: !ref <compute_features>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
    mean_var_norm: !ref <mean_var_norm>

# Optimizer configuration for model training.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr_start>
    weight_decay: !ref <weight_decay>

# Linear learning-rate decay from lr_start to lr_final over number_of_epochs.
lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler
    initial_value: !ref <lr_start>
    final_value: !ref <lr_final>
    epoch_count: !ref <number_of_epochs>

# Checkpointer for managing the saving and loading of model states.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        embedding_model: !ref <embedding_model>
        classifier: !ref <classifier>
        normalizer: !ref <mean_var_norm>
        counter: !ref <epoch_counter>

In [None]:
%%file train_xvector_fbanks.py

import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils import hpopt as hp
import torchaudio


class XVectorSpkCounter(sb.Brain):
    """
    A custom Brain class for training and evaluating a speaker counting model.
    This class handles the forward pass, loss computation, and the training,
    validation and test bookkeeping, leveraging SpeechBrain's workflow.

    Waveform augmentation is applied only in the TRAIN stage and only when the
    hparams define ``wav_augment`` (the same logic as the augmented recipe).
    A config without ``wav_augment`` trains on the clean signals.
    """

    def compute_forward(self, batch, stage):
        """
        Processes the input batch to produce model predictions.

        Parameters:
        - batch (PaddedBatch): Contains all tensors needed for computation.
        - stage (sb.Stage): The stage of the pipeline (TRAIN, VALID, or TEST).

        Returns:
        - Tensor: Classifier output. It is scored with ``nll_loss``, so it is
          presumably log-probabilities over the speaker-count classes.
        """

        batch = batch.to(self.device)
        feats, lens = self.prepare_features(batch.sig, stage)
        embeddings = self.modules.embedding_model(feats, lens)
        predictions = self.modules.classifier(embeddings)

        return predictions

    def prepare_features(self, wavs, stage):
        """
        Prepares the signal features for model computation: optional waveform
        augmentation (TRAIN stage with ``wav_augment`` defined), feature
        extraction and normalization.

        Parameters:
        - wavs (tuple): Tuple of signals and their relative lengths.
        - stage (sb.Stage): Current stage.

        Returns:
        - Tuple[Tensor, Tensor]: Features and their lengths.
        """

        wavs, lens = wavs

        # Augment only when training and only if an augmenter is configured.
        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
            wavs, lens = self.hparams.wav_augment(wavs, lens)

        # Feature extraction and normalization
        feats = self.modules.compute_features(wavs)
        feats = self.modules.mean_var_norm(feats, lens)

        return feats, lens

    def compute_objectives(self, predictions, batch, stage):
        """
        Computes the loss given predictions and targets.

        Parameters:
        - predictions (Tensor): Model predictions.
        - batch (PaddedBatch): Batch providing the targets.
        - stage (sb.Stage): The training stage.

        Returns:
        - Tensor: The loss tensor.
        """

        _, lens = batch.sig
        spks, _ = batch.num_speakers_encoded

        # If augmentation enlarged the batch, replicate targets/lengths to match.
        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
            spks = self.hparams.wav_augment.replicate_labels(spks)
            lens = self.hparams.wav_augment.replicate_labels(lens)

        # Compute the cost function
        loss = sb.nnet.losses.nll_loss(predictions, spks, lens)

        self.loss_metric.append(
            batch.id, predictions, spks, lens, reduction="batch"
        )

        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, spks, lens)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """
        Initializes trackers at the beginning of each stage.

        Parameters:
        - stage (sb.Stage): Current stage.
        - epoch (int, optional): Current epoch number.
        """

        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # Set up evaluation-only statistics trackers
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """
        Handles logging, learning rate annealing and checkpointing at the end
        of each stage.

        Parameters:
        - stage (sb.Stage): Current stage.
        - stage_loss (float): Average loss of the stage.
        - epoch (int, optional): Current epoch number.
        """

        # Store the train loss until the validation stage.
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss

        # Summarize the statistics from the stage for record-keeping.
        else:
            stats = {
                "loss": stage_loss,
                "error": self.error_metrics.summarize("average"),
            }

        if stage == sb.Stage.VALID:

            old_lr, new_lr = self.hparams.lr_annealing(epoch)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            # The train_logger writes a summary to stdout and to the logfile.
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Keep only the checkpoint with the lowest validation error.
            if self.hparams.ckpt_enable:
                self.checkpointer.save_and_keep_only(
                    meta=stats, min_keys=["error"]
                )
            hp.report_result(stats)

        # We also write statistics about test data to stdout and to the logfile.
        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )


def dataio_prep(hparams):
    """
    Prepares and returns datasets for training, validation, and testing.

    Parameters:
    - hparams (dict): Loaded hyperparameters. Uses ``train_annotation``,
      ``valid_annotation``, ``test_annotation``, ``data_folder``,
      ``save_folder`` and ``dataloader_options``. ``sample_rate``, if present,
      sets the rate every signal is resampled to (default 16000).

    Returns:
    - datasets (dict): ``DynamicItemDataset`` objects under 'train', 'valid'
      and 'test', each exposing ``id``, ``sig`` and ``num_speakers_encoded``.

    Side effects:
    - Sets ``hparams["dataloader_options"]["shuffle"] = True``.
    - Loads the label encoder from ``<save_folder>/label_encoder.txt``, or
      creates it from the training set's ``num_speakers`` values.
    """

    # Rate every signal is brought to before feature extraction.
    target_sample_rate = hparams.get("sample_rate", 16000)

    # Maps num_speakers values to contiguous class indices.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    @sb.utils.data_pipeline.takes("wav_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav_path):
        """
        Load an audio file as a mono 1-D signal at ``target_sample_rate``.

        Parameters:
            - wav_path (str): Path to the audio file.
        Returns:
            - sig (Tensor): 1-D audio signal tensor.
        """
        sig, fs = torchaudio.load(wav_path)  # (channels, samples)
        # Downmix to mono. For single-channel audio this equals squeeze(0);
        # for multichannel audio squeeze(0) would leave a 2-D tensor that
        # cannot be batched with mono signals.
        sig = sig.mean(dim=0)
        return torchaudio.functional.resample(sig, fs, target_sample_rate)

    @sb.utils.data_pipeline.takes("num_speakers")
    @sb.utils.data_pipeline.provides("num_speakers", "num_speakers_encoded")
    def label_pipeline(num_speakers):
        """
        Pass through and encode the number of speakers.

        Parameters:
        - num_speakers (str): Speaker count as stored in the JSON manifest
          (a string field parsed from the file name, e.g. "3").

        Yields:
        - num_speakers (str): The original value.
        - num_speakers_encoded (Tensor): Class index from the label encoder.
        """
        yield num_speakers
        num_speakers_encoded = label_encoder.encode_label_torch(num_speakers)
        yield num_speakers_encoded

    # Define datasets and connect them to the processing pipelines above.
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }
    # NOTE: this dict is shared; callers that reuse it for valid/test loaders
    # should override "shuffle" there.
    hparams["dataloader_options"]["shuffle"] = True
    for dataset in data_info:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "num_speakers_encoded"],
        )

    # The encoder is fitted on train only, so every valid/test label must also
    # occur in train. `from_didatasets` is SpeechBrain's own spelling.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="num_speakers",
    )

    return datasets


# RECIPE BEGINS!
if __name__ == "__main__":

    with hp.hyperparameter_optimization(objective_key="error") as hp_ctx:

        # Parse the hparams file, run options (e.g. --device) and overrides.
        hparams_file, run_opts, overrides = hp_ctx.parse_arguments(
            sys.argv[1:], pass_trial_id=False
        )

        # Load hyperparameters file with command-line overrides.
        with open(hparams_file) as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Create experiment directory
        sb.create_experiment_directory(
            experiment_directory=hparams["output_folder"],
            hyperparams_to_save=hparams_file,
            overrides=overrides,
        )

        # Create dataset objects "train", "valid", and "test".
        datasets = dataio_prep(hparams)

        # Shuffling only matters for training. Evaluate valid/test in a fixed
        # order, using a copy so the shared dataloader_options stay untouched.
        train_loader_opts = hparams["dataloader_options"]
        eval_loader_opts = {**train_loader_opts, "shuffle": False}

        # Initialize the Brain object for speaker-count training.
        spk_counter = XVectorSpkCounter(
            modules=hparams["modules"],
            opt_class=hparams["opt_class"],
            hparams=hparams,
            run_opts=run_opts,
            checkpointer=hparams["checkpointer"],
        )

        spk_counter.fit(
            epoch_counter=spk_counter.hparams.epoch_counter,
            train_set=datasets["train"],
            valid_set=datasets["valid"],
            train_loader_kwargs=train_loader_opts,
            valid_loader_kwargs=eval_loader_opts,
        )

        # During a hyperparameter search only the validation objective is
        # reported; the test set is evaluated on regular runs.
        if not hp_ctx.enabled:
            # min_key="error": load the checkpoint with the lowest valid error.
            spk_counter.evaluate(
                test_set=datasets["test"],
                min_key="error",
                test_loader_kwargs=eval_loader_opts,
            )


In [None]:
import torch
torch.cuda.set_device("cuda:0")
!python train_xvector_fbanks.py hparams_xvector_fbanks.yaml