In [None]:
"""
# colab에서만 사용하는 코드. import될 때 주석처리 되어있어야 한다.

# drive mount. colab에 내 구글 드라이브 연결
from google.colab import drive
drive.mount('/content/drive')

# import_ipynb module 설치
!pip install import_ipynb

# import를 위한 경로이동
%cd /content/drive/MyDrive/team_malmungchi/colab/speaker_verification/code
!ls
"""

In [None]:
import os
import sys
import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import SGD
import pickle

#sys.path.append("/content/drive/MyDrive/team_malmungchi/colab/speaker_verification/code")
import import_ipynb
from constants import *
import batcher
import network
import loss
import fitter
import utils
from test_eer import test_frame, test_utt

# Force every tf.function to run eagerly. The EER computed during validation
# (test_step) uses numpy, which needs concrete tensor values, not graph tensors.
# NOTE(review): this is a global switch and disables graph compilation everywhere,
# which can slow training noticeably.
tf.config.run_functions_eagerly(run_eagerly=True)

In [None]:
 def train(model_name, batcher_name, loss_name, data_paths, hyper_params=None, pre_checkpoint_dir=None, saving_tag=None, root='/content/drive/MyDrive/team_malmungchi/colab/speaker_verification'):
    """Train, validate and test a speaker verification model.

    Builds the batcher, model, optimizer and loss named by the arguments and trains
    with checkpointing, early stopping and learning-rate reduction, all driven by the
    validation EER (``val_eer``). Afterwards it reloads the best checkpoint (the one
    returned by ``utils.load_best_checkpoint``; ModelCheckpoint keeps only
    improvements of the minimum val_eer). It then runs a frame-level test
    (``test_frame``) and an utterance-level test (``test_utt``) on that model.

    Checkpoints are saved in
    ``{root}/model/{model_name}-{batcher_name}-{loss_name}[-{saving_tag}]`` as
    ``{epoch:05d}-{val_eer:.4f}.hdf5``. If that directory already holds a checkpoint,
    its weights are loaded and training resumes from its epoch.

    For transfer learning, pass ``pre_checkpoint_dir`` (relative to ``root``). The best
    checkpoint in that directory is loaded before training, and
    ``--transferedFrom--{last component of pre_checkpoint_dir}`` is appended to the
    checkpoint directory name.

    Args:
      model_name: model name understood by ``network.get_network`` (network.ipynb).
      batcher_name: batcher name understood by ``batcher.get_batcher``
        (currently naive_batcher or simMat_batcher).
      loss_name: loss name understood by ``loss.get_loss``
        (currently cross_entropy or simMat_loss).
      data_paths: one of three forms, all relative to ``root``:
        - the training dataset path;
        - a 1-element list/tuple ``[train]``;
        - a 4-element list/tuple ``[train, val, test_frame, test_utt]``.
        The first two forms use the default validation/test datasets under
        ``data/dataset``.
      hyper_params: optional dict of hyper parameters. Keys read here, with defaults:
        initial_epoch (0), learning_rate (0.001), early_stopping_patience (100),
        reduce_lr_patience (40), reduce_lr_min_lr (1e-5). A copy, with initial_epoch
        filled in, is passed to ``fitter.fit``, which may read further keys
        (e.g. max_epoch). The caller's dict is not modified.
      pre_checkpoint_dir: checkpoint directory of a pre-trained model for transfer
        learning (relative to ``root``).
      saving_tag: additional name tag appended to the checkpoint directory name.
      root: root path that contains the directories 'model', 'code' and 'data'.

    Returns:
      None

    Raises:
      ValueError: if ``data_paths`` is not in one of the supported formats.
      FileNotFoundError: if ``pre_checkpoint_dir`` contains no checkpoint, or if no
        checkpoint was saved during training.
    """
    # Checkpoint file names follow '{epoch:05d}-{val_eer:.4f}.hdf5' (see the
    # ModelCheckpoint filepath below). These helpers read the fields back out.
    def ckpt_epoch(ckpt_name):
        return int(ckpt_name.split('-')[0])

    def ckpt_eer(ckpt_name):
        # Take the last '-'-separated field and drop the 5-character '.hdf5' suffix.
        return ckpt_name.split('-')[-1][:-5]

    # ---- argument handling ----
    # Work on a private copy. initial_epoch is written into it below, and mutating the
    # caller's dict (or a shared default dict) would leak state between train() calls.
    hyper_params = dict(hyper_params) if hyper_params else {}

    # Default validation / test datasets (relative to root), used when only a training
    # dataset is given.
    default_eval_paths = [
        'data/dataset/val_cpp_517_25_49_40.pickle',
        'data/dataset/test_cpp_516_25_49_40.pickle',
        'data/dataset/testUtt_cpp_356_25.pickle',
    ]
    rel_paths = list(data_paths) if isinstance(data_paths, (list, tuple)) else [data_paths]
    if len(rel_paths) == 1:
        rel_paths += default_eval_paths
    if len(rel_paths) != 4:
        raise ValueError('argument data_paths is not right format')
    train_dataset_path, val_dataset_path, test_frame_dataset_path, test_utt_dataset_path = [root + '/' + p for p in rel_paths]

    initial_epoch = hyper_params.get('initial_epoch', 0)
    learning_rate = hyper_params.get('learning_rate', 0.001)
    early_stopping_patience = hyper_params.get('early_stopping_patience', 100)
    reduce_lr_patience = hyper_params.get('reduce_lr_patience', 40)
    reduce_lr_min_lr = hyper_params.get('reduce_lr_min_lr', 0.00001)

    # ---- create the batcher (this loads the training dataset) ----
    print('==================================================')
    print('create batcher...')
    Batcher = batcher.get_batcher(batcher_name, train_dataset_path)
    print('batcher is created')

    # ---- load the validation dataset ----
    print('==================================================')
    print('load validation dataset...')
    # NOTE: pickle.load can execute arbitrary code. Only load dataset files you trust.
    with open(val_dataset_path, "rb") as f:
        val_X = pickle.load(f)
    print('validation dataset is loaded')
    print('shape of data :', val_X.shape)

    # ---- build the model ----
    print('==================================================')
    print('create model...')
    Model = network.get_network(model_name)
    print(model_name+' is created')
    Model.summary()

    # ---- optimizer and loss ----
    optimizer = tf.optimizers.Adam(learning_rate)
    Loss = loss.get_loss(loss_name)

    # ---- checkpoint directory ----
    print('==================================================')
    print('preparing checkpoint...')
    checkpoint_dir = root + f'/model/{model_name}-{batcher_name}-{loss_name}'
    if saving_tag:
        checkpoint_dir += '-' + saving_tag

    # Transfer learning: start from the pre-trained model's best checkpoint.
    if pre_checkpoint_dir:
        # rstrip so a trailing '/' does not produce an empty source name.
        pre_name = pre_checkpoint_dir.rstrip('/').split('/')[-1]
        checkpoint_dir += '--transferedFrom--' + pre_name
        pre_ckpt_dir = root + '/' + pre_checkpoint_dir
        pre_ckpt = utils.load_best_checkpoint(pre_ckpt_dir)
        if not pre_ckpt:
            raise FileNotFoundError('no pre-trained checkpoint found in ' + pre_ckpt_dir)
        Model.load_weights(pre_ckpt_dir + '/' + pre_ckpt)
        # Report the epoch of the loaded pre-trained checkpoint, not this run's initial_epoch.
        pre_epoch = ckpt_epoch(pre_ckpt)
        print('load pre-trainded checkpoint that model:' + pre_name + f', epoch:{pre_epoch}, EER:' + ckpt_eer(pre_ckpt))

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Resume an interrupted run: load this run's best checkpoint and continue from its
    # epoch. This overrides both initial_epoch and any pre-trained weights loaded above.
    resume_ckpt = utils.load_best_checkpoint(checkpoint_dir)
    if resume_ckpt:
        initial_epoch = ckpt_epoch(resume_ckpt)
        Model.load_weights(checkpoint_dir + '/' + resume_ckpt)
        print('load exist checkpoint that model:' + checkpoint_dir.split('/')[-1] + f', epoch:{initial_epoch}, EER:' + ckpt_eer(resume_ckpt))

    # Save a checkpoint only when val_eer reaches a new minimum.
    checkpoint = ModelCheckpoint(filepath=checkpoint_dir + '/{epoch:05d}-{val_eer:.4f}.hdf5', monitor='val_eer', mode='min', save_best_only=True)
    hyper_params['initial_epoch'] = initial_epoch

    # Stop when val_eer has not decreased by at least min_delta=0.001 for
    # early_stopping_patience epochs (default 100). If val_eer is a fraction, 0.001 is
    # 0.1 percentage points (assumption; confirm against test_step).
    early_stopping = EarlyStopping(monitor='val_eer', min_delta=0.001, patience=early_stopping_patience, mode='min', verbose=1)

    # Halve the learning rate (factor=0.5) when val_eer has not improved for
    # reduce_lr_patience epochs (default 40). Never go below reduce_lr_min_lr.
    reduce_lr = ReduceLROnPlateau(monitor='val_eer', factor=0.5, patience=reduce_lr_patience, mode='min', min_lr=reduce_lr_min_lr, verbose=1)

    # ---- train ----
    print('==================================================')
    print('start training')
    callbacks = [checkpoint, early_stopping, reduce_lr]
    fitter.fit(Model, Batcher, val_X, Loss, hyper_params, optimizer, callbacks)
    # Release the training / validation data before loading the test datasets.
    del Batcher
    del val_X
    print('training is end')

    # ---- reload the best weights ----
    print('==================================================')
    best_ckpt = utils.load_best_checkpoint(checkpoint_dir)
    if not best_ckpt:
        raise FileNotFoundError('no checkpoint was saved in ' + checkpoint_dir)
    Model.load_weights(checkpoint_dir + '/' + best_ckpt)
    best_epoch = ckpt_epoch(best_ckpt)
    print(f'load best checkpoint that epoch:{best_epoch}, EER:' + ckpt_eer(best_ckpt))

    # ---- test: frame level, then utterance level ----
    test_frame(Model, test_frame_dataset_path)
    test_utt(Model, test_utt_dataset_path)

In [None]:
""" example

# naive train
train('ACRNN', 'naive_batcher', 'cross_entropy', 'data/dataset/train_300_200_128_512.pickle')

# simMat train
train('CNN', 'simMat_batcher', 'simMat_loss', 'data/dataset/train_300_200_128_512.pickle')

# transfer learning
train('naive_model', 'naive_batcher', 'cross_entropy', 'data/dataset/train_300_200_128_512.pickle')
train('naive_model', 'simMat_batcher', 'simMat_loss', 'data/dataset/train_300_200_128_512.pickle', pre_checkpoint_dir='model/naive_model-naive_batcher-cross_entropy')
"""