In [13]:
import tensorflow as tf
import numpy as np
import argparse
from datetime import datetime

from data_loader import DataGenerator

from trainer import MatchingModelTrainer
from preprocessor import DynamicPreprocessor
from utils.dirs import create_dirs
from utils.logger import SummaryWriter
from utils.config import load_config, save_config
from models.base import get_model
from utils.utils import JamoProcessor

  return f(*args, **kwds)
  return f(*args, **kwds)


In [14]:
import os
# Restrict this process to the GPU with index 1. The variable is set in this
# cell, before the tf.Session is created in a later cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [15]:
class Config:
    """Hand-written experiment settings for the reply-matching debug run.

    Every setting is a plain instance attribute. ``Config()`` with no
    arguments yields exactly the defaults listed below. Keyword arguments
    replace individual defaults, e.g. ``Config(batch_size=32)``, so a variant
    run does not require editing this class or mutating the instance later.

    Args:
        **overrides: attribute name -> value pairs replacing the defaults.

    Raises:
        TypeError: if a keyword does not name one of the default attributes
            (guards against silently ignored typos).
    """

    def __init__(self, **overrides):
        self.mode = "train"
        self.name = "test"

        # Data and checkpoint locations. Note that val_dir points at a single
        # file, not a directory, despite its name.
        self.train_dir = "/media/scatter/scatterdisk/reply_matching_model/debug/"
        self.val_dir = "/media/scatter/scatterdisk/reply_matching_model/debug/sol.small.txt"
        self.checkpoint_dir = "/media/scatter/scatterdisk/reply_matching_model/runs/"

        # Model and text-processing components, referenced by name.
        self.sent_piece_model = "/media/scatter/scatterdisk/tokenizer/sent_piece.50K.model"
        self.model = "DualEncoderLSTM"
        self.normalizer = "DummyNormalizer"
        self.tokenizer = "SentencePieceTokenizer"
        self.negative_sampling = "random"
        self.num_negative_samples = 1

        # Vocabulary and pretrained embedding.
        self.pretrained_embed_dir = "/media/scatter/scatterdisk/pretrained_embedding/fasttext.sent_piece_50K.256D"
        self.vocab_size = 50000
        self.vocab_list = "/media/scatter/scatterdisk/pretrained_embedding/vocab_list.sent_piece_50K.txt"
        self.embed_dim = 256

        # Optimisation and training-loop settings.
        self.learning_rate = 1e-3
        self.min_length = 1
        self.max_length = 50
        self.lstm_dim = 512
        self.batch_size = 256
        self.num_epochs = 86
        self.evaluate_every = 50000
        self.save_every = 50000
        self.max_to_keep = 10
        self.shuffle = False
        # NOTE(review): the meaning of "a" is not visible in this notebook;
        # confirm against whatever consumes config.gpu.
        self.gpu = "a"

        # Apply caller overrides last; only names that already exist as
        # defaults are accepted.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("Unknown config option(s): " + ", ".join(unknown))
        for key, value in overrides.items():
            setattr(self, key, value)


config = Config()

In [18]:
# Scratch check: how os.path.join composes a checkpoint path from a directory
# string that already ends with a slash.
os.path.join("/media/scatter/scatterdisk/reply_matching_model/runs/test/", "model.ckpt")

'/media/scatter/scatterdisk/reply_matching_model/runs/test/model.ckpt'

In [16]:
# Build the objects the trainer depends on, in dependency order.
# create_dirs comes from utils.dirs (not visible here); presumably it prepares
# the run directories named in config -- TODO confirm.
create_dirs(config)
# Allow ops to fall back to another device and log where each op is placed.
tf_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
# Let this process claim at most 20% of the GPU's memory.
tf_config.gpu_options.per_process_gpu_memory_fraction = 0.2
sess = tf.Session(config=tf_config)
preprocessor = DynamicPreprocessor(config)
data = DataGenerator(preprocessor, config)
summary_writer = SummaryWriter(sess, config)
trainer = MatchingModelTrainer(sess, preprocessor, data, config, summary_writer)

In [17]:
# Start the trainer's training loop. In the recorded run it was stopped
# manually (see the KeyboardInterrupt in the output).
trainer.train()

