#### Testing Dataset

In [1]:
import os
import random
import math
from glob import glob
import torch as th
import torchaudio
import pytorch_lightning as pl
from typing import Optional, List
import torch.nn.functional as F
import pandas as pd 
from tqdm.notebook import tqdm 
import numpy as np

from utils.measure_time import measure_time 

In [2]:
class AudioDataset(th.utils.data.Dataset):
    def __init__(self, mix_file_paths: List[str], ref_file_paths: List[List[str]], 
                 sr: int = 8000, chunk_size: int = 32000, least_size: int = 16000):
        super().__init__()
        self.mix_audio = []
        self.ref_audio = []
        k = len(ref_file_paths[1])
        for mix_path, ref_paths in zip(mix_file_paths, ref_file_paths):
            common_len = ref_paths[0]
            chunked_mix = self._load_audio(mix_path, sr, common_len, chunk_size, least_size)
            if not chunked_mix: continue
            ref_audio_chunks = []    
            for ref_path in ref_paths[1]:
                res = self._load_audio(ref_path, sr, common_len, chunk_size, least_size)
                ref_audio_chunks.append(res)
            if k != len(ref_audio_chunks): continue
            self.mix_audio.append(chunked_mix)
            self.ref_audio.append(ref_audio_chunks)

    def __len__(self):
        return len(self.mix_audio)

    def __getitem__(self, idx):
        mix = self.mix_audio[idx][0]
        refs = self.ref_audio[idx][0]
        # refs = [ref[idx] for ref in self.ref_audio]
        return mix, refs
    
    @staticmethod
    def _load_audio(path:str, sr: int, common_len: int, chunk_size: int, least_size: int):
        audio, _sr = torchaudio.load(path)
        audio = audio[:, :common_len]
        if _sr != sr: raise RuntimeError(f"Sample rate mismatch: {_sr} vs {sr}")
        if audio.shape[-1] < least_size: return []
        audio_chunks = []
        if least_size < audio.shape[-1] < chunk_size:
            pad_size = chunk_size - audio.shape[-1]
            audio_chunks.append(F.pad(audio, (0, pad_size), mode='constant'))
        else:
            start = 0
            while start + chunk_size <= audio.shape[-1]:
                audio = audio.squeeze() # warning
                audio_chunks.append(audio[start:start + chunk_size])
                start += least_size
        return audio_chunks 

class ExtendedAudioDataset(AudioDataset):
    def __getitems__(self, item):
        return self.__getitem__(item)

class AudioDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, csv_file:bool = False, total_percent: float = 0.1, train_percent:float = 0.8, valid_percent:float = 0.1, 
                 test_percent: float = 0.1, num_workers: int = 4, batch_size: int = 512, pin_memory = False, seed: int = 42, 
                 sample_rate: int = 8000, chunk_size: int = 32000, least_size: int = 16000):
        super().__init__()
        self.batch_size = batch_size
        self.sr = sample_rate
        self.chunk_size = chunk_size
        self.least_size = least_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.seed = seed
        self._set_seed(seed)
        self.g = th.Generator()
        self.g.manual_seed(seed)
        self.mix_paths = []
        self.ref_paths = []
        
        if csv_file:
            full_df = pd.read_csv(data_dir)
            for _, row in full_df.iterrows():
                self.mix_paths.append (row.iloc[0])
                self.ref_paths.append([row.iloc[1], sorted([row[column] for column in full_df.columns[2:]])])
        else: 
            mixed_list = glob(os.path.join(data_dir, "*.flac"))
            for mx in tqdm(mixed_list):
                mx = mx.replace('\\', '/')
                self.mix_paths.append (mx)
                mx_df = pd.read_csv(mx.replace('flac', 'csv'))
                f_real = [mx_df.iloc[0], sorted([mx_df.iloc[0][column] for column in mx_df.columns[2:]])] # THERE IS BAG need to fixed 
                self.ref_paths.append(f_real)

        self.mix_paths = self.mix_paths[:int(len(self.mix_paths) * total_percent)]
        self.ref_paths = self.ref_paths[:int(len(self.ref_paths) * total_percent)]
        random.shuffle(self.mix_paths)
        assert math.isclose(train_percent + valid_percent + test_percent, 1.0, rel_tol=1e-9), "Sum doesnt equal to 1" 
        self.train_len = int(len(self.mix_paths) * train_percent)
        self.valid_len = int(len(self.mix_paths) * valid_percent)
        self.test_len = int(len(self.mix_paths) * test_percent)

    @measure_time
    def setup(self, stage = 'train'):
        assert stage in ['train', 'eval'], "Invalid stage"
        
        if stage == 'train': 
            self.train_dataset = AudioDataset(self.mix_paths[:self.train_len], 
                                              self.ref_paths[:self.train_len], 
                                              sr = self.sr, 
                                              chunk_size = self.chunk_size, 
                                              least_size = self.least_size)
            print(f"Size of training set: {len(self.train_dataset)}")
            
            self.val_dataset = AudioDataset(self.mix_paths[self.train_len:self.train_len + self.valid_len], 
                                            self.ref_paths[self.train_len:self.train_len + self.valid_len], 
                                            sr = self.sr, 
                                            chunk_size = self.chunk_size, 
                                            least_size = self.least_size) 
            print(f"Size of validation set: {len(self.val_dataset)}")
            
        if stage == 'eval':
            self.test_dataset = AudioDataset(self.mix_paths[self.train_len + self.valid_len:], 
                                             self.ref_paths[self.train_len + self.valid_len:], 
                                             sr = self.sr, 
                                             chunk_size = self.chunk_size, 
                                             least_size = self.least_size)
            print(f"Size of test set: {len(self.test_dataset)}")

        
    def train_dataloader(self):
        return th.utils.data.DataLoader(self.train_dataset, 
                                        batch_size=self.batch_size, 
                                        pin_memory = self.pin_memory,
                                        shuffle=True, 
                                        num_workers=self.num_workers,
                                        worker_init_fn=self.seed_worker,
                                        generator=self.g)

    def val_dataloader(self):
        return th.utils.data.DataLoader(self.val_dataset, 
                                        batch_size=self.batch_size, 
                                        pin_memory = self.pin_memory,
                                        shuffle=False, 
                                        num_workers=self.num_workers,
                                        worker_init_fn=self.seed_worker,
                                        generator=self.g)
    
    def test_dataloader(self):
        return th.utils.data.DataLoader(self.test_dataset, 
                                        batch_size=self.batch_size,
                                        pin_memory = self.pin_memory, 
                                        shuffle=False, 
                                        num_workers=self.num_workers, 
                                        worker_init_fn=self.seed_worker,
                                        generator=self.g)

    def _set_seed(self, seed: int):
        random.seed(seed)
        np.random.seed(seed)
        th.manual_seed(seed)
        th.cuda.manual_seed_all(seed)

    def seed_worker(self, worker_id):
        worker_seed = th.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

In [3]:
import argparse
import sys
from utils.load_config import load_config  

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--hparams", type=str, default="./configs/train_rnn.yml", help="hparams config file")
args, unknown = parser.parse_known_args()  # Игнорирует нераспознанные аргументы
cfg = load_config(args.hparams)

In [4]:
cfg['data']

{'data_dir': 'F:/ISSAI_KSC2_unpacked/DIHARD_DATA_INFO/CONCATED_DFS_tts=2000_k_2.csv',
 'csv_file': True,
 'total_percent': 1.0,
 'train_percent': 0.8,
 'valid_percent': 0.1,
 'test_percent': 0.1,
 'num_workers': 0,
 'batch_size': 1,
 'pin_memory': True,
 'seed': 42,
 'sample_rate': 16000,
 'chunk_size': 32000,
 'least_size': 16000}

In [5]:
datamodule = AudioDataModule(**cfg['data'])
datamodule.setup(stage = 'train')

Size of training set: 800
Size of validation set: 100
Elapsed time 'setup': 00:00:02.23


In [6]:
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}

In [7]:
# Получение первого батча данных из DataLoader
dataloader = dataloaders['train'] 
sample_mix, sample_refs = next(iter(dataloader))  # Используем iter и next для доступа к данным
print(sample_mix)
print(sample_refs)

tensor([[ 0.0005,  0.0007,  0.0008,  ..., -0.0423, -0.0381, -0.0295]])
[tensor([[0.0052, 0.0042, 0.0026,  ..., 0.0020, 0.0026, 0.0033]])]


#### Training 

In [45]:
import os
import torch
import torchmetrics
import argparse
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter as TensorBoard

