In [1]:
import argparse
import sys
from utils.load_config import load_config  

parser = argparse.ArgumentParser()
parser.add_argument(
    "-p",
    "--hparams",
    type=str,
    default="./configs/train_rnn.yml",
    help="hparams config file",
)
# parse_known_args (rather than parse_args) ignores arguments this parser does not
# define (e.g. extra flags a Jupyter kernel may put on the command line) instead of
# exiting with an error; they are collected in `unknown`.
args, unknown = parser.parse_known_args()

In [2]:
import os
import argparse
from pathlib import Path

import torch
import torchmetrics
from torch.utils.tensorboard import SummaryWriter as TensorBoard
from tqdm.notebook import tqdm

from losses import sisnr_loss, sdr_loss
from utils.load_config import load_config 
from utils.training import metadata_info, configure_optimizer, p_output_log
from utils.measure_time import measure_time
from models import MODELS
from data.DiarizationDataset import DiarizationDataset


# Release cached CUDA memory held by this kernel and allow TF32 / reduced-precision
# float32 matmuls for speed.
torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')

cfg = load_config(args.hparams)

# Data: the object returned by setup() is used directly for the dataloaders.
datamodule = DiarizationDataset(**cfg['data']).setup(stage='train')
dataloaders = {'train': datamodule.train_dataloader(), 'valid': datamodule.val_dataloader()}

# Model: the class is looked up by name in the MODELS registry.
model_class = MODELS[cfg['xp_config']['model_type']]
model = model_class(**cfg['model'])
metadata_info(model)

# One TensorBoard run directory per hparams file. SummaryWriter ignores `comment`
# whenever an explicit log_dir is given, so the former
# comment=cfg['trainer']['ckpt_folder'] never had any effect and is no longer passed.
# To separate runs by checkpoint folder, make it part of the log_dir path instead.
writer = TensorBoard(f'tb_logs/{Path(args.hparams).stem}')
optimizer = configure_optimizer(cfg, model)

Size of training set: 1371
Size of validation set: 178
Elapsed time 'setup': 00:00:01.83
Trainable parametrs: 2633729
Size of model: 10.05 MB, in float32


In [3]:
from utils.checkpointer import Checkpointer
from utils.training import *


