In [2]:
# -*- coding: utf-8 -*-

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

import tensorflow as tf
import numpy as np
import time

from dnn_model import Model
from imdb_loader import text_data

def initialize_session():
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    return tf.Session(config=config)

##################################################
max_len = 200           # sequence 단어 수 제한
max_vocab = 20000       # maximum 단어 개수
BATCH_SIZE = 10         # 배치 사이즈
emb_dim = 64            # 단어 embedding dimension
hidden_dim = 128        # RNN hidden dim
learning_rate = 0.005   # Learning rate
use_clip = True         # Gradient clipping 쓸지 여부
##################################################

END_TOKEN = "<eos>"
data = text_data("./dataset/aclImdb/", max_len=max_len, end_token=END_TOKEN)
model = Model(max_len=max_len,
              emb_dim=emb_dim,
              hidden_dim=hidden_dim,
              vocab_size=data.vocab_size,
              class_size=2,
              use_clip=True, learning_rate=learning_rate, end_token=data.w2idx[END_TOKEN])

sess = initialize_session()
sess.run(tf.global_variables_initializer())

def test_model():
    num_it = int(len(data.test_ids) / BATCH_SIZE)
    num_it = 10
    test_loss, test_cnt = 0, 0

    for _ in range(num_it):
        test_ids, length, label = data.get_test(BATCH_SIZE)
        ind, val = [], []
        for i, ids in enumerate(test_ids):
            for j in range(length[i]):
                ind.append([i, ids[j]])
                val.append(1.0)
        loss = sess.run(model.loss, feed_dict={model.ind: ind, model.val: val, model.y: label})

        test_loss += loss
        test_cnt += 1
    print("test loss: {:.3f}".format(test_loss / test_cnt))

# 0: neg, 1: pos
avg_loss, it_cnt, same = 0, 0, .0
it_log, it_test, it_save, it_sample = 10, 100, 1000, 100
start_time = time.time()

for it in range(0, 10000):
    train_ids, length, label = data.get_train(BATCH_SIZE)
    val, ind = [], []
    for i, ids in enumerate(train_ids):
        for j in range(length[i]):
            ind.append([i, ids[j]])
            val.append(1.0)
    loss, _, out = sess.run([model.loss, model.update, model.out_label],
                            feed_dict={model.ind: ind, model.val: val, model.y: label})

    for i, o in enumerate(out):
        if o == i % 2:
            same += 1
    avg_loss += loss
    it_cnt += 1

    if it % it_log == 0:
        print(" it: {:4d} | loss: {:.3f} | acc: {:.3f} - {:.2f}s".format(
            it, avg_loss / it_cnt, same/BATCH_SIZE/it_log, time.time() - start_time))
        avg_loss, it_cnt, same = 0, 0, .0

    if it % it_test == 0 and it > 0:
        test_model()
    if it % it_save == 0 and it > 0:
        model.save(sess)


 it:    0 | loss: 0.730 | acc: 0.050 - 0.30s
 it:   10 | loss: 0.651 | acc: 0.620 - 0.43s
 it:   20 | loss: 0.704 | acc: 0.620 - 0.56s
 it:   30 | loss: 0.636 | acc: 0.630 - 0.78s
 it:   40 | loss: 0.547 | acc: 0.730 - 0.92s
 it:   50 | loss: 0.529 | acc: 0.780 - 1.06s
 it:   60 | loss: 0.548 | acc: 0.750 - 1.19s
 it:   70 | loss: 0.559 | acc: 0.690 - 1.31s
 it:   80 | loss: 0.515 | acc: 0.760 - 1.44s
 it:   90 | loss: 0.484 | acc: 0.770 - 1.56s
 it:  100 | loss: 0.429 | acc: 0.850 - 1.69s
test loss: 0.701
 it:  110 | loss: 0.632 | acc: 0.700 - 1.89s
 it:  120 | loss: 0.451 | acc: 0.780 - 2.01s
 it:  130 | loss: 0.487 | acc: 0.820 - 2.14s
 it:  140 | loss: 0.393 | acc: 0.810 - 2.27s
 it:  150 | loss: 0.542 | acc: 0.720 - 2.39s
 it:  160 | loss: 0.444 | acc: 0.810 - 2.52s
 it:  170 | loss: 0.412 | acc: 0.800 - 2.65s
 it:  180 | loss: 0.433 | acc: 0.820 - 2.77s
 it:  190 | loss: 0.480 | acc: 0.770 - 2.90s
 it:  200 | loss: 0.510 | acc: 0.710 - 3.03s