from utils.load_config import load_config 
from utils.training import metadata_info, configure_optimizer
from utils.measure_time import measure_time
from utils.training import p_output_log 
from models.model_rnn import Dual_RNN_model
from losses import Loss

In [46]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')

In [47]:
cfg['model']

{'in_channels': 256,
 'out_channels': 64,
 'hidden_channels': 128,
 'kernel_size': 2,
 'rnn_type': 'LSTM',
 'norm': 'ln',
 'dropout': 0,
 'bidirectional': True,
 'num_layers': 6,
 'K': 250,
 'speaker_num': 2}

In [48]:
model = Dual_RNN_model(**cfg['model'])

In [49]:
metadata_info(model)
writer = TensorBoard(f'tb_logs/{Path(args.hparams).stem}', comment = f"{cfg['trainer']['ckpt_folder']}")
optimizer = configure_optimizer (cfg, model)

Trainable parametrs: 2633729
Size of model: 10.05 MB, in float32 



In [50]:
cfg['trainer']

{'num_epochs': 100,
 'device': 'cuda',
 'best_weights': False,
 'checkpointing': False,
 'checkpoint_interval': 5,
 'model_name': 'Dual_Path_RNN',
 'path_to_weights': './weights',
 'ckpt_folder': './checkpoints/train_rnn',
 'speaker_num': 2,
 'resume': False}

In [65]:
import torch
from itertools import permutations

def sisnr(x, s, eps=1e-8):
    """
    calculate training loss
    input:
          x: separated signal, N x S tensor
          s: reference signal, N x S tensor
    Return:
          sisnr: N tensor
    """

    def l2norm(mat, keepdim=False):
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    if x.shape != s.shape:
        raise RuntimeError(
            "Dimention mismatch when calculate si-snr, {} vs {}".format(
                x.shape, s.shape))
    x_zm = x - torch.mean(x, dim=-1, keepdim=True)
    s_zm = s - torch.mean(s, dim=-1, keepdim=True)
    t = torch.sum(
        x_zm * s_zm, dim=-1,
        keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
    return 20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))


def CustomLoss(ests, egs):
    # spks x n x S
    refs = egs
    num_spks = len(refs)

    def sisnr_loss(permute):
        print(f"Length of ests: {len(ests)}, Length of refs: {len(refs)}")
        print(f"Permute: {permute}")
        # for one permute
        return sum([sisnr(ests[s], refs[t]) for s, t in enumerate(permute)]) / len(permute)  # average the value

    # P x N
    N = egs[0].size(0)
    sisnr_mat = torch.stack(
        [sisnr_loss(p) for p in permutations(range(num_spks))])
    max_perutt, _ = torch.max(sisnr_mat, dim=0)
    # si-snr
    return -torch.sum(max_perutt) / N


In [66]:
import torch

# Размерность: N = 3 (батч), S = 8 (длина сигнала)
# Предсказанные сигналы (2 говорящих)
ests = [
    torch.tensor([
        [0.5, 0.6, 0.8, 0.9, 1.0, 0.7, 0.3, 0.2],
        [0.3, 0.4, 0.6, 0.8, 0.5, 0.3, 0.1, 0.0],
        [0.2, 0.1, 0.3, 0.5, 0.7, 0.6, 0.4, 0.2]
    ]),
    torch.tensor([
        [0.1, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.5],
        [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 0.8],
        [0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1]
    ])
]

# Эталонные сигналы (2 говорящих)
egs = [
    torch.tensor([
        [0.6, 0.7, 0.9, 1.0, 0.8, 0.6, 0.4, 0.3],
        [0.2, 0.3, 0.5, 0.7, 0.6, 0.4, 0.2, 0.1],
        [0.1, 0.2, 0.4, 0.6, 0.8, 0.7, 0.5, 0.3]
    ]),
    torch.tensor([
        [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0, 0.6],
        [0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 0.7],
        [0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1, 0.2]
    ])
]

# Рассчитаем функцию потерь
loss = CustomLoss(ests, egs)
print(f"Custom Loss: {loss.item()}")


Length of ests: 2, Length of refs: 2
Permute: (0, 1)
Length of ests: 2, Length of refs: 2
Permute: (1, 0)
Custom Loss: -9.563565254211426


In [67]:
import torch

