In [1]:
# %%capture
# Install the runtime dependencies into the kernel's environment.
# NOTE(review): only transformers is pinned; every other package resolves to
# whatever version is current at install time, so a fresh run may not reproduce
# the environment this notebook was last executed in — consider pinning.
%pip install torch
%pip install datasets
%pip install transformers==4.19.2
%pip install jiwer
%pip install torchaudio
%pip install librosa
%pip install accelerate -U

# Monitor the training process
%pip install wandb

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Use a UTF-8 locale for the kernel process.
%env LC_ALL=C.UTF-8
%env LANG=C.UTF-8
# Cache locations are absolute paths on a SageMaker instance; adjust them when
# running this notebook on another machine.
%env TRANSFORMERS_CACHE=/home/ec2-user/SageMaker/cache
%env HF_DATASETS_CACHE=/home/ec2-user/SageMaker/cache
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (it makes CUDA
# kernel launches synchronous) and presumably slows training — confirm it is
# still wanted for regular runs.
%env CUDA_LAUNCH_BLOCKING=1

env: LC_ALL=C.UTF-8
env: LANG=C.UTF-8
env: TRANSFORMERS_CACHE=/home/ec2-user/SageMaker/cache
env: HF_DATASETS_CACHE=/home/ec2-user/SageMaker/cache
env: CUDA_LAUNCH_BLOCKING=1


In [3]:
import time
import os
from glob import glob
import datetime
from pathlib import Path
import numpy as np
import pandas as pd; pd.options.mode.chained_assignment=None
from tqdm import tqdm
import yaml
import torch
import torch.nn as nn
from torch import optim
import wandb
from torch.utils.data import DataLoader
import NISQA_lib as NL

In [4]:
class nisqaModel(object):    
    def __init__(self, args):
        '''
        Prepare device, model and datasets from a configuration dictionary.

        Parameters
        ----------
        args : dict
            Run configuration. If it has no 'mode' entry, 'mode' is set to
            'main'. The dictionary is modified in place (e.g. 'mode', 'now').

        In 'main' mode the effective configuration is printed as YAML, with the
        'wandb_key' entry left out so the API key never lands in cell output.
        '''
        self.args = args

        if 'mode' not in self.args:
            self.args['mode'] = 'main'

        self.runinfos = {}
        self._getDevice()
        self._loadModel()
        self._loadDatasets()
        self.args['now'] = datetime.datetime.today()

        if self.args['mode']=='main':
            # The config dict also carries the wandb API key; do not echo it.
            printable_args = {k: v for k, v in self.args.items() if k != 'wandb_key'}
            print(yaml.dump(printable_args, default_flow_style=None, sort_keys=False))
            
    def evaluate(self, mapping='first_order', do_print=True, do_plot=False):
        if self.args['dim']==True:
            self._evaluate_dim(mapping=mapping, do_print=do_print, do_plot=do_plot)
        else:
            self._evaluate_mos(mapping=mapping, do_print=do_print, do_plot=do_plot)      
            
    def predict(self):
        '''
        Run the model on the validation dataset and return its dataframe.

        Uses NL.predict_dim for dimension models and NL.predict_mos otherwise.
        If self.args['output_dir'] is set, the dataframe (with an added 'model'
        column holding the run name) is also written to 'NISQA_results.csv' in
        that directory.

        Returns
        -------
        The dataframe of self.ds_val (presumably filled with the predictions by
        the NL.predict_* call — confirm against NISQA_lib).
        '''
        print('---> Predicting ...')
        # Wrap only once: calling predict() repeatedly must not nest DataParallel.
        if self.args['tr_parallel'] and not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        predict_fn = NL.predict_dim if self.args['dim']==True else NL.predict_mos
        predict_fn(
            self.model,
            self.ds_val,
            self.args['tr_bs_val'],
            self.dev,
            num_workers=self.args['tr_num_workers'])

        if self.args['output_dir']:
            # Create the target folder first, as the other writers in this class do.
            Path(self.args['output_dir']).mkdir(parents=True, exist_ok=True)
            self.ds_val.df['model'] = self.args['name']
            self.ds_val.df.to_csv(
                os.path.join(self.args['output_dir'], 'NISQA_results.csv'),
                index=False)

        return self.ds_val.df

    def _train_mos(self):
        '''
        Trains speech quality model.

        Runs self.args['tr_epochs'] epochs with Adam and the loss provided by
        NL.biasLoss. After every epoch the training, validation and test sets
        are evaluated with NL.eval_results, the metrics are logged with
        wandb.log (an active wandb run is therefore required) and results and
        checkpoints are written through self._saveResults.

        Early stopping is not applied: all epochs are always run. The "best"
        flag handed to _saveResults is True whenever the validation RMSE of the
        current epoch is lower than in every previous epoch.
        '''
        # Initialize  -------------------------------------------------------------
        # Guard against nesting DataParallel when training is started more than once.
        if self.args['tr_parallel'] and not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)
        self.model.to(self.dev)

        # Runname and savepath  ---------------------------------------------------
        self.runname = self._makeRunnameAndWriteYAML()

        # Optimizer  -------------------------------------------------------------
        opt = optim.Adam(self.model.parameters(), lr=self.args['tr_lr'])
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
        # against the installed torch version (torch is installed unpinned).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                opt,
                'min',
                verbose=True,
                threshold=0.003,
                patience=self.args['tr_lr_patience'])

        # NOTE(review): the 'filepath' column is what biasLoss receives as its first
        # argument — confirm this is the intended grouping column.
        biasLoss = NL.biasLoss(
            self.ds_train.df.filepath,
            anchor_db=self.args['tr_bias_anchor_db'],
            mapping=self.args['tr_bias_mapping'],
            min_r=self.args['tr_bias_min_r'],
            do_print=(self.args['tr_verbose']>0),
            )

        # Dataloader    -----------------------------------------------------------
        dl_train = DataLoader(
            self.ds_train,
            batch_size=self.args['tr_bs'],
            shuffle=True,
            drop_last=False,
            pin_memory=True,
            num_workers=self.args['tr_num_workers'])

        show_pbar = (self.args['tr_verbose'] == 2)
        # The NL early stopper is not stepped in this loop, so the "best" flag is
        # derived here; otherwise tr_checkpoint='best_only' would receive the same
        # flag in every epoch.
        best_val_rmse = float('inf')

        # Start training loop   ---------------------------------------------------
        print('--> start training')

        for epoch in range(self.args['tr_epochs']):
            tic_epoch = time.time()
            batch_cnt = 0
            loss = 0.0
            y_train = self.ds_train.df[self.args['csv_mos_train']].to_numpy().reshape(-1)
            y_train_hat = np.zeros((len(self.ds_train), 1))
            self.model.train()

            # Progress bar (updated manually, so only the total is given)
            if show_pbar:
                pbar = tqdm(total=len(dl_train), ascii=">—",
                            bar_format='{bar} {percentage:3.0f}%, {n_fmt}/{total_fmt}, {elapsed}<{remaining}{postfix}')

            for xb_spec, yb_mos, (idx, n_wins) in dl_train:

                # Estimate batch ---------------------------------------------------
                xb_spec = xb_spec.to(self.dev)
                yb_mos = yb_mos.to(self.dev)
                n_wins = n_wins.to(self.dev)

                # Forward pass ----------------------------------------------------
                yb_mos_hat = self.model(xb_spec, n_wins)
                y_train_hat[idx] = yb_mos_hat.detach().cpu().numpy()

                # Loss ------------------------------------------------------------
                lossb = biasLoss.get_loss(yb_mos, yb_mos_hat, idx)
                # Backprop  -------------------------------------------------------
                lossb.backward()
                opt.step()
                opt.zero_grad()

                # Update total loss -----------------------------------------------
                loss += lossb.item()
                batch_cnt += 1

                if show_pbar:
                    pbar.set_postfix(loss=lossb.item())
                    pbar.update()

            if show_pbar:
                pbar.close()

            # Mean batch loss of this epoch
            loss = loss/batch_cnt

            biasLoss.update_bias(y_train, y_train_hat)

            # Evaluate   -----------------------------------------------------------
            if self.args['tr_verbose']>0:
                print('\n<---- Training ---->')
            self.ds_train.df['mos_pred'] = y_train_hat
            db_results_train, r_train = NL.eval_results(
                self.ds_train.df,
                dcon=self.ds_train.df_con,
                target_mos=self.args['csv_mos_train'],
                target_ci=self.args['csv_mos_train'] + '_ci',
                pred='mos_pred',
                mapping = 'first_order',
                do_print=(self.args['tr_verbose']>0)
                )

            if self.args['tr_verbose']>0:
                print('<---- Validation ---->')
            NL.predict_mos(self.model, self.ds_val, self.args['tr_bs_val'], self.dev, num_workers=self.args['tr_num_workers'])
            NL.predict_mos(self.model, self.ds_test, self.args['tr_bs_val'], self.dev, num_workers=self.args['tr_num_workers'])

            db_results, r_val = NL.eval_results(
                self.ds_val.df,
                dcon=self.ds_val.df_con,
                target_mos=self.args['csv_mos_val'],
                target_ci=self.args['csv_mos_val'] + '_ci',
                pred='mos_pred',
                mapping = 'first_order',
                do_print=(self.args['tr_verbose']>0)
                )
            # db_results is overwritten here: the per-database results handed to
            # _saveResults below are those of the test set.
            db_results, r_test = NL.eval_results(
                self.ds_test.df,
                dcon=self.ds_test.df_con,
                target_mos=self.args['csv_mos_test'],
                target_ci=self.args['csv_mos_test'] + '_ci',
                pred='mos_pred',
                mapping = 'first_order',
                do_print=(self.args['tr_verbose']>0)
                )

            # The '*_mse' entries are filled from the 'rmse_all' results; the key
            # names are kept so existing wandb charts and result files stay comparable.
            r = {
                 'train_pearson': r_train['r_p_all'],
                 'train_spearman': r_train['r_s_all'],
                 'train_mse': r_train['rmse_all'],

                 'val_pearson': r_val['r_p_all'],
                 'val_spearman': r_val['r_s_all'],
                 'val_mse': r_val['rmse_all'],

                 'test_pearson': r_test['r_p_all'],
                 'test_spearman': r_test['r_s_all'],
                 'test_mse': r_test['rmse_all']
                }

            # Scheduler update (driven by the training loss) ----------------------
            scheduler.step(loss)

            wandb.log(r)

            # Best-epoch bookkeeping ---------------------------------------------
            is_best = r['val_mse'] < best_val_rmse
            if is_best:
                best_val_rmse = r['val_mse']

            # Save    ---------------------------------------------------------
            ep_runtime = time.time() - tic_epoch

            self._saveResults(self.model, self.model_args, opt, epoch, loss, ep_runtime, r, db_results, is_best)

        return
     
        
    
    def _evaluate_mos(self, mapping='first_order', do_print=True, do_plot=False):
        '''
        Score the MOS predictions stored in the validation dataframe.

        Calls NL.eval_results with the fixed columns 'mos', 'mos_ci' and
        'mos_pred' and stores its two return values in self.db_results and
        self.r.
        '''
        print('--> MOS:')
        val_set = self.ds_val
        frame = val_set.df
        eval_options = {
            'dcon': val_set.df_con,
            'target_mos': 'mos',
            'target_ci': 'mos_ci',
            'pred': 'mos_pred',
            'mapping': mapping,
            'do_print': do_print,
            'do_plot': do_plot,
        }
        self.db_results, self.r = NL.eval_results(frame, **eval_options)

    def _makeRunnameAndWriteYAML(self):
        '''
        Write the run configuration to '<output_dir>/<name>.yaml'.

        The output directory is created if needed. The 'wandb_key' entry of the
        configuration is left out of the file so the API key is not stored on
        disk next to the results.

        Returns
        -------
        str
            The run name (self.args['name']).
        '''
        runname = self.args['name']
        print('runname: ' + runname)
        Path(self.args['output_dir']).mkdir(parents=True, exist_ok=True)
        yaml_path = os.path.join(self.args['output_dir'], runname+'.yaml')
        # Strip the credential before persisting the configuration.
        safe_args = {k: v for k, v in self.args.items() if k != 'wandb_key'}
        with open(yaml_path, 'w') as file:
            yaml.dump(safe_args, file, default_flow_style=None, sort_keys=False)
        return runname
    
    def _loadDatasets(self):
        if self.args['mode']=='predict_file':
            self._loadDatasetsFile()
        elif self.args['mode']=='predict_dir':
            self._loadDatasetsFolder()  
        elif self.args['mode']=='predict_csv':
            self._loadDatasetsCSVpredict()
        elif self.args['mode']=='main':
            self._loadDatasetsCSV()
        else:
            raise NotImplementedError('mode not available')                        
            
    
    def _loadDatasetsFolder(self):
        files = glob( os.path.join(self.args['data_dir'], '*.wav') )
        files = [os.path.basename(files) for files in files]
        df_val = pd.DataFrame(files, columns=['deg'])
     
        print('# files: {}'.format( len(df_val) ))
        if len(df_val)==0:
            raise ValueError('No wav files found in data_dir')   
        
        # creating Datasets ---------------------------------------------------                        
        self.ds_val = NL.SpeechQualityDataset(
            df_val,
            df_con=None,
            data_dir = self.args['data_dir'],
            filename_column = 'deg',
            mos_column = 'predict_only',              
            seg_length = self.args['ms_seg_length'],
            max_length = self.args['ms_max_segments'],
            to_memory = None,
            to_memory_workers = None,
            seg_hop_length = self.args['ms_seg_hop_length'],
            transform = None,
            ms_n_fft = self.args['ms_n_fft'],
            ms_hop_length = self.args['ms_hop_length'],
            ms_win_length = self.args['ms_win_length'],
            ms_n_mels = self.args['ms_n_mels'],
            ms_sr = self.args['ms_sr'],
            ms_fmax = self.args['ms_fmax'],
            ms_channel = self.args['ms_channel'],
            double_ended = self.args['double_ended'],
            dim = self.args['dim'],
            filename_column_ref = None,
            )
        
        
    def _loadDatasetsFile(self):
        '''
        Build the prediction dataset for the single file given in self.args['deg'].

        The path is split into its directory (used as data_dir) and its file
        name (the only row of the 'deg' column).
        '''
        folder, file_name = os.path.split(self.args['deg'])
        df_val = pd.DataFrame([file_name], columns=['deg'])

        # creating Datasets ---------------------------------------------------
        cfg = self.args
        self.ds_val = NL.SpeechQualityDataset(
            df_val,
            df_con=None,
            data_dir=folder,
            filename_column='deg',
            mos_column='predict_only',
            seg_length=cfg['ms_seg_length'],
            max_length=cfg['ms_max_segments'],
            to_memory=None,
            to_memory_workers=None,
            seg_hop_length=cfg['ms_seg_hop_length'],
            transform=None,
            ms_n_fft=cfg['ms_n_fft'],
            ms_hop_length=cfg['ms_hop_length'],
            ms_win_length=cfg['ms_win_length'],
            ms_n_mels=cfg['ms_n_mels'],
            ms_sr=cfg['ms_sr'],
            ms_fmax=cfg['ms_fmax'],
            ms_channel=cfg['ms_channel'],
            double_ended=cfg['double_ended'],
            dim=cfg['dim'],
            filename_column_ref=None,
        )
                
        
    def _loadDatasetsCSVpredict(self):
        '''
        Loads validation dataset for prediction only.

        Reads the file list from self.args['csv_file'] inside
        self.args['data_dir']; a per-condition table is read as well when the
        configuration has a 'csv_con' entry.
        '''
        root_dir = self.args['data_dir']
        dfile = pd.read_csv(os.path.join(root_dir, self.args['csv_file']))

        dcon = None
        if 'csv_con' in self.args:
            dcon = pd.read_csv(os.path.join(self.args['data_dir'], self.args['csv_con']))

        # creating Datasets ---------------------------------------------------
        cfg = self.args
        self.ds_val = NL.SpeechQualityDataset(
            dfile,
            df_con=dcon,
            data_dir=cfg['data_dir'],
            filename_column=cfg['csv_deg'],
            mos_column='predict_only',
            seg_length=cfg['ms_seg_length'],
            max_length=cfg['ms_max_segments'],
            to_memory=False,
            to_memory_workers=None,
            seg_hop_length=cfg['ms_seg_hop_length'],
            transform=None,
            ms_n_fft=cfg['ms_n_fft'],
            ms_hop_length=cfg['ms_hop_length'],
            ms_win_length=cfg['ms_win_length'],
            ms_n_mels=cfg['ms_n_mels'],
            ms_sr=cfg['ms_sr'],
            ms_fmax=cfg['ms_fmax'],
            ms_channel=cfg['ms_channel'],
            double_ended=cfg['double_ended'],
            dim=cfg['dim'],
            filename_column_ref=cfg['csv_ref'],
        )

        
    def _loadDatasetsCSV(self):
        '''
        Load the training, validation and test datasets for 'main' mode.

        The three CSV files are read from self.args['csv_file'],
        self.args['csv_file_val'] and self.args['csv_file_test']. Audio is
        looked up in self.args['data_dir'] joined with the first entry of
        'csv_db_train' / 'csv_db_val' / 'csv_db_test'. No per-condition tables
        are used (df_con is None for all three datasets).

        The test dataset takes its MOS column from 'csv_mos_test', the column
        the training loop evaluates it against; 'csv_mos_val' is used when that
        entry is absent or when a dimension model is configured.

        Dataset sizes are stored in self.runinfos.
        '''
        df_train = pd.read_csv(self.args['csv_file']).reset_index()
        df_val = pd.read_csv(self.args['csv_file_val']).reset_index()
        df_test = pd.read_csv(self.args['csv_file_test']).reset_index()

        print('Training size: {}, Validation size: {}'.format(len(df_train), len(df_val)))

        # Settings shared by all three datasets
        common = dict(
            df_con = None,
            filename_column = self.args['csv_deg'],
            seg_length = self.args['ms_seg_length'],
            max_length = self.args['ms_max_segments'],
            to_memory = self.args['tr_ds_to_memory'],
            to_memory_workers = self.args['tr_ds_to_memory_workers'],
            seg_hop_length = self.args['ms_seg_hop_length'],
            transform = None,
            ms_n_fft = self.args['ms_n_fft'],
            ms_hop_length = self.args['ms_hop_length'],
            ms_win_length = self.args['ms_win_length'],
            ms_n_mels = self.args['ms_n_mels'],
            ms_sr = self.args['ms_sr'],
            ms_fmax = self.args['ms_fmax'],
            ms_channel = self.args['ms_channel'],
            double_ended = self.args['double_ended'],
            dim = self.args['dim'],
            filename_column_ref = self.args['csv_ref'],
        )

        # For dimension models keep the validation setting (it is reset to None
        # when the model is loaded); otherwise use the test column itself.
        if self.args['dim']:
            mos_column_test = self.args['csv_mos_val']
        else:
            mos_column_test = self.args.get('csv_mos_test', self.args['csv_mos_val'])

        # creating Datasets ---------------------------------------------------
        self.ds_train = NL.SpeechQualityDataset(
            df_train,
            data_dir = self.args['data_dir'] + '/' + self.args['csv_db_train'][0],
            mos_column = self.args['csv_mos_train'],
            **common,
        )

        self.ds_val = NL.SpeechQualityDataset(
            df_val,
            data_dir = self.args['data_dir'] + '/' + self.args['csv_db_val'][0],
            mos_column = self.args['csv_mos_val'],
            **common,
        )

        self.ds_test = NL.SpeechQualityDataset(
            df_test,
            data_dir = self.args['data_dir'] + '/' + self.args['csv_db_test'][0],
            mos_column = mos_column_test,
            **common,
        )

        self.runinfos['ds_train_len'] = len(self.ds_train)
        self.runinfos['ds_val_len'] = len(self.ds_val)
        self.runinfos['ds_test_len'] = len(self.ds_test)
    
    def _loadModel(self):    
        '''
        Loads the Pytorch models with given input arguments.
        '''   
        # if True overwrite input arguments from pretrained model
        if self.args['pretrained_model']:
            if os.path.isabs(self.args['pretrained_model']):
                model_path = os.path.join(self.args['pretrained_model'])
            else:
                model_path = os.path.join(os.getcwd(), self.args['pretrained_model'])
            checkpoint = torch.load(model_path, map_location=self.dev)
            
            # update checkpoint arguments with new arguments
            checkpoint['args'].update(self.args)
            self.args = checkpoint['args']
            
        if self.args['model']=='NISQA_DIM':
            self.args['dim'] = True
            self.args['csv_mos_train'] = None # column names hardcoded for dim models
            self.args['csv_mos_val'] = None  
        else:
            self.args['dim'] = False
            
        if self.args['model']=='NISQA_DE':
            self.args['double_ended'] = True
        else:
            self.args['double_ended'] = False     
            self.args['csv_ref'] = None

        # Load Model
        self.model_args = {
            
            'ms_seg_length': self.args['ms_seg_length'],
            'ms_n_mels': self.args['ms_n_mels'],
            
            'cnn_model': self.args['cnn_model'],
            'cnn_c_out_1': self.args['cnn_c_out_1'],
            'cnn_c_out_2': self.args['cnn_c_out_2'],
            'cnn_c_out_3': self.args['cnn_c_out_3'],
            'cnn_kernel_size': self.args['cnn_kernel_size'],
            'cnn_dropout': self.args['cnn_dropout'],
            'cnn_pool_1': self.args['cnn_pool_1'],
            'cnn_pool_2': self.args['cnn_pool_2'],
            'cnn_pool_3': self.args['cnn_pool_3'],
            'cnn_fc_out_h': self.args['cnn_fc_out_h'],
            
            'td': self.args['td'],
            'td_sa_d_model': self.args['td_sa_d_model'],
            'td_sa_nhead': self.args['td_sa_nhead'],
            'td_sa_pos_enc': self.args['td_sa_pos_enc'],
            'td_sa_num_layers': self.args['td_sa_num_layers'],
            'td_sa_h': self.args['td_sa_h'],
            'td_sa_dropout': self.args['td_sa_dropout'],
            'td_lstm_h': self.args['td_lstm_h'],
            'td_lstm_num_layers': self.args['td_lstm_num_layers'],
            'td_lstm_dropout': self.args['td_lstm_dropout'],
            'td_lstm_bidirectional': self.args['td_lstm_bidirectional'],
            
            'td_2': self.args['td_2'],
            'td_2_sa_d_model': self.args['td_2_sa_d_model'],
            'td_2_sa_nhead': self.args['td_2_sa_nhead'],
            'td_2_sa_pos_enc': self.args['td_2_sa_pos_enc'],
            'td_2_sa_num_layers': self.args['td_2_sa_num_layers'],
            'td_2_sa_h': self.args['td_2_sa_h'],
            'td_2_sa_dropout': self.args['td_2_sa_dropout'],
            'td_2_lstm_h': self.args['td_2_lstm_h'],
            'td_2_lstm_num_layers': self.args['td_2_lstm_num_layers'],
            'td_2_lstm_dropout': self.args['td_2_lstm_dropout'],
            'td_2_lstm_bidirectional': self.args['td_2_lstm_bidirectional'],                
            
            'pool': self.args['pool'],
            'pool_att_h': self.args['pool_att_h'],
            'pool_att_dropout': self.args['pool_att_dropout'],
            }
            
        if self.args['double_ended']:
            self.model_args.update({
                'de_align': self.args['de_align'],
                'de_align_apply': self.args['de_align_apply'],
                'de_fuse_dim': self.args['de_fuse_dim'],
                'de_fuse': self.args['de_fuse'],        
                })
                      
        print('Model architecture: ' + self.args['model'])
        if self.args['model']=='NISQA':
            self.model = NL.NISQA(**self.model_args)     
        elif self.args['model']=='NISQA_DIM':
            self.model = NL.NISQA_DIM(**self.model_args)     
        elif self.args['model']=='NISQA_DE':
            self.model = NL.NISQA_DE(**self.model_args)     
        else:
            raise NotImplementedError('Model not available')                        
        
        # Load weights if pretrained model is used ------------------------------------
        if self.args['pretrained_model']:
            missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
            # print('Loaded pretrained model from ' + self.args['pretrained_model'])
            if missing_keys:
                # print('missing_keys:')
                print(missing_keys)
            if unexpected_keys:
                # print('unexpected_keys:')
                print(unexpected_keys)        
            
    def _getDevice(self):
        '''
        Train on GPU if available.
        '''         
        if torch.cuda.is_available():
            self.dev = torch.device("cuda")
        else:
            self.dev = torch.device("cpu")
    
        if "tr_device" in self.args:
            if self.args['tr_device']=='cpu':
                self.dev = torch.device("cpu")
            elif self.args['tr_device']=='cuda':
                self.dev = torch.device("cuda")
        print('Device: {}'.format(self.dev))
        
        if "tr_parallel" in self.args:
            if (self.dev==torch.device("cpu")) and self.args['tr_parallel']==True:
                self.args['tr_parallel']==False 
                # print('Using CPU -> tr_parallel set to False')

    def _saveResults(self, model, model_args, opt, epoch, loss, ep_runtime, r, db_results, best):
        '''
        Save model/results in dictionary and write results csv.
        ''' 
        if (self.args['tr_checkpoint'] == 'best_only'):
            filename = self.runname + '.tar'
        else:
            filename = self.runname + '_' + ('ep_{:03d}'.format(epoch+1)) + '.tar'
        model_path = os.path.join(self.args['output_dir'], filename)
        results_path = os.path.join(self.args['output_dir'], self.runname+'__results.csv')
        Path(self.args['output_dir']).mkdir(parents=True, exist_ok=True)              
        
        results = {
            'runname': self.runname,
            'epoch': '{:05d}'.format(epoch+1),
            'filename': filename,
            'loss': loss,
            'ep_runtime': '{:0.2f}'.format(ep_runtime),
            **self.runinfos,
            **r,
            **self.args,
            }
        
        for key in results: 
            results[key] = str(results[key])                        

        if epoch==0:
            self.results_hist = pd.DataFrame(results, index=[0])
        else:
            self.results_hist.loc[epoch] = results
        self.results_hist.to_csv(results_path, index=False)


        if (self.args['tr_checkpoint'] == 'every_epoch') or (self.args['tr_checkpoint'] == 'best_only' and best):
      
            if hasattr(model, 'module'):
                state_dict = model.module.state_dict()
                model_name = model.module.name
            else:
                state_dict = model.state_dict()
                model_name = model.name
    
            torch_dict = {
                'runname': self.runname,
                'epoch': epoch+1,
                'model_args': model_args,
                'args': self.args,
                'model_state_dict': state_dict,
                'optimizer_state_dict': opt.state_dict(),
                'db_results': db_results,
                'results': results,
                'model_name': model_name,
                }
            
            torch.save(torch_dict, model_path)
            
            model_files = sorted(glob(os.path.join(self.args['output_dir'], self.runname + '_ep_*.tar')))
            if len(model_files) > 5:
                for old_file in model_files[:-5]:
                    os.remove(old_file)
            
        elif (self.args['tr_checkpoint']!='every_epoch') and (self.args['tr_checkpoint']!='best_only'):
            raise ValueError('selected tr_checkpoint option not available')

            