test loss: 0.493
 it:  210 | loss: 0.4

 it: 1680 | loss: 0.398 | acc: 0.810 - 23.81s
 it: 1690 | loss: 0.327 | acc: 0.840 - 23.93s
 it: 1700 | loss: 0.346 | acc: 0.860 - 24.05s
test loss: 0.402
 it: 1710 | loss: 0.352 | acc: 0.840 - 24.26s
 it: 1720 | loss: 0.344 | acc: 0.860 - 24.38s
 it: 1730 | loss: 0.419 | acc: 0.790 - 24.49s
 it: 1740 | loss: 0.351 | acc: 0.830 - 24.62s
 it: 1750 | loss: 0.285 | acc: 0.900 - 24.74s
 it: 1760 | loss: 0.371 | acc: 0.820 - 24.87s
 it: 1770 | loss: 0.322 | acc: 0.850 - 24.99s
 it: 1780 | loss: 0.297 | acc: 0.860 - 25.10s
 it: 1790 | loss: 0.266 | acc: 0.890 - 25.22s
 it: 1800 | loss: 0.252 | acc: 0.880 - 25.34s
test loss: 0.494
 it: 1810 | loss: 0.410 | acc: 0.840 - 25.49s
 it: 1820 | loss: 0.362 | acc: 0.840 - 25.61s
 it: 1830 | loss: 0.384 | acc: 0.840 - 25.73s
 it: 1840 | loss: 0.390 | acc: 0.830 - 25.84s
 it: 1850 | loss: 0.482 | acc: 0.820 - 25.95s
 it: 1860 | loss: 0.425 | acc: 0.830 - 26.07s
 it: 1870 | loss: 0.398 | acc: 0.770 - 26.18s
 it: 1880 | loss: 0.428 | acc: 0.840 - 26.30s


 it: 3400 | loss: 0.008 | acc: 1.000 - 47.38s
test loss: 0.481
 it: 3410 | loss: 0.016 | acc: 1.000 - 47.60s
 it: 3420 | loss: 0.162 | acc: 0.960 - 47.72s
 it: 3430 | loss: 0.029 | acc: 1.000 - 47.83s
 it: 3440 | loss: 0.135 | acc: 0.950 - 47.95s
 it: 3450 | loss: 0.051 | acc: 0.980 - 48.07s
 it: 3460 | loss: 0.073 | acc: 0.990 - 48.18s
 it: 3470 | loss: 0.085 | acc: 0.960 - 48.30s
 it: 3480 | loss: 0.084 | acc: 0.960 - 48.42s
 it: 3490 | loss: 0.106 | acc: 0.970 - 48.55s
 it: 3500 | loss: 0.023 | acc: 0.990 - 48.66s
test loss: 0.534
 it: 3510 | loss: 0.069 | acc: 0.960 - 48.81s
 it: 3520 | loss: 0.190 | acc: 0.910 - 48.92s
 it: 3530 | loss: 0.118 | acc: 0.960 - 49.04s
 it: 3540 | loss: 0.045 | acc: 0.980 - 49.15s
 it: 3550 | loss: 0.065 | acc: 0.970 - 49.27s
 it: 3560 | loss: 0.059 | acc: 0.990 - 49.38s
 it: 3570 | loss: 0.079 | acc: 0.970 - 49.50s
 it: 3580 | loss: 0.183 | acc: 0.950 - 49.61s
 it: 3590 | loss: 0.032 | acc: 0.980 - 49.73s
 it: 3600 | loss: 0.013 | acc: 1.000 - 49.84s


 it: 5120 | loss: 0.039 | acc: 0.990 - 70.25s
 it: 5130 | loss: 0.117 | acc: 0.970 - 70.37s
 it: 5140 | loss: 0.061 | acc: 0.980 - 70.48s
 it: 5150 | loss: 0.078 | acc: 0.980 - 70.68s
 it: 5160 | loss: 0.043 | acc: 0.990 - 70.81s
 it: 5170 | loss: 0.031 | acc: 0.990 - 70.92s
 it: 5180 | loss: 0.002 | acc: 1.000 - 71.04s
 it: 5190 | loss: 0.066 | acc: 0.990 - 71.17s
 it: 5200 | loss: 0.022 | acc: 0.990 - 71.29s
