In [None]:
import speechbrain as sb
import os, sys
from speechbrain.utils.data_utils import get_all_files
from speechbrain.dataio.dataio import read_audio
import json
import torchaudio
from joblib import Parallel, delayed

def process_file(path):
    """
    Processes a single audio file to extract its identifier, path, number of speakers, and length.

    The file name is split on "_": every field except the last one forms the
    identifier, and the fourth field (index 3) is taken as the number of speakers.
    NOTE(review): this assumes names shaped like `<a>_<b>_<c>_<num_speakers>_segment.wav`;
    confirm against the dataset naming scheme.

    Parameters:
    - path (str): The file path of the audio file ("/" or "\\" separators are both accepted).

    Returns:
    - tuple: A tuple containing the audio file identifier and a dictionary with metadata about the audio file.
             The metadata includes the normalized path ("wav_path"), number of speakers
             ("num_speakers", kept as the string found in the file name), and length in seconds ("length").

    Raises:
    - IndexError: If the file name has fewer than four "_"-separated fields.
    """
    # Normalise Windows separators once so the same string serves both the
    # basename extraction and the stored path.
    normalized_path = path.replace("\\", "/")
    parts = normalized_path.rsplit("/", 1)[-1].split("_")

    utt_id = "_".join(parts[:-1])
    num_speakers = parts[3]

    info = torchaudio.info(path)
    # Use the sample rate reported by the file header instead of assuming 16 kHz,
    # so the duration stays correct for audio stored at any rate.
    length = info.num_frames / info.sample_rate

    return utt_id, {
        "wav_path": normalized_path,
        "num_speakers": num_speakers,
        "length": length
    }

def load_json(json_paths, save_file="train", output_dir="../data"):
    """
    Processes audio files in parallel with `process_file` and saves their metadata in a JSON file.

    Parameters:
    - json_paths (list): A list of paths to the audio files to process (despite the name,
      these are audio paths, not JSON paths).
    - save_file (str): The base name for the output JSON file; the file is written as
      `<output_dir>/<save_file>_data.json`.
    - output_dir (str): Directory that receives the JSON file. Defaults to '../data',
      the location used before this parameter existed. It is created if missing.

    Returns:
    None. This function generates a JSON file mapping each audio identifier to its metadata.
    """
    # Parallel processing: one `process_file` call per audio path, on all available cores.
    results = Parallel(n_jobs=-1, verbose=10)(
        delayed(process_file)(path) for path in json_paths
    )

    # Each result is an (identifier, metadata) pair; a later duplicate identifier
    # overrides an earlier one, exactly as an explicit assignment loop would.
    data = dict(results)

    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/{save_file}_data.json", 'w') as json_file:
        json.dump(data, json_file, indent=4)

# Collect every file whose name contains '_segment.wav' for each split.
# Directory-to-split mapping used here: "../data/dev" becomes the *test* annotation
# and "../data/eval" becomes the *valid* annotation.
# NOTE(review): "dev" usually denotes the validation split and "eval" the test split —
# confirm that this mapping is intentional and not swapped.
train_files = get_all_files("../data/train", match_and=['_segment.wav'])
test_files = get_all_files("../data/dev", match_and=['_segment.wav'])
valid_files = get_all_files("../data/eval", match_and=['_segment.wav'])

# Write ../data/train_data.json, ../data/test_data.json and ../data/valid_data.json.
load_json(train_files, save_file="train")
load_json(test_files, save_file="test")
load_json(valid_files, save_file="valid")


## XVector Model

In [None]:
%%file hparams_selfsupervised_xvector.yaml

# Seed for reproducibility of results. Must be set before initializing model components that depend on randomness.
seed: 1986
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Paths for saving outputs and logs from the model training process.
output_folder: !ref ../results/selfsupervised/Xvector/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# SSL (Self-Supervised Learning) Model Configuration
# We choose the base model of wav2vec2 which is not fine-tuned to demonstrate the generalizability and potential improvements.
sslmodel_hub: facebook/wav2vec2-base
sslmodel_folder: !ref <save_folder>/ssl_checkpoint