# Эталонные сигналы (2 говорящих, батч из 3, длина сигнала 8)
egs = [
    torch.tensor([
        [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.8, 0.6],
        [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    ]),
    torch.tensor([
        [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        [0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1],
        [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    ])
]

# Предсказанные сигналы, идентичные эталонным
ests = [
    torch.tensor([
        [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.8, 0.6],
        [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    ]),
    torch.tensor([
        [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        [0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1],
        [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    ])
]

# Рассчитаем функцию потерь
loss = CustomLoss(ests, egs)
print(f"Custom Loss (ideal case): {loss.item()}")


Length of ests: 2, Length of refs: 2
Permute: (0, 1)
Length of ests: 2, Length of refs: 2
Permute: (1, 0)
Custom Loss (ideal case): -148.2356414794922


In [68]:
from itertools import permutations
import torch
from torch import nn

EPS = 1e-8


class MixerMSE(nn.Module):

    def __init__(self):

        super(MixerMSE, self).__init__()

        self.criterion1 = nn.MSELoss()

        self.criterion2 = nn.MSELoss()

    def forward(self, x, target):

        loss = self.criterion1(x[0, 0, :], target[0, 0, :]) + self.criterion2(x[0, 1, :], target[0, 1, :])

        return loss


def cal_loss_no(source, estimate_source, source_lengths):
    """
        Args:
            source: [B, C, T], B is batch size,C is the number of speaker,T is the length of each batch
            estimate_source: [B, C, T]
            source_lengths: [B]
    """
    max_snr, perms, max_snr_idx = cal_si_snr(source, estimate_source, source_lengths)

    loss = 0 - torch.mean(max_snr)

    reorder_estimate_source = reorder_source(estimate_source, perms, max_snr_idx)

    return loss, max_snr, estimate_source, reorder_estimate_source


def cal_si_snr(source, estimate_source, source_lengths):
    """
        Calculate SI-SNR with PIT training.

        Args:
            source: [B, C, T], B is batch size
            estimate_source: [B, C, T]
            source_lengths: [B], each item is between [0, T]
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()  # get all parameters
    # mask padding position along T
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # print("s_target.type()", s_target.type(), s_estimate.type())
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # print("pair_wise_dot.type()", pair_wise_dot.type(), "s_target_energy.type()", s_target_energy.type())
    # print("pair_wise_proj.type()", pair_wise_proj.type())
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # print(e_noise.type())
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C, C]
    # print("pair_wise_si_snr",pair_wise_si_snr.type())

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    # print(index.type())
    # 如果不加.type(torch.float),perms-one-hot为long，在执行torch.einsum时会报错
    perms_one_hot = torch.unsqueeze(perms, dim=0).type(torch.float)
    # print("perms_one_hot", perms_one_hot.type())
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    # print("snr_set.type()",snr_set.type())
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C

    return max_snr, perms, max_snr_idx


def cal_loss_pit(source, estimate_source, source_lengths):
    """
        Args:
            source: [B, C, T], B is batch size,C is the number of speaker,T is the length of each batch
            estimate_source: [B, C, T]
            source_lengths: [B]
    """
    max_snr, perms, max_snr_idx = cal_si_snr_with_pit(source, estimate_source, source_lengths)

    loss = 0 - torch.mean(max_snr)

    reorder_estimate_source = reorder_source(estimate_source, perms, max_snr_idx)

    return loss, max_snr, estimate_source, reorder_estimate_source


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    """
        Calculate SI-SNR with PIT training.

        Args:
            source: [B, C, T], B is batch size
            estimate_source: [B, C, T]
            source_lengths: [B], each item is between [0, T]
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()  # get all parameters
    # mask padding position along T
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # print("s_target.type()", s_target.type(), s_estimate.type())
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # print("pair_wise_dot.type()", pair_wise_dot.type(), "s_target_energy.type()", s_target_energy.type())
    # print("pair_wise_proj.type()", pair_wise_proj.type())
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # print(e_noise.type())
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C, C]
    # print("pair_wise_si_snr",pair_wise_si_snr.type())


    # Get max_snr of each utterance
    # permutations, [C!, C] 
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    # print(index.type())
    # 如果不加.type(torch.float),perms-one-hot为long，在执行torch.einsum时会报错
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1).type(torch.float)
    # print("perms_one_hot", perms_one_hot.type())
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    # print("snr_set.type()",snr_set.type())
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C
    return max_snr, perms, max_snr_idx


def reorder_source(source, perms, max_snr_idx):
    """
        Args:
            source: [B, C, T]
            perms: [C!, C], permutations
            max_snr_idx: [B], each item is between [0, C!)
        Returns:
            reorder_source: [B, C, T]
    """
    B, C, *_ = source.size()
    # [B, C], permutation whose SI-SNR is max of each utterance
    # for each utterance, reorder estimate source according this permutation
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
    # print('max_snr_perm', max_snr_perm)
    # maybe use torch.gather()/index_select()/scatter() to impl this?
    reorder_source = torch.zeros_like(source)

    for b in range(B):
        for c in range(C):
            reorder_source[b, c] = source[b, max_snr_perm[b][c]]

    return reorder_source


def get_mask(source, source_lengths):
    """
        Args:
            source: [B, C, T]
            source_lengths: [B]
        Returns:
            mask: [B, 1, T]
    """
    B, _, T = source.size()

    mask = source.new_ones((B, 1, T))

    for i in range(B):
        mask[i, :, source_lengths[i]:] = 0

    return mask


if __name__ == "__main__":
    torch.manual_seed(123)
    B, C, T = 1, 2, 32000
    # fake data
    source = torch.randint(4, (B, C, T))
    estimate_source = torch.randint(4, (B, C, T))
    source[0, :, -3:] = 0
    estimate_source[0, :, -3:] = 0
    source_lengths = torch.FloatTensor([T, T - 1]).type(torch.int)
    print('source', source)
    print('estimate_source', estimate_source)
    print('source_lengths', source_lengths)

    loss, max_snr, estimate_source, reorder_estimate_source = cal_loss_no(source, estimate_source, source_lengths)
    print('loss', loss)
    print('max_snr', max_snr)
    print('reorder_estimate_source', reorder_estimate_source)

source tensor([[[2, 1, 2,  ..., 0, 0, 0],
         [2, 1, 2,  ..., 0, 0, 0]]])
estimate_source tensor([[[0, 0, 2,  ..., 0, 0, 0],
         [2, 3, 1,  ..., 0, 0, 0]]])
source_lengths tensor([32000, 31999], dtype=torch.int32)
loss tensor(43.5560)
max_snr tensor([[-43.5560],
        [-43.5560]])
reorder_estimate_source tensor([[[0, 0, 2,  ..., 0, 0, 0],
         [2, 3, 1,  ..., 0, 0, 0]]])


In [57]:
class Trainer:
    def __init__(self, num_epochs = 100, device='cuda', best_weights = False, checkpointing = False, 
                 checkpoint_interval = 10, model_name = '', path_to_weights= './weights', ckpt_folder = '',
                 speaker_num = 2, resume = False) -> None:
        self.num_epochs = num_epochs
        self.device = device
        self.best_weights = best_weights
        self.checkpointing = checkpointing
        self.checkpoint_interval = checkpoint_interval
        self.model_name = model_name
        os.makedirs(path_to_weights, exist_ok=True)
        self.path_to_weights = path_to_weights
        self.ckpt_folder = ckpt_folder
        self.speaker_num = speaker_num
        self.resume = resume

    @measure_time
    def fit(self, model, dataloaders, criterion, optimizer, writer) -> None:
        model.to(self.device)
        min_val_loss = float('inf')
        for epoch in range(self.num_epochs):
            for phase in ['train', 'valid']:
                model.train() if phase == 'train' else model.eval()
                dataloader = dataloaders[phase] 
                running_loss = 0.0
                for inputs, labels in tqdm(dataloader):
                    inputs, labels = inputs.to(self.device), [label.to(self.device) for label in labels]
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                    running_loss += loss.item() * inputs.size(0)

                break
                epoch_loss = running_loss / len(dataloader.dataset)
                print(epoch_loss)
                # p_output_log(self.num_epochs, epoch, epoch_loss)

In [58]:
Trainer(**cfg['trainer']).fit(model, dataloaders, CustomLoss, optimizer, writer)

  0%|          | 0/800 [00:00<?, ?it/s]

Length of ests: 2, Length of refs: 1
Permute: (0,)
Length of ests: 2, Length of refs: 4
Permute: (0, 1, 2, 3)


IndexError: list index out of range