#### Testing Dataset

In [1]:
import os
import random
import math
from glob import glob
import torch as th
import torchaudio
import pytorch_lightning as pl
from typing import Optional, List
import torch.nn.functional as F
import pandas as pd 
from tqdm.notebook import tqdm 
import numpy as np

from utils.measure_time import measure_time 

In [2]:
class AudioDataset(th.utils.data.Dataset):
    def __init__(self, mix_file_paths: List[str], ref_file_paths: List[List[str]], 
                 sr: int = 8000, chunk_size: int = 32000, least_size: int = 16000):
        super().__init__()
        self.mix_audio = []
        self.ref_audio = []
        k = len(ref_file_paths[1])
        
        for mix_path, ref_paths in zip(mix_file_paths, ref_file_paths):
            common_len = ref_paths[0]
            chunked_mix = self._load_audio(mix_path, sr, common_len, chunk_size, least_size)
            if not chunked_mix: continue
            ref_audio_chunks = []
            
            chunks_num, same_shape = 0, 0.0
            for ref_path in ref_paths[1]:
                res = self._load_audio(ref_path, sr, common_len, chunk_size, least_size)
                ref_audio_chunks.append(res)
                
                if chunks_num != len(res) and chunks_num != 0:
                    raise RuntimeError('different chuncks')
                if same_shape != ref_audio_chunks[0][0].shape and same_shape != 0.0:
                    raise RuntimeError('differen shape of chuncks')
                chunks_num = len(res)
                same_shape = ref_audio_chunks[0][0].shape
  
            if k != len(ref_audio_chunks): continue
            
            if chunked_mix[0].shape != ref_audio_chunks[0][0].shape:
                raise RuntimeError('chunked_mix[0].shape != ref_audio_chunks[0][0].shape')
            
            self.mix_audio.append(chunked_mix)
            self.ref_audio.append(ref_audio_chunks)
            
    def __len__(self):
        return len(self.mix_audio)

    def __getitem__(self, idx):
        mix = self.mix_audio[idx][0]
        refs = self.ref_audio[idx]
        # refs = self.ref_audio[idx][0]
        # refs = [ref[idx] for ref in self.ref_audio]
        return mix, refs
    
    @staticmethod
    def _load_audio(path:str, sr: int, common_len: int, chunk_size: int, least_size: int):
        audio, _sr = torchaudio.load(path)
        audio = audio.squeeze()
        audio = audio[:common_len]
        if _sr != sr: raise RuntimeError(f"Sample rate mismatch: {_sr} vs {sr}")
        if audio.shape[0] < least_size: return []
        audio_chunks = []
        if least_size < audio.shape[0] < chunk_size:
            pad_size = chunk_size - audio.shape[0]
            audio_chunks.append(F.pad(audio, (0, pad_size), mode='constant'))
        else:
            start = 0
            while start + chunk_size <= audio.shape[0]:
                audio_chunks.append(audio[start:start + chunk_size])
                start += least_size
        return audio_chunks 

class ExtendedAudioDataset(AudioDataset):
    def __getitems__(self, item):
        return self.__getitem__(item)

class AudioDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, csv_file:bool = False, total_percent: float = 0.1, train_percent:float = 0.8, valid_percent:float = 0.1, 
                 test_percent: float = 0.1, num_workers: int = 4, batch_size: int = 512, pin_memory = False, seed: int = 42, 
                 sample_rate: int = 8000, chunk_size: int = 32000, least_size: int = 16000):
        super().__init__()
        self.batch_size = batch_size
        self.sr = sample_rate
        self.chunk_size = chunk_size
        self.least_size = least_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.seed = seed
        self._set_seed(seed)
        self.g = th.Generator()
        self.g.manual_seed(seed)
        self.mix_paths = []
        self.ref_paths = []
        
        if csv_file:
            full_df = pd.read_csv(data_dir)
            for _, row in full_df.iterrows():
                self.mix_paths.append (row.iloc[0])
                self.ref_paths.append([row.iloc[1], sorted([row[column] for column in full_df.columns[2:]])])
        else: 
            mixed_list = glob(os.path.join(data_dir, "*.flac"))
            for mx in tqdm(mixed_list):
                mx = mx.replace('\\', '/')
                self.mix_paths.append (mx)
                mx_df = pd.read_csv(mx.replace('flac', 'csv'))
                f_real = [mx_df.iloc[0], sorted([mx_df.iloc[0][column] for column in mx_df.columns[2:]])] # THERE IS BAG need to fixed 
                self.ref_paths.append(f_real)

        self.mix_paths = self.mix_paths[:int(len(self.mix_paths) * total_percent)]
        self.ref_paths = self.ref_paths[:int(len(self.ref_paths) * total_percent)]
        random.shuffle(self.mix_paths)
        assert math.isclose(train_percent + valid_percent + test_percent, 1.0, rel_tol=1e-9), "Sum doesnt equal to 1" 
        self.train_len = int(len(self.mix_paths) * train_percent)
        self.valid_len = int(len(self.mix_paths) * valid_percent)
        self.test_len = int(len(self.mix_paths) * test_percent)

    @measure_time
    def setup(self, stage = 'train'):
        assert stage in ['train', 'eval'], "Invalid stage"
        
        if stage == 'train': 
            self.train_dataset = AudioDataset(self.mix_paths[:self.train_len], 
                                              self.ref_paths[:self.train_len], 
                                              sr = self.sr, 
                                              chunk_size = self.chunk_size, 
                                              least_size = self.least_size)
            print(f"Size of training set: {len(self.train_dataset)}")
            
            self.val_dataset = AudioDataset(self.mix_paths[self.train_len:self.train_len + self.valid_len], 
                                            self.ref_paths[self.train_len:self.train_len + self.valid_len], 
                                            sr = self.sr, 
                                            chunk_size = self.chunk_size, 
                                            least_size = self.least_size) 
            print(f"Size of validation set: {len(self.val_dataset)}")
            
        if stage == 'eval':
            self.test_dataset = AudioDataset(self.mix_paths[self.train_len + self.valid_len:], 
                                             self.ref_paths[self.train_len + self.valid_len:], 
                                             sr = self.sr, 
                                             chunk_size = self.chunk_size, 
                                             least_size = self.least_size)
            print(f"Size of test set: {len(self.test_dataset)}")

        
    def train_dataloader(self):
        return th.utils.data.DataLoader(self.train_dataset, 
                                        batch_size=self.batch_size, 
                                        pin_memory = self.pin_memory,
                                        shuffle=True, 
                                        num_workers=self.num_workers,
                                        worker_init_fn=self.seed_worker,
                                        generator=self.g)

    def val_dataloader(self):
        return th.utils.data.DataLoader(self.val_dataset, 
                                        batch_size=self.batch_size, 
                                        pin_memory = self.pin_memory,
                                        shuffle=False, 
                                        num_workers=self.num_workers,
                                        worker_init_fn=self.seed_worker,
                                        generator=self.g)
    
    def test_dataloader(self):
        return th.utils.data.DataLoader(self.test_dataset, 
                                        batch_size=self.batch_size,
                                        pin_memory = self.pin_memory, 
                                        shuffle=False, 
                                        num_workers=self.num_workers, 
                                        worker_init_fn=self.seed_worker,
                                        generator=self.g)

    def _set_seed(self, seed: int):
        random.seed(seed)
        np.random.seed(seed)
        th.manual_seed(seed)
        th.cuda.manual_seed_all(seed)

    def seed_worker(self, worker_id):
        worker_seed = th.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

In [3]:
import argparse
import sys
from utils.load_config import load_config  

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--hparams", type=str, default="./configs/train_rnn.yml", help="hparams config file")
args, unknown = parser.parse_known_args()  # Игнорирует нераспознанные аргументы
cfg = load_config(args.hparams)

In [4]:
datamodule = AudioDataModule(**cfg['data'])
datamodule.setup(stage = 'train')

Size of training set: 800
Size of validation set: 100
Elapsed time 'setup': 00:00:02.18


In [5]:
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}