# Directories for storing processed data and annotations.
data_folder: ../data  # e.g., /path/to/data
train_annotation: !ref <data_folder>/train_data.json
valid_annotation: !ref <data_folder>/valid_data.json
test_annotation: !ref <data_folder>/test_data.json

# Logger for recording training progress and statistics.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>


####################### Training Parameters ####################################
number_of_epochs: 5
batch_size: 32
lr: 0.001 # Learning rate for the model optimizer.
lr_final: 0.0001 # Final learning rate after annealing.
lr_ssl: 0.00001 # Learning rate specific to the SSL model components.

# Control the freezing of model layers to fine-tune specific components.
freeze_ssl: False # Freeze all layers of the SSL model.
freeze_ssl_conv: True # Only freeze convolutional layers of the SSL model for potential performance improvement.

####################### Model Parameters #######################################
# Dimensions for the encoder and embeddings used within the x-vector architecture.
encoder_dim: 768
emb_dim: 64
out_n_neurons: 5 # Output neurons corresponding to the number of classes in the task.

# DataLoader configuration to specify how training data is batched and handled during training.
dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: 0  # Number of workers for data loading. Use 2 for Linux, 0 for Windows compatibility.
    drop_last: False

# Configuration for the SSL model loaded from Hugging Face's Transformers library.
ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <sslmodel_hub>
    output_norm: True
    freeze: !ref <freeze_ssl>
    freeze_feature_extractor: !ref <freeze_ssl_conv>
    save_path: !ref <sslmodel_folder>

# Statistical pooling layer to aggregate model outputs.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

# Per-sentence mean normalization of the SSL features (std_norm is disabled, so the variance is left untouched).
mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

# X-vector model configuration for generating embeddings from audio inputs.
embedding_model: !new:speechbrain.lobes.models.Xvector.Xvector
    in_channels: !ref <encoder_dim>
    activation: !name:torch.nn.LeakyReLU
    tdnn_blocks: 3
    tdnn_channels: [64, 64, 128]
    tdnn_kernel_sizes: [5, 2, 1]
    tdnn_dilations: [1, 2, 1]
    lin_neurons: !ref <emb_dim>

# Classifier configuration for predicting output classes from embeddings.
classifier: !new:speechbrain.lobes.models.Xvector.Classifier
    input_shape: [null, null, !ref <emb_dim>]
    activation: !name:torch.nn.LeakyReLU
    lin_blocks: 1
    lin_neurons: !ref <emb_dim>
    out_neurons: !ref <out_n_neurons>

# Epoch counter to manage the training cycles.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Grouping of modules for training.
modules:
    ssl_model: !ref <ssl_model>
    mean_var_norm: !ref <mean_var_norm>
    embedding_model:  !ref <embedding_model>
    classifier: !ref <classifier>

# Module list grouping for combined optimization.
model: !new:torch.nn.ModuleList
    - [!ref <embedding_model>, !ref <classifier>]

# Log softmax activation for output normalization.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# Loss function for training.
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Metric statistics for evaluating model performance.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Optimizers for the main model and the SSL components.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
ssl_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_ssl>

# Learning rate schedulers for the main model and SSL model to improve training convergence.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

lr_annealing_ssl: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_ssl>
    improvement_threshold: 0.0025
    annealing_factor: 0.9

# Checkpointing configuration to save and recover training states.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        ssl_model: !ref <ssl_model>
        embedding_model: !ref <embedding_model>
        classifier: !ref <classifier>
        normalizer: !ref <mean_var_norm>
        lr_annealing: !ref <lr_annealing>
        lr_annealing_ssl: !ref <lr_annealing_ssl>
        counter: !ref <epoch_counter>


In [None]:
%%file selfsupervised_xvector.py
import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml


class SelfSupervisedSpeakerCounter(sb.Brain):
    """Brain that chains an SSL model, normalization, an embedding model and a classifier."""

    def compute_forward(self, batch, stage):
        """
        Runs the forward pass on one batch.

        Parameters:
        - batch: Batch object exposing `sig` as a (waveforms, relative lengths) pair.
        - stage (sb.Stage): The stage of the process (TRAIN, VALID, or TEST).

        Returns:
        - The output of the classifier module for the batch.
        """
        batch = batch.to(self.device)
        wavs, lens = batch.sig

        # SSL features -> normalization -> embedding -> class scores.
        ssl_feats = self.modules.ssl_model(wavs, lens)
        normed_feats = self.modules.mean_var_norm(ssl_feats, lens)
        embeddings = self.modules.embedding_model(normed_feats, lens)
        return self.modules.classifier(embeddings)

    def compute_objectives(self, predictions, batch, stage):
        """
        Computes the loss for one batch and, outside training, tracks the error metric.

        Parameters:
        - predictions (Tensor): The output of `compute_forward`.
        - batch: Batch object exposing `id` and `num_speakers_encoded`.
        - stage (sb.Stage): The current stage (TRAIN, VALID, or TEST).

        Returns:
        - loss (Tensor): The value returned by `hparams.compute_cost`.
        """
        targets, _ = batch.num_speakers_encoded

        # Drop the singleton dimension at index 1 from both tensors before the loss.
        predictions = predictions.squeeze(1)
        targets = targets.squeeze(1)

        loss = self.hparams.compute_cost(predictions, targets)

        evaluating = stage != sb.Stage.TRAIN
        if evaluating:
            self.error_metrics.append(batch.id, predictions, targets)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """
        Resets the metric trackers at the beginning of a stage.

        Parameters:
        - stage (sb.Stage): The current stage (TRAIN, VALID, or TEST).
        - epoch (int, optional): The current epoch number.
        """
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # The error tracker is only needed for VALID and TEST.
        if stage == sb.Stage.TRAIN:
            return
        self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """
        Summarizes and logs the results of a finished stage.

        Parameters:
        - stage (sb.Stage): The current stage (TRAIN, VALID, or TEST).
        - stage_loss (float): The average loss of the stage.
        - epoch (int, optional): The current epoch number, if applicable.
        """
        # Training only remembers its loss; it is logged with the validation stats.
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            return

        stats = {
            "loss": stage_loss,
            "error_rate": self.error_metrics.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            error_rate = stats["error_rate"]

            # Anneal both learning rates based on the validation error rate.
            old_lr, new_lr = self.hparams.lr_annealing(error_rate)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            old_lr_ssl, new_lr_ssl = self.hparams.lr_annealing_ssl(error_rate)
            sb.nnet.schedulers.update_learning_rate(
                self.ssl_optimizer, new_lr_ssl
            )

            epoch_stats = {"Epoch": epoch, "lr": old_lr, "ssl_lr": old_lr_ssl}
            self.hparams.train_logger.log_stats(
                epoch_stats,
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Save a checkpoint, keeping only the one with the lowest error rate.
            self.checkpointer.save_and_keep_only(
                meta=stats, min_keys=["error_rate"]
            )

        # Test statistics are sent to the train logger as well.
        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def init_optimizers(self):
        """
        Builds one optimizer for the SSL model and one for the remaining model parameters.
        """
        ssl_params = self.modules.ssl_model.parameters()
        self.ssl_optimizer = self.hparams.ssl_opt_class(ssl_params)

        model_params = self.hparams.model.parameters()
        self.optimizer = self.hparams.opt_class(model_params)

        # Register both optimizers so their state is stored in the checkpoints.
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("ssl_opt", self.ssl_optimizer)
            self.checkpointer.add_recoverable("optimizer", self.optimizer)

        self.optimizers_dict = {
            "model_optimizer": self.optimizer,
            "ssl_optimizer": self.ssl_optimizer,
        }


def dataio_prep(hparams):
    """
    Prepares and returns datasets for training, validation, and testing.

    The 'train', 'valid' and 'test' annotations are always loaded. The per-class test
    sets ('test_annotation_0_spk' ... 'test_annotation_4_spk') are optional: each one is
    loaded only when its key is present in `hparams` and its JSON file exists, so a
    configuration that does not define them no longer fails with a KeyError.

    Parameters:
    - hparams (dict): A dictionary of hyperparameters for data preparation. Must contain
      'train_annotation', 'valid_annotation', 'test_annotation', 'data_folder' and 'save_folder'.
    Returns:
    - datasets (dict): A dictionary containing 'train', 'valid', and 'test' datasets, plus
      every per-class test set that could be loaded.
    """

    # Define audio pipeline
    @sb.utils.data_pipeline.takes("wav_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav_path):
        """
        Audio processing pipeline that loads and returns an audio signal.
        Parameters:
            - wav_path (str): Path to the audio file.
        Returns:
            - sig (Tensor): Loaded audio signal tensor.
        """
        sig = sb.dataio.dataio.read_audio(wav_path)
        return sig

    # Initialization of the label encoder.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    # Define label pipeline:
    @sb.utils.data_pipeline.takes("num_speakers")
    @sb.utils.data_pipeline.provides("num_speakers", "num_speakers_encoded")
    def label_pipeline(num_speakers):
        """
        Processes and encodes the number of speakers.

        Parameters:
        - num_speakers: The number-of-speakers label stored in the annotation file.

        Yields:
        - num_speakers: The original label.
        - num_speakers_encoded (Tensor): Encoded tensor of the number of speakers.
        """
        yield num_speakers
        num_speakers_encoded = label_encoder.encode_label_torch(num_speakers)
        yield num_speakers_encoded

    # Mandatory splits.
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }

    # Optional per-class test sets, only needed for the per-class evaluation.
    for n_spk in range(5):
        key = f"test_annotation_{n_spk}_spk"
        if key not in hparams:
            continue
        json_path = hparams[key]
        if not os.path.isfile(json_path):
            print(
                f"Skipping optional dataset '{key}': "
                f"annotation file not found at {json_path}"
            )
            continue
        data_info[key] = json_path

    datasets = {}
    for dataset in data_info:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "num_speakers_encoded"],
        )

    # Load the label encoder from disk if present, otherwise fit it on the
    # training labels and store it in the save folder.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="num_speakers",
    )

    return datasets


