# In [None]:
import os
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import SGD

import import_ipynb
from constants import CHECKPOINTS_DIR, NUM_FRAME, NUM_FBANK
import batcher
import net_factory
import loss
from utils import load_best_checkpoint, ensures_dir


def fit_model_softmax(dsm: DeepSpeakerModel, kx_train, ky_train, kx_test, ky_test,
                      batch_size=BATCH_SIZE, max_epochs=1000, initial_epoch=0):
    """Fit ``dsm.m`` on in-memory arrays while monitoring ``val_accuracy``.

    Both splits are trimmed to a whole number of batches before fitting, so up
    to ``batch_size - 1`` trailing samples of each split are left unused.

    NOTE(review): ``DeepSpeakerModel``, ``BATCH_SIZE`` and
    ``CHECKPOINTS_SOFTMAX_DIR`` are not among the imports visible at the top of
    this file -- confirm they are defined before this function is created.

    Args:
        dsm: Object whose ``m`` attribute provides ``name`` and ``fit``
            (presumably a Keras model -- confirm against DeepSpeakerModel).
        kx_train: Training inputs; must support ``len`` and slicing.
        ky_train: Training targets, sliced to the same length as ``kx_train``.
        kx_test: Validation inputs; must support ``len`` and slicing.
        ky_test: Validation targets, sliced to the same length as ``kx_test``.
        batch_size: Batch size passed to ``fit`` and used for trimming.
        max_epochs: Number of epochs to run on top of ``initial_epoch``.
        initial_epoch: Epoch index at which training (re)starts.

    Returns:
        None. The value returned by ``fit`` is discarded.
    """
    def whole_batches(features, targets):
        # Keep only as many leading samples as fill complete batches.
        usable = len(features) - len(features) % batch_size
        return features[0:usable], targets[0:usable]

    # One weights file per epoch; '{epoch}' is left for the callback to fill in.
    file_stem = dsm.m.name + '_checkpoint'
    weights_path = os.path.join(CHECKPOINTS_SOFTMAX_DIR, file_stem + '_{epoch}.h5')

    training_callbacks = [
        # Stop when val_accuracy has not improved by at least 0.001 for 20 epochs.
        EarlyStopping(monitor='val_accuracy', min_delta=0.001, patience=20, verbose=1, mode='max'),
        # Halve the learning rate (never below 0.0001) after 10 epochs without improvement.
        ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=10, min_lr=0.0001, verbose=1),
        # Write a checkpoint only when val_accuracy beats the best value so far.
        ModelCheckpoint(monitor='val_accuracy', filepath=weights_path, save_best_only=True),
    ]

    kx_train, ky_train = whole_batches(kx_train, ky_train)
    kx_test, ky_test = whole_batches(kx_test, ky_test)

    # `epochs` is the index of the final epoch, hence the offset by initial_epoch.
    final_epoch = initial_epoch + max_epochs
    dsm.m.fit(
        x=kx_train,
        y=ky_train,
        batch_size=batch_size,
        epochs=final_epoch,
        initial_epoch=initial_epoch,
        verbose=1,
        shuffle=True,
        validation_data=(kx_test, ky_test),
        callbacks=training_callbacks,
    )

def train_simMat(working_dir):
    """Placeholder for a training routine that has not been written yet.

    Presumably intended for similarity-matrix based training, judging only by
    the name -- TODO confirm once the routine is implemented.

    Args:
        working_dir: Accepted to fix the planned call signature; currently unused.

    Raises:
        NotImplementedError: Always, until the routine is implemented.
    """
    # A `def` with no body is a syntax error that stops the whole module from
    # importing; failing loudly here keeps the file importable without
    # pretending that any training took place.
    raise NotImplementedError('train_simMat is not implemented yet')
    

    
def train(working_dir, batcher_name, model_name, loss_name, fit_name):
    
    # Batcher 로드
    Batcher = batcher(batcher_name)
    num_speaker= len(Batcher.num_speaker)
    
    # 모델 생성
    Model = net_factory(model_name)
    Model.m.compile(optimizer='adam', loss=loss(loss_name))
    
    # 에폭 설정
    initial_epoch, max_epoch = 0, 1000
    
    # 체크포인트 로드. TODO : name별로 다른 체크포인트 폴더 만들기.
    ensures_dir(CHECKPOINTS_DIR)
    pre_training_checkpoint = load_best_checkpoint(CHECKPOINTS_DIR)
    if pre_training_checkpoint:
        initial_epoch = int(pre_training_checkpoint.split('/')[-1].split('.')[0].split('_')[-1])
        Molde.m.load_weights(pre_training_checkpoint)  # latest one.
    
    # 체크포인트 객체 생성
    checkpoint_name = dsm.m.name + '_checkpoint'
    checkpoint_filename = os.path.join(CHECKPOINTS_DIR, checkpoint_name + '_{epoch}.h5')
    checkpoint = ModelCheckpoint(monitor='val_accuracy', filepath=checkpoint_filename, save_best_only=True)

    # early_stopping 객체 생성
    early_stopping = EarlyStopping(monitor='val_accuracy', min_delta=0.001, patience=20, verbose=1, mode='max')

    # reduce_lr 객체 생성
    reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=10, min_lr=0.0001, verbose=1)
    
    Model.m.fit(x=Batcher.train_generator(), y=None, steps_per_epoch=2000, shuffle=False,
              epochs=1000, validation_data=Batcher.test_generator(), validation_steps=len(test_batches),
              callbacks=[checkpoint])
    
    # 실전에서 등록 화자의 발화들의 평균과 평가 발화를 비교하는 것도 필요하겠다. 따로 loss를 정의하고 m.evaluation해야 하려나
    
    dsm.m.fit(x=kx_train,
              y=ky_train,
              batch_size=batch_size,
              epochs=initial_epoch + max_epochs,
              initial_epoch=initial_epoch,
              verbose=1,
              shuffle=True,
              validation_data=(kx_test, ky_test),
              callbacks=[early_stopping, reduce_lr, checkpoint])
    
    
    max_epoch = 1000
    """
    
    