In [1]:
# -*- coding: utf-8 -*-

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import tensorflow as tf
import numpy as np
import time

from encoder_decoder_model import Model
from data_loader import text_data

  from ._conv import register_converters as _register_converters


In [2]:
def initialize_session():
    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    return tf.Session(config=config)

##################################################
max_len = 40            # sequence 단어 수 제한
BATCH_SIZE = 20         # 배치 사이즈 - 1이 아니면 입력 데이터 구성이 어려움
emb_dim = 64            # 단어 embedding dimension
hidden_dim = 128        # RNN hidden dim
learning_rate = 0.005   # Learning rate
use_clip = True         # Gradient clipping 쓸지 여부
##################################################

END_TOKEN = "<eos>"
data = text_data("./dataset/ptb", max_len=max_len, end_token=END_TOKEN)
model = Model(emb_dim=emb_dim, hidden_dim=hidden_dim, vocab_size=data.vocab_size,
              use_clip=True, learning_rate=learning_rate, end_token=data.w2idx[END_TOKEN])

sess = initialize_session()
sess.run(tf.global_variables_initializer())


def sample_test(test_input=""):
    # test_input = raw_input("test text: ") # input("test text: ") for python 2, 3
    words = test_input.split()
    input_x = np.zeros((1, max_len), dtype=np.int32)
    for i, word in enumerate(words):
        if i == max_len:
            break
        input_x[0][i] = data.w2idx[word]

    input_x_len = [i+1]
    output = sess.run(model.output, feed_dict={model.x: input_x, model.x_len: input_x_len})
    line = " ".join([data.idx2w[o] for o in output[0]])
    print(line)

def test_model():
    num_it = int(len(data.test_ids) / BATCH_SIZE)
    test_loss, test_cnt = 0, 0

    for _ in range(num_it):
        test_ids, length = data.get_test(BATCH_SIZE)
        loss = sess.run(model.loss, feed_dict={model.x: test_ids, model.x_len: length})

        test_loss += loss
        test_cnt += 1
    print("test loss: {:.3f}".format(test_loss / test_cnt))

avg_loss, it_cnt = 0, 0
it_log, it_test, it_save, it_sample = 10, 1500, 1500, 100
start_time = time.time()

for it in range(0, 10000):
    train_ids, length = data.get_train(BATCH_SIZE)
    loss, _, batch_max_len_ = sess.run([model.loss, model.update, model.batch_max_len],
                       feed_dict={model.x: train_ids, model.x_len: length})
    # Can get text embedding
    # text_vector = sess.run(model.enc_states, feed_dict={model.x: train_ids, model.x_len: length})
    print('batch_max_len', batch_max_len_)
    avg_loss += loss
    it_cnt += 1

    if it % it_log == 0:
        print(" it: {:4d} | loss: {:.3f} - {:.2f}s".format(it, avg_loss / it_cnt, time.time() - start_time))
        avg_loss, it_cnt = 0, 0

    if it % it_test == 0 and it > 0:
        test_model()
    if it % it_save == 0 and it > 0:
        model.save(sess)
    if it % it_sample == 0 and it > 0:
        sample_test(" there is no asbestos in our products now ")

#sess.close()

batch_max_len 34
 it:    0 | loss: 9.210 - 0.51s
batch_max_len 40
batch_max_len 33
batch_max_len 35
batch_max_len 34
batch_max_len 40
batch_max_len 40
batch_max_len 33
batch_max_len 40
batch_max_len 40
batch_max_len 40
 it:   10 | loss: 8.110 - 1.07s
batch_max_len 40
batch_max_len 34
batch_max_len 40
batch_max_len 40
batch_max_len 37
batch_max_len 35
batch_max_len 40
batch_max_len 39
batch_max_len 31
batch_max_len 40
 it:   20 | loss: 7.597 - 1.67s
batch_max_len 31
batch_max_len 40
batch_max_len 40
batch_max_len 34
batch_max_len 37
batch_max_len 40
batch_max_len 40
batch_max_len 39
batch_max_len 39
batch_max_len 40
 it:   30 | loss: 7.470 - 2.23s
batch_max_len 36
batch_max_len 40
batch_max_len 39
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 36
batch_max_len 35
batch_max_len 38
batch_max_len 31
 it:   40 | loss: 7.128 - 2.78s
batch_max_len 40
batch_max_len 39
batch_max_len 40
batch_max_len 37
batch_max_len 40
batch_max_len 40
batch_max_len 37
batch_max_len 35
batch_m

batch_max_len 36
batch_max_len 35
batch_max_len 40
batch_max_len 33
batch_max_len 40
batch_max_len 36
batch_max_len 36
batch_max_len 32
 it:  410 | loss: 2.965 - 23.77s
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 35
batch_max_len 33
 it:  420 | loss: 2.664 - 24.38s
batch_max_len 30
batch_max_len 33
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
 it:  430 | loss: 2.581 - 24.97s
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 39
batch_max_len 40
batch_max_len 40
batch_max_len 39
batch_max_len 40
batch_max_len 40
batch_max_len 40
 it:  440 | loss: 2.750 - 25.56s
batch_max_len 32
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 38
batch_max_len 40
batch_max_len 40
batch_max_len 30
batch_max_len 40
batch_max_len 34
 it:  450 | loss: 2.599 - 26.12s
batch_max_len 38
ba

batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
 it:  810 | loss: 1.213 - 46.32s
batch_max_len 31
batch_max_len 38
batch_max_len 40
batch_max_len 38
batch_max_len 40
batch_max_len 33
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
 it:  820 | loss: 1.393 - 46.86s
batch_max_len 40
batch_max_len 38
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 30
batch_max_len 35
batch_max_len 32
batch_max_len 40
batch_max_len 38
 it:  830 | loss: 1.114 - 47.40s
batch_max_len 32
batch_max_len 40
batch_max_len 35
batch_max_len 37
batch_max_len 40
batch_max_len 35
batch_max_len 39
batch_max_len 40
batch_max_len 40
batch_max_len 40
 it:  840 | loss: 1.183 - 47.95s
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 32
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 40
batch_max_len 34
 it:  850 | loss: 1.526 - 48.50s
batch_max_len 40
batch_max_len 40
batch_max_len 31
ba

KeyboardInterrupt: 