[32m[18:28:52][INFO] Building train graph... [0m
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
[32m[18:28:56][INFO] Loading checkpoint from /media/scatter/scatterdisk/reply_matching_model/runs/debug/ [0m


INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


[32m[18:28:56][INFO] Building val graph... [0m
100%|██████████| 4/4 [00:01<00:00,  3.25it/s]


INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  5.41it/s]


INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  9.06it/s]


INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  8.44it/s]


KeyboardInterrupt: 

In [19]:
# Scratch check: count the characters in this run of '=' signs.
len("====================")

20

In [6]:
# Ask TensorFlow which checkpoint under config.checkpoint_dir is the most
# recent one, then display its path prefix.
latest_checkpoint = tf.train.latest_checkpoint(config.checkpoint_dir)
latest_checkpoint

'/media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt'

---

### train / val 그래프 새로 생성

In [5]:
# Build two models through the trainer, one named "train" and one named "val".
# Each call returns a (model, session) pair. The log shows a checkpoint being
# restored while the train graph is built.
train_model, train_sess = trainer.build_graph(name="train")
val_model, val_sess = trainer.build_graph(name="val")

[32m[15:45:01][INFO] Building train graph... [0m
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
[32m[15:45:05][INFO] Loading checkpoint from /media/scatter/scatterdisk/reply_matching_model/runs/debug/ [0m
[32m[15:45:05][INFO] Loading checkpoint from /media/scatter/scatterdisk/reply_matching_model/runs/debug/ [0m


INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


[32m[15:45:05][INFO] Building val graph... [0m
[32m[15:45:05][INFO] Building val graph... [0m


### 불러온 모델에 대해 validation

In [6]:
# Baseline: evaluate the just-built train model on the validation data before
# any further training. The result is a 2-tuple, presumably (loss, accuracy)
# -- TODO confirm against MatchingModelTrainer.val.
trainer.val(train_model, train_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:01<00:00,  3.69it/s]


(10.888929, 0.5065375)

### 다시 학습

In [7]:
# Run 30 more training steps on the train model/session. The count is a
# literal chosen for this experiment; no config value backs it.
for _ in range(30):
    trainer.train_step(train_model, train_sess)

In [8]:
# Re-evaluate the train model after the extra training steps, to compare with
# the baseline evaluation earlier in the notebook.
trainer.val(train_model, train_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  8.63it/s]


(0.0043217856, 0.99897176)

In [9]:
# Take the first batch from a fresh validation iterator and bind each of its
# four arrays to the matching input tensor of the train model.
val_iterator = data.get_val_iterator(config.batch_size)
(
    batch_queries,
    batch_replies,
    batch_queries_lengths,
    batch_replies_lengths,
) = next(val_iterator)

feed_dict = {
    train_model.input_queries: batch_queries,
    train_model.input_replies: batch_replies,
    train_model.queries_lengths: batch_queries_lengths,
    train_model.replies_lengths: batch_replies_lengths,
}

In [10]:
# Evaluate that one validation batch through the model's own val method
# (rather than trainer.val) and show only the first two returned values.
train_model.val(train_sess, feed_dict=feed_dict)[:2]

(0.00037988753, 1.0)

### 저장

In [11]:
# Save the train session's current weights under <checkpoint_dir>/model.ckpt.
# os.path.join replaces plain string concatenation so the path stays correct
# even when config.checkpoint_dir has no trailing slash; with a trailing slash
# the resulting path is identical to the concatenated one.
train_model.save(train_sess, os.path.join(config.checkpoint_dir, "model.ckpt"))

In [12]:
# After saving, check which checkpoint TensorFlow now reports as the latest
# under config.checkpoint_dir.
latest_checkpoint = tf.train.latest_checkpoint(config.checkpoint_dir)
latest_checkpoint

'/media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt'

### 다시 불러오기

In [13]:
# Load a checkpoint into the val model's session. The log line in the output
# shows which checkpoint path TensorFlow restored from.
val_model.load(val_sess)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


In [14]:
# Evaluate the val model right after loading. In the recorded run the metrics
# stay near the pre-training baseline instead of matching the retrained
# train model.
trainer.val(val_model, val_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  5.78it/s]


(10.809477, 0.48448467)

In [15]:
# Same evaluation through the train model/session, for a side-by-side
# comparison with the val model's result.
trainer.val(train_model, train_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  8.48it/s]


(0.0043217856, 0.99897176)

### 일단 확실한 건 저장한 체크포인트를 val 모델로 불러와도 학습된 가중치가 반영되지 않음

In [20]:
# Load the checkpoint into the train session as well, to see whether loading
# changes the train model's metrics.
train_model.load(train_sess)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


In [21]:
# Evaluate the train model after the load. In the recorded run the metrics
# are identical to those measured before loading.
trainer.val(train_model, train_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  8.78it/s]


(0.0043217856, 0.99897176)

In [25]:
# Retry: load the checkpoint into the val session once more.
val_model.load(val_sess)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


In [26]:
# Evaluate the val model after the second load. The recorded metrics are
# unchanged from the first attempt.
trainer.val(val_model, val_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  8.24it/s]


(10.809477, 0.48448467)

In [27]:
# Control measurement: evaluate the train model again for comparison with the
# val model.
trainer.val(train_model, train_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt


100%|██████████| 4/4 [00:00<00:00,  8.81it/s]


(0.0043217856, 0.99897176)

In [28]:
# Save the train session's weights again, this time under the checkpoint
# prefix "debug" inside config.checkpoint_dir. os.path.join replaces string
# concatenation so the path does not depend on checkpoint_dir ending with a
# slash; with a trailing slash the resulting path is identical.
train_model.save(train_sess, os.path.join(config.checkpoint_dir, "debug"))

In [30]:
# Refresh latest_checkpoint after saving under the new prefix. The value is
# assigned but not displayed in this cell.
latest_checkpoint = tf.train.latest_checkpoint(config.checkpoint_dir)

In [32]:
# Load into the val session again. The log in the output shows the restore
# now reads from the ".../debug/debug" checkpoint prefix.
val_model.load(val_sess)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/debug


In [34]:
# Evaluate the val model after loading the new checkpoint prefix. The recorded
# metrics are still the same as in the earlier val-model evaluations.
trainer.val(val_model, val_sess, trainer.global_step)

INFO:tensorflow:Restoring parameters from /media/scatter/scatterdisk/reply_matching_model/runs/debug/debug


100%|██████████| 4/4 [00:00<00:00,  8.31it/s]


(10.809477, 0.48448467)

In [35]:
# Inspect the train model's query input tensor (name, shape, dtype).
train_model.input_queries

<tf.Tensor 'inputs/Placeholder:0' shape=(?, 50) dtype=int32>

In [37]:
# Scratch check of the path produced by plain string concatenation. It is only
# a valid path because config.checkpoint_dir ends with a slash.
config.checkpoint_dir + "model.ckpt"

'/media/scatter/scatterdisk/reply_matching_model/runs/debug/model.ckpt'

In [10]:
from utils.logger import setup_logger

In [36]:
# Inspect the val model's query input tensor, to compare its name, shape and
# dtype with the train model's.
val_model.input_queries

<tf.Tensor 'inputs/Placeholder:0' shape=(?, 50) dtype=int32>