# RECIPE BEGINS!
if __name__ == "__main__":

    # Parse the CLI: hyperparameter file, runtime options and YAML overrides.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Instantiate the hyperparameters, applying the command-line overrides.
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Set up the experiment folder.
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Build the dataset objects used below ("train", "valid", "test").
    datasets = dataio_prep(hparams)

    # Move the SSL model to the requested device.
    hparams["ssl_model"] = hparams["ssl_model"].to(device=run_opts["device"])

    # When the SSL model is fine-tuned, optionally keep its feature extractor frozen.
    fine_tune_ssl = not hparams["freeze_ssl"]
    if fine_tune_ssl and hparams["freeze_ssl_conv"]:
        hparams["ssl_model"].model.feature_extractor._freeze_parameters()

    # Initialize the Brain object that trains and evaluates the speaker counter.
    spkcounter = SelfSupervisedSpeakerCounter(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    loader_opts = hparams["dataloader_options"]

    spkcounter.fit(
        epoch_counter=spkcounter.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=loader_opts,
        valid_loader_kwargs=loader_opts,
    )

    # Evaluate on the full test set, selecting the checkpoint by "error_rate".
    test_stats = spkcounter.evaluate(
        test_set=datasets["test"],
        min_key="error_rate",
        test_loader_kwargs=loader_opts,
    )

    # To get test accuracy on each class, uncomment the loop below. It needs the
    # "test_annotation_<N>_spk" entries in the hyperparameter file and their JSON files.
    #
    # for n_spk in range(5):
    #     print(f"Evaluating on {n_spk} spk class")
    #     test_stats = spkcounter.evaluate(
    #         test_set=datasets[f"test_annotation_{n_spk}_spk"],
    #         min_key="error_rate",
    #         test_loader_kwargs=loader_opts,
    #     )


In [None]:
# Select GPU 0 for this notebook kernel, then launch the x-vector training script.
# NOTE(review): `!python` runs the script in a separate process, so the set_device call
# only affects this kernel and not the training run — confirm it is still needed.
import torch
torch.cuda.set_device("cuda:0")
!python selfsupervised_xvector.py hparams_selfsupervised_xvector.yaml

## Linear Classifier

In [None]:
%%file hparams_selfsupervised_mlp.yaml

# Seed for reproducibility of results. Must be set before initializing model components that depend on randomness.
seed: 1993
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Paths for saving outputs and logs from the model training process.
data_folder: ../data  # e.g., /path/to/data
output_folder: !ref ../results/train_with_wav2vec2/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# SSL (Self-Supervised Learning) Model Configuration
# We choose the base model of wav2vec2 which is not fine-tuned to demonstrate the generalizability and potential improvements.
sslmodel_hub: facebook/wav2vec2-base
sslmodel_folder: !ref <save_folder>/ssl_checkpoint

# Directories for storing processed data and annotations.
train_annotation: !ref <data_folder>/train_data.json
valid_annotation: !ref <data_folder>/valid_data.json
test_annotation: !ref <data_folder>/test_data.json
test_annotation_0_spk: !ref <data_folder>/test_files_no_spk_data.json
test_annotation_1_spk: !ref <data_folder>/test_files_1_spk_data.json
test_annotation_2_spk: !ref <data_folder>/test_files_2_spk_data.json
test_annotation_3_spk: !ref <data_folder>/test_files_3_spk_data.json
test_annotation_4_spk: !ref <data_folder>/test_files_4_spk_data.json


# Logger for recording training progress and statistics.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>


####################### Training Parameters ####################################
number_of_epochs: 5
batch_size: 64
lr: 0.001 # Learning rate for the model optimizer.
lr_ssl: 0.0001 # Learning rate specific to the SSL model components.

# Control the freezing of model layers to fine-tune specific components.
freeze_ssl: False # Freeze all layers of the SSL model.
freeze_ssl_conv: True # Only freeze convolutional layers of the SSL model for potential performance improvement.

####################### Model Parameters #######################################
# Output dimension of the SSL encoder and number of output classes of the linear classifier.
encoder_dim: 768
out_n_neurons: 5

# DataLoader configuration to specify how training data is batched and handled during training.
dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: 0  # Number of workers for data loading. Use 2 for Linux, 0 for Windows compatibility.
    drop_last: False

# Configuration for the SSL model loaded from Hugging Face's Transformers library.
ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <sslmodel_hub>
    output_norm: True
    freeze: !ref <freeze_ssl>
    freeze_feature_extractor: !ref <freeze_ssl_conv>
    save_path: !ref <sslmodel_folder>

# Statistical pooling layer to aggregate model outputs.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

output_mlp: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dim>
    n_neurons: !ref <out_n_neurons>
    bias: False

# Epoch counter to manage the training cycles.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Grouping of modules for training.
modules:
    ssl_model: !ref <ssl_model>
    output_mlp: !ref <output_mlp>

# Module list grouping for combined optimization.
# The single list below is the only positional argument of torch.nn.ModuleList;
# an additional empty "-" item would be passed as an extra (null) argument.
model: !new:torch.nn.ModuleList
    - [!ref <output_mlp>]

# Log softmax activation for output normalization.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# Loss function for training.
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Metric statistics for evaluating model performance.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Optimizers for the main model and the SSL components.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

ssl_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_ssl>

# Learning rate schedulers for the main model and SSL model to improve training convergence.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

lr_annealing_ssl: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_ssl>
    improvement_threshold: 0.0025
    annealing_factor: 0.9

# Checkpointing configuration to save and recover training states.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        ssl_model: !ref <ssl_model>
        lr_annealing_output: !ref <lr_annealing>
        lr_annealing_ssl: !ref <lr_annealing_ssl>
        counter: !ref <epoch_counter>



In [None]:
%%file selfsupervised_mlp.py
import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml


class SelfSupervisedSpeakerCounter(sb.Brain):
    """Brain that pools SSL features over time and classifies them with a linear layer."""

    def compute_forward(self, batch, stage):
        """
        Runs the forward pass on one batch.

        Parameters:
        - batch: Batch object exposing `sig` as a (waveforms, relative lengths) pair.
        - stage (sb.Stage): The stage of the process (TRAIN, VALID, or TEST).

        Returns:
        - The output of `hparams.log_softmax` applied to the linear layer's scores.
        """
        batch = batch.to(self.device)
        wavs, lens = batch.sig

        ssl_feats = self.modules.ssl_model(wavs, lens)

        # Pool the features, then flatten everything but the batch dimension.
        pooled = self.hparams.avg_pool(ssl_feats, lens)
        pooled = pooled.view(pooled.shape[0], -1)

        scores = self.modules.output_mlp(pooled)
        return self.hparams.log_softmax(scores)

    def compute_objectives(self, predictions, batch, stage):
        """
        Computes the loss for one batch and, outside training, tracks the error metric.

        Parameters:
        - predictions (Tensor): The output of `compute_forward`.
        - batch: Batch object exposing `id` and `num_speakers_encoded`.
        - stage (sb.Stage): The current stage (TRAIN, VALID, or TEST).

        Returns:
        - loss (Tensor): The value returned by `hparams.compute_cost`.
        """
        targets, _ = batch.num_speakers_encoded

        # Drop the singleton dimension to meet the input form of the nll loss.
        targets = targets.squeeze(1)
        loss = self.hparams.compute_cost(predictions, targets)

        evaluating = stage != sb.Stage.TRAIN
        if evaluating:
            self.error_metrics.append(batch.id, predictions, targets)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """
        Resets the metric trackers at the beginning of a stage.

        Parameters:
        - stage (sb.Stage): The current stage (TRAIN, VALID, or TEST).
        - epoch (int, optional): The current epoch number.
        """
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # The error tracker is only needed for VALID and TEST.
        if stage == sb.Stage.TRAIN:
            return
        self.error_metrics = self.hparams.error_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """
        Summarizes and logs the results of a finished stage.

        Parameters:
        - stage (sb.Stage): The current stage (TRAIN, VALID, or TEST).
        - stage_loss (float): The average loss of the stage.
        - epoch (int, optional): The current epoch number, if applicable.
        """
        # Training only remembers its loss; it is logged with the validation stats.
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            return

        stats = {
            "loss": stage_loss,
            "error_rate": self.error_metrics.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            error_rate = stats["error_rate"]

            # Anneal both learning rates based on the validation error rate.
            old_lr, new_lr = self.hparams.lr_annealing(error_rate)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            old_lr_ssl, new_lr_ssl = self.hparams.lr_annealing_ssl(error_rate)
            sb.nnet.schedulers.update_learning_rate(
                self.ssl_optimizer, new_lr_ssl
            )

            epoch_stats = {"Epoch": epoch, "lr": old_lr, "ssl_lr": old_lr_ssl}
            self.hparams.train_logger.log_stats(
                epoch_stats,
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Save a checkpoint, keeping only the one with the lowest error rate.
            self.checkpointer.save_and_keep_only(
                meta=stats, min_keys=["error_rate"]
            )

        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def init_optimizers(self):
        """
        Builds one optimizer for the SSL model and one for the remaining model parameters.
        """
        ssl_params = self.modules.ssl_model.parameters()
        self.ssl_optimizer = self.hparams.ssl_opt_class(ssl_params)

        model_params = self.hparams.model.parameters()
        self.optimizer = self.hparams.opt_class(model_params)

        # Register both optimizers so their state is stored in the checkpoints.
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("ssl_opt", self.ssl_optimizer)
            self.checkpointer.add_recoverable("optimizer", self.optimizer)

        self.optimizers_dict = {
            "model_optimizer": self.optimizer,
            "ssl_optimizer": self.ssl_optimizer,
        }


def dataio_prep(hparams):
    """
    Prepares and returns datasets for training, validation, and testing.

    The 'train', 'valid' and 'test' annotations are always loaded ('test' is required by
    the evaluation that follows training). The per-class test sets
    ('test_annotation_0_spk' ... 'test_annotation_4_spk') are optional: each one is
    loaded only when its key is present in `hparams` and its JSON file exists.

    Parameters:
    - hparams (dict): A dictionary of hyperparameters for data preparation. Must contain
      'train_annotation', 'valid_annotation', 'test_annotation', 'data_folder' and 'save_folder'.
    Returns:
    - datasets (dict): A dictionary containing 'train', 'valid', and 'test' datasets, plus
      every per-class test set that could be loaded.
    """

    # Define audio pipeline
    @sb.utils.data_pipeline.takes("wav_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav_path):
        """
        Audio processing pipeline that loads and returns an audio signal.
        Parameters:
            - wav_path (str): Path to the audio file.
        Returns:
            - sig (Tensor): Loaded audio signal tensor.
        """
        sig = sb.dataio.dataio.read_audio(wav_path)
        return sig

    # Label Encoder
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    # Define label pipeline:
    @sb.utils.data_pipeline.takes("num_speakers")
    @sb.utils.data_pipeline.provides("num_speakers", "num_speakers_encoded")
    def label_pipeline(num_speakers):
        """
        Processes and encodes the number of speakers.
        Parameters:
             - num_speakers: The number-of-speakers label stored in the annotation file.
        Yields:
            - num_speakers: The original label.
            - num_speakers_encoded (Tensor): Encoded tensor of the number of speakers.
        """
        yield num_speakers
        num_speakers_encoded = label_encoder.encode_label_torch(num_speakers)
        yield num_speakers_encoded

    # Mandatory splits.
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }

    # Optional per-class test sets, only needed for the per-class evaluation.
    for n_spk in range(5):
        key = f"test_annotation_{n_spk}_spk"
        if key not in hparams:
            continue
        json_path = hparams[key]
        if not os.path.isfile(json_path):
            print(
                f"Skipping optional dataset '{key}': "
                f"annotation file not found at {json_path}"
            )
            continue
        data_info[key] = json_path

    datasets = {}
    for dataset in data_info:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "num_speakers_encoded"],
        )

    # Load the label encoder from disk if present, otherwise fit it on the
    # training labels and store it in the save folder.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="num_speakers",
    )

    return datasets


