In [1]:
import torch.optim as optim
from datetime import datetime
from utils_local import *
import os
import matplotlib.pyplot as plt

In [2]:
# Label lookup tables; populated from the dataset's label file in a later cell.
char2index = dict()
index2char = dict()

# Special-token ids: placeholders, overwritten once the label map is loaded.
SOS_token = 0
EOS_token = 0
PAD_token = 0

# Dataset root. Defaults to the original local path; set the NSML_DATASET_PATH
# environment variable to point the notebook at another copy without editing it.
DATASET_PATH = os.environ.get('NSML_DATASET_PATH', 'D:/nsml-dataset')

In [9]:
# --- Data pipeline ---
batch_size = 8
num_thread = 4          # thread count handed to the preloaders' initialize_batch
num_mels = 160          # mel bins per frame (input width of the encoder's first Linear)
nsc_in_ms = 40          # forwarded to the preloaders; meaning defined in utils_local (TODO confirm)
train_ratio = 0.9       # leading fraction of the utterances used for training

# --- Model sizes ---
num_hidden_enc = 1024
num_hidden_dec = 512
num_hidden_seq = 1024
num_layers = 4

# --- Optimisation ---
lr_1 = 1e-5             # Adam learning rate for `net` (CTC model)
lr_2 = 1e-5             # Adam learning rate for `net_B` (seq2seq model)
max_epochs = 100
loss_lim = 0.03         # per-batch CTC loss below which net_B also trains on net's CTC output
ref_repeat = 1          # net_B passes over the ground-truth jamo per training batch

# Directory holding checkpoints and metric histories.
filename = './model_saved/nsml_local_4_layer_residual'

In [10]:
# Build the character <-> index maps from the dataset's label file.
label_path = os.path.join(DATASET_PATH, 'hackathon.labels')
char2index, index2char = load_label_local(label_path)

# Resolve special-token ids from the label map. Depending on the label file the
# symbols may instead be '<sos>' / '<eos>' / '-'.
SOS_token, EOS_token, PAD_token = (char2index[symbol] for symbol in ('<s>', '</s>', '_'))

# Jamo symbol list (from utils_local), the tokenizer built on it, and its numeric encoding.
unicode_jamo_list = My_Unicode_Jamo_v2()
tokenizer = Tokenizer(unicode_jamo_list)
jamo_tokens = tokenizer.word2num(unicode_jamo_list)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
# `net`: mel frames -> jamo scores, optimised with CTC loss.
vocab_size = len(unicode_jamo_list)
net = Mel2SeqNet_General_Residual(num_mels, num_hidden_enc, num_hidden_dec, vocab_size, num_layers, device)
net_optimizer = optim.Adam(net.parameters(), lr=lr_1)
ctc_loss = nn.CTCLoss().to(device)

# `net_B`: jamo-id sequence -> character sequence, optimised with per-token NLL
# (reduction='none' so the caller can apply the script loss mask).
net_B = Seq2SeqNet_v2(num_hidden_seq, jamo_tokens, char2index, device)
net_B_optimizer = optim.Adam(net_B.parameters(), lr=lr_2)
net_B_criterion = nn.NLLLoss(reduction='none').to(device)

print(net)
print(net_B)

