In [None]:
# -*- coding: utf-8 -*-

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import tensorflow as tf
import numpy as np
import time

from rnn_model import Model
from imdb_loader import text_data

def initialize_session(gpu_memory_fraction=0.4):
    """Create a TF1 session that may allocate only part of the GPU's memory.

    Args:
        gpu_memory_fraction: fraction of each visible GPU's memory this
            process is allowed to allocate. Defaults to 0.4, the value that
            was previously hard-coded.

    Returns:
        A new ``tf.Session`` built with that memory limit.
    """
    config = tf.ConfigProto()
    # Limit per-process GPU memory instead of letting TF grab the whole device.
    config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
    return tf.Session(config=config)

##################################################
# Hyperparameters
max_len = 200           # maximum number of tokens kept per sequence
max_vocab = 20000       # maximum vocabulary size
BATCH_SIZE = 10         # batch size
emb_dim = 64            # word embedding dimension
hidden_dim = 128        # RNN hidden dimension
learning_rate = 0.0025  # learning rate
use_clip = True         # whether to apply gradient clipping
##################################################

END_TOKEN = "<eos>"
data = text_data("./dataset/aclImdb/", max_len=max_len, end_token=END_TOKEN)
# NOTE(review): max_vocab is defined above but is not passed to text_data or
# Model here; confirm whether the vocabulary is supposed to be capped.
model = Model(max_len=max_len,
              emb_dim=emb_dim,
              hidden_dim=hidden_dim,
              vocab_size=data.vocab_size,
              class_size=2,  # two labels: 0 = neg, 1 = pos
              # Pass the flag through; previously True was hard-coded here,
              # so setting use_clip = False above had no effect.
              use_clip=use_clip,
              learning_rate=learning_rate,
              end_token=data.w2idx[END_TOKEN])

sess = initialize_session()
sess.run(tf.global_variables_initializer())

