#### Testing Dataset

In [1]:
import os
import random
import math
from glob import glob
import torch as th
import torchaudio
import pytorch_lightning as pl
from typing import Optional, List
import torch.nn.functional as F
import pandas as pd 
from tqdm.notebook import tqdm 
import numpy as np

from utils.measure_time import measure_time 

In [3]:
import argparse
import sys
from utils.load_config import load_config  

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--hparams", type=str, default="./configs/train_rnn.yml", help="hparams config file")
args, unknown = parser.parse_known_args()  # Игнорирует нераспознанные аргументы
cfg = load_config(args.hparams)

In [4]:
datamodule = AudioDataModule(**cfg['data']).setup(stage = 'train')

Size of training set: 420
Size of validation set: 63
Elapsed time 'setup': 00:00:02.20


In [5]:
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}

In [6]:
# Получение первого батча данных из DataLoader
dataloader = dataloaders['train'] 
sample_mix, sample_refs = next(iter(dataloader))  # Используем iter и next для доступа к данным

In [7]:
print(sample_mix, '\n')
print('chunks_num', len(sample_mix), '\n')
print(sample_mix[0], '\n')
print(sample_mix[0].shape, '\n')
print('----------------------------------------------', '\n')
print('spekears num', len(sample_refs), '\n')
print('firs_speaker list:', sample_refs[0], '\n')
print('chunks_nums', len(sample_refs[0]), '\n')
print(sample_refs[0][0].shape, '\n')

[tensor([[0.0016, 0.0045, 0.0016,  ..., 0.0505, 0.1465, 0.1519]])] 

chunks_num 1 

tensor([[0.0016, 0.0045, 0.0016,  ..., 0.0505, 0.1465, 0.1519]]) 

torch.Size([1, 32000]) 

---------------------------------------------- 

spekears num 2 

firs_speaker list: [tensor([[-0.0042, -0.0083, -0.0139,  ..., -0.0009, -0.0040, -0.0038]])] 

chunks_nums 1 

torch.Size([1, 32000]) 



#### Testing dataloaders LAST UPDATE. 2.

In [1]:
import argparse
import sys
from utils.load_config import load_config  

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--hparams", type=str, default="./configs/train_rnn.yml", help="hparams config file")
args, unknown = parser.parse_known_args()  # Игнорирует нераспознанные аргументы
cfg = load_config(args.hparams)

In [2]:
from data.DiarizationDataset import DiarizationDataset

datamodule = DiarizationDataset(**cfg['data']).setup(stage = 'train')
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}

Size of training set: 3140
Size of validation set: 641
Elapsed time 'setup': 00:00:02.50


In [2]:
import torch.nn.functional as F
import torch
import torchaudio

from typing import List, Tuple
import os.path as ospth

def get_file_name(file_path: str):
    return ospth.splitext(ospth.basename(file_path))[0]

def handle_df(audios: List[Tuple[int, str]]) -> dict:
    scp_dict = dict()
    for audio in audios:
        common_len, l = audio
        if len(audio) != 2:
            raise RuntimeError("Format error in")
        if len(audio) == 2:
            key, value = f"{get_file_name (l)}.flac", l
        if key in scp_dict:
            raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
                l, l))
        scp_dict[key] = {'common_len': common_len, 'name': value}
    return scp_dict
        
    
def read_wav(fname, return_rate=False):
    '''
         Read wavfile using Pytorch audio
         input:
               fname: wav file path
               return_rate: Whether to return the sampling rate
         output:
                src: output tensor of size C x L 
                     L is the number of audio frames 
                     C is the number of channels. 
                sr: sample rate
    '''
    src, sr = torchaudio.load(fname, channels_first=True)
    if return_rate:
        return src.squeeze(), sr
    else:
        return src.squeeze()


def write_wav(fname, src, sample_rate):
    '''
         Write wav file
         input:
               fname: wav file path
               src: frames of audio
               sample_rate: An integer which is the sample rate of the audio
         output:
               None
    '''
    torchaudio.save(fname, src, sample_rate)


class CustomAudioReader(object):
    '''
        Class that reads Wav format files
        Input:
            scp_path (str): a different scp file address
            sample_rate (int, optional): sample rate (default: 8000)
            chunk_size (int, optional): split audio size (default: 32000(4 s))
            least_size (int, optional): Minimum split size (default: 16000(2 s))
        Output:
            split audio (list)
    '''

    def __init__(self, scp_path, sample_rate=8000, chunk_size=32000, least_size=16000):
        super(CustomAudioReader, self).__init__()
        self.sample_rate = sample_rate
        self.index_dict = handle_df(scp_path)
        self.keys = list(self.index_dict.keys())
        # print(self.keys[0])
        self.audio = []
        self.chunk_size = chunk_size
        self.least_size = least_size
        self.split()

    def split(self):
        '''
            split audio with chunk_size and least_size
        '''
        for key in self.keys:
            common_len, name = self.index_dict[key]['common_len'], self.index_dict[key]['name']
            utt = read_wav(name)
            utt = utt[:common_len]
            if utt.shape[0] < self.least_size:
                continue
            if utt.shape[0] > self.least_size and utt.shape[0] < self.chunk_size:
                gap = self.chunk_size-utt.shape[0]
                self.audio.append(F.pad(utt, (0, gap), mode='constant'))
            if utt.shape[0] >= self.chunk_size:
                start = 0
                while True:
                    if start + self.chunk_size > utt.shape[0]:
                        break
                    self.audio.append(utt[start:start+self.chunk_size])
                    start += self.least_size

    def get_num_after_splitting(self):
        print(len(self.audio))

In [3]:
import torch

import numpy as np

class CustomDatasets(torch.utils.data.Dataset):
    '''
       Load audio data
       mix_scp: file path of mix audio (type: str)
       ref_scp: file path of ground truth audio (type: list[spk1,spk2])
       chunk_size (int, optional): split audio size (default: 32000(4 s))
       least_size (int, optional): Minimum split size (default: 16000(2 s))
    '''

    def __init__(self, df=None, sample_rate=16000, chunk_size=32000, least_size=16000):
        super(torch.utils.data.Dataset, self).__init__()
        k = len(df.iloc[0]) - 2
        mix_scp = []
        ref_scp = [[] for _ in range (k)]
        
        for _, row in df.iterrows():
            common_len_idx = row['common_len_idx']
            mix_scp.append([common_len_idx, row['mixed_audio']])
            i = 0
            for col in df.columns[2:]:
                audio_value = row[col]
                ref_scp[i].append([common_len_idx, audio_value])
                i += 1 
    
        self.mix_audio = CustomAudioReader(mix_scp, sample_rate=sample_rate, chunk_size=chunk_size, least_size=least_size).audio
        self.ref_audio = [CustomAudioReader(r, sample_rate=sample_rate, chunk_size=chunk_size, least_size=least_size).audio for r in ref_scp]

    def __len__(self):
        return len(self.mix_audio)

    def __getitem__(self, index):
        return self.mix_audio[index], [ref[index] for ref in self.ref_audio]

In [4]:
from utils.measure_time import measure_time 
import torch as th
import random
import numpy as np
import math
import pandas as pd

class TestingDiarizationDataset:
    def __init__(self, data_root = './', train_percent = 0.75, valid_percent = 0.15, test_percent = 0.0, shuffle=False, 
                 num_workers=0, batch_size=1, pin_memory = False, sample_rate=8000, chunk_size=32000, least_size=16000, seed = 42):
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.least_size = least_size
        self.seed = seed
        self._set_seed(seed)
        self.g = th.Generator()
        self.g.manual_seed(seed)
        full_data_df = pd.read_csv(data_root) 
        assert math.isclose(train_percent + valid_percent + test_percent, 1.0, rel_tol=1e-9), "Sum doesnt equal to 1" 
        train_size = int(train_percent * len(full_data_df)) 
        val_size = int(valid_percent * len(full_data_df)) 
        test_size = len(full_data_df) - train_size - val_size
        self.train_df = full_data_df.iloc[:train_size] 
        self.val_df = full_data_df.iloc[train_size:train_size + val_size] 
        self.test_df = full_data_df.iloc[train_size + val_size:]
         
    @measure_time
    def setup(self, stage = 'train'):
        assert stage in ['train', 'eval'], "Invalid stage" 
        if stage == 'train': 
            self.train_dataset = CustomDatasets(self.train_df, 
                                            sample_rate = self.sample_rate,
                                            chunk_size = self.chunk_size,
                                            least_size = self.least_size)
            print(f"Size of training set: {len(self.train_dataset)}")
            self.val_dataset = CustomDatasets(self.val_df, 
                                        sample_rate = self.sample_rate,
                                        chunk_size = self.chunk_size,
                                        least_size = self.least_size)
            print(f"Size of validation set: {len(self.val_dataset)}")
        # To Do 
        # self.test_dataset
        
        return self # warning! 
        
    def train_dataloader(self):
        return th.utils.data.DataLoader(self.train_dataset,
                                    batch_size = self.batch_size,
                                    pin_memory = self.pin_memory,
                                    shuffle = self.shuffle,
                                    num_workers = self.num_workers,
                                    worker_init_fn=self.seed_worker,
                                    generator=self.g)
        
    def val_dataloader(self):
        return th.utils.data.DataLoader(self.train_dataset,
                                    batch_size = self.batch_size,
                                    pin_memory = self.pin_memory,
                                    shuffle = False,
                                    num_workers = self.num_workers,
                                    worker_init_fn=self.seed_worker,
                                    generator=self.g)

    def _set_seed(self, seed: int):
        random.seed(seed)
        np.random.seed(seed)
        th.manual_seed(seed)
        th.cuda.manual_seed_all(seed)

    def seed_worker(self, worker_id):
        worker_seed = th.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        
    # ToDo
    # def test_dataloader(self):