Mel2SeqNet_General_Residual(
  (encoder): Encoder_General_Residual(
    (fc): Linear(in_features=160, out_features=1024, bias=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.2, inplace=False)
    (gru_layers): ModuleList(
      (0): Residual_GRU(
        (gru): GRU(1024, 512, batch_first=True, bidirectional=True)
      )
      (1): Residual_GRU(
        (gru): GRU(1024, 512, batch_first=True, bidirectional=True)
      )
      (2): Residual_GRU(
        (gru): GRU(1024, 512, batch_first=True, bidirectional=True)
      )
      (3): Residual_GRU(
        (gru): GRU(1024, 512, batch_first=True, bidirectional=True)
      )
    )
  )
  (decoder): CTC_Decoder_General_Residual(
    (fc_embed): Linear(in_features=1024, out_features=1024, bias=True)
    (relu_embed): ReLU()
    (dropout_embed): Dropout(p=0.2, inplace=False)
    (gru_layers): ModuleList(
      (0): GRU(1024, 512, batch_first=True)
      (1): Residual_GRU(
        (gru): GRU(512, 512, batch_first=True)
      )
      (2): Resid

In [12]:
# Collect the wav, script and Korean-script file paths of the dataset.
wav_paths, script_paths, korean_script_paths = get_paths(DATASET_PATH)

print('wav_paths len: {}'.format(len(wav_paths)))
print('script_paths len: {}'.format(len(script_paths)))
print('korean_script_paths len: {}'.format(len(korean_script_paths)))

korean_script_list, jamo_script_list = get_korean_and_jamo_list_v2_local(korean_script_paths)

# Sanity-check the first transcript in both representations.
first_korean, first_jamo = korean_script_list[0], jamo_script_list[0]
print('Korean script 0: {}'.format(first_korean))
print('Korean script 0 length: {}'.format(len(first_korean)))
print('Jamo script 0: {}'.format(first_jamo))
print('Jamo script 0 length: {}'.format(len(first_jamo)))

script_path_list = get_script_list(script_paths, SOS_token, EOS_token)

# Jamo token ids for every transcript, wrapped in start/end markers.
ground_truth_list = [tokenizer.word2num(['<s>', *jamo_script, '</s>']) for jamo_script in jamo_script_list]

wav_paths len: 29805
script_paths len: 29805
korean_script_paths len: 29805
Korean script 0: 예약하고 싶은데 어떻게 해야하나요?
Korean script 0 length: 19
Jamo script 0: 예약하고 싶은데 어떻게 해야하나요?
Jamo script 0 length: 38


In [13]:
print('Train Ratio: {}'.format(train_ratio))

# Every per-utterance list is cut at the same index: the head trains, the tail evaluates.
split_index = int(train_ratio * len(wav_paths))
head = slice(None, split_index)
tail = slice(split_index, None)

wav_path_list_train, wav_path_list_eval = wav_paths[head], wav_paths[tail]
ground_truth_list_train, ground_truth_list_eval = ground_truth_list[head], ground_truth_list[tail]
korean_script_list_train, korean_script_list_eval = korean_script_list[head], korean_script_list[tail]
script_path_list_train, script_path_list_eval = script_path_list[head], script_path_list[tail]

print('Total:Train:Eval = {}:{}:{}'.format(len(wav_paths), len(wav_path_list_train), len(wav_path_list_eval)))

preloader_train = Threading_Batched_Preloader_v2_local(
    wav_path_list_train, ground_truth_list_train, script_path_list_train, korean_script_list_train,
    batch_size, num_mels, nsc_in_ms, is_train=True)

preloader_eval = Threading_Batched_Preloader_v2_local(
    wav_path_list_eval, ground_truth_list_eval, script_path_list_eval, korean_script_list_eval,
    batch_size, num_mels, nsc_in_ms, is_train=False)

Train Ratio: 0.9
Total:Train:Eval = 29805:26824:2981


In [14]:
# Running "best so far" trackers. best_eval_cer decides when the *_best.pt
# checkpoints are written; best_loss is not read by the training loop below.
best_eval_cer = 1e10
best_loss = 1e10

# Load every target script up front to avoid re-reading them from disk during training.
target_path = os.path.join(DATASET_PATH, 'train_label')
load_targets(target_path)

In [None]:
# Joint training of the two networks built above:
#   * `net` is trained with CTC loss on mel features (via `train` / `evaluate`).
#   * `net_B` is trained on the ground-truth jamo ids of every batch (the
#     "reference" path). For batches whose CTC loss is below `loss_lim`, it is
#     also trained on `net`'s decoded CTC output.
# After every epoch this cell writes the metric histories and the latest
# checkpoint to `filename`. When the evaluation CER improves, it also writes a
# "best" checkpoint.
#
# NOTE: the histories named "seq2seq_loss_*" hold CERs, not losses. The names
# are kept because they double as the on-disk file names.
# NOTE: earlier versions of this cell saved the train-reference CER history
# over 'loss_train_history.npy', so that file is unreliable for such runs.
# NOTE: best_eval_cer is initialised in an earlier cell and is not restored
# from the loaded history. After a resume, the first epoch therefore always
# overwrites modelA_best.pt / modelB_best.pt.

model_a_path = os.path.join(filename, 'modelA.pt')
model_b_path = os.path.join(filename, 'modelB.pt')


def _history_path(name):
    """Return the .npy path under `filename` for the history called `name`."""
    return os.path.join(filename, name + '.npy')


def _load_history(name):
    """Load a saved history as a list, or return an empty list if it was never saved."""
    path = _history_path(name)
    return list(np.load(path)) if os.path.exists(path) else list()


def _save_histories():
    """Write all six metric histories to `filename` (the directory must exist)."""
    histories = {
        'loss_train_history': loss_train_history,
        'seq2seq_loss_train_history': seq2seq_loss_train_history,
        'seq2seq_loss_train_ref_history': seq2seq_loss_train_ref_history,
        'loss_eval_history': loss_eval_history,
        'seq2seq_loss_eval_history': seq2seq_loss_eval_history,
        'seq2seq_loss_eval_ref_history': seq2seq_loss_eval_ref_history,
    }
    for name, values in histories.items():
        np.save(_history_path(name), values)


def _load_checkpoint(model, optimizer, path):
    """Restore the model/optimizer state dicts written by `_save_checkpoint`."""
    # map_location lets a checkpoint saved on a GPU machine be resumed on a CPU-only one.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])


def _save_checkpoint(model, optimizer, path):
    """Write the model and optimizer state dicts to `path`."""
    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, path)


def _cer(dist, length):
    """Character error rate; NaN when nothing was scored (rather than a misleading 0)."""
    return dist / length if length > 0 else float('nan')


def _run_net_b(is_train, lev_input, num_script, num_script_mask):
    """Train (is_train=True) or evaluate net_B on one batch.

    Returns whatever net_B.net_train / net_B.net_eval return; callers unpack it
    as (prediction, attentions, loss).
    """
    if is_train:
        return net_B.net_train(lev_input.to(device), num_script.to(device), num_script_mask.to(device),
                               net_B_optimizer, net_B_criterion)
    return net_B.net_eval(lev_input.to(device), num_script.to(device), num_script_mask.to(device),
                          net_B_criterion)


def _run_epoch(preloader, num_samples, is_train):
    """Run one pass over `preloader`, training both nets when `is_train`.

    Args:
        preloader: batched preloader exposing initialize_batch/get_batch/end_flag.
        num_samples: number of utterances behind `preloader`; used only to space the log lines.
        is_train: True to train, False to evaluate.

    Returns:
        (mean CTC loss, CER of net_B on net's CTC output, CER of net_B on the ground truth).
        The CTC-path CER only covers batches whose CTC loss was below `loss_lim`.
        Either CER is NaN when no characters were scored.
    """
    phase = 'Train' if is_train else 'Eval'
    phase_name = 'Training' if is_train else 'Evaluation'
    # net_B sees the ground truth `ref_repeat` times per training batch (at least once,
    # otherwise the reference metrics below would be undefined) and once in evaluation.
    repeats = max(1, ref_repeat) if is_train else 1
    # Roughly 15 log lines per epoch; never 0 so the modulus below is always defined.
    log_every = max(1, int(num_samples / batch_size / 15))

    if is_train:
        net.train()
        net_B.train()
    else:
        net.eval()
        net_B.eval()

    preloader.initialize_batch(num_thread)
    print("Initialized {} Preloader".format(phase_name))

    ctc_losses = list()
    total_dist, total_length = 0, 0
    total_dist_ref, total_length_ref = 0, 0
    count = 0

    while not preloader.end_flag:
        batch = preloader.get_batch()
        if batch is None:
            print("{} Batch is None".format(phase_name))
            continue

        tensor_input, ground_truth, loss_mask, length_list, batched_num_script, batched_num_script_loss_mask = batch
        if is_train:
            pred_tensor, loss = train(net, net_optimizer, ctc_loss, tensor_input.to(device),
                                      ground_truth.to(device), length_list.to(device), device)
        else:
            pred_tensor, loss = evaluate(net, ctc_loss, tensor_input.to(device), ground_truth.to(device),
                                         length_list.to(device), device)
        ctc_losses.append(loss)

        jamo_result = Decode_Prediction_No_Filtering(pred_tensor, tokenizer)
        true_string_list = Decode_Num_Script(batched_num_script.detach().cpu().numpy(), index2char)

        for _repeat in range(repeats):
            # Reference path: net_B on the ground-truth jamo ids.
            lev_pred_ref, _attentions_ref, _seq2seq_loss_ref = _run_net_b(
                is_train, ground_truth, batched_num_script, batched_num_script_loss_mask)
            pred_string_list_ref = Decode_Lev_Prediction(lev_pred_ref, index2char)
            dist_ref, length_ref = char_distance_list(true_string_list, pred_string_list_ref)

            pred_string_list = [None]
            dist, length = 0, 0
            # CTC path: only once net's CTC loss is low enough for its output to be useful.
            if loss < loss_lim:
                lev_input = Decode_CTC_Prediction_And_Batch(pred_tensor)
                lev_pred, _attentions, _seq2seq_loss = _run_net_b(
                    is_train, lev_input, batched_num_script, batched_num_script_loss_mask)
                pred_string_list = Decode_Lev_Prediction(lev_pred, index2char)
                dist, length = char_distance_list(true_string_list, pred_string_list)

        # Only the last repeat contributes to the CER totals.
        total_dist_ref += dist_ref
        total_length_ref += length_ref
        total_dist += dist
        total_length += length

        count += 1
        if count % log_every == 0:
            print("{}: Count {} | {} => {}".format(phase, count, true_string_list[0], pred_string_list_ref[0]))
            print("{}: Count {} | {} => {} => {}".format(phase, count, true_string_list[0], jamo_result[0],
                                                        pred_string_list[0]))

    mean_loss = np.mean(np.asarray(ctc_losses))
    return mean_loss, _cer(total_dist, total_length), _cer(total_dist_ref, total_length_ref)


def _plot_histories():
    """Plot train vs. eval curves for the CTC loss and both CERs."""
    fig, axes = plt.subplots(1, 3)
    panels = (
        ('CTC loss', loss_train_history, loss_eval_history),
        ('CER (CTC output -> net_B)', seq2seq_loss_train_history, seq2seq_loss_eval_history),
        ('CER (ground truth -> net_B)', seq2seq_loss_train_ref_history, seq2seq_loss_eval_ref_history),
    )
    for ax, (title, train_hist, eval_hist) in zip(axes, panels):
        ax.plot(train_hist, label='train')
        ax.plot(eval_hist, label='eval')
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.legend()
    plt.show()
    # A new figure is created every epoch; release it explicitly.
    plt.close(fig)


# ---- Resume from a previous run if its checkpoints exist ----
loss_train_history = list()
seq2seq_loss_train_history = list()
seq2seq_loss_train_ref_history = list()

loss_eval_history = list()
seq2seq_loss_eval_history = list()
seq2seq_loss_eval_ref_history = list()

if os.path.exists(model_a_path) and os.path.exists(model_b_path):
    # Histories missing from older checkpoints load as empty lists instead of aborting the resume.
    loss_train_history = _load_history('loss_train_history')
    seq2seq_loss_train_history = _load_history('seq2seq_loss_train_history')
    seq2seq_loss_train_ref_history = _load_history('seq2seq_loss_train_ref_history')

    print("Train history loaded")

    loss_eval_history = _load_history('loss_eval_history')
    seq2seq_loss_eval_history = _load_history('seq2seq_loss_eval_history')
    seq2seq_loss_eval_ref_history = _load_history('seq2seq_loss_eval_ref_history')

    print("Evaluation history loaded")

    # Any error while restoring an existing checkpoint is raised rather than silently
    # ignored; training from scratch would overwrite that checkpoint below.
    _load_checkpoint(net, net_optimizer, model_a_path)
    print('Model A loaded')

    _load_checkpoint(net_B, net_B_optimizer, model_b_path)
    print('Model B loaded')

    # The restored optimizer state carries the old learning rate; apply the configured one.
    for g in net_optimizer.param_groups:
        g['lr'] = lr_1
        print('Learning rate of the net: {}'.format(g['lr']))

    for g in net_B_optimizer.param_groups:
        g['lr'] = lr_2
        print('Learning rate of the net B: {}'.format(g['lr']))

    print("Loaded from {}".format(filename))
else:
    print("No previous data on {}".format(filename))

print('start')

train_begin = time.time()

for epoch in range(max_epochs):

    print((datetime.now().strftime('%m-%d %H:%M:%S')))

    train_loss, train_cer, train_cer_ref = _run_epoch(preloader_train, len(wav_path_list_train), is_train=True)

    print("Mean Train Loss: {}".format(train_loss))
    print("Total Train CER: {}".format(train_cer))
    print("Total Train Reference CER: {}".format(train_cer_ref))

    # Evaluation performs no backward pass; skip building the autograd graph.
    with torch.no_grad():
        eval_loss, eval_cer, eval_cer_ref = _run_epoch(preloader_eval, len(wav_path_list_eval), is_train=False)

    print("Mean Evaluation Loss: {}".format(eval_loss))
    print("Total Evaluation CER: {}".format(eval_cer))
    print("Total Evaluation Reference CER: {}".format(eval_cer_ref))

    loss_train_history.append(train_loss)
    seq2seq_loss_train_history.append(train_cer)
    seq2seq_loss_train_ref_history.append(train_cer_ref)

    loss_eval_history.append(eval_loss)
    seq2seq_loss_eval_history.append(eval_cer)
    seq2seq_loss_eval_ref_history.append(eval_cer_ref)

    _plot_histories()

    os.makedirs(filename, exist_ok=True)
    _save_histories()
    _save_checkpoint(net, net_optimizer, model_a_path)
    _save_checkpoint(net_B, net_B_optimizer, model_b_path)

    # A NaN CER (nothing scored this epoch) compares False, so it never becomes the best.
    if eval_cer < best_eval_cer:
        _save_checkpoint(net, net_optimizer, os.path.join(filename, 'modelA_best.pt'))
        _save_checkpoint(net_B, net_B_optimizer, os.path.join(filename, 'modelB_best.pt'))
        best_eval_cer = eval_cer



Train history loaded
Evaluation history loaded
Model A loaded
Model B loaded
Learning rate of the net: 1e-05
Learning rate of the net B: 1e-05
Loaded from ./model_saved/nsml_local_4_layer_residual
start
10-06 23:59:21
6706
batch initialized
Initialized Training Preloader
Train: Count 223 | 오후 7시요  => 오후 5시  
Train: Count 223 | 오후 7시요  => <s>오후 *시요.</s> => 오후 5시  
Train: Count 446 | 샐러드바 이용시 아이 요금은 어떻게 되나요?  => 샐러드  용용  이   금은어어게게 되나요? 
Train: Count 446 | 샐러드바 이용시 아이 요금은 어떻게 되나요?  => <s>샐라드바이 용신 아이오부은 어떻게 되나요?</s> => 샐러드드  용   이이    떻떻게되나요? 
Train: Count 669 | 그럼 다음에 다시 전화 드릴께요  => 그럼 다음에 다  전화드리릴 요요 
Train: Count 669 | 그럼 다음에 다시 전화 드릴께요  => <s>검 당ㅁ엘 다기 전확드릴께요.</s> => 김  음일 다기 전화드리릴릴요요 
Train: Count 892 | 설명 감사합니다  => 설명 감사합니다 
Train: Count 892 | 설명 감사합니다  => <s>ㅅ명 감사합니다.</s> => 4  감사합니다 
Train: Count 1115 | 삼성카드 할인 혜택 있나요?  => 삼성카드 할인 혜택 있나요? 
Train: Count 1115 | 삼성카드 할인 혜택 있나요?  => <s>한성카드 할인혜땔 수 ㅣ있나요?</s> => 한성카드 할인  수  있나요? 
Train: Count 1338 | 파스타 쿠폰 있는데 이것만 따로 포장 가능한가요?  => 파스타 쿠폰