In [6]:
# Получение первого батча данных из DataLoader
dataloader = dataloaders['train'] 
sample_mix, sample_refs = next(iter(dataloader))  # Используем iter и next для доступа к данным
print(sample_mix[0])
print(len(sample_refs))

tensor([ 0.0005,  0.0007,  0.0008,  ..., -0.0423, -0.0381, -0.0295])
2


In [17]:
sample_mix[0].shape

torch.Size([32000])

In [23]:
sample_mix

tensor([[ 0.0005,  0.0007,  0.0008,  ..., -0.0423, -0.0381, -0.0295]])

In [26]:
output = sample_refs[0][0].shape
output

torch.Size([1, 32000])

In [32]:
output = sample_refs[0][0].squeeze()
output

tensor([0.0052, 0.0042, 0.0026,  ..., 0.0020, 0.0026, 0.0033])

In [33]:
output.shape

torch.Size([32000])

In [22]:
sample_refs[0][0]

tensor([[0.0052, 0.0042, 0.0026,  ..., 0.0020, 0.0026, 0.0033]])

#### Training 

In [46]:
import os
import torch
import torchmetrics
import argparse
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter as TensorBoard

from utils.load_config import load_config 
from utils.training import metadata_info, configure_optimizer
from utils.measure_time import measure_time
from utils.training import p_output_log 
from models.model_rnn import Dual_RNN_model
from losses import Loss

In [47]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')

model = Dual_RNN_model(**cfg['model'])

metadata_info(model)
writer = TensorBoard(f'tb_logs/{Path(args.hparams).stem}', comment = f"{cfg['trainer']['ckpt_folder']}")
optimizer = configure_optimizer (cfg, model)

Trainable parametrs: 2633729
Size of model: 10.05 MB, in float32 



In [20]:
import torch
from itertools import permutations

def sisnr(x, s, eps=1e-8):
    """
    calculate training loss
    input:
          x: separated signal, N x S tensor
          s: reference signal, N x S tensor
    Return:
          sisnr: N tensor
    """
    print(x.shape)
    def l2norm(mat, keepdim=False):
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    if x.shape != s.shape:
        raise RuntimeError(
            "Dimention mismatch when calculate si-snr, {} vs {}".format(
                x.shape, s.shape))
    x_zm = x - torch.mean(x, dim=-1, keepdim=True)
    s_zm = s - torch.mean(s, dim=-1, keepdim=True)
    t = torch.sum(
        x_zm * s_zm, dim=-1,
        keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
    return 20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))


def CustomLoss(ests, egs):
    print('ests', ests)
    print('len ests', len(ests))
    print('egs', egs)
    print('len egs', len(egs))
    # spks x n x S
    refs = egs
    num_spks = len(refs)

    def sisnr_loss(permute):
        # print(f"Length of ests: {len(ests)}, Length of refs: {len(refs)}")
        # print(f"Permute: {permute}")
        # for one permute
        return sum([sisnr(ests[s], refs[t]) for s, t in enumerate(permute)]) / len(permute)  # average the value

    # P x N
    N = egs[0].size(0)
    sisnr_mat = torch.stack(
        [sisnr_loss(p) for p in permutations(range(num_spks))])
    max_perutt, _ = torch.max(sisnr_mat, dim=0)
    # si-snr
    return -torch.sum(max_perutt) / N


In [21]:
import torch

# Размерность: N = 3 (батч), S = 8 (длина сигнала)
# Предсказанные сигналы (2 говорящих)
ests = [
    torch.tensor([
        [0.5, 0.6, 0.8, 0.9, 1.0, 0.7, 0.3, 0.2]
    ]),
    torch.tensor([
        [0.1, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.5]
    ])
]

# Эталонные сигналы (2 говорящих)
egs = [
    torch.tensor([
        [0.6, 0.7, 0.9, 1.0, 0.8, 0.6, 0.4, 0.3]
    ]),
    torch.tensor([
        [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0, 0.6]
    ])
]

# Рассчитаем функцию потерь
loss = CustomLoss(ests, egs)
print(f"Custom Loss: {loss.item()}")