In [5]:
datamodule = TestingDiarizationDataset(**cfg['data']).setup(stage = 'train')
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}

Size of training set: 3140
Size of validation set: 641
Elapsed time 'setup': 00:00:02.54


In [6]:
# Получение первого батча данных из DataLoader
dataloader = dataloaders['train'] 
sample_mix, sample_refs = next(iter(dataloader))  
print(sample_mix)
print('chunks_num', len(sample_mix))
print(sample_mix[0])
print(sample_mix[0].shape)
print('----------------------------------------------')
print('spekears num', len(sample_refs))
print('firs_speaker list:', sample_refs[0])
print('chunks_nums', len(sample_refs[0]))
print(sample_refs[0][0].shape)

tensor([[0.0209, 0.0117, 0.0137,  ..., 0.0167, 0.0109, 0.0182]])
chunks_num 1
tensor([0.0209, 0.0117, 0.0137,  ..., 0.0167, 0.0109, 0.0182])
torch.Size([32000])
----------------------------------------------
spekears num 2
firs_speaker list: tensor([[ 0.0179,  0.0152,  0.0104,  ...,  0.0005, -0.0008, -0.0029]])
chunks_nums 1
torch.Size([32000])


In [None]:
tensor([[0.0209, 0.0117, 0.0137,  ..., 0.0167, 0.0109, 0.0182]])
chunks_num 1
tensor([0.0209, 0.0117, 0.0137,  ..., 0.0167, 0.0109, 0.0182])
torch.Size([32000])
----------------------------------------------
spekears num 2
firs_speaker list: tensor([[ 0.0179,  0.0152,  0.0104,  ...,  0.0005, -0.0008, -0.0029]])
chunks_nums 1
torch.Size([32000])

In [None]:
# Size of training set: 3140
# Size of validation set: 641
# Elapsed time 'setup': 00:00:03.57

#### Developed | Loss Functions

In [2]:
import torch
from losses import sdr_loss
from torchmetrics.audio import PermutationInvariantTraining
from torchmetrics.functional.audio import signal_distortion_ratio
from torchmetrics.audio import SignalDistortionRatio
from torchmetrics.functional.audio import permutation_invariant_training

from losses import sisnr_pit
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio

seed = 42
torch.manual_seed(seed)

batch = 1
spk = 2
time = 32000
sample_mix = [torch.randn(batch, time) for _ in range(spk)]
sample_refs = [torch.randn(batch, time) for _ in range(spk)]

# print(sample_mix[0].shape)

# Функция для смешивания тензоров с заданной схожестью
def mix_tensors_with_similarity(tensor1, tensor2, unsimilarity=0.05):
    # Векторная интерполяция между tensor1 и tensor2
    return unsimilarity * tensor1 + (1 - unsimilarity) * tensor2

# Применяем смешивание для всех спикеров
sample_mix_similar = [
    mix_tensors_with_similarity(m, r, unsimilarity=0.999999999999999999999999)
    for m, r in zip(sample_mix, sample_refs)
]

sample_mix = sample_mix_similar
r_sample_mix = sample_mix[::-1]

sample_mix_tensor = torch.stack(sample_mix, dim=1)  # по оси 1 (spk) ожидаемый вывод: torch.Size([1, 2, 3200])
sample_refs_tensor = torch.stack(sample_refs, dim=1)  # по оси 1 (spk)  # ожидаемый вывод: torch.Size([1, 2, 3200])
r_sample_mix_tensor = torch.stack(r_sample_mix, dim=1)  # по оси 1 (spk) ожидаемый вывод: torch.Size([1, 2, 3200])

# print(sample_mix_tensor.shape)
# print(sample_refs_tensor.shape)

pit = PermutationInvariantTraining(signal_distortion_ratio, mode="speaker-wise", eval_func="max").to('cuda')
print('taudio:', - pit(r_sample_mix_tensor, sample_refs_tensor)) # Warining "-" minus before
sdr = SignalDistortionRatio()
print('another taudio:', - sdr(sample_mix_tensor, sample_refs_tensor)) # Warining "-" minus before
print('my sdr:', sdr_loss(sample_mix, sample_refs), '\n')
# print('')

print('my sisnr:', sisnr_pit(sample_mix, sample_refs).item())
print('reversed:', sisnr_pit(r_sample_mix, sample_refs).item())
sisnr = ScaleInvariantSignalNoiseRatio().to('cuda')
sample_mix_tensor = sample_mix_tensor.to('cuda')
sample_refs_tensor = sample_refs_tensor.to('cuda')
sisnr.update(sample_mix_tensor, sample_refs_tensor)
TEMP = sisnr.compute()
print('sisnr taudio:', - TEMP.item()) # Warining "-" minus before

taudio: tensor(17.5640, device='cuda:0')
another taudio: tensor(18.1067)
my sdr: tensor(3.0031) 

my sisnr: 47.03880310058594
reversed: 47.03880310058594
sisnr taudio: 51.2036018371582


In [2]:
seed = 44
torch.manual_seed(seed)

batch = 1
spk = 2
time = 32000
sample_mix = [torch.randn(batch, time) for _ in range(spk)]
sample_refs = [torch.randn(batch, time) for _ in range(spk)]

# print(sample_mix[0].shape)

# Функция для смешивания тензоров с заданной схожестью
def mix_tensors_with_similarity(tensor1, tensor2, unsimilarity=0.05):
    # Векторная интерполяция между tensor1 и tensor2
    return unsimilarity * tensor1 + (1 - unsimilarity) * tensor2

# Применяем смешивание для всех спикеров
sample_mix_similar = [
    mix_tensors_with_similarity(m, r, unsimilarity=0.000000000000000000000000000000001)
    for m, r in zip(sample_mix, sample_refs)
]

sample_mix = sample_mix_similar
r_sample_mix = sample_mix[::-1]

sample_mix_tensor = torch.stack(sample_mix, dim=1)  # по оси 1 (spk) ожидаемый вывод: torch.Size([1, 2, 3200])
sample_refs_tensor = torch.stack(sample_refs, dim=1)  # по оси 1 (spk)  # ожидаемый вывод: torch.Size([1, 2, 3200])
r_sample_mix_tensor = torch.stack(r_sample_mix, dim=1)  # по оси 1 (spk) ожидаемый вывод: torch.Size([1, 2, 3200])

pit2 = PermutationInvariantTraining(signal_distortion_ratio, mode="speaker-wise", eval_func="max")
print('taudio:', - pit2(r_sample_mix_tensor, sample_refs_tensor)) # Warining "-" minus before

taudio: tensor(-138.7370)


Вердикт - использовать мою функцию в качестве sisnr и использовать torchaudio для sdr. 

Интересный факт - функция sdr ведет себя sdr(x) -> min как sisnr (x) -> min.

In [1]:
import argparse
import sys

import torch
from torchmetrics.audio import PermutationInvariantTraining as PIT
from torchmetrics.functional.audio import signal_distortion_ratio as sdr
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio as sisnr
# from torchmetrics.audio import ScaleInvariantSignalNoiseRatio as sisnr
from tqdm.notebook import tqdm

# from losses import sisnr_pit
from utils.load_config import load_config  
from models import MODELS
from data.DiarizationDataset import DiarizationDataset

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--hparams", type=str, default="./configs/dev_dualpathrnn.yml", help="hparams config file")
args, unknown = parser.parse_known_args()  
cfg = load_config(args.hparams)

In [2]:
model_class = MODELS[cfg['xp_config']['model_type']]
model = model_class(**cfg['model'])
device = cfg['trainer']['device']
model.to(device)
datamodule = DiarizationDataset(**cfg['data']).setup(stage = 'eval')
test_dataloader = datamodule.test_dataloader()

Size of test set: 848
Elapsed time 'setup': 00:00:00.72


In [3]:
weight = './weights/__DualPath_RNN_179_-3.1895.pt'
dicts = torch.load(weight, map_location=device, weights_only=False)
model.load_state_dict(dicts['model_state_dict'])

model.eval()
pit_sdr = PIT(sdr).to(device)
pit_sisnr = PIT(sisnr).to(device)

