In [5]:
import os
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

In [6]:
class EmoIdBrain(sb.Brain):
    """Brain for utterance-level emotion classification.

    A wav2vec2 encoder produces frame-level features that are pooled over
    time and classified by an MLP. The encoder and the rest of the model are
    updated by two separate optimizers created in ``init_optimizers``.
    """

    def compute_forward(self, batch, stage):
        """Computation pipeline based on a wav2vec2 encoder + emotion classifier.

        Arguments
        ---------
        batch : PaddedBatch
            Batch whose ``sig`` field is a ``(wavs, lens)`` pair.
        stage : sb.Stage
            Current stage (not used here; required by the Brain interface).

        Returns
        -------
        torch.Tensor
            Result of ``hparams.log_softmax`` applied to the MLP outputs.
        """
        batch = batch.to(self.device)
        wavs, lens = batch.sig

        outputs = self.modules.wav2vec2(wavs)

        # Pool over the time axis; `lens` is forwarded so the pooling can take
        # padding into account (assumes hparams.avg_pool is length-aware —
        # confirm against the yaml). Then flatten to (batch, features).
        outputs = self.hparams.avg_pool(outputs, lens)
        outputs = outputs.view(outputs.shape[0], -1)

        outputs = self.modules.output_mlp(outputs)
        outputs = self.hparams.log_softmax(outputs)
        return outputs

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss using the encoded emotion as the label.

        Outside of training, the predictions are also recorded in the
        error-rate and accuracy trackers created by ``on_stage_start``.
        """
        emoid, _ = batch.emo_encoded

        # Drop dimension 1 so the targets match the form the NLL loss expects
        # (assumes emoid is (batch, 1) — TODO confirm).
        emoid = emoid.squeeze(1)
        loss = self.hparams.compute_cost(predictions, emoid)
        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, emoid)
            self.acc_metrics.append(predictions, emoid)

        return loss

    def fit_batch(self, batch):
        """Trains the parameters given a single batch in input.

        Both optimizers step only when ``check_gradients`` accepts the loss;
        gradients are cleared unconditionally afterwards.
        """
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
        loss.backward()
        if self.check_gradients(loss):
            self.wav2vec2_optimizer.step()
            self.optimizer.step()

        self.wav2vec2_optimizer.zero_grad()
        self.optimizer.zero_grad()

        return loss.detach()

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """
        # Set up statistics trackers for this stage
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # Evaluation-only trackers, filled in by compute_objectives.
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()
            self.acc_metrics = self.hparams.acc_stats()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            The average loss for all of the data processed in this stage.
        epoch : int
            The epoch that just finished. This is passed
            `None` during the test stage.
        """
        # Store the train loss until the validation stage.
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            return

        # Summarize the statistics from the stage for record-keeping.
        stats = {
            "loss": stage_loss,
            "error_rate": self.error_metrics.summarize("average"),
            "accuracy": self.acc_metrics.summarize() * 100,
        }

        if stage == sb.Stage.VALID:
            # Anneal both learning rates based on the validation error rate.
            old_lr, new_lr = self.hparams.lr_annealing(stats["error_rate"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            (
                old_lr_wav2vec2,
                new_lr_wav2vec2,
            ) = self.hparams.lr_annealing_wav2vec2(stats["error_rate"])
            sb.nnet.schedulers.update_learning_rate(
                self.wav2vec2_optimizer, new_lr_wav2vec2
            )

            # The train_logger writes a summary to stdout and to the logfile.
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr, "wave2vec_lr": old_lr_wav2vec2},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )

            # Save a checkpoint every epoch, with the validation stats as
            # metadata. Previous checkpoints are deliberately not pruned;
            # use save_and_keep_only(meta=stats, min_keys=["error_rate"])
            # instead to keep only the best one.
            self.checkpointer.save_checkpoint(meta=stats)

        # We also write statistics about test data to stdout and to logfile.
        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def init_optimizers(self):
        """Initializes the wav2vec2 optimizer and model optimizer.

        Both optimizers are registered with the checkpointer (when present)
        so that their state is saved and restored with the model.
        """
        self.wav2vec2_optimizer = self.hparams.wav2vec2_opt_class(
            self.modules.wav2vec2.parameters()
        )
        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable(
                "wav2vec2_opt", self.wav2vec2_optimizer
            )
            self.checkpointer.add_recoverable("optimizer", self.optimizer)

In [7]:
def dataio_prep(hparams):
    """Prepare the train/valid/test datasets used by the Brain class.

    It also defines the data processing pipelines through user-defined
    functions. The JSON manifests referenced by ``hparams["train_annotation"]``,
    ``hparams["valid_annotation"]`` and ``hparams["test_annotation"]`` must
    already exist; the ``data_root`` placeholder inside them is replaced with
    ``hparams["data_folder"]``.

    Arguments
    ---------
    hparams : dict
        This dictionary is loaded from the hyperparameter yaml file. It must
        provide the ``*_annotation`` paths, ``data_folder`` and
        ``save_folder``.

    Returns
    -------
    datasets : dict
        Contains three keys, "train", "valid" and "test", each mapped to the
        corresponding DynamicItemDataset object.
    """

    # Define audio pipeline
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        """Load the audio signal from the path stored in ``wav``."""
        return sb.dataio.dataio.read_audio(wav)

    # The label encoder assigns each observed emotion label a unique integer
    # index. It is populated (or loaded from disk) after the datasets exist.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()

    # Define label pipeline: provide both the raw label and its encoding.
    @sb.utils.data_pipeline.takes("emo")
    @sb.utils.data_pipeline.provides("emo", "emo_encoded")
    def label_pipeline(emo):
        """Yield the raw emotion label, then its encoded tensor."""
        yield emo
        yield label_encoder.encode_label_torch(emo)

    # Define datasets and connect them to the pipelines defined above.
    datasets = {
        split: sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=hparams[f"{split}_annotation"],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "emo_encoded"],
        )
        for split in ["train", "valid", "test"]
    }

    # Load the label encoder from lab_enc_file if it exists, otherwise build
    # it from the training set's "emo" labels. The file contains the
    # label-to-index mapping.
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="emo",
    )

    return datasets

In [8]:


# Run options that would normally be parsed from the command line. They are
# hard-coded here because the Jupyter kernel's own argv is not a valid
# SpeechBrain command line.
hparams_file = "hparams/train_with_wav2vec2.yaml"
run_opts = {
    "debug": False,
    "debug_batches": 2,
    "debug_epochs": 2,
    "device": "cuda:0",
    "data_parallel_backend": False,
    "distributed_launch": False,
    "distributed_backend": "nccl",
    "find_unused_parameters": False,
}
# No hyperparameter overrides. load_hyperpyyaml skips an empty string, but a
# whitespace-only string such as " " would be parsed as YAML and fed to the
# override merge.
overrides = ""

# Initialize ddp (useful only for multi-GPU DDP training).
sb.utils.distributed.ddp_init_group(run_opts)

# Load hyperparameters file with the overrides applied.
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)

datasets = dataio_prep(hparams)

# Use the same device as the Brain instead of a second hard-coded one.
hparams["wav2vec2"] = hparams["wav2vec2"].to(run_opts["device"])
# When wav2vec2 is fine-tuned, optionally keep its conv feature extractor frozen.
if not hparams["freeze_wav2vec2"] and hparams["freeze_wav2vec2_conv"]:
    hparams["wav2vec2"].model.feature_extractor._freeze_parameters()

# Initialize the Brain object for emotion recognition.
emo_id_brain = EmoIdBrain(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Load the checkpoint with the lowest validation error rate and evaluate it.
test_stats = emo_id_brain.evaluate(
    test_set=datasets["test"],
    min_key="error_rate",
    test_loader_kwargs=hparams["dataloader_options"],
)

usage: ipykernel_launcher.py [-h] [--debug] [--debug_batches DEBUG_BATCHES]
                             [--debug_epochs DEBUG_EPOCHS]
                             [--log_config LOG_CONFIG]
                             [--local_rank LOCAL_RANK] [--device DEVICE]
                             [--data_parallel_backend] [--distributed_launch]
                             [--distributed_backend DISTRIBUTED_BACKEND]
                             [--find_unused_parameters]
                             [--jit_module_keys [JIT_MODULE_KEYS [JIT_MODULE_KEYS ...]]]
                             [--auto_mix_prec] [--max_grad_norm MAX_GRAD_NORM]
                             [--nonfinite_patience NONFINITE_PATIENCE]
                             [--noprogressbar]
                             [--ckpt_interval_minutes CKPT_INTERVAL_MINUTES]
                             param_file
ipykernel_launcher.py: error: argument --find_unused_parameters: ignored explicit argument 'C:\\Users\\LAYLAR~1\\AppData\\Local

Traceback (most recent call last):
  File "C:\Users\LaylarZhang\anaconda3\envs\emotion\lib\argparse.py", line 1787, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
  File "C:\Users\LaylarZhang\anaconda3\envs\emotion\lib\argparse.py", line 1993, in _parse_known_args
    start_index = consume_optional(start_index)
  File "C:\Users\LaylarZhang\anaconda3\envs\emotion\lib\argparse.py", line 1915, in consume_optional
    raise ArgumentError(action, msg % explicit_arg)
argparse.ArgumentError: argument --find_unused_parameters: ignored explicit argument 'C:\\Users\\LAYLAR~1\\AppData\\Local\\Temp\\tmp-161568E2OPeiRw7NV.json'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\LaylarZhang\anaconda3\envs\emotion\lib\site-packages\IPython\core\interactiveshell.py", line 3444, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\LAYLAR~1\AppData\Local\Temp/ipykernel_

TypeError: object of type 'NoneType' has no len()