class Trainer:
    """Train/validate loop with optional best-weight saving and periodic checkpoints.

    Args:
        num_epochs: Total number of epochs; the loop runs ``range(start_epoch, num_epochs)``.
        device: Device passed to ``model.to`` and to ``torch.load(map_location=...)``.
        best_weights: If True, save weights whenever the validation loss improves.
        checkpointing: If True, save a checkpoint every ``checkpoint_interval`` epochs.
        checkpoint_interval: Epoch period used when ``checkpointing`` is enabled.
        model_name: Forwarded to ``Checkpointer``.
        trained_model: Path to a checkpoint to restore before training. A falsy value or
            a directory (such as the default ``'./'``) means "start from scratch".
        path_to_weights: Forwarded to ``Checkpointer``; the directory is created if missing.
        ckpt_folder: Forwarded to ``Checkpointer``.
        speaker_num: Stored on the instance; not used by this class.
        resume: Stored on the instance; not used by this class.
            NOTE(review): loading is driven by ``trained_model`` alone — confirm whether
            ``resume`` was meant to gate restoring the optimizer/epoch state.
    """

    def __init__(self, num_epochs = 100, device='cuda', best_weights = False, checkpointing = False, 
                 checkpoint_interval = 10, model_name = '', trained_model = './', path_to_weights= './weights', 
                 ckpt_folder = '', speaker_num = 2, resume = False) -> None:
        self.num_epochs = num_epochs
        self.device = device
        self.best_weights = best_weights
        self.ckpointer = Checkpointer(model_name, path_to_weights, ckpt_folder, metrics = False)
        self.checkpointing = checkpointing
        self.checkpoint_interval = checkpoint_interval
        self.model_name = model_name
        os.makedirs(path_to_weights, exist_ok=True)
        self.path_to_weights = path_to_weights
        self.ckpt_folder = ckpt_folder
        self.speaker_num = speaker_num
        self.resume = resume
        self.trained_model = trained_model

    @measure_time
    def fit(self, model, dataloaders, criterion, optimizer, writer) -> None:
        """Run the train/valid loop, logging every epoch and saving weights/checkpoints.

        Args:
            model: Module to train; moved to ``self.device`` before any state is restored.
            dataloaders: Mapping with ``'train'`` and ``'valid'`` loaders yielding
                ``(inputs, labels)``, where ``labels`` is an iterable of tensors.
            criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
            optimizer: Optimizer over ``model``'s parameters.
            writer: TensorBoard writer handed to ``torch_logger`` once per epoch.
        """
        model.to(self.device)
        start_epoch, min_val_loss, model, optimizer = self.load_pretrained_model(model, optimizer)
        epoch_state = EpochState(metrics = None)
        for epoch in tqdm(range(start_epoch, self.num_epochs)):
            for phase in ('train', 'valid'):
                epoch_loss = self._run_phase(model, dataloaders[phase], criterion, optimizer,
                                             train=(phase == 'train'))
                epoch_state.update_state(epoch_loss, phase)
                p_output_log(self.num_epochs, epoch, phase, epoch_state)

                # Keep the weights with the lowest validation loss seen so far.
                if phase == 'valid' and self.best_weights and epoch_loss < min_val_loss:
                    min_val_loss = epoch_loss
                    self.ckpointer.save_best_weight(model, optimizer, epoch, epoch_state)

            torch_logger(writer, epoch, epoch_state)

            if self.checkpointing and (epoch + 1) % self.checkpoint_interval == 0:
                self.ckpointer.save_checkpoint(model, optimizer, epoch, epoch_state)

    def _run_phase(self, model, dataloader, criterion, optimizer, train):
        """Make one pass over ``dataloader`` (updating weights iff ``train``) and return its loss.

        Raises:
            ValueError: if the dataloader's dataset is empty (the loss would be undefined).
        """
        num_samples = len(dataloader.dataset)
        if num_samples == 0:
            raise ValueError("dataloader has an empty dataset; cannot compute an epoch loss")

        model.train(train)  # model.train(False) is exactly what model.eval() does
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs = inputs.to(self.device)
            labels = [label.to(self.device) for label in labels]
            with torch.set_grad_enabled(train):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            running_loss += loss.item()
        # NOTE(review): this sums one loss value per batch and divides by the number of
        # samples. If `criterion` returns a batch mean, the reported loss is scaled by about
        # 1/batch_size. It is kept unchanged so values stay comparable with the 'val_loss'
        # stored in existing checkpoints. Confirm the intended reduction.
        return running_loss / num_samples

    def load_pretrained_model(self, model, optimizer):
        """Restore model/optimizer state from ``self.trained_model`` when it names a file.

        Returns:
            ``(start_epoch, min_val_loss, model, optimizer)``. Without a checkpoint this is
            ``(0, inf, model, optimizer)``. Otherwise training continues from the epoch after
            the saved ``'epoch'``, with the saved ``'val_loss'`` as the best loss so far.
        """
        # The default trained_model='./' is truthy. A plain truthiness check therefore made
        # the default configuration call torch.load on a directory and crash.
        if not self.trained_model or os.path.isdir(self.trained_model):
            return 0, float('inf'), model, optimizer

        print(f"Load pretrained model: {self.trained_model}", '\n')
        # weights_only=False unpickles arbitrary objects; only load checkpoint files you trust.
        checkpoint = torch.load(self.trained_model, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'] + 1, checkpoint['val_loss'], model, optimizer

In [4]:
# Build the trainer from the 'trainer' config section, then train with the SI-SNR loss.
trainer = Trainer(**cfg['trainer'])
trainer.fit(model, dataloaders, sisnr_loss, optimizer, writer)

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1/200
TRAIN, Loss: -0.0333
VALID, Loss: -0.0854
------------------------------------------------------------------------------------------------------------ 



KeyboardInterrupt: 

In [None]:
# -0.4693
# VALID, Loss: -0.6935