loss_funcs = {name: PIT(func).to(cfg['trainer']['device']) for name, func in {"sisnr": sisnr, "sdr": sdr}.items()}

In [5]:
from utils.training import * 

running_sdr = 0.0
running_sisnr = 0.0
for inputs, labels in tqdm(test_dataloader):
    inputs, labels = inputs.to(device), [l.to(device) for l in labels]
    with torch.no_grad():
        outputs = [s.detach() for s in model(inputs)]
        outputs, labels = tensify(outputs).to(device), tensify(labels).to(device) 
        losses = {'sisnr': - loss_funcs['sisnr'](outputs, labels),
                  'sdr': - loss_funcs['sdr'](outputs, labels)} 
        running_sisnr += losses['sisnr'].item()
        running_sdr += losses['sdr'].item()
        pit_sdr.update(outputs, labels)
        pit_sisnr.update(outputs, labels)

print('sdr', running_sdr / len(test_dataloader))
print('sisnr', running_sisnr / len(test_dataloader))

  0%|          | 0/848 [00:00<?, ?it/s]

sdr -5.2180030130174995
sisnr -4.563263170192686


In [None]:
# --------------------------------

#### Training.

In [1]:
import argparse
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter as TensorBoard
from torchmetrics.audio import PermutationInvariantTraining as PIT
from torchmetrics.functional.audio import signal_distortion_ratio as sdr
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio as sisnr

from utils.load_config import load_config 
from utils.training import metadata_info, configure_optimizer
from models import MODELS
from trainer import Trainer
from data import DiarizationDataset

torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')

cfg_path = './configs/superior_sepformer.yml'
# Loading config file    
cfg = load_config(cfg_path)
# Load data 
datamodule = DiarizationDataset(**cfg['data']).setup(stage = 'train')
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}
# Load model
model_class = MODELS[cfg['trainer']['model_name']]
model = model_class(**cfg['model'])
# Meta-data
metadata_info(model)
# TensorBoard
writer = TensorBoard(f'tb_logs/{Path(cfg_path).stem}', comment = f"{cfg['trainer']['ckpt_folder']}")
# Optimizer
optimizer = configure_optimizer (cfg, model)
# Loss and metrics
loss_funcs = {name: PIT(func).to(cfg['trainer']['device']) for name, func in {"sisnr": sisnr, "sdr": sdr}.items()}

Size of training set: 135
Size of validation set: 15
Elapsed time 'setup': 00:00:00.12




Trainable parametrs: 68353025
Size of model: 260.75 MB, in float32
--------------------------------------------------------------------


In [2]:
import os

import torch
from tqdm.notebook import tqdm

from utils.measure_time import measure_time
from utils.checkpointer import Checkpointer
from utils.training import * 

class TrainerTest:
    def __init__(self, epochs = 100, device='cuda', best_weights = False, checkpointing = False, 
                 checkpoint_interval = 10, model_name = '', trained_model = './', path_to_weights= './weights', 
                 ckpt_folder = '', speaker_num = 2, alpha = 0.5, beta = 0.5) -> None:
        self.epochs = epochs
        self.device = device
        self.best_weights = best_weights
        self.ckpointer = Checkpointer(model_name, path_to_weights, ckpt_folder)
        self.checkpointing = checkpointing
        self.checkpoint_interval = checkpoint_interval
        self.model_name = model_name
        os.makedirs(path_to_weights, exist_ok=True)
        self.path_to_weights = path_to_weights
        self.ckpt_folder = ckpt_folder
        self.speaker_num = speaker_num
        self.trained_model = trained_model
        self.alpha = alpha
        self.beta = beta

    @measure_time
    def fit(self, model, dataloaders, criterions, optimizer, writer) -> None:
        model.to(self.device)
        start_epoch, min_val_loss, model, optimizer = self.load_pretrained_model(model, optimizer)
        epoch_state = EpochState(metrics = criterions, epochs=self.epochs)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
        for epoch in range(start_epoch, self.epochs):
            for phase in ['train', 'valid']:
                model.train() if phase == 'train' else model.eval()
                dataloader = dataloaders[phase] 
                for inputs, labels in tqdm(dataloader):
                    inputs, labels = inputs.to(self.device), [l.to(self.device) for l in labels]
                    '''
                    batch = 1 spk = 2 time = 16000
                    inputs: [batch, time] 
                    outputs and labels: [torch.randn(batch, time) for _ in range(spk)]  
                    expectention outputs and labels for torch audio-loss: torch.Size([batch, spk, time])
                    '''
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        labels = tensify(labels).to(self.device) 
                        outputs = tensify(outputs).to(self.device) if not isinstance(outputs, torch.Tensor) else outputs.to(self.device)
                        losses = {'sisnr': - criterions['sisnr'](outputs, labels),
                                  'sdr': - criterions['sdr'](outputs, labels)}  
                        loss = self.alpha * losses['sisnr'] + self.beta * losses['sdr']
                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
                            optimizer.step()
                    epoch_state.update_loss(phase, loss)
                    epoch_state.update_metrics(phase, losses)
                epoch_loss = epoch_state.compute_loss(phase, len(dataloader))
                epoch_metrics = epoch_state.compute_metrics(phase, len(dataloader))
                epoch_state.p_output(epoch, phase)
                if phase == 'valid':
                    if self.best_weights and epoch_loss < min_val_loss:
                        min_val_loss = epoch_loss
                        self.ckpointer.save_best_weight(model, optimizer, epoch, epoch_state)
                    # scheduler.step(epoch_loss)
            torch_logger(writer, epoch, epoch_state)
            if self.checkpointing and (epoch + 1) % self.checkpoint_interval == 0:
                self.ckpointer.save_checkpoint(model, optimizer, epoch, epoch_state)
            epoch_state.reset_state()

    def load_pretrained_model(self, model, optimizer):
        if self.trained_model:
            print(f"Load pretrained mode: {self.trained_model}", '\n')
            checkpoint = torch.load(self.trained_model, map_location=self.device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            return checkpoint['epoch'] + 1, checkpoint['val_loss'] , model, optimizer
        else:
            return 0, float('inf'), model, optimizer

In [3]:
# Train
TrainerTest(**cfg['trainer']).fit(model, 
                                  dataloaders, 
                                  loss_funcs, 
                                  optimizer, 
                                  writer)

  0%|          | 0/135 [00:00<?, ?it/s]

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/900
TRAIN, loss: -0.3161 | sisnr: 0.0135 | sdr: -0.7190 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.1635 | sisnr: 0.3019 | sdr: -0.7322 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 2/900
TRAIN, loss: -0.7614 | sisnr: -0.3903 | sdr: -1.2151 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.2815 | sisnr: 0.1793 | sdr: -0.8447 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 3/900
TRAIN, loss: -0.8586 | sisnr: -0.4591 | sdr: -1.3470 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.2960 | sisnr: 0.1696 | sdr: -0.8650 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 4/900
TRAIN, loss: -0.8920 | sisnr: -0.4923 | sdr: -1.3806 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.3326 | sisnr: 0.1261 | sdr: -0.8933 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 5/900
TRAIN, loss: -0.9191 | sisnr: -0.4940 | sdr: -1.4387 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.3567 | sisnr: 0.1154 | sdr: -0.9337 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 6/900
TRAIN, loss: -0.9458 | sisnr: -0.5177 | sdr: -1.4691 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.3752 | sisnr: 0.0825 | sdr: -0.9348 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 7/900
TRAIN, loss: -0.9531 | sisnr: -0.5374 | sdr: -1.4612 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.3715 | sisnr: 0.1072 | sdr: -0.9566 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 8/900
TRAIN, loss: -0.9853 | sisnr: -0.5447 | sdr: -1.5238 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.3959 | sisnr: 0.0999 | sdr: -1.0018 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 9/900
TRAIN, loss: -1.0064 | sisnr: -0.5465 | sdr: -1.5684 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4099 | sisnr: 0.0884 | sdr: -1.0190 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 10/900
TRAIN, loss: -1.0274 | sisnr: -0.5580 | sdr: -1.6011 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4200 | sisnr: 0.0773 | sdr: -1.0278 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 11/900
TRAIN, loss: -1.0346 | sisnr: -0.5706 | sdr: -1.6017 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4268 | sisnr: 0.0568 | sdr: -1.0178 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 12/900
TRAIN, loss: -1.0476 | sisnr: -0.5707 | sdr: -1.6306 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4367 | sisnr: 0.0596 | sdr: -1.0434 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 13/900
TRAIN, loss: -1.0607 | sisnr: -0.5788 | sdr: -1.6497 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4427 | sisnr: 0.0481 | sdr: -1.0425 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 14/900
TRAIN, loss: -1.0694 | sisnr: -0.5810 | sdr: -1.6664 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4393 | sisnr: 0.0611 | sdr: -1.0510 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 15/900
TRAIN, loss: -1.0774 | sisnr: -0.5871 | sdr: -1.6766 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4430 | sisnr: 0.0479 | sdr: -1.0431 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 16/900
TRAIN, loss: -1.0699 | sisnr: -0.5840 | sdr: -1.6639 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4397 | sisnr: 0.0498 | sdr: -1.0380 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 17/900
TRAIN, loss: -1.0861 | sisnr: -0.5981 | sdr: -1.6827 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4410 | sisnr: 0.0546 | sdr: -1.0468 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 18/900
TRAIN, loss: -1.0960 | sisnr: -0.5993 | sdr: -1.7031 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4376 | sisnr: 0.0637 | sdr: -1.0504 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 19/900
TRAIN, loss: -1.0999 | sisnr: -0.6006 | sdr: -1.7101 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4386 | sisnr: 0.0612 | sdr: -1.0494 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 20/900
TRAIN, loss: -1.1075 | sisnr: -0.6091 | sdr: -1.7166 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4481 | sisnr: 0.0436 | sdr: -1.0490 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 21/900
TRAIN, loss: -1.1129 | sisnr: -0.6118 | sdr: -1.7254 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4449 | sisnr: 0.0486 | sdr: -1.0481 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 22/900
TRAIN, loss: -1.1169 | sisnr: -0.6136 | sdr: -1.7322 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4495 | sisnr: 0.0440 | sdr: -1.0526 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 23/900
TRAIN, loss: -1.1237 | sisnr: -0.6186 | sdr: -1.7410 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4488 | sisnr: 0.0434 | sdr: -1.0505 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 24/900
TRAIN, loss: -1.1257 | sisnr: -0.6146 | sdr: -1.7503 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4539 | sisnr: 0.0304 | sdr: -1.0459 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 25/900
TRAIN, loss: -1.1303 | sisnr: -0.6235 | sdr: -1.7498 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4492 | sisnr: 0.0424 | sdr: -1.0501 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 26/900
TRAIN, loss: -1.1349 | sisnr: -0.6257 | sdr: -1.7572 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4557 | sisnr: 0.0291 | sdr: -1.0481 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

Epoch 27/900
TRAIN, loss: -1.1394 | sisnr: -0.6288 | sdr: -1.7635 |


  0%|          | 0/15 [00:00<?, ?it/s]

VALID, loss: -0.4580 | sisnr: 0.0256 | sdr: -1.0490 |
--------------------------------------------------------------------


  0%|          | 0/135 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [4]:
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner

# Оборачиваем модель в Trainer
trainer = Trainer(accelerator="gpu", max_epochs=1)
tuner = Tuner(trainer)

# Поиск лучшего learning rate
lr_finder = tuner.lr_find(model, train_dataloaders=dataloaders['train'])

# Оптимальный learning rate
best_lr = lr_finder.suggestion()
print("Оптимальный learning rate:", best_lr)

# График зависимости loss от lr
lr_finder.plot(show=True)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `SuperiorSepformer`

In [4]:
import torch
import torch.optim as optim

optimizer = optim.Adam([torch.tensor(1.0, requires_grad=True)], lr=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

# Передача обычного числа
scheduler.step(0.05)

# Передача скалярного тензора
metric = torch.tensor(0.05)
scheduler.step(metric)


#### Ligtning

In [None]:
# Mel sepformer 

#### Saving

In [1]:
import csv

log = """Epoch 1/200
TRAIN, loss: -1.2249 | sisnr: -0.8991 | sdr: -1.6231 |
VALID, loss: -1.5157 | sisnr: -1.2375 | sdr: -1.8557 |
--------------------------------------------------------------------
Epoch 2/200
TRAIN, loss: -2.1142 | sisnr: -1.7203 | sdr: -2.5957 |
VALID, loss: -2.2680 | sisnr: -1.9200 | sdr: -2.6934 |
--------------------------------------------------------------------
Epoch 3/200
TRAIN, loss: -2.8568 | sisnr: -2.4274 | sdr: -3.3815 |
VALID, loss: -3.0377 | sisnr: -2.6459 | sdr: -3.5167 |
--------------------------------------------------------------------
Epoch 4/200
TRAIN, loss: -3.5181 | sisnr: -3.0665 | sdr: -4.0700 |
VALID, loss: -1.9302 | sisnr: -1.6749 | sdr: -2.2423 |
--------------------------------------------------------------------
Epoch 5/200
TRAIN, loss: -3.8550 | sisnr: -3.4071 | sdr: -4.4024 |
VALID, loss: -3.3887 | sisnr: -3.0413 | sdr: -3.8132 |
--------------------------------------------------------------------
Epoch 6/200
TRAIN, loss: -4.1481 | sisnr: -3.7050 | sdr: -4.6898 |
VALID, loss: -3.8374 | sisnr: -3.4440 | sdr: -4.3182 |
--------------------------------------------------------------------
Epoch 7/200
TRAIN, loss: -4.8348 | sisnr: -4.3830 | sdr: -5.3870 |
VALID, loss: -4.7052 | sisnr: -4.2890 | sdr: -5.2138 |
--------------------------------------------------------------------
Epoch 8/200
TRAIN, loss: -5.3020 | sisnr: -4.8613 | sdr: -5.8406 |
VALID, loss: -5.3846 | sisnr: -4.9074 | sdr: -5.9679 |
--------------------------------------------------------------------
Epoch 9/200
TRAIN, loss: -5.7526 | sisnr: -5.3066 | sdr: -6.2976 |
VALID, loss: -5.4953 | sisnr: -5.0003 | sdr: -6.1002 |
--------------------------------------------------------------------
Epoch 10/200
TRAIN, loss: -6.1007 | sisnr: -5.6617 | sdr: -6.6374 |
VALID, loss: -5.9594 | sisnr: -5.4607 | sdr: -6.5691 |
--------------------------------------------------------------------
Epoch 11/200
TRAIN, loss: -6.2755 | sisnr: -5.8382 | sdr: -6.8100 |
VALID, loss: -4.3845 | sisnr: -3.9057 | sdr: -4.9697 |
--------------------------------------------------------------------
Epoch 12/200
TRAIN, loss: -6.6196 | sisnr: -6.1841 | sdr: -7.1518 |
VALID, loss: -6.0972 | sisnr: -5.6632 | sdr: -6.6278 |
--------------------------------------------------------------------
Epoch 13/200
TRAIN, loss: -6.8672 | sisnr: -6.4381 | sdr: -7.3917 |
VALID, loss: -6.1875 | sisnr: -5.7395 | sdr: -6.7350 |
--------------------------------------------------------------------
Epoch 14/200
TRAIN, loss: -6.7335 | sisnr: -6.3104 | sdr: -7.2506 |
VALID, loss: -5.9863 | sisnr: -5.5417 | sdr: -6.5297 |
--------------------------------------------------------------------
Epoch 15/200
TRAIN, loss: -6.9610 | sisnr: -6.5406 | sdr: -7.4748 |
VALID, loss: -6.2609 | sisnr: -5.7833 | sdr: -6.8445 |
--------------------------------------------------------------------
Epoch 16/200
TRAIN, loss: -6.9797 | sisnr: -6.5590 | sdr: -7.4939 |
VALID, loss: -6.7391 | sisnr: -6.3142 | sdr: -7.2583 |
--------------------------------------------------------------------
Epoch 17/200
TRAIN, loss: -7.2018 | sisnr: -6.7872 | sdr: -7.7085 |
VALID, loss: -7.0540 | sisnr: -6.6068 | sdr: -7.6005 |
--------------------------------------------------------------------
Epoch 18/200
TRAIN, loss: -7.0898 | sisnr: -6.6783 | sdr: -7.5928 |
VALID, loss: -6.8033 | sisnr: -6.3702 | sdr: -7.3327 |
--------------------------------------------------------------------
Epoch 19/200
TRAIN, loss: -7.3717 | sisnr: -6.9605 | sdr: -7.8743 |
VALID, loss: -6.8814 | sisnr: -6.4557 | sdr: -7.4017 |
--------------------------------------------------------------------
Epoch 20/200
TRAIN, loss: -6.9482 | sisnr: -6.5361 | sdr: -7.4518 |
VALID, loss: -6.6264 | sisnr: -6.2075 | sdr: -7.1385 |
--------------------------------------------------------------------
Epoch 21/200
TRAIN, loss: -7.2933 | sisnr: -6.8850 | sdr: -7.7924 |
VALID, loss: -6.6639 | sisnr: -6.2204 | sdr: -7.2059 |
--------------------------------------------------------------------
Epoch 22/200
TRAIN, loss: -6.4923 | sisnr: -6.0802 | sdr: -6.9958 |
VALID, loss: -6.2044 | sisnr: -5.7849 | sdr: -6.7172 |
--------------------------------------------------------------------
Epoch 23/200
TRAIN, loss: -6.8907 | sisnr: -6.4842 | sdr: -7.3876 |
VALID, loss: -6.1372 | sisnr: -5.7200 | sdr: -6.6471 |
--------------------------------------------------------------------
Epoch 24/200
TRAIN, loss: -7.1909 | sisnr: -6.7829 | sdr: -7.6894 |
VALID, loss: -6.3481 | sisnr: -5.9461 | sdr: -6.8395 |
--------------------------------------------------------------------
Epoch 25/200
TRAIN, loss: -7.2353 | sisnr: -6.8292 | sdr: -7.7316 |
VALID, loss: -6.5742 | sisnr: -6.1763 | sdr: -7.0606 |
--------------------------------------------------------------------
Epoch 26/200
TRAIN, loss: -7.3218 | sisnr: -6.9156 | sdr: -7.8183 |
VALID, loss: -6.9426 | sisnr: -6.5027 | sdr: -7.4803 |
--------------------------------------------------------------------
Epoch 27/200
TRAIN, loss: -7.4863 | sisnr: -7.0801 | sdr: -7.9826 |
VALID, loss: -6.6326 | sisnr: -6.2051 | sdr: -7.1552 |
--------------------------------------------------------------------
Epoch 28/200
TRAIN, loss: -7.5068 | sisnr: -7.1005 | sdr: -8.0033 |
VALID, loss: -6.6399 | sisnr: -6.2279 | sdr: -7.1434 |
--------------------------------------------------------------------
Epoch 29/200
TRAIN, loss: -7.6027 | sisnr: -7.1993 | sdr: -8.0958 |
VALID, loss: -7.1478 | sisnr: -6.7268 | sdr: -7.6625 |
--------------------------------------------------------------------
Epoch 30/200
TRAIN, loss: -7.2847 | sisnr: -6.8840 | sdr: -7.7743 |
VALID, loss: -6.6978 | sisnr: -6.2879 | sdr: -7.1987 |
--------------------------------------------------------------------
Epoch 31/200
TRAIN, loss: -6.9961 | sisnr: -6.5969 | sdr: -7.4840 |
VALID, loss: -6.5984 | sisnr: -6.2147 | sdr: -7.0673 |
--------------------------------------------------------------------
Epoch 32/200
TRAIN, loss: -7.1961 | sisnr: -6.7966 | sdr: -7.6844 |
VALID, loss: -7.0008 | sisnr: -6.6004 | sdr: -7.4903 |
--------------------------------------------------------------------
Epoch 33/200
TRAIN, loss: -7.6189 | sisnr: -7.2205 | sdr: -8.1058 |
VALID, loss: -5.9690 | sisnr: -5.5658 | sdr: -6.4618 |
--------------------------------------------------------------------
Epoch 34/200
TRAIN, loss: -7.2954 | sisnr: -6.8985 | sdr: -7.7805 |
VALID, loss: -7.1037 | sisnr: -6.6875 | sdr: -7.6125 |
--------------------------------------------------------------------
Epoch 35/200
TRAIN, loss: -7.6552 | sisnr: -7.2550 | sdr: -8.1444 |
VALID, loss: -6.8849 | sisnr: -6.4699 | sdr: -7.3922 |
--------------------------------------------------------------------
Epoch 36/200
TRAIN, loss: -7.7905 | sisnr: -7.3916 | sdr: -8.2781 |
VALID, loss: -6.8013 | sisnr: -6.3619 | sdr: -7.3383 |
--------------------------------------------------------------------
Epoch 37/200
TRAIN, loss: -7.3751 | sisnr: -6.9751 | sdr: -7.8640 |
VALID, loss: -6.4656 | sisnr: -6.0467 | sdr: -6.9776 |
--------------------------------------------------------------------
Epoch 38/200
TRAIN, loss: -6.8159 | sisnr: -6.4201 | sdr: -7.2996 |
VALID, loss: -6.6676 | sisnr: -6.2465 | sdr: -7.1822 |
--------------------------------------------------------------------
Epoch 39/200
TRAIN, loss: -7.1392 | sisnr: -6.7419 | sdr: -7.6247 |
VALID, loss: -5.9670 | sisnr: -5.5696 | sdr: -6.4527 |
--------------------------------------------------------------------
Epoch 40/200
TRAIN, loss: -6.8061 | sisnr: -6.4092 | sdr: -7.2911 |
VALID, loss: -6.6983 | sisnr: -6.2873 | sdr: -7.2008 |
--------------------------------------------------------------------
Epoch 41/200
TRAIN, loss: -7.0143 | sisnr: -6.6188 | sdr: -7.4978 |
VALID, loss: -6.8773 | sisnr: -6.4693 | sdr: -7.3761 |
--------------------------------------------------------------------
Epoch 42/200
TRAIN, loss: -7.1242 | sisnr: -6.7291 | sdr: -7.6072 |
VALID, loss: -6.0281 | sisnr: -5.5868 | sdr: -6.5673 |
--------------------------------------------------------------------
Epoch 43/200
TRAIN, loss: -6.9717 | sisnr: -6.5765 | sdr: -7.4546 |
VALID, loss: -6.5418 | sisnr: -6.1353 | sdr: -7.0387 |
--------------------------------------------------------------------
Epoch 44/200
TRAIN, loss: -7.0707 | sisnr: -6.6763 | sdr: -7.5527 |
VALID, loss: -5.9960 | sisnr: -5.6278 | sdr: -6.4460 |
--------------------------------------------------------------------
Epoch 45/200
TRAIN, loss: -7.0334 | sisnr: -6.6415 | sdr: -7.5124 |
VALID, loss: -6.8909 | sisnr: -6.4769 | sdr: -7.3970 |
--------------------------------------------------------------------
Epoch 46/200
TRAIN, loss: -7.2047 | sisnr: -6.8106 | sdr: -7.6865 |
VALID, loss: -6.2342 | sisnr: -5.8445 | sdr: -6.7105 |
--------------------------------------------------------------------
Epoch 47/200
TRAIN, loss: -7.3047 | sisnr: -6.9177 | sdr: -7.7778 |
VALID, loss: -6.0995 | sisnr: -5.7139 | sdr: -6.5708 |
--------------------------------------------------------------------
Epoch 48/200
TRAIN, loss: -7.3005 | sisnr: -6.9063 | sdr: -7.7822 |
VALID, loss: -6.4764 | sisnr: -6.0901 | sdr: -6.9485 |
--------------------------------------------------------------------
Epoch 49/200
TRAIN, loss: -6.9525 | sisnr: -6.5609 | sdr: -7.4312 |
VALID, loss: -5.4282 | sisnr: -5.0278 | sdr: -5.9175 |
--------------------------------------------------------------------
Epoch 50/200
TRAIN, loss: -6.8184 | sisnr: -6.4258 | sdr: -7.2983 |
VALID, loss: -5.9082 | sisnr: -5.5213 | sdr: -6.3811 |
--------------------------------------------------------------------
Epoch 51/200
TRAIN, loss: -6.7916 | sisnr: -6.3992 | sdr: -7.2712 |
VALID, loss: -6.0780 | sisnr: -5.6737 | sdr: -6.5721 |
--------------------------------------------------------------------
Epoch 52/200
TRAIN, loss: -7.0879 | sisnr: -6.6968 | sdr: -7.5660 |
VALID, loss: -6.5364 | sisnr: -6.1449 | sdr: -7.0149 |
--------------------------------------------------------------------
Epoch 53/200
TRAIN, loss: -7.3054 | sisnr: -6.9127 | sdr: -7.7855 |
VALID, loss: -6.5629 | sisnr: -6.1583 | sdr: -7.0575 |
--------------------------------------------------------------------
Epoch 54/200
TRAIN, loss: -7.3261 | sisnr: -6.9349 | sdr: -7.8043 |
VALID, loss: -6.4022 | sisnr: -5.9659 | sdr: -6.9355 |
--------------------------------------------------------------------
Epoch 55/200
TRAIN, loss: -7.3064 | sisnr: -6.9130 | sdr: -7.7872 |
VALID, loss: -6.0837 | sisnr: -5.6259 | sdr: -6.6434 |
--------------------------------------------------------------------
Epoch 56/200
TRAIN, loss: -6.8455 | sisnr: -6.4537 | sdr: -7.3243 |
VALID, loss: -6.5540 | sisnr: -6.1593 | sdr: -7.0363 |
--------------------------------------------------------------------
Epoch 57/200
TRAIN, loss: -7.1631 | sisnr: -6.7705 | sdr: -7.6429 |
VALID, loss: -6.5300 | sisnr: -6.1304 | sdr: -7.0184 |
--------------------------------------------------------------------
Epoch 58/200
TRAIN, loss: -7.2546 | sisnr: -6.8677 | sdr: -7.7275 |
VALID, loss: -6.5930 | sisnr: -6.2129 | sdr: -7.0575 |
--------------------------------------------------------------------
Epoch 59/200
TRAIN, loss: -7.2764 | sisnr: -6.8894 | sdr: -7.7494 |
VALID, loss: -6.0057 | sisnr: -5.6039 | sdr: -6.4967 |
--------------------------------------------------------------------
Epoch 60/200
TRAIN, loss: -7.2165 | sisnr: -6.8264 | sdr: -7.6933 |
VALID, loss: -6.3453 | sisnr: -5.9623 | sdr: -6.8135 |
--------------------------------------------------------------------
Epoch 61/200
TRAIN, loss: -7.2592 | sisnr: -6.8715 | sdr: -7.7331 |
VALID, loss: -6.6821 | sisnr: -6.2732 | sdr: -7.1818 |
--------------------------------------------------------------------
Epoch 62/200
TRAIN, loss: -7.3158 | sisnr: -6.9278 | sdr: -7.7902 |
VALID, loss: -6.4328 | sisnr: -6.0531 | sdr: -6.8969 |
--------------------------------------------------------------------
Epoch 63/200
TRAIN, loss: -7.4211 | sisnr: -7.0336 | sdr: -7.8946 |
VALID, loss: -6.3253 | sisnr: -5.9053 | sdr: -6.8387 |
--------------------------------------------------------------------
Epoch 64/200
TRAIN, loss: -7.6945 | sisnr: -7.3038 | sdr: -8.1722 |
VALID, loss: -7.0419 | sisnr: -6.6333 | sdr: -7.5412 |
--------------------------------------------------------------------
Epoch 65/200
TRAIN, loss: -7.2459 | sisnr: -6.8547 | sdr: -7.7239 |
VALID, loss: -6.8539 | sisnr: -6.4524 | sdr: -7.3446 |
--------------------------------------------------------------------
Epoch 66/200
TRAIN, loss: -7.5394 | sisnr: -7.1479 | sdr: -8.0179 |
VALID, loss: -6.8390 | sisnr: -6.4283 | sdr: -7.3409 |
--------------------------------------------------------------------
Epoch 67/200
TRAIN, loss: -7.5297 | sisnr: -7.1386 | sdr: -8.0079 |
VALID, loss: -6.8479 | sisnr: -6.4400 | sdr: -7.3465 |
--------------------------------------------------------------------
Epoch 68/200
TRAIN, loss: -7.6717 | sisnr: -7.2822 | sdr: -8.1478 |
VALID, loss: -6.9980 | sisnr: -6.5962 | sdr: -7.4892 |
--------------------------------------------------------------------
Epoch 69/200
TRAIN, loss: -7.4379 | sisnr: -7.0484 | sdr: -7.9140 |
VALID, loss: -6.8589 | sisnr: -6.4631 | sdr: -7.3427 |
--------------------------------------------------------------------
Epoch 70/200
TRAIN, loss: -7.7183 | sisnr: -7.3332 | sdr: -8.1890 |
VALID, loss: -6.7430 | sisnr: -6.3451 | sdr: -7.2293 |
--------------------------------------------------------------------
Epoch 71/200
TRAIN, loss: -7.0969 | sisnr: -6.7109 | sdr: -7.5686 |
VALID, loss: -6.6540 | sisnr: -6.2768 | sdr: -7.1150 |
--------------------------------------------------------------------
Epoch 72/200
TRAIN, loss: -7.6091 | sisnr: -7.2234 | sdr: -8.0805 |
VALID, loss: -6.9562 | sisnr: -6.5730 | sdr: -7.4246 |
--------------------------------------------------------------------
Epoch 73/200
TRAIN, loss: -7.7031 | sisnr: -7.3183 | sdr: -8.1734 |
VALID, loss: -7.0216 | sisnr: -6.6318 | sdr: -7.4980 |
--------------------------------------------------------------------
Epoch 74/200
TRAIN, loss: -7.6321 | sisnr: -7.2469 | sdr: -8.1028 |
VALID, loss: -7.0983 | sisnr: -6.6833 | sdr: -7.6055 |
--------------------------------------------------------------------
Epoch 75/200
TRAIN, loss: -7.8932 | sisnr: -7.5075 | sdr: -8.3645 |
VALID, loss: -6.9401 | sisnr: -6.5485 | sdr: -7.4189 |
--------------------------------------------------------------------
Epoch 76/200
TRAIN, loss: -7.6400 | sisnr: -7.2543 | sdr: -8.1113 |
VALID, loss: -7.0621 | sisnr: -6.6855 | sdr: -7.5224 |
--------------------------------------------------------------------
Epoch 77/200
TRAIN, loss: -7.5100 | sisnr: -7.1234 | sdr: -7.9826 |
TRAIN, loss: -7.5100 | sisnr: -7.1234 | sdr: -7.9826 |
VALID, loss: -6.0354 | sisnr: -5.6314 | sdr: -6.5293 |
--------------------------------------------------------------------
Epoch 78/200
TRAIN, loss: -7.6612 | sisnr: -7.2725 | sdr: -8.1363 |
VALID, loss: -5.9898 | sisnr: -5.6051 | sdr: -6.4599 |
--------------------------------------------------------------------
Epoch 79/200
TRAIN, loss: -7.7874 | sisnr: -7.4005 | sdr: -8.2603 |
VALID, loss: -6.7497 | sisnr: -6.3433 | sdr: -7.2463 |
--------------------------------------------------------------------
Epoch 80/200
TRAIN, loss: -7.6624 | sisnr: -7.2790 | sdr: -8.1311 |
VALID, loss: -6.9575 | sisnr: -6.5531 | sdr: -7.4518 |
--------------------------------------------------------------------
Epoch 81/200
TRAIN, loss: -7.8274 | sisnr: -7.4419 | sdr: -8.2985 |
VALID, loss: -6.7766 | sisnr: -6.3641 | sdr: -7.2808 |
--------------------------------------------------------------------
Epoch 82/200
TRAIN, loss: -7.7192 | sisnr: -7.3342 | sdr: -8.1896 |
VALID, loss: -6.9869 | sisnr: -6.5718 | sdr: -7.4942 |
--------------------------------------------------------------------
Epoch 83/200
TRAIN, loss: -7.8552 | sisnr: -7.4725 | sdr: -8.3229 |
VALID, loss: -7.0355 | sisnr: -6.6367 | sdr: -7.5228 |
--------------------------------------------------------------------
Epoch 84/200
TRAIN, loss: -7.9783 | sisnr: -7.5960 | sdr: -8.4457 |
VALID, loss: -7.0789 | sisnr: -6.6744 | sdr: -7.5732 |
--------------------------------------------------------------------
Epoch 85/200
TRAIN, loss: -7.9747 | sisnr: -7.5897 | sdr: -8.4453 |
VALID, loss: -6.7453 | sisnr: -6.3609 | sdr: -7.2151 |
--------------------------------------------------------------------
Epoch 86/200
TRAIN, loss: -7.9535 | sisnr: -7.5678 | sdr: -8.4248 |
VALID, loss: -7.3051 | sisnr: -6.8965 | sdr: -7.8045 |
--------------------------------------------------------------------
Epoch 87/200
TRAIN, loss: -7.6815 | sisnr: -7.2987 | sdr: -8.1494 |
VALID, loss: -7.0143 | sisnr: -6.6212 | sdr: -7.4948 |
--------------------------------------------------------------------
Epoch 88/200
TRAIN, loss: -8.0541 | sisnr: -7.6712 | sdr: -8.5221 |
VALID, loss: -7.2238 | sisnr: -6.8435 | sdr: -7.6885 |
--------------------------------------------------------------------
Epoch 89/200
TRAIN, loss: -7.9671 | sisnr: -7.5838 | sdr: -8.4356 |
VALID, loss: -6.6541 | sisnr: -6.2596 | sdr: -7.1362 |
--------------------------------------------------------------------
Epoch 90/200
TRAIN, loss: -8.1201 | sisnr: -7.7383 | sdr: -8.5868 |
VALID, loss: -6.0590 | sisnr: -5.6782 | sdr: -6.5244 |
--------------------------------------------------------------------
Epoch 91/200
TRAIN, loss: -8.0352 | sisnr: -7.6509 | sdr: -8.5050 |
VALID, loss: -7.0325 | sisnr: -6.6073 | sdr: -7.5522 |
--------------------------------------------------------------------
Epoch 92/200
TRAIN, loss: -8.1673 | sisnr: -7.7834 | sdr: -8.6366 |
VALID, loss: -7.1992 | sisnr: -6.7821 | sdr: -7.7090 |
--------------------------------------------------------------------
Epoch 93/200
TRAIN, loss: -8.2769 | sisnr: -7.8938 | sdr: -8.7452 |
VALID, loss: -7.2632 | sisnr: -6.8655 | sdr: -7.7494 |
--------------------------------------------------------------------
Epoch 94/200
TRAIN, loss: -8.1995 | sisnr: -7.8177 | sdr: -8.6660 |
VALID, loss: -6.5032 | sisnr: -6.1030 | sdr: -6.9922 |
--------------------------------------------------------------------
Epoch 95/200
TRAIN, loss: -8.2936 | sisnr: -7.9100 | sdr: -8.7624 |
VALID, loss: -6.7545 | sisnr: -6.3291 | sdr: -7.2743 |
--------------------------------------------------------------------
Epoch 96/200
TRAIN, loss: -8.2408 | sisnr: -7.8563 | sdr: -8.7107 |
VALID, loss: -6.2832 | sisnr: -5.8946 | sdr: -6.7581 |
--------------------------------------------------------------------
Epoch 97/200
TRAIN, loss: -7.9001 | sisnr: -7.5181 | sdr: -8.3670 |
VALID, loss: -6.9372 | sisnr: -6.5408 | sdr: -7.4218 |
--------------------------------------------------------------------
Epoch 98/200
TRAIN, loss: -7.9848 | sisnr: -7.6017 | sdr: -8.4531 |
VALID, loss: -6.8823 | sisnr: -6.4932 | sdr: -7.3579 |
--------------------------------------------------------------------
Epoch 99/200
TRAIN, loss: -8.0258 | sisnr: -7.6422 | sdr: -8.4946 |
VALID, loss: -5.8188 | sisnr: -5.3969 | sdr: -6.3345 |
--------------------------------------------------------------------
Epoch 100/200
TRAIN, loss: -7.6943 | sisnr: -7.3112 | sdr: -8.1624 |
VALID, loss: -6.6938 | sisnr: -6.2715 | sdr: -7.2100 |
--------------------------------------------------------------------
Epoch 101/200
TRAIN, loss: -7.6996 | sisnr: -7.3164 | sdr: -8.1679 |
VALID, loss: -6.3166 | sisnr: -5.9113 | sdr: -6.8120 |
--------------------------------------------------------------------
Epoch 102/200
TRAIN, loss: -7.1173 | sisnr: -6.7295 | sdr: -7.5913 |
VALID, loss: -5.8846 | sisnr: -5.4347 | sdr: -6.4345 |
--------------------------------------------------------------------
Epoch 103/200
TRAIN, loss: -7.3787 | sisnr: -6.9937 | sdr: -7.8493 |
VALID, loss: -6.3741 | sisnr: -5.9597 | sdr: -6.8808 |
--------------------------------------------------------------------
Epoch 104/200
TRAIN, loss: -6.9537 | sisnr: -6.5728 | sdr: -7.4191 |
VALID, loss: -6.3376 | sisnr: -5.9371 | sdr: -6.8271 |
--------------------------------------------------------------------
Epoch 105/200
TRAIN, loss: -7.2484 | sisnr: -6.8671 | sdr: -7.7145 |
VALID, loss: -6.7083 | sisnr: -6.3107 | sdr: -7.1942 |
--------------------------------------------------------------------
Epoch 106/200
TRAIN, loss: -7.3348 | sisnr: -6.9560 | sdr: -7.7979 |
VALID, loss: -6.4565 | sisnr: -6.0385 | sdr: -6.9675 |
--------------------------------------------------------------------
Epoch 107/200
TRAIN, loss: -7.7559 | sisnr: -7.3724 | sdr: -8.2246 |
VALID, loss: -6.0838 | sisnr: -5.6240 | sdr: -6.6459 |
--------------------------------------------------------------------
Epoch 108/200
TRAIN, loss: -7.1114 | sisnr: -6.7297 | sdr: -7.5779 |
VALID, loss: -6.5929 | sisnr: -6.1884 | sdr: -7.0873 |
--------------------------------------------------------------------
Epoch 109/200
TRAIN, loss: -7.3454 | sisnr: -6.9632 | sdr: -7.8125 |
VALID, loss: -6.8711 | sisnr: -6.4629 | sdr: -7.3700 |
--------------------------------------------------------------------
Epoch 110/200
TRAIN, loss: -7.5979 | sisnr: -7.2150 | sdr: -8.0660 |
VALID, loss: -6.2920 | sisnr: -5.8743 | sdr: -6.8025 |
--------------------------------------------------------------------
Epoch 111/200
TRAIN, loss: -7.6962 | sisnr: -7.3106 | sdr: -8.1675 |
VALID, loss: -6.7929 | sisnr: -6.3969 | sdr: -7.2769 |
--------------------------------------------------------------------
Epoch 112/200
TRAIN, loss: -7.2625 | sisnr: -6.8755 | sdr: -7.7355 |
VALID, loss: -6.3301 | sisnr: -5.8886 | sdr: -6.8696 |
--------------------------------------------------------------------
Epoch 113/200
TRAIN, loss: -7.9230 | sisnr: -7.5383 | sdr: -8.3932 |
VALID, loss: -6.7914 | sisnr: -6.3652 | sdr: -7.3122 |
--------------------------------------------------------------------
Epoch 114/200
TRAIN, loss: -7.8121 | sisnr: -7.4290 | sdr: -8.2804 |
VALID, loss: -6.3975 | sisnr: -5.9889 | sdr: -6.8968 |
--------------------------------------------------------------------
Epoch 115/200
TRAIN, loss: -7.6951 | sisnr: -7.3125 | sdr: -8.1626 |
VALID, loss: -5.9344 | sisnr: -5.5220 | sdr: -6.4384 |
--------------------------------------------------------------------
Epoch 116/200
TRAIN, loss: -7.7713 | sisnr: -7.3847 | sdr: -8.2439 |
VALID, loss: -6.8493 | sisnr: -6.4312 | sdr: -7.3603 |
--------------------------------------------------------------------
Epoch 117/200
TRAIN, loss: -7.8300 | sisnr: -7.4492 | sdr: -8.2953 |
VALID, loss: -6.5118 | sisnr: -6.0818 | sdr: -7.0373 |
--------------------------------------------------------------------
Epoch 118/200
TRAIN, loss: -7.9795 | sisnr: -7.5934 | sdr: -8.4514 |
VALID, loss: -6.7924 | sisnr: -6.3916 | sdr: -7.2823 |
--------------------------------------------------------------------
Epoch 119/200
TRAIN, loss: -8.0084 | sisnr: -7.6226 | sdr: -8.4799 |
VALID, loss: -6.6711 | sisnr: -6.2539 | sdr: -7.1811 |
--------------------------------------------------------------------
Epoch 120/200
TRAIN, loss: -7.8534 | sisnr: -7.4683 | sdr: -8.3241 |
VALID, loss: -6.4449 | sisnr: -6.0283 | sdr: -6.9540 |
--------------------------------------------------------------------
Epoch 121/200
TRAIN, loss: -8.0260 | sisnr: -7.6428 | sdr: -8.4943 |
VALID, loss: -6.8365 | sisnr: -6.4296 | sdr: -7.3339 |
--------------------------------------------------------------------
Epoch 122/200
TRAIN, loss: -7.8491 | sisnr: -7.4638 | sdr: -8.3202 |
VALID, loss: -5.9352 | sisnr: -5.5267 | sdr: -6.4345 |
--------------------------------------------------------------------
Epoch 123/200
TRAIN, loss: -7.8827 | sisnr: -7.4982 | sdr: -8.3528 |
VALID, loss: -6.9906 | sisnr: -6.5833 | sdr: -7.4885 |
--------------------------------------------------------------------
Epoch 124/200
TRAIN, loss: -7.3348 | sisnr: -6.9511 | sdr: -7.8038 |
VALID, loss: -6.3729 | sisnr: -5.9697 | sdr: -6.8657 |
--------------------------------------------------------------------
Epoch 125/200
TRAIN, loss: -7.6897 | sisnr: -7.3034 | sdr: -8.1617 |
VALID, loss: -6.1440 | sisnr: -5.7435 | sdr: -6.6335 |
--------------------------------------------------------------------
Epoch 126/200
TRAIN, loss: -7.3681 | sisnr: -6.9853 | sdr: -7.8360 |
VALID, loss: -6.0060 | sisnr: -5.6043 | sdr: -6.4970 |
--------------------------------------------------------------------
Epoch 127/200
TRAIN, loss: -7.5552 | sisnr: -7.1711 | sdr: -8.0245 |
VALID, loss: -5.7100 | sisnr: -5.2787 | sdr: -6.2371 |
--------------------------------------------------------------------
Epoch 128/200
TRAIN, loss: -7.5386 | sisnr: -7.1554 | sdr: -8.0069 |
VALID, loss: -5.9418 | sisnr: -5.5494 | sdr: -6.4213 |
--------------------------------------------------------------------
Epoch 129/200
TRAIN, loss: -7.7473 | sisnr: -7.3618 | sdr: -8.2185 |
VALID, loss: -6.1912 | sisnr: -5.7833 | sdr: -6.6897 |
--------------------------------------------------------------------
Epoch 130/200
TRAIN, loss: -7.6480 | sisnr: -7.2639 | sdr: -8.1175 |
VALID, loss: -6.4299 | sisnr: -6.0130 | sdr: -6.9395 |
--------------------------------------------------------------------
Epoch 131/200
TRAIN, loss: -7.7825 | sisnr: -7.3999 | sdr: -8.2501 |
VALID, loss: -6.6271 | sisnr: -6.2172 | sdr: -7.1281 |
--------------------------------------------------------------------
Epoch 132/200
TRAIN, loss: -7.7215 | sisnr: -7.3357 | sdr: -8.1929 |
VALID, loss: -6.2635 | sisnr: -5.8425 | sdr: -6.7781 |
--------------------------------------------------------------------
Epoch 133/200
TRAIN, loss: -7.9973 | sisnr: -7.6103 | sdr: -8.4703 |
VALID, loss: -6.0068 | sisnr: -5.5926 | sdr: -6.5129 |
--------------------------------------------------------------------
Epoch 134/200
TRAIN, loss: -8.1496 | sisnr: -7.7636 | sdr: -8.6214 |
VALID, loss: -6.6899 | sisnr: -6.2729 | sdr: -7.1995 |
--------------------------------------------------------------------
Epoch 135/200
TRAIN, loss: -8.1855 | sisnr: -7.7990 | sdr: -8.6579 |
VALID, loss: -6.7587 | sisnr: -6.3518 | sdr: -7.2560 |
--------------------------------------------------------------------
Epoch 136/200
TRAIN, loss: -7.9097 | sisnr: -7.5250 | sdr: -8.3800 |
VALID, loss: -6.7963 | sisnr: -6.3830 | sdr: -7.3015 |
--------------------------------------------------------------------
Epoch 137/200
TRAIN, loss: -8.0404 | sisnr: -7.6522 | sdr: -8.5149 |
VALID, loss: -6.7986 | sisnr: -6.3898 | sdr: -7.2983 |
--------------------------------------------------------------------
Epoch 138/200
TRAIN, loss: -8.3324 | sisnr: -7.9441 | sdr: -8.8070 |
VALID, loss: -6.8945 | sisnr: -6.4916 | sdr: -7.3870 |
--------------------------------------------------------------------
Epoch 139/200
TRAIN, loss: -8.1466 | sisnr: -7.7608 | sdr: -8.6181 |
VALID, loss: -6.5352 | sisnr: -6.1392 | sdr: -7.0192 |
--------------------------------------------------------------------
Epoch 140/200
TRAIN, loss: -7.9028 | sisnr: -7.5165 | sdr: -8.3750 |
VALID, loss: -6.6968 | sisnr: -6.2977 | sdr: -7.1845 |
--------------------------------------------------------------------
Epoch 141/200
TRAIN, loss: -7.8107 | sisnr: -7.4245 | sdr: -8.2827 |
VALID, loss: -6.8545 | sisnr: -6.4426 | sdr: -7.3580 |
--------------------------------------------------------------------
Epoch 142/200
TRAIN, loss: -8.0738 | sisnr: -7.6862 | sdr: -8.5476 |
VALID, loss: -6.4748 | sisnr: -6.0490 | sdr: -6.9951 |
--------------------------------------------------------------------
Epoch 143/200
TRAIN, loss: -7.9945 | sisnr: -7.6057 | sdr: -8.4697 |
VALID, loss: -6.5958 | sisnr: -6.1803 | sdr: -7.1036 |
--------------------------------------------------------------------
Epoch 144/200
TRAIN, loss: -8.0198 | sisnr: -7.6299 | sdr: -8.4963 |
VALID, loss: -6.4638 | sisnr: -6.0384 | sdr: -6.9837 |
--------------------------------------------------------------------
Epoch 145/200
TRAIN, loss: -7.9533 | sisnr: -7.5647 | sdr: -8.4283 |
VALID, loss: -6.8246 | sisnr: -6.4188 | sdr: -7.3206 |
--------------------------------------------------------------------
Epoch 146/200
TRAIN, loss: -7.8598 | sisnr: -7.4764 | sdr: -8.3284 |
VALID, loss: -6.7586 | sisnr: -6.3251 | sdr: -7.2885 |
--------------------------------------------------------------------
Epoch 147/200
TRAIN, loss: -7.9914 | sisnr: -7.6009 | sdr: -8.4687 |
VALID, loss: -6.2392 | sisnr: -5.8401 | sdr: -6.7271 |
--------------------------------------------------------------------
Epoch 148/200
TRAIN, loss: -7.6684 | sisnr: -7.2773 | sdr: -8.1465 |
VALID, loss: -6.6775 | sisnr: -6.2730 | sdr: -7.1719 |
--------------------------------------------------------------------
Epoch 149/200
TRAIN, loss: -8.2045 | sisnr: -7.8134 | sdr: -8.6825 |
VALID, loss: -6.2450 | sisnr: -5.8413 | sdr: -6.7384 |
--------------------------------------------------------------------
Epoch 150/200
TRAIN, loss: -8.4640 | sisnr: -8.0738 | sdr: -8.9408 |
VALID, loss: -6.9657 | sisnr: -6.5629 | sdr: -7.4579 |
--------------------------------------------------------------------
Epoch 151/200
TRAIN, loss: -8.6822 | sisnr: -8.2884 | sdr: -9.1636 |
VALID, loss: -6.5461 | sisnr: -6.1303 | sdr: -7.0544 |
--------------------------------------------------------------------
Epoch 152/200
TRAIN, loss: -8.7924 | sisnr: -8.4005 | sdr: -9.2715 |
VALID, loss: -7.0991 | sisnr: -6.6759 | sdr: -7.6164 |
--------------------------------------------------------------------
Epoch 153/200
TRAIN, loss: -8.7358 | sisnr: -8.3453 | sdr: -9.2132 |
VALID, loss: -6.7246 | sisnr: -6.3066 | sdr: -7.2356 |
--------------------------------------------------------------------
Epoch 154/200
TRAIN, loss: -8.5136 | sisnr: -8.1239 | sdr: -8.9899 |
VALID, loss: -6.1220 | sisnr: -5.6851 | sdr: -6.6559 |
--------------------------------------------------------------------
Epoch 155/200
TRAIN, loss: -8.7194 | sisnr: -8.3283 | sdr: -9.1974 |
VALID, loss: -6.2451 | sisnr: -5.8485 | sdr: -6.7299 |
--------------------------------------------------------------------
Epoch 156/200
TRAIN, loss: -8.7451 | sisnr: -8.3554 | sdr: -9.2215 |
VALID, loss: -7.2331 | sisnr: -6.8170 | sdr: -7.7418 |
--------------------------------------------------------------------
Epoch 157/200
TRAIN, loss: -9.0217 | sisnr: -8.6279 | sdr: -9.5029 |
VALID, loss: -7.3077 | sisnr: -6.8876 | sdr: -7.8212 |
--------------------------------------------------------------------
Epoch 158/200
TRAIN, loss: -8.5982 | sisnr: -8.2055 | sdr: -9.0782 |
VALID, loss: -6.9922 | sisnr: -6.5698 | sdr: -7.5084 |
--------------------------------------------------------------------
Epoch 159/200
TRAIN, loss: -7.9755 | sisnr: -7.5833 | sdr: -8.4549 |
VALID, loss: -7.0194 | sisnr: -6.5965 | sdr: -7.5362 |
--------------------------------------------------------------------
Epoch 160/200
TRAIN, loss: -8.3852 | sisnr: -7.9940 | sdr: -8.8632 |
VALID, loss: -7.0532 | sisnr: -6.6289 | sdr: -7.5718 |
--------------------------------------------------------------------
Epoch 161/200
TRAIN, loss: -8.7058 | sisnr: -8.3137 | sdr: -9.1850 |
VALID, loss: -6.7074 | sisnr: -6.2923 | sdr: -7.2147 |
--------------------------------------------------------------------"""

In [4]:
def _logger(logs, stage, index=0, csv_filename = './'):
    lines = log.split("\n")
    data = []
    
    for line in lines:
        if line.startswith("Epoch"):
            epoch = int(line.split()[1].split('/')[0]) - 1
        elif line.startswith(stage):
            val = float(line.split("|")[index].split(":")[1].strip())
            data.append([1735105736, epoch, val])
            
    with open(csv_filename, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["Wall time", "Step", "Value"])  
        writer.writerows(data)

    print(f"Данные успешно сохранены в файл {csv_filename}")

In [6]:
_logger(log, 
        "TRAIN", 
        index = 1,
        csv_filename = "./checkpoints/dualpathrnn/train_sisnr_160_epoch.csv")

Данные успешно сохранены в файл ./checkpoints/dualpathrnn/train_sisnr_160_epoch.csv


In [8]:
_logger(log, 
        "TRAIN", 
        index = 2,
        csv_filename = "./checkpoints/dualpathrnn/train_sdr_160_epoch.csv")

Данные успешно сохранены в файл ./checkpoints/dualpathrnn/train_sdr_160_epoch.csv


In [7]:
_logger(log, 
        "VALID", 
        index = 1,
        csv_filename = "./checkpoints/dualpathrnn/valid_sisnr_160_epoch.csv")

Данные успешно сохранены в файл ./checkpoints/dualpathrnn/valid_sisnr_160_epoch.csv


In [9]:
_logger(log, 
        "VALID", 
        index = 2,
        csv_filename = "./checkpoints/dualpathrnn/valid_sdr_160_epoch.csv")

Данные успешно сохранены в файл ./checkpoints/dualpathrnn/valid_sdr_160_epoch.csv