test loss: 0.783
 it: 5210 | loss: 0.012 | acc: 0.990 - 71.44s
 it: 5220 | loss: 0.035 | acc: 0.990 - 71.57s
 it: 5230 | loss: 0.114 | acc: 0.970 - 71.69s
 it: 5240 | loss: 0.024 | acc: 0.990 - 71.81s
 it: 5250 | loss: 0.037 | acc: 0.980 - 71.92s
 it: 5260 | loss: 0.054 | acc: 0.980 - 72.04s
 it: 5270 | loss: 0.011 | acc: 1.000 - 72.15s
 it: 5280 | loss: 0.009 | acc: 1.000 - 72.27s
 it: 5290 | loss: 0.044 | acc: 0.990 - 72.38s
 it: 5300 | loss: 0.010 | acc: 1.000 - 72.50s
test loss: 0.680
 it: 5310 | loss: 0.002 | acc: 1.000 - 72.64s
 it: 5320 | loss: 0.002 | acc: 1.000 - 72.76s


 it: 6840 | loss: 0.042 | acc: 0.980 - 92.13s
 it: 6850 | loss: 0.122 | acc: 0.970 - 92.25s
 it: 6860 | loss: 0.006 | acc: 1.000 - 92.37s
 it: 6870 | loss: 0.005 | acc: 1.000 - 92.49s
 it: 6880 | loss: 0.019 | acc: 0.990 - 92.61s
 it: 6890 | loss: 0.085 | acc: 0.980 - 92.72s
 it: 6900 | loss: 0.060 | acc: 0.990 - 92.85s
test loss: 0.739
 it: 6910 | loss: 0.011 | acc: 1.000 - 93.00s
 it: 6920 | loss: 0.007 | acc: 1.000 - 93.12s
 it: 6930 | loss: 0.020 | acc: 0.990 - 93.30s
 it: 6940 | loss: 0.020 | acc: 0.990 - 93.42s
 it: 6950 | loss: 0.061 | acc: 0.990 - 93.53s
 it: 6960 | loss: 0.039 | acc: 0.990 - 93.64s
 it: 6970 | loss: 0.021 | acc: 1.000 - 93.77s
 it: 6980 | loss: 0.014 | acc: 0.990 - 93.89s
 it: 6990 | loss: 0.019 | acc: 0.990 - 94.01s
 it: 7000 | loss: 0.036 | acc: 0.980 - 94.13s
test loss: 1.168
 * model saved at 'models/cnn'
 it: 7010 | loss: 0.001 | acc: 1.000 - 95.12s
 it: 7020 | loss: 0.000 | acc: 1.000 - 95.24s
 it: 7030 | loss: 0.010 | acc: 1.000 - 95.36s
 it: 7040 | los

 it: 8540 | loss: 0.001 | acc: 1.000 - 114.95s
 it: 8550 | loss: 0.012 | acc: 0.990 - 115.07s
 it: 8560 | loss: 0.016 | acc: 1.000 - 115.18s
 it: 8570 | loss: 0.001 | acc: 1.000 - 115.30s
 it: 8580 | loss: 0.185 | acc: 0.990 - 115.41s
 it: 8590 | loss: 0.017 | acc: 1.000 - 115.53s
 it: 8600 | loss: 0.069 | acc: 0.980 - 115.65s
test loss: 1.310
 it: 8610 | loss: 0.027 | acc: 0.990 - 115.80s
 it: 8620 | loss: 0.005 | acc: 1.000 - 115.92s
 it: 8630 | loss: 0.007 | acc: 1.000 - 116.04s
 it: 8640 | loss: 0.001 | acc: 1.000 - 116.16s
 it: 8650 | loss: 0.021 | acc: 0.990 - 116.28s
 it: 8660 | loss: 0.029 | acc: 0.990 - 116.40s
 it: 8670 | loss: 0.001 | acc: 1.000 - 116.52s
 it: 8680 | loss: 0.002 | acc: 1.000 - 116.65s
 it: 8690 | loss: 0.001 | acc: 1.000 - 116.77s
 it: 8700 | loss: 0.002 | acc: 1.000 - 116.88s
test loss: 0.655
 it: 8710 | loss: 0.000 | acc: 1.000 - 117.10s
 it: 8720 | loss: 0.002 | acc: 1.000 - 117.21s
 it: 8730 | loss: 0.006 | acc: 1.000 - 117.33s
 it: 8740 | loss: 0.029 | 