In [5]:
# Load the run configuration (yaml is imported in the import cell at the top).
yaml_file_path = 'config.yaml'
with open(yaml_file_path, "r", encoding="utf-8") as ymlfile:
    # FullLoader is kept on purpose: presumably the config may use Python-specific
    # tags such as !!python/tuple, which safe_load rejects. Only load trusted files.
    args_yaml = yaml.load(ymlfile, Loader=yaml.FullLoader)

# Fail early with a clear message instead of a confusing error further down.
if not isinstance(args_yaml, dict):
    raise ValueError('{} must contain a mapping of settings'.format(yaml_file_path))

args = args_yaml

In [None]:
# The API key is taken from the config if present, otherwise from the
# WANDB_API_KEY environment variable (wandb falls back to its own stored
# credentials when the key is None).
wandb.login(key=args.get('wandb_key') or os.environ.get('WANDB_API_KEY'))

# Do not upload the credential as part of the run configuration.
wandb_config = {k: v for k, v in args.items() if k != 'wandb_key'}
wandb.init(project=args['wandb_proj_name'], config=wandb_config, name=args['wandb_run_name'])

try:
    nisqa = nisqaModel(args)
    nisqa._train_mos()
finally:
    # Close the run even when training fails, so it is not left dangling.
    wandb.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mshubham-kumar1[0m ([33mshubham-kumar1-shl[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ec2-user/.netrc


Device: cuda
Model architecture: NISQA
Training size: 11020, Validation size: 2700
pretrained_model: /home/ec2-user/SageMaker/Noise_modelling/pre_trained_model/NISQA/nisqa_mos_only.tar
name: NISQA
tr_epochs: 30
tr_early_stop: 20
tr_bs: 40
tr_bs_val: 40
tr_lr: 0.001
tr_lr_patience: 15
tr_num_workers: 4
tr_parallel: true
tr_ds_to_memory: false
tr_ds_to_memory_workers: 0
tr_verbose: 2
tr_device: null
ms_sr: null
ms_fmax: 20000
ms_n_fft: 4096
ms_hop_length: 0.01
ms_win_length: 0.02
ms_n_mels: 48
ms_seg_length: 15
ms_seg_hop_length: 4
ms_max_segments: 1300
cnn_model: adapt
cnn_c_out_1: 16
cnn_c_out_2: 32
cnn_c_out_3: 64
cnn_kernel_size: !!python/tuple [3, 3]
cnn_dropout: 0.2
cnn_fc_out_h: null
cnn_pool_1: [24, 7]
cnn_pool_2: [12, 5]
cnn_pool_3: [6, 3]
td: self_att
td_sa_d_model: 64
td_sa_nhead: 1
td_sa_pool_size: null
td_sa_pos_enc: false
td_sa_num_layers: 2
td_sa_h: 64
td_sa_dropout: 0.1
td_lstm_h: null
td_lstm_num_layers: null
td_lstm_dropout: null
td_lstm_bidirectional: null
td_2: skip
t



--> start training


—————————— 100%, 276/276, 09:10<00:00, loss=2.36



<---- Training ---->
<---- Validation ---->