ests [tensor([[0.5000, 0.6000, 0.8000, 0.9000, 1.0000, 0.7000, 0.3000, 0.2000]]), tensor([[0.1000, 0.2000, 0.4000, 0.6000, 0.7000, 0.8000, 0.9000, 0.5000]])]
len ests 2
egs [tensor([[0.6000, 0.7000, 0.9000, 1.0000, 0.8000, 0.6000, 0.4000, 0.3000]]), tensor([[0.0000, 0.1000, 0.3000, 0.5000, 0.7000, 0.9000, 1.0000, 0.6000]])]
len egs 2
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
torch.Size([1, 8])
Custom Loss: -10.556488990783691


In [None]:
import torch

# Эталонные сигналы (2 говорящих, батч из 3, длина сигнала 8)
egs = [
    torch.tensor([
        [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.8, 0.6],
        [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    ]),
    torch.tensor([
        [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        [0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1],
        [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    ])
]

# Предсказанные сигналы, идентичные эталонным
ests = [
    torch.tensor([
        [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.8, 0.6],
        [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    ]),
    torch.tensor([
        [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        [0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1],
        [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    ])
]

# Рассчитаем функцию потерь
loss = CustomLoss(ests, egs)
print(f"Custom Loss (ideal case): {loss.item()}")


In [36]:
class Trainer:
    def __init__(self, num_epochs = 100, device='cuda', best_weights = False, checkpointing = False, 
                 checkpoint_interval = 10, model_name = '', path_to_weights= './weights', ckpt_folder = '',
                 speaker_num = 2, resume = False) -> None:
        self.num_epochs = num_epochs
        self.device = device
        self.best_weights = best_weights
        self.checkpointing = checkpointing
        self.checkpoint_interval = checkpoint_interval
        self.model_name = model_name
        os.makedirs(path_to_weights, exist_ok=True)
        self.path_to_weights = path_to_weights
        self.ckpt_folder = ckpt_folder
        self.speaker_num = speaker_num
        self.resume = resume

    @measure_time
    def fit(self, model, dataloaders, criterion, optimizer, writer) -> None:
        model.to(self.device)
        min_val_loss = float('inf')
        for epoch in range(self.num_epochs):
            for phase in ['train', 'valid']:
                model.train() if phase == 'train' else model.eval()
                dataloader = dataloaders[phase] 
                running_loss = 0.0
                for inputs, labels in tqdm(dataloader):
                    inputs, labels = inputs.to(self.device), [label.to(self.device) for label in labels]
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        print(outputs)
                        loss = criterion(outputs, labels)
                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                    running_loss += loss.item() * inputs.size(0)

                break
                epoch_loss = running_loss / len(dataloader.dataset)
                print(epoch_loss)
                # p_output_log(self.num_epochs, epoch, epoch_loss)

In [37]:
Trainer(**cfg['trainer']).fit(model, dataloaders, CustomLoss, optimizer, writer)

  0%|          | 0/800 [00:00<?, ?it/s]

[tensor([[ 8.9875e-05, -4.2176e-05,  9.0446e-06,  ...,  7.8968e-03,
          4.8392e-03,  4.3211e-03]], device='cuda:0',
       grad_fn=<SqueezeBackward1>), tensor([[ 1.1987e-04, -1.8660e-05, -1.4660e-05,  ...,  2.4879e-03,
          1.1837e-03, -8.5887e-03]], device='cuda:0',
       grad_fn=<SqueezeBackward1>)]
egs [tensor([[-0.0025,  0.0034,  0.0062,  ..., -0.0254, -0.0289, -0.0309]],
       device='cuda:0'), tensor([[ 0.0685,  0.0726,  0.0726,  ..., -0.0013, -0.0017, -0.0021]],
       device='cuda:0'), tensor([[-0.0311, -0.0313, -0.0363,  ..., -0.0239, -0.0259, -0.0262]],
       device='cuda:0')]
len egs 3
torch.Size([1, 32000])
torch.Size([1, 32000])


IndexError: list index out of range