def test_model(max_batches=100):
    """Evaluate the model on test batches and print mean loss and accuracy.

    Uses the module-level ``sess``, ``model``, ``data`` and ``BATCH_SIZE``.
    Runs at most ``max_batches`` batches drawn with ``data.get_test``, capped
    at the number of full batches in ``data.test_ids``.

    Args:
        max_batches: upper bound on the number of evaluated batches. Defaults
            to 100, the value that was previously hard-coded.
    """
    num_it = min(len(data.test_ids) // BATCH_SIZE, max_batches)
    n_correct, n_seen, test_loss, test_cnt = 0, 0, 0.0, 0

    for _ in range(num_it):
        test_ids, length, label = data.get_test(BATCH_SIZE)
        # Fetch the predictions together with the loss. The previous version
        # fetched only the loss and then compared the training loop's global
        # `out` against the test labels, which made the accuracy meaningless.
        # keep_prob is fed 0.5 during training; 1.0 is fed here on the
        # assumption that it is a dropout keep-probability (TODO confirm).
        loss, out = sess.run(
            [model.loss, model.out_label],
            feed_dict={model.x: test_ids, model.x_len: length,
                       model.y: label, model.keep_prob: 1.0})

        n_correct += sum(1 for o, y in zip(out, label) if o == y)
        n_seen += len(label)
        test_loss += loss
        test_cnt += 1

    if test_cnt == 0 or n_seen == 0:
        print(" --> no test batches evaluated")
        return
    print(" --> test_loss: {:.3f} | test_acc: {:.3f}".format(test_loss / test_cnt, n_correct / n_seen))

# Training loop. Labels: 0: neg, 1: pos
avg_loss, it_cnt, same = 0, 0, .0
# NOTE(review): it_sample is not used in this cell.
it_log, it_test, it_save, it_sample = 10, 100, 1000, 100
start_time = time.time()

for it in range(0, 10000):
    train_ids, length, label = data.get_train(BATCH_SIZE)
    loss, _, out = sess.run([model.loss, model.update, model.out_label],
                            feed_dict={model.x: train_ids, model.x_len: length, model.y: label, model.keep_prob: 0.5})
    same += sum(1 for o, y in zip(out, label) if o == y)
    avg_loss += loss
    it_cnt += 1

    if it % it_log == 0 and it > 0:
        # Divide by the batches actually accumulated (it_cnt), not it_log:
        # the first window spans it = 0..it_log, i.e. it_log + 1 batches,
        # which previously overstated that window's accuracy.
        print(" it: {:4d} | loss: {:.3f} | acc: {:.3f} - {:.2f}s".format(
            it, avg_loss / it_cnt, same / BATCH_SIZE / it_cnt, time.time() - start_time))
        avg_loss, it_cnt, same = 0, 0, .0

    if it % it_test == 0 and it > 0:
        test_model()
    # NOTE(review): with range(0, 10000) the last save happens at it = 9000;
    # updates from later iterations are not saved by this loop.
    if it % it_save == 0 and it > 0:
        model.save(sess)


  from ._conv import register_converters as _register_converters


 it:   10 | loss: 0.692 | acc: 0.530 - 2.97s
 it:   20 | loss: 0.698 | acc: 0.610 - 5.21s
 it:   30 | loss: 0.703 | acc: 0.470 - 7.35s
 it:   40 | loss: 0.692 | acc: 0.520 - 9.44s
 it:   50 | loss: 0.692 | acc: 0.520 - 11.56s
 it:   60 | loss: 0.689 | acc: 0.560 - 13.60s
 it:   70 | loss: 0.677 | acc: 0.650 - 15.85s
 it:   80 | loss: 0.713 | acc: 0.520 - 18.28s
 it:   90 | loss: 0.688 | acc: 0.500 - 20.46s
 it:  100 | loss: 0.682 | acc: 0.590 - 22.60s
 --> test_loss: 0.611 | test_acc: 0.700
 it:  110 | loss: 0.736 | acc: 0.570 - 33.04s
 it:  120 | loss: 0.689 | acc: 0.530 - 35.44s
 it:  130 | loss: 0.704 | acc: 0.450 - 37.80s
 it:  140 | loss: 0.668 | acc: 0.550 - 40.04s
 it:  150 | loss: 0.695 | acc: 0.550 - 42.56s
 it:  160 | loss: 0.681 | acc: 0.570 - 44.82s
 it:  170 | loss: 0.678 | acc: 0.560 - 47.05s
 it:  180 | loss: 0.719 | acc: 0.650 - 49.48s
 it:  190 | loss: 0.707 | acc: 0.520 - 51.73s
 it:  200 | loss: 0.702 | acc: 0.520 - 54.12s
 --> test_loss: 0.685 | test_acc: 0.600
 it:

 it: 1570 | loss: 0.397 | acc: 0.810 - 506.91s
 it: 1580 | loss: 0.383 | acc: 0.850 - 509.39s
 it: 1590 | loss: 0.304 | acc: 0.870 - 511.63s
 it: 1600 | loss: 0.309 | acc: 0.860 - 514.16s
 --> test_loss: 0.343 | test_acc: 0.800
 it: 1610 | loss: 0.318 | acc: 0.870 - 528.07s
 it: 1620 | loss: 0.334 | acc: 0.840 - 531.00s
 it: 1630 | loss: 0.265 | acc: 0.900 - 533.85s
 it: 1640 | loss: 0.308 | acc: 0.890 - 536.91s
 it: 1650 | loss: 0.395 | acc: 0.860 - 539.38s
 it: 1660 | loss: 0.325 | acc: 0.870 - 541.50s
 it: 1670 | loss: 0.362 | acc: 0.840 - 543.74s
 it: 1680 | loss: 0.371 | acc: 0.850 - 545.92s
 it: 1690 | loss: 0.313 | acc: 0.840 - 548.76s
 it: 1700 | loss: 0.343 | acc: 0.850 - 551.68s
 --> test_loss: 0.322 | test_acc: 1.000
 it: 1710 | loss: 0.340 | acc: 0.850 - 565.54s
 it: 1720 | loss: 0.347 | acc: 0.840 - 567.94s
 it: 1730 | loss: 0.433 | acc: 0.780 - 570.49s
 it: 1740 | loss: 0.346 | acc: 0.830 - 572.73s
 it: 1750 | loss: 0.302 | acc: 0.890 - 574.95s
 it: 1760 | loss: 0.367 | a

 it: 3170 | loss: 0.072 | acc: 0.960 - 1052.67s
 it: 3180 | loss: 0.154 | acc: 0.950 - 1055.67s
 it: 3190 | loss: 0.073 | acc: 0.960 - 1058.61s
 it: 3200 | loss: 0.101 | acc: 0.940 - 1061.68s
 --> test_loss: 0.416 | test_acc: 1.000
 it: 3210 | loss: 0.068 | acc: 0.980 - 1073.41s
 it: 3220 | loss: 0.091 | acc: 0.970 - 1075.70s
 it: 3230 | loss: 0.036 | acc: 0.980 - 1078.06s
 it: 3240 | loss: 0.060 | acc: 0.980 - 1080.31s
 it: 3250 | loss: 0.069 | acc: 0.960 - 1082.83s
 it: 3260 | loss: 0.143 | acc: 0.940 - 1085.19s
 it: 3270 | loss: 0.043 | acc: 0.990 - 1087.51s
 it: 3280 | loss: 0.043 | acc: 0.990 - 1090.02s
 it: 3290 | loss: 0.116 | acc: 0.960 - 1092.38s
 it: 3300 | loss: 0.033 | acc: 1.000 - 1094.74s
 --> test_loss: 0.395 | test_acc: 1.000
 it: 3310 | loss: 0.062 | acc: 0.980 - 1105.65s
 it: 3320 | loss: 0.069 | acc: 0.970 - 1107.95s
 it: 3330 | loss: 0.030 | acc: 0.990 - 1110.24s
 it: 3340 | loss: 0.123 | acc: 0.980 - 1112.79s
 it: 3350 | loss: 0.090 | acc: 0.960 - 1115.11s
 it: 336

 it: 4740 | loss: 0.005 | acc: 1.000 - 1555.46s
 it: 4750 | loss: 0.098 | acc: 0.980 - 1557.94s
 it: 4760 | loss: 0.010 | acc: 1.000 - 1560.06s
 it: 4770 | loss: 0.113 | acc: 0.970 - 1562.23s
 it: 4780 | loss: 0.034 | acc: 0.990 - 1564.34s
 it: 4790 | loss: 0.096 | acc: 0.980 - 1566.47s
 it: 4800 | loss: 0.012 | acc: 1.000 - 1568.58s
 --> test_loss: 0.563 | test_acc: 1.000
 it: 4810 | loss: 0.013 | acc: 1.000 - 1578.34s
 it: 4820 | loss: 0.033 | acc: 0.990 - 1580.88s
 it: 4830 | loss: 0.039 | acc: 0.980 - 1583.69s
 it: 4840 | loss: 0.048 | acc: 0.990 - 1586.40s
 it: 4850 | loss: 0.009 | acc: 1.000 - 1589.14s
 it: 4860 | loss: 0.075 | acc: 0.980 - 1591.88s
 it: 4870 | loss: 0.026 | acc: 0.990 - 1594.65s
 it: 4880 | loss: 0.063 | acc: 0.970 - 1597.40s
 it: 4890 | loss: 0.063 | acc: 0.990 - 1599.97s
 it: 4900 | loss: 0.010 | acc: 1.000 - 1602.43s
 --> test_loss: 0.475 | test_acc: 1.000
 it: 4910 | loss: 0.021 | acc: 0.990 - 1612.18s
 it: 4920 | loss: 0.017 | acc: 0.990 - 1614.76s
 it: 493

 --> test_loss: 0.812 | test_acc: 1.000
 it: 6310 | loss: 0.001 | acc: 1.000 - 2048.65s
 it: 6320 | loss: 0.000 | acc: 1.000 - 2050.74s
 it: 6330 | loss: 0.000 | acc: 1.000 - 2052.80s
 it: 6340 | loss: 0.004 | acc: 1.000 - 2054.93s
 it: 6350 | loss: 0.002 | acc: 1.000 - 2057.09s
 it: 6360 | loss: 0.000 | acc: 1.000 - 2059.17s
 it: 6370 | loss: 0.001 | acc: 1.000 - 2061.18s
 it: 6380 | loss: 0.000 | acc: 1.000 - 2063.25s
 it: 6390 | loss: 0.000 | acc: 1.000 - 2065.29s
 it: 6400 | loss: 0.000 | acc: 1.000 - 2067.35s
 --> test_loss: 0.840 | test_acc: 1.000
 it: 6410 | loss: 0.000 | acc: 1.000 - 2077.11s
 it: 6420 | loss: 0.000 | acc: 1.000 - 2079.18s
 it: 6430 | loss: 0.000 | acc: 1.000 - 2081.33s
 it: 6440 | loss: 0.000 | acc: 1.000 - 2083.33s
 it: 6450 | loss: 0.001 | acc: 1.000 - 2085.40s
 it: 6460 | loss: 0.000 | acc: 1.000 - 2087.46s
 it: 6470 | loss: 0.000 | acc: 1.000 - 2089.50s
 it: 6480 | loss: 0.000 | acc: 1.000 - 2091.57s
 it: 6490 | loss: 0.000 | acc: 1.000 - 2093.68s
 it: 650

 it: 7880 | loss: 0.000 | acc: 1.000 - 2520.62s
 it: 7890 | loss: 0.001 | acc: 1.000 - 2523.43s
 it: 7900 | loss: 0.000 | acc: 1.000 - 2526.04s
 --> test_loss: 1.062 | test_acc: 1.000
 it: 7910 | loss: 0.000 | acc: 1.000 - 2538.30s
 it: 7920 | loss: 0.000 | acc: 1.000 - 2540.87s
 it: 7930 | loss: 0.000 | acc: 1.000 - 2543.02s
 it: 7940 | loss: 0.006 | acc: 1.000 - 2545.12s
 it: 7950 | loss: 0.009 | acc: 0.990 - 2547.23s
 it: 7960 | loss: 0.001 | acc: 1.000 - 2549.56s
 it: 7970 | loss: 0.004 | acc: 1.000 - 2552.19s
 it: 7980 | loss: 0.003 | acc: 1.000 - 2554.56s
 it: 7990 | loss: 0.001 | acc: 1.000 - 2556.65s
 it: 8000 | loss: 0.036 | acc: 0.990 - 2558.78s
 --> test_loss: 0.977 | test_acc: 0.900
 * model saved at 'models/cnn'
 it: 8010 | loss: 0.002 | acc: 1.000 - 2569.39s
 it: 8020 | loss: 0.009 | acc: 0.990 - 2571.50s
 it: 8030 | loss: 0.001 | acc: 1.000 - 2573.64s
 it: 8040 | loss: 0.001 | acc: 1.000 - 2575.75s
 it: 8050 | loss: 0.005 | acc: 1.000 - 2577.91s
 it: 8060 | loss: 0.001 |

 it: 9450 | loss: 0.000 | acc: 1.000 - 3005.01s
 it: 9460 | loss: 0.001 | acc: 1.000 - 3007.10s
 it: 9470 | loss: 0.001 | acc: 1.000 - 3009.16s
 it: 9480 | loss: 0.001 | acc: 1.000 - 3011.28s
 it: 9490 | loss: 0.001 | acc: 1.000 - 3013.82s
 it: 9500 | loss: 0.001 | acc: 1.000 - 3016.21s
 --> test_loss: 0.950 | test_acc: 1.000
 it: 9510 | loss: 0.009 | acc: 1.000 - 3026.12s
 it: 9520 | loss: 0.000 | acc: 1.000 - 3028.28s
 it: 9530 | loss: 0.001 | acc: 1.000 - 3030.40s
 it: 9540 | loss: 0.001 | acc: 1.000 - 3032.56s
 it: 9550 | loss: 0.001 | acc: 1.000 - 3034.65s
 it: 9560 | loss: 0.001 | acc: 1.000 - 3036.79s
 it: 9570 | loss: 0.001 | acc: 1.000 - 3038.90s
 it: 9580 | loss: 0.002 | acc: 1.000 - 3041.06s
 it: 9590 | loss: 0.009 | acc: 1.000 - 3043.14s
 it: 9600 | loss: 0.002 | acc: 1.000 - 3045.23s
 --> test_loss: 1.096 | test_acc: 1.000
 it: 9610 | loss: 0.032 | acc: 0.990 - 3055.15s
 it: 9620 | loss: 0.001 | acc: 1.000 - 3057.27s
 it: 9630 | loss: 0.002 | acc: 1.000 - 3059.39s
 it: 964