In [13]:
%%file trainmoespeed.py
#!/usr/bin/env python3
"""Recipe for training a neural speech separation system on Libri2/3Mix datasets.
The system employs an encoder, a decoder, and a masking network.

To run this recipe, do the following:
> python train.py hparams/sepformer-libri2mix.yaml
> python train.py hparams/sepformer-libri3mix.yaml


The experiment file is flexible enough to support different neural
networks. By properly changing the parameter files, you can try
different architectures. The script supports both libri2mix and
libri3mix.


Authors
 * Cem Subakan 2020
 * Mirco Ravanelli 2020
 * Samuele Cornell 2020
 * Mirko Bronzi 2020
 * Jianyuan Zhong 2020
"""

import csv
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm

import speechbrain as sb
import speechbrain.nnet.schedulers as schedulers
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


# Define training procedure
class Separation(sb.Brain):
    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals.

        Arguments
        ---------
        mix : tuple
            ``(signal, lengths)`` pair of the mixture, as found in the batch.
        targets : list
            One ``(signal, lengths)`` pair per source. Only the signals are
            used; they are concatenated along a new last dimension.
        stage : sb.Stage
            Current stage. The augmentations are applied only for
            ``sb.Stage.TRAIN``.
        noise : None
            Unused; kept so that existing callers keep working.

        Returns
        -------
        est_source : torch.Tensor
            Estimated sources stacked on the last dimension, padded or cropped
            on dimension 1 to the length of the (augmented) mixture.
        targets : torch.Tensor
            Target sources (after augmentation) stacked on the last dimension.
        """

        # Unpack lists and put tensors in the right device
        mix, mix_lens = mix
        mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)
        # Convert targets to tensor
        targets = torch.cat(
            [targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        ).to(self.device)

        # Add speech distortions
        if stage == sb.Stage.TRAIN:
            with torch.no_grad():
                if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
                    # The returned mixture is already the sum of the
                    # perturbed sources.
                    mix, targets = self.add_speed_perturb(targets, mix_lens)

                if self.hparams.use_wavedrop:
                    mix = self.hparams.drop_chunk(mix, mix_lens)
                    mix = self.hparams.drop_freq(mix)

                if self.hparams.limit_training_signal_len:
                    mix, targets = self.cut_signals(mix, targets)

        # Separation
        mix_w = self.hparams.Encoder(mix)
        est_mask = self.hparams.MaskNet(mix_w)

        # Mask selection: MaskNet returns the candidate masks on dimension 0
        # and the batch on dimension 1. The candidates are split into
        # consecutive groups (one per scoring network); inside each group the
        # best-scored masks are multiplied into a single mask.
        # NOTE(review): only the top-k *indices* of the scores are used, so no
        # gradient appears to reach the scoring networks - confirm this is
        # intended.
        experts_per_group = 10
        top_k = 4
        candidate_masks = est_mask.permute(1, 0, 2, 3)
        scorers = (self.hparams.SPK_Scores1, self.hparams.SPK_Scores2)
        group_masks = []
        for group_idx, scorer in enumerate(scorers):
            first = group_idx * experts_per_group
            group = candidate_masks[:, first : first + experts_per_group]
            group_masks.append(
                self._fuse_topk_masks(group, scorer(group), top_k)
            )
        est_mask = torch.stack(group_masks)

        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [
                self.hparams.Decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source, targets

    def _fuse_topk_masks(self, group_masks, scores, k):
        """Multiplies, for every batch item, the k best-scored masks of a group.

        Arguments
        ---------
        group_masks : torch.Tensor
            Candidate masks with the batch on dimension 0 and the candidates
            on dimension 1, followed by two mask dimensions.
        scores : torch.Tensor
            One score per candidate, shaped ``[batch, candidates]``.
        k : int
            Number of candidates to keep.

        Returns
        -------
        torch.Tensor
            Element-wise product of the selected masks; the candidate
            dimension is removed.
        """
        top_idx = torch.topk(scores, k=k, dim=1).indices
        # Broadcast the indices over the mask dimensions so that gather picks
        # whole masks, independently for every batch item. Indexing dimension
        # 0 directly with the [batch, k] tensor only works for batch size 1.
        index = top_idx[:, :, None, None].expand(
            -1, -1, *group_masks.shape[2:]
        )
        return torch.gather(group_masks, 1, index).prod(dim=1)

    def compute_objectives(self, predictions, targets):
        """Evaluates the configured loss (si-snr in this recipe) on a batch.

        The loss callable stored in the hyperparameters receives the targets
        first and the predictions second.
        """
        loss_fn = self.hparams.loss
        return loss_fn(targets, predictions)

    def fit_batch(self, batch):
        """Runs one optimisation step on a training batch.

        The loss is computed inside the training context. With loss
        thresholding enabled, only the items whose loss exceeds the threshold
        contribute. The update is skipped (and counted) when no item is left
        or when the loss is not below ``loss_upper_lim``.

        Returns
        -------
        torch.Tensor
            The detached loss, moved to the CPU.
        """
        mixture = batch.mix_sig
        sources = [batch.vocals, batch.accompaniment]

        with self.training_ctx:
            estimates, sources = self.compute_forward(
                mixture, sources, sb.Stage.TRAIN, None
            )
            loss = self.compute_objectives(estimates, sources)

            if not self.hparams.threshold_byloss:
                loss = loss.mean()
            else:
                # Discard the "easy" items, i.e. those at or below the threshold
                loss = loss[loss > self.hparams.threshold]
                if loss.nelement() > 0:
                    loss = loss.mean()

        step_is_valid = (
            loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
        )
        if not step_is_valid:
            self.nonfinite_count += 1
            logger.info(
                "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                    self.nonfinite_count
                )
            )
            # Report a zero loss for the skipped batch
            loss.data = torch.tensor(0.0).to(self.device)
        else:
            self.scaler.scale(loss).backward()
            if self.hparams.clip_grad_norm >= 0:
                # Gradients must be unscaled before their norm is clipped
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.modules.parameters(),
                    self.hparams.clip_grad_norm,
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computes the loss of a validation/test batch.

        During the test stage the audio of the batch is also written to disk
        when ``save_audio`` is set. If the hyperparameters define
        ``n_audio_to_save``, it is used as a countdown of the items still to
        be saved.

        Returns
        -------
        torch.Tensor
            The detached mean loss of the batch.
        """
        snt_id = batch.name
        mixture = batch.mix_sig
        sources = [batch.vocals, batch.accompaniment]

        with torch.no_grad():
            estimates, sources = self.compute_forward(mixture, sources, stage)
            loss = self.compute_objectives(estimates, sources)

        # Manage audio file saving
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if not hasattr(self.hparams, "n_audio_to_save"):
                self.save_audio(snt_id, mixture, sources, estimates)
            elif self.hparams.n_audio_to_save > 0:
                self.save_audio(snt_id, mixture, sources, estimates)
                self.hparams.n_audio_to_save -= 1

        return loss.mean().detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended.
        stage_loss : float
            Average loss of the stage; it is logged under the "si-snr" key.
        epoch : int
            Current epoch, used for logging and by the scheduler.
        """
        # Compute/store important stats
        stage_stats = {"si-snr": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            # Learning rate annealing
            if isinstance(
                self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
            ):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss
                )
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # Without ReduceLROnPlateau the learning rate is left alone.
                # Read it from the live optimizer: hparams.optimizer is only
                # the factory handed to the Brain as opt_class and has no
                # parameter groups.
                current_lr = self.optimizer.param_groups[0]["lr"]

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": current_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Keep only the checkpoint with the lowest validation loss
            self.checkpointer.save_and_keep_only(
                meta={"si-snr": stage_stats["si-snr"]},
                min_keys=["si-snr"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
    def add_speed_perturb(self, targets, targ_lens):
        """Applies speed perturbation and random shift to every source.

        Each source (last dimension of ``targets``) is processed on its own
        and then brought back to the original length (dimension 1), so the
        returned targets have the same shape as the input.

        Arguments
        ---------
        targets : torch.Tensor
            Sources stacked on the last dimension; dimension 1 is the length.
        targ_lens : torch.Tensor
            Unused; kept for interface compatibility.

        Returns
        -------
        mix : torch.Tensor
            Sum of the processed sources over the last dimension.
        targets : torch.Tensor
            Processed sources, stacked on the last dimension.
        """
        original_len = targets.shape[1]
        processed_sources = []

        for i in range(targets.shape[-1]):
            source = targets[:, :, i]

            # NOTE: the random shift is nested in this branch, so it is only
            # applied when speed perturbation is enabled as well.
            if self.hparams.use_speedperturb:
                source = self.hparams.speed_perturb(source)

                if self.hparams.use_rand_shift:
                    # Keep the shift range within the length of the signal
                    max_shift = min(self.hparams.max_shift, source.shape[1])
                    min_shift = max(self.hparams.min_shift, -source.shape[1])

                    if max_shift > min_shift:
                        shift = torch.randint(
                            min_shift,
                            max_shift + 1,
                            (1,),
                            device=source.device,
                        ).item()
                        source = torch.roll(source, shifts=(shift,), dims=1)

            current_len = source.shape[1]

            if current_len > original_len:
                source = source[:, :original_len]
            elif current_len < original_len:
                missing = original_len - current_len
                if current_len == 0:
                    # Nothing to repeat: fill with silence
                    filler = source.new_zeros(source.shape[0], missing)
                else:
                    # Pad with the tail of the signal. The signal is tiled
                    # first so that the tail is long enough even when more
                    # than ``current_len`` samples are missing; otherwise the
                    # output would end up shorter than ``original_len``.
                    repeats = -(-missing // current_len)  # ceil division
                    filler = source.repeat(1, repeats)[:, -missing:]
                source = torch.cat((source, filler), dim=1)

            processed_sources.append(source.to(targets.device))

        targets = torch.stack(processed_sources, dim=-1)

        # Rebuild the mixture from the processed sources
        mix = targets.sum(dim=-1)

        return mix, targets


    def cut_signals(self, mixture, targets):
        """Crops the mixture and the targets to the same random window.

        The window is ``training_signal_len`` samples long and its start is
        drawn uniformly among the positions where it fits. Signals that are
        not longer than the window are returned from position 0 onwards.
        """
        segment_len = self.hparams.training_signal_len
        latest_start = max(0, mixture.shape[1] - segment_len)
        start = torch.randint(0, 1 + latest_start, (1,)).item()
        window = slice(start, start + segment_len)
        return mixture[:, window], targets[:, window, :]

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of a layer and of all its sub-layers.

        Every module that defines ``reset_parameters`` is reset exactly once,
        the layer itself first and its descendants afterwards.

        Arguments
        ---------
        layer : torch.nn.Module
            Root of the module tree to reinitialize.
        """
        # ``modules()`` already walks the whole tree and yields each module a
        # single time. Recursing on every element of ``modules()`` would visit
        # deep modules once per ancestor path and reset them many times.
        for module in layer.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    def save_results(self, test_data):
        """Scores the test set and writes SDR / SI-SNR metrics to a csv file.

        One row is written for every test item that could be scored, followed
        by an "avg" row; the averages are logged as well. Items for which the
        metric computation raises a ValueError are reported and left out.
        """

        # mir_eval is only needed here, for the SDR computation
        from mir_eval.separation import bss_eval_sources

        def sources_first(signal):
            """First batch item as a numpy array, transposed for mir_eval."""
            return signal[0].t().detach().cpu().numpy()

        csv_path = os.path.join(self.hparams.output_folder, "test_results.csv")
        fieldnames = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]
        # Per-item values, keyed like the csv columns
        collected = {"sdr": [], "sdr_i": [], "si-snr": [], "si-snr_i": []}

        loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts
        )

        with open(csv_path, "w", newline="", encoding="utf-8") as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=fieldnames)
            writer.writeheader()

            with tqdm(loader, dynamic_ncols=True) as progress:
                for batch in progress:
                    mixture, _ = batch.mix_sig
                    snt_id = batch.name
                    targets = [batch.vocals, batch.accompaniment]

                    # Separate the mixture
                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST
                        )

                    sisnr = self.compute_objectives(predictions, targets)

                    # Baseline: the unprocessed mixture used as the estimate
                    # of every source
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_spks, dim=-1
                    ).to(targets.device)
                    sisnr_baseline = self.compute_objectives(
                        mixture_signal, targets
                    )
                    sisnr_i = sisnr - sisnr_baseline

                    try:
                        sdr, _, _, _ = bss_eval_sources(
                            sources_first(targets),
                            sources_first(predictions),
                        )
                        sdr_baseline, _, _, _ = bss_eval_sources(
                            sources_first(targets),
                            sources_first(mixture_signal),
                        )
                        sdr_i = sdr.mean() - sdr_baseline.mean()

                        # The loss is the negative si-snr, hence the sign flip
                        row = {
                            "snt_id": snt_id,
                            "sdr": sdr.mean(),
                            "sdr_i": sdr_i,
                            "si-snr": -sisnr.item(),
                            "si-snr_i": -sisnr_i.item(),
                        }
                        writer.writerow(row)

                        for key, values in collected.items():
                            values.append(row[key])
                    except ValueError as e:
                        # Catch potential mir_eval errors that might still occur in edge cases
                        print(f"Error processing sample {snt_id}: {e}")

                averages = {
                    key: np.array(values).mean()
                    for key, values in collected.items()
                }
                writer.writerow({"snt_id": "avg", **averages})

        logger.info("Mean SISNR is {}".format(averages["si-snr"]))
        logger.info("Mean SISNRi is {}".format(averages["si-snr_i"]))
        logger.info("Mean SDR is {}".format(averages["sdr"]))
        logger.info("Mean SDRi is {}".format(averages["sdr_i"]))

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, targets, and estimated sources) on disk.

        Only the first item of the batch is written. Every signal is
        peak-normalised before saving; silent signals are written unchanged.

        Arguments
        ---------
        snt_id : str or list
            Name of the item, or the list of names of the batch (the first
            name is used, matching the item that is saved).
        mixture : tuple
            ``(signal, lengths)`` pair of the mixture.
        targets : torch.Tensor
            Target sources stacked on the last dimension.
        predictions : torch.Tensor
            Estimated sources stacked on the last dimension.
        """
        # Create output folder (including missing parents)
        save_path = os.path.join(self.hparams.save_folder, "audio_results")
        os.makedirs(save_path, exist_ok=True)

        # The batch carries a list of names; formatting the list itself would
        # put brackets and quotes into the file names.
        if isinstance(snt_id, (list, tuple)) and len(snt_id) > 0:
            snt_id = snt_id[0]

        def write_normalized(signal, filename):
            """Peak-normalises a 1-D signal and writes it as a wav file."""
            peak = signal.abs().max()
            # A silent signal has a zero peak; dividing by it would fill the
            # file with NaNs.
            if peak > 0:
                signal = signal / peak
            torchaudio.save(
                os.path.join(save_path, filename),
                signal.unsqueeze(0).cpu(),
                self.hparams.sample_rate,
            )

        for ns in range(self.hparams.num_spks):
            # Estimated source
            write_normalized(
                predictions[0, :, ns],
                "item{}_source{}hat.wav".format(snt_id, ns + 1),
            )

            # Original source
            write_normalized(
                targets[0, :, ns],
                "item{}_source{}.wav".format(snt_id, ns + 1),
            )

        # Mixture
        write_normalized(mixture[0][0, :], "item{}_mix.wav".format(snt_id))


def dataio_prep(hparams):
    """Creates the data processing pipeline for MUSDB18.

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters. ``sample_rate`` and ``audio_length`` are
        required; ``musdb_root`` is optional and defaults to
        "/notebooks/musdb18".

    Returns
    -------
    train_data, valid_data, test_data : DynamicItemDataset
        Datasets whose items provide "name", "mix_sig", "vocals" and
        "accompaniment".
    """
    import random

    import musdb

    root = hparams.get("musdb_root", "/notebooks/musdb18")
    sample_rate = hparams["sample_rate"]
    splits = (
        musdb.DB(
            root=root, subsets="train", split="train", sample_rate=sample_rate
        ),
        musdb.DB(
            root=root, subsets="train", split="valid", sample_rate=sample_rate
        ),
        musdb.DB(root=root, subsets="test", sample_rate=sample_rate),
    )

    # One dataset per split; every item only carries its musdb track
    datasets = [
        sb.dataio.dataset.DynamicItemDataset(
            {track.name: {"track": track} for track in db}
        )
        for db in splits
    ]

    @sb.utils.data_pipeline.takes("track")
    @sb.utils.data_pipeline.provides(
        "name", "mix_sig", "vocals", "accompaniment"
    )
    def audio_pipeline_mix(track):
        """Loads a random chunk of the track and of its two sources."""
        name = track.name
        # presumably seconds, as expected by musdb - TODO confirm
        track.chunk_duration = hparams["audio_length"]
        # Clamp the upper bound so that a track shorter than the chunk does
        # not produce a negative start position
        track.chunk_start = random.uniform(
            0, max(0.0, track.duration - track.chunk_duration)
        )

        # Index 1 of the transposed audio keeps a single channel
        # (presumably the right one of a stereo track)
        mix_sig = torch.from_numpy(track.audio.T)[1].float()

        vocals = torch.from_numpy(track.sources["vocals"].audio.T)[1].float()

        accompaniment = torch.from_numpy(
            track.targets["accompaniment"].audio.T
        )[1].float()
        return name, mix_sig, vocals, accompaniment

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_mix)
    sb.dataio.dataset.set_output_keys(
        datasets, ["name", "mix_sig", "vocals", "accompaniment"]
    )

    return datasets[0], datasets[1], datasets[2]

if __name__ == "__main__":
    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file, encoding="utf-8") as fin:
        hparams = load_hyperpyyaml(fin, overrides)
    run_opts['device']="cuda"
    print(run_opts)
    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Update precision to bf16 if the device is CPU and precision is fp16
    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
        hparams["precision"] = "bf16"
    
    train_data, valid_data, test_data = dataio_prep(hparams)

    # Load pretrained model if pretrained_separator is present in the yaml
    if "pretrained_separator" in hparams:
        run_on_main(hparams["pretrained_separator"].collect_files)
        hparams["pretrained_separator"].load_collected()

    # Brain class initialization
    separator = Separation(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )
    
    # re-initialize the parameters if we don't use a pretrained model
    if "pretrained_separator" not in hparams:
        for module in separator.modules.values():
            separator.reset_layer_recursively(module)

    # Training
    separator.fit(
        separator.hparams.epoch_counter,
        train_data,
        valid_data,
        train_loader_kwargs=hparams["dataloader_opts"],
        valid_loader_kwargs=hparams["dataloader_opts"],
    )

    # Eval
    separator.evaluate(test_data, min_key="si-snr")
    separator.save_results(test_data)

Overwriting trainmoespeed.py


In [10]:
%%file Transformermoespeed.yaml
# ################################
# Model: SepFormer for source separation
# https://arxiv.org/abs/2010.13154
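# Dataset : MUSDB18 (loaded through the musdb package by trainmoespeed.py);
#           the Libri2Mix names further down appear to be left over from the original recipe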
# ################################
#
# Basic parameters
# Seed needs to be set at top of yaml, before objects with parameters are made
#
seed: 1234
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]

# Data params

# e.g. '/yourpath/Libri2Mix/train-clean-360/'
# the data folder is needed even if dynamic mixing is applied
data_folder: /yourpath/Libri2Mix/train-clean-360/

experiment_name: moespeed-scored-former-libri2mix
output_folder: !ref results/<experiment_name>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save
train_data: !ref <save_folder>/libri2mix_train-360.csv
valid_data: !ref <save_folder>/libri2mix_dev.csv
test_data: !ref <save_folder>/libri2mix_test.csv
skip_prep: False

ckpt_interval_minutes: 60

# Experiment params
precision: fp16 # bf16, fp16 or fp32 # Set it to True for mixed precision
num_spks: 2
noprogressbar: False
save_audio: True # Save estimated sources on disk
sample_rate: 16000

####################### Training Parameters ####################################
N_epochs: 100
batch_size: 1
lr: 0.00015
clip_grad_norm: 5
loss_upper_lim: 999999  # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to a specified length
limit_training_signal_len: False
# this is the length of sequences if we choose to limit
# the signal length of training sequences
training_signal_len: 57000
audio_length: 6


# Parameters for data augmentation
use_wavedrop: True
use_speedperturb: True
use_rand_shift: True
min_shift: -8000
max_shift: 8000

# Speed perturbation
speed_changes: [95, 100, 105]  # List of speed changes for time-stretching

speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: !ref <speed_changes>

# Frequency drop: randomly drops a number of frequency bands to zero.
drop_freq_low: 0  # Min frequency band dropout probability
drop_freq_high: 1  # Max frequency band dropout probability
drop_freq_count_low: 1  # Min number of frequency bands to drop
drop_freq_count_high: 3  # Max number of frequency bands to drop
drop_freq_width: 0.05  # Width of frequency bands to drop

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Time drop: randomly drops a number of temporal chunks.
drop_chunk_count_low: 1  # Min number of audio chunks to drop
drop_chunk_count_high: 5  # Max number of audio chunks to drop
drop_chunk_length_low: 1000  # Min length of audio chunks to drop
drop_chunk_length_high: 2000  # Max length of audio chunks to drop

drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: !ref <drop_chunk_length_low>
    drop_length_high: !ref <drop_chunk_length_high>
    drop_count_low: !ref <drop_chunk_count_low>
    drop_count_high: !ref <drop_chunk_count_high>


# loss thresholding -- this thresholds the training loss
threshold_byloss: True
threshold: -30

# Encoder parameters
N_encoder_out: 256
out_channels: 256
kernel_size: 16
kernel_stride: 8
d_ffn: 1024
dropout: 0.5
dnn_neurons: 512

# Dataloader options
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 8


# Specifying the network
Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: !ref <kernel_size>
    out_channels: !ref <N_encoder_out>


SBtfintra: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
    num_layers: 8
    d_model: !ref <out_channels>
    nhead: 8
    d_ffn: !ref <d_ffn>
    dropout: 0
    use_positional_encoding: True
    norm_before: True

SBtfinter: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
    num_layers: 8
    d_model: !ref <out_channels>
    nhead: 8
    d_ffn: !ref <d_ffn>
    dropout: 0
    use_positional_encoding: True
    norm_before: True

# Masking network. It is asked for num_spks * 10 outputs: the training script
# (compute_forward) treats them as candidate masks, scores the first 10 with
# SPK_Scores1 and the next 10 with SPK_Scores2, and multiplies the 4 best-scored
# candidates of each group into the final mask of one source.
MaskNet: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
    num_spks: !ref <num_spks> * 10
    in_channels: !ref <N_encoder_out>
    out_channels: !ref <out_channels>
    num_layers: 2
    K: 250
    intra_model: !ref <SBtfintra>
    inter_model: !ref <SBtfinter>
    norm: ln
    linear_layer_after_inter_intra: False
    skip_around_intra: True

Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: !ref <N_encoder_out>
    out_channels: 1
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    bias: False

# Scoring network for the first group of 10 candidate masks (indices 0-9 of the
# MaskNet output). It yields one score per candidate; the training script keeps
# the 4 best-scored candidates of the group.
# The input shape is fixed: [batch, candidates in the group, 256, 11999]
# (presumably encoder channels and frames of an audio_length chunk - confirm
# when changing audio_length, sample_rate or the encoder settings).
SPK_Scores1: !new:speechbrain.nnet.containers.Sequential
    input_shape: [1, 10, 256 ,11999] 
    linear1: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    activation: !new:torch.nn.LeakyReLU
    drop: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear2: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    activation2: !new:torch.nn.LeakyReLU
    drop2: !new:torch.nn.Dropout
        p: !ref <dropout>

    # Collapse everything after the batch dimension before the output layer
    flatten: !new:torch.nn.Flatten
        start_dim: 1

    # One score per candidate mask of the group
    linear_out: !name:speechbrain.nnet.linear.Linear
        n_neurons: 10
        bias: True

# Scoring network for the second group of 10 candidate masks (indices 10-19 of
# the MaskNet output). Same layout as SPK_Scores1.
SPK_Scores2: !new:speechbrain.nnet.containers.Sequential
    input_shape: [1, 10, 256 ,11999] # Input shape: [batch, candidates in the group, 256, 11999]
    linear1: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    # bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation: !new:torch.nn.LeakyReLU
    drop: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear2: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    # bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation2: !new:torch.nn.LeakyReLU
    drop2: !new:torch.nn.Dropout
        p: !ref <dropout>

    # --- Layers added to achieve a [batch, 10] output ---
    # At this point the shape is [batch, 10, 256, dnn_neurons] (the linear layers act on the last dimension)
    flatten: !new:torch.nn.Flatten
        # Flatten from the dimension after the batch dimension (dim 1)
        # This collapses [10, 256, dnn_neurons] into a single dimension
        start_dim: 1
    # Shape after flatten: [batch, 10 * 256 * dnn_neurons]

    linear_out: !name:speechbrain.nnet.linear.Linear
        # This final linear layer projects the flattened features to 10 scores
        n_neurons: 10
        bias: True
    # Shape after linear_out: [batch, 10], i.e. one score per candidate mask of the group
    
optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: 0

loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 2
    dont_halve_until_epoch: 5

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

modules:
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>
    masknet: !ref <MaskNet>
    spk_scores1: !ref <SPK_Scores1>
    spk_scores2: !ref <SPK_Scores2>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder: !ref <Encoder>
        decoder: !ref <Decoder>
        masknet: !ref <MaskNet>
        spk_scores1: !ref <SPK_Scores1>
        spk_scores2: !ref <SPK_Scores2>
        counter: !ref <epoch_counter>
        # lr_scheduler: !ref <lr_scheduler>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch


Overwriting Transformermoespeed.yaml


In [16]:
# !pip install museval openunmix
# !python -m openunmix.evaluate --outdir /path/to/musdb/estimates --evaldir /path/to/museval/results

In [15]:
!zip -r /notebooks/file.zip /notebooks/results/wavformer-libri2mix/1234/save/audio_results

  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/ (stored 0%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item['The Easton Ellises - Falcon 69']_source2hat.wav (deflated 17%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item['Zeno - Signs']_source2.wav (deflated 7%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item["Juliet's Rescue - Heartbeats"]_source1.wav (deflated 4%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item['Skelpolu - Resurrection']_source1.wav (deflated 100%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item['Secretariat - Over The Top']_source1hat.wav (deflated 11%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item['PR - Oh No']_mix.wav (deflated 12%)
  adding: notebooks/results/wavformer-libri2mix/1234/save/audio_results/item['Signe Jakobsen - What Have You Done To Me']_source2.wav (de