# RECIPE BEGINS!
if __name__ == "__main__":

    # Parse the CLI: hyperparameter file, runtime options and YAML overrides.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Instantiate the hyperparameters, applying the command-line overrides.
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Set up the experiment folder.
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Build the dataset objects returned by dataio_prep.
    datasets = dataio_prep(hparams)

    # Move the SSL model to the requested device.
    hparams["ssl_model"] = hparams["ssl_model"].to(device=run_opts["device"])

    # When the SSL model is fine-tuned, optionally keep its feature extractor frozen.
    fine_tune_ssl = not hparams["freeze_ssl"]
    if fine_tune_ssl and hparams["freeze_ssl_conv"]:
        hparams["ssl_model"].model.feature_extractor._freeze_parameters()

    # Initialize the Brain object that trains and evaluates the speaker counter.
    spkcounter = SelfSupervisedSpeakerCounter(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    loader_opts = hparams["dataloader_options"]

    spkcounter.fit(
        epoch_counter=spkcounter.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=loader_opts,
        valid_loader_kwargs=loader_opts,
    )

    # Evaluate on the full test set, selecting the checkpoint by "error_rate".
    print("Evaluating on all classes")
    test_stats = spkcounter.evaluate(
        test_set=datasets["test"],
        min_key="error_rate",
        test_loader_kwargs=loader_opts,
    )

    # To get test accuracy on each class, uncomment the loop below. It needs the
    # "test_annotation_<N>_spk" entries in the hyperparameter file and their JSON files.
    #
    # for n_spk in range(5):
    #     print(f"Evaluating on {n_spk} spk class")
    #     test_stats = spkcounter.evaluate(
    #         test_set=datasets[f"test_annotation_{n_spk}_spk"],
    #         min_key="error_rate",
    #         test_loader_kwargs=loader_opts,
    #     )

In [None]:
# Select GPU 0 for this notebook kernel, then launch the linear-classifier training script.
# NOTE(review): `!python` runs the script in a separate process, so the set_device call
# only affects this kernel and not the training run — confirm it is still needed.
import torch
torch.cuda.set_device("cuda:0")
!python selfsupervised_mlp.py hparams_selfsupervised